script.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
"""Run Tune with frequent pausing.

See context: https://github.com/ray-project/ray/issues/34197.

An m5.large node has 7.2 GB of memory. With `RAY_memory_usage_threshold=0.5`,
once the node's memory usage exceeds 3.6 GB, any new tasks are killed.
Note that this node memory is also shared by background processes such as
the Ray dashboard. Without Ray object store reference leakage from the
application code, all of these background processes together take less than
2 GB of memory. With reference leakage, the 3.6 GB threshold was reached
within 5 minutes at the time this test was written.

Success criteria: run for 10 minutes without crashing.
Cost: a few dollars.
"""
  13. import numpy as np
  14. import os
  15. import pickle
  16. import tempfile
  17. from ray import train
  18. from ray.train import Checkpoint, RunConfig
  19. from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
  20. from ray.tune.tune_config import TuneConfig
  21. from ray.tune.tuner import Tuner
  22. def func(config):
  23. starting_epoch = 0
  24. checkpoint = train.get_checkpoint()
  25. if checkpoint:
  26. with checkpoint.as_directory() as checkpoint_dir:
  27. with open(os.path.join(checkpoint_dir, "ckpt.pkl"), "rb") as f:
  28. checkpoint_dict = pickle.load(f)
  29. checkpoint_epoch = checkpoint_dict["epoch"]
  30. starting_epoch = checkpoint_epoch + 1
  31. for epoch in range(starting_epoch, 1000):
  32. checkpoint_dict = {"epoch": epoch, "large_data": np.zeros(10000000)}
  33. with tempfile.TemporaryDirectory() as tmpdir:
  34. with open(os.path.join(tmpdir, "ckpt.pkl"), "wb") as f:
  35. pickle.dump(checkpoint_dict, f)
  36. train.report({}, checkpoint=Checkpoint.from_directory(tmpdir))
  37. class FrequentPausesScheduler(FIFOScheduler):
  38. def on_trial_result(self, tune_controller, trial, result):
  39. return TrialScheduler.PAUSE
  40. tuner = Tuner(
  41. func,
  42. tune_config=TuneConfig(num_samples=2, scheduler=FrequentPausesScheduler()),
  43. run_config=RunConfig(storage_path="/mnt/cluster_storage", name="frequent_pausing"),
  44. )
  45. tuner.fit()