12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- import os
- import time
- import json
- from pytorch_lightning.loggers.csv_logs import CSVLogger
- import ray
- from ray.air.config import ScalingConfig
- from ray.train.lightning import LightningTrainer, LightningConfigBuilder
- from lightning_test_utils import MNISTClassifier, MNISTDataModule
if __name__ == "__main__":
    # Script entry point: connect to a running Ray cluster, fit
    # MNISTClassifier through LightningTrainer, then record the wall-clock
    # time and the reported "val_accuracy" metric in a JSON file.

    # Ship the directory containing this script to the workers as their
    # working directory.
    script_dir = os.path.dirname(__file__)
    ray.init(address="auto", runtime_env={"working_dir": script_dir})

    # The timer starts after ray.init(), so the measured span covers config
    # construction and trainer.fit(), but not cluster connection.
    started_at = time.time()

    # Assemble the Lightning configuration one builder step at a time.
    builder = LightningConfigBuilder()
    builder = builder.module(MNISTClassifier, feature_dim=128, lr=0.001)
    builder = builder.trainer(
        max_epochs=3,
        accelerator="gpu",
        logger=CSVLogger("logs", name="my_exp_name"),
    )
    builder = builder.fit_params(datamodule=MNISTDataModule(batch_size=128))
    builder = builder.checkpointing(
        monitor="val_accuracy", mode="max", save_last=True
    )
    lightning_cfg = builder.build()

    # Three workers, each requesting one CPU and one GPU.
    per_worker_resources = {"CPU": 1, "GPU": 1}
    scaling_cfg = ScalingConfig(
        num_workers=3,
        use_gpu=True,
        resources_per_worker=per_worker_resources,
    )

    lightning_trainer = LightningTrainer(
        lightning_config=lightning_cfg,
        scaling_config=scaling_cfg,
    )
    fit_result = lightning_trainer.fit()

    elapsed = time.time() - started_at
    summary = {
        "time_taken": elapsed,
        "val_accuracy": fit_result.metrics["val_accuracy"],
    }

    # The output location can be overridden via the TEST_OUTPUT_JSON
    # environment variable; otherwise a fixed path under /tmp is used.
    output_path = os.environ.get(
        "TEST_OUTPUT_JSON", "/tmp/lightning_trainer_test.json"
    )
    with open(output_path, "wt") as out_file:
        json.dump(summary, out_file)

    print("Test Successful!")
|