# test_trainer.py
  1. import os
  2. import time
  3. import json
  4. from pytorch_lightning.loggers.csv_logs import CSVLogger
  5. import ray
  6. from ray.air.config import ScalingConfig
  7. from ray.train.lightning import LightningTrainer, LightningConfigBuilder
  8. from lightning_test_utils import MNISTClassifier, MNISTDataModule
  9. if __name__ == "__main__":
  10. ray.init(address="auto", runtime_env={"working_dir": os.path.dirname(__file__)})
  11. start = time.time()
  12. lightning_config = (
  13. LightningConfigBuilder()
  14. .module(MNISTClassifier, feature_dim=128, lr=0.001)
  15. .trainer(
  16. max_epochs=3,
  17. accelerator="gpu",
  18. logger=CSVLogger("logs", name="my_exp_name"),
  19. )
  20. .fit_params(datamodule=MNISTDataModule(batch_size=128))
  21. .checkpointing(monitor="val_accuracy", mode="max", save_last=True)
  22. .build()
  23. )
  24. scaling_config = ScalingConfig(
  25. num_workers=3, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
  26. )
  27. trainer = LightningTrainer(
  28. lightning_config=lightning_config,
  29. scaling_config=scaling_config,
  30. )
  31. result = trainer.fit()
  32. taken = time.time() - start
  33. result = {
  34. "time_taken": taken,
  35. "val_accuracy": result.metrics["val_accuracy"],
  36. }
  37. test_output_json = os.environ.get(
  38. "TEST_OUTPUT_JSON", "/tmp/lightning_trainer_test.json"
  39. )
  40. with open(test_output_json, "wt") as f:
  41. json.dump(result, f)
  42. print("Test Successful!")