1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- from dataclasses import dataclass
- from pathlib import Path
- from typing import List
- from datasets import Dataset, concatenate_datasets
- import logging
- import shutil
- logger = logging.getLogger("raft")
- @dataclass
- class Checkpoint:
- path: Path
- num: int
- def load(self) -> Dataset:
- return Dataset.load_from_disk(self.path)
- def __lt__(self, other: 'Checkpoint') -> bool:
- return self.num < other.num
- def __eq__(self, other: 'Checkpoint') -> bool:
- return self.num == other.num
- def __hash__(self) -> int:
- return hash(self.num)
- class Checkpointing:
- def __init__(self, checkpoints_dir: Path) -> None:
- self.checkpoints_dir = checkpoints_dir
- def missing_checkpoints(self, num) -> List[int]:
- return [n for n in range(0, num) if not (self.checkpoints_dir / f"checkpoint-{n}").exists()]
- def save_checkpoint(self, ds: Dataset, num: int):
- checkpoint_path = self.checkpoints_dir / ("checkpoint-" + str(num))
- ds.save_to_disk(checkpoint_path)
- def load_checkpoint(self, num: int):
- checkpoint_path = self.checkpoints_dir / ("checkpoint-" + str(num))
- if checkpoint_path.exists():
- return Dataset.load_from_disk(checkpoint_path)
- return None
- def get_checkpoints(self) -> List[Checkpoint]:
- checkpoints = []
- if not self.checkpoints_dir.exists():
- return checkpoints
- for dir_path in self.checkpoints_dir.iterdir():
- if dir_path.is_dir() and dir_path.name.startswith("checkpoint-"):
- num = int(dir_path.name.split("-")[1])
- checkpoints.append(Checkpoint(dir_path, num))
- return checkpoints
- def has_checkpoints(self) -> bool:
- return len(self.get_checkpoints()) > 0
- def collect_checkpoints(self) -> Dataset:
- ds_list = list([checkpoint.load() for checkpoint in self.get_checkpoints()])
- ds = concatenate_datasets(ds_list)
- return ds
- def delete_checkpoints(self):
- shutil.rmtree(self.checkpoints_dir)
- def checkpointed(checkpointing: Checkpointing):
- def wrapped(func):
- def wrapper(chunk_id, *args, **kwargs):
- ds = checkpointing.load_checkpoint(chunk_id)
- if ds:
- return ds
- ds = func(chunk_id=chunk_id, *args, **kwargs)
- if ds.num_rows > 0:
- checkpointing.save_checkpoint(ds, chunk_id)
- return ds
- return wrapper
- return wrapped
|