test_tuner.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import os
  2. import time
  3. import json
  4. from pytorch_lightning.loggers.csv_logs import CSVLogger
  5. import ray
  6. import ray.tune as tune
  7. from ray.air.config import CheckpointConfig, ScalingConfig
  8. from ray.train.lightning import LightningTrainer, LightningConfigBuilder
  9. from ray.tune.schedulers import PopulationBasedTraining
  10. from lightning_test_utils import MNISTClassifier, MNISTDataModule
  11. if __name__ == "__main__":
  12. ray.init(address="auto", runtime_env={"working_dir": os.path.dirname(__file__)})
  13. start = time.time()
  14. lightning_config = (
  15. LightningConfigBuilder()
  16. .module(
  17. MNISTClassifier,
  18. feature_dim=tune.choice([64, 128]),
  19. lr=tune.grid_search([0.01, 0.001]),
  20. )
  21. .trainer(
  22. max_epochs=5,
  23. accelerator="gpu",
  24. logger=CSVLogger("logs", name="my_exp_name"),
  25. )
  26. .fit_params(datamodule=MNISTDataModule(batch_size=200))
  27. .checkpointing(monitor="val_accuracy", mode="max")
  28. .build()
  29. )
  30. scaling_config = ScalingConfig(
  31. num_workers=3, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
  32. )
  33. lightning_trainer = LightningTrainer(
  34. scaling_config=scaling_config,
  35. )
  36. mutation_config = (
  37. LightningConfigBuilder()
  38. .module(
  39. lr=tune.choice([0.01, 0.001]),
  40. )
  41. .build()
  42. )
  43. tuner = tune.Tuner(
  44. lightning_trainer,
  45. param_space={"lightning_config": lightning_config},
  46. run_config=ray.air.RunConfig(
  47. name="release-tuner-test",
  48. verbose=2,
  49. checkpoint_config=CheckpointConfig(
  50. num_to_keep=2,
  51. checkpoint_score_attribute="val_accuracy",
  52. checkpoint_score_order="max",
  53. ),
  54. ),
  55. tune_config=tune.TuneConfig(
  56. metric="val_accuracy",
  57. mode="max",
  58. num_samples=2,
  59. scheduler=PopulationBasedTraining(
  60. time_attr="training_iteration",
  61. hyperparam_mutations={"lightning_config": mutation_config},
  62. perturbation_interval=1,
  63. ),
  64. ),
  65. )
  66. results = tuner.fit()
  67. best_result = results.get_best_result(metric="val_accuracy", mode="max")
  68. best_result
  69. assert len(results.errors) == 0
  70. taken = time.time() - start
  71. # Report experiment results
  72. result = {
  73. "time_taken": taken,
  74. "val_accuracy": best_result.metrics["val_accuracy"],
  75. }
  76. test_output_json = os.environ.get(
  77. "TEST_OUTPUT_JSON",
  78. "/tmp/lightning_gpu_tuner_test.json",
  79. )
  80. with open(test_output_json, "wt") as f:
  81. json.dump(result, f)
  82. print("Test Successful!")