test_durable_trainable.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
"""Durable trainable (16 trials, checkpoint to cloud)

In this run, we will start 16 trials on a cluster. The trials create
10 MB checkpoints every 12 seconds and should only keep 2 of these. This test
ensures that durable checkpoints don't slow down experiment progress too much.

Cluster: cluster_16x2.yaml

Test owner: krfricke

Acceptance criteria: Should run faster than 500 seconds.

Theoretical minimum time: 300 seconds
"""
  10. import argparse
  11. import os
  12. import ray
  13. from ray.tune.utils.release_test_util import timed_tune_run
  14. def main(bucket):
  15. secrets_file = os.path.join(os.path.dirname(__file__), "..", "aws_secrets.txt")
  16. if os.path.isfile(secrets_file):
  17. print(f"Loading AWS secrets from file {secrets_file}")
  18. from configparser import ConfigParser
  19. config = ConfigParser()
  20. config.read(secrets_file)
  21. for k, v in config.items():
  22. for x, y in v.items():
  23. var = str(x).upper()
  24. os.environ[var] = str(y)
  25. else:
  26. print("No AWS secrets file found. Loading from boto.")
  27. try:
  28. from boto3 import Session
  29. session = Session()
  30. credentials = session.get_credentials()
  31. current_credentials = credentials.get_frozen_credentials()
  32. os.environ["AWS_ACCESS_KEY_ID"] = current_credentials.access_key
  33. os.environ["AWS_SECRET_ACCESS_KEY"] = current_credentials.secret_key
  34. os.environ["AWS_SESSION_TOKEN"] = current_credentials.token
  35. except Exception:
  36. print("Cannot setup AWS credentials (is this running on GCE?)")
  37. if all(
  38. os.getenv(k, "")
  39. for k in [
  40. "AWS_ACCESS_KEY_ID",
  41. "AWS_SECRET_ACCESS_KEY",
  42. "AWS_SESSION_TOKEN",
  43. ]
  44. ):
  45. print("AWS secrets found in env.")
  46. else:
  47. print("Warning: No AWS secrets found in env!")
  48. ray.init(address="auto")
  49. num_samples = 16
  50. results_per_second = 5 / 60 # 5 results per minute = 1 every 12 seconds
  51. trial_length_s = 300
  52. max_runtime = 650
  53. timed_tune_run(
  54. name="durable trainable",
  55. num_samples=num_samples,
  56. results_per_second=results_per_second,
  57. trial_length_s=trial_length_s,
  58. max_runtime=max_runtime,
  59. checkpoint_freq_s=12, # Once every 12 seconds (once per result)
  60. checkpoint_size_b=int(10 * 1000**2), # 10 MB
  61. keep_checkpoints_num=2,
  62. resources_per_trial={"cpu": 2},
  63. storage_path=bucket,
  64. )
  65. if __name__ == "__main__":
  66. parser = argparse.ArgumentParser()
  67. parser.add_argument("--bucket", type=str, help="Bucket name")
  68. args, _ = parser.parse_known_args()
  69. main(args.bucket or "ray-tune-scalability-test")