checkpointing.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from dataclasses import dataclass
  2. from pathlib import Path
  3. from typing import List
  4. from datasets import Dataset, concatenate_datasets
  5. import logging
  6. import shutil
  7. logger = logging.getLogger("raft")
  8. @dataclass
  9. class Checkpoint:
  10. path: Path
  11. num: int
  12. def load(self) -> Dataset:
  13. return Dataset.load_from_disk(self.path)
  14. def __lt__(self, other: 'Checkpoint') -> bool:
  15. return self.num < other.num
  16. def __eq__(self, other: 'Checkpoint') -> bool:
  17. return self.num == other.num
  18. def __hash__(self) -> int:
  19. return hash(self.num)
  20. class Checkpointing:
  21. def __init__(self, checkpoints_dir: Path) -> None:
  22. self.checkpoints_dir = checkpoints_dir
  23. def missing_checkpoints(self, num) -> List[int]:
  24. return [n for n in range(0, num) if not (self.checkpoints_dir / f"checkpoint-{n}").exists()]
  25. def save_checkpoint(self, ds: Dataset, num: int):
  26. checkpoint_path = self.checkpoints_dir / ("checkpoint-" + str(num))
  27. ds.save_to_disk(checkpoint_path)
  28. def load_checkpoint(self, num: int):
  29. checkpoint_path = self.checkpoints_dir / ("checkpoint-" + str(num))
  30. if checkpoint_path.exists():
  31. return Dataset.load_from_disk(checkpoint_path)
  32. return None
  33. def get_checkpoints(self) -> List[Checkpoint]:
  34. checkpoints = []
  35. if not self.checkpoints_dir.exists():
  36. return checkpoints
  37. for dir_path in self.checkpoints_dir.iterdir():
  38. if dir_path.is_dir() and dir_path.name.startswith("checkpoint-"):
  39. num = int(dir_path.name.split("-")[1])
  40. checkpoints.append(Checkpoint(dir_path, num))
  41. return checkpoints
  42. def has_checkpoints(self) -> bool:
  43. return len(self.get_checkpoints()) > 0
  44. def collect_checkpoints(self) -> Dataset:
  45. ds_list = list([checkpoint.load() for checkpoint in self.get_checkpoints()])
  46. ds = concatenate_datasets(ds_list)
  47. return ds
  48. def delete_checkpoints(self):
  49. shutil.rmtree(self.checkpoints_dir)
  50. def checkpointed(checkpointing: Checkpointing):
  51. def wrapped(func):
  52. def wrapper(chunk_id, *args, **kwargs):
  53. ds = checkpointing.load_checkpoint(chunk_id)
  54. if ds:
  55. return ds
  56. ds = func(chunk_id=chunk_id, *args, **kwargs)
  57. if ds.num_rows > 0:
  58. checkpointing.save_checkpoint(ds, chunk_id)
  59. return ds
  60. return wrapper
  61. return wrapped