# simple_example.py
# This file is duplicated in ray/tests/ray_lightning
import argparse
import os

import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
from importlib_metadata import version
from packaging.version import parse as v_parse

# ray_lightning renamed its integration point at 0.3.0 (plugin -> strategy),
# so detect the installed version once and import the matching symbol.
rlt_use_master = v_parse(version("ray_lightning")) >= v_parse("0.3.0")
if rlt_use_master:
    # ray_lightning >= 0.3.0
    from ray_lightning import RayStrategy
else:
    # ray_lightning < 0.3.0
    from ray_lightning import RayPlugin
  20. class LitAutoEncoder(pl.LightningModule):
  21. def __init__(self):
  22. super().__init__()
  23. self.encoder = nn.Sequential(
  24. nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3)
  25. )
  26. self.decoder = nn.Sequential(
  27. nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28)
  28. )
  29. def forward(self, x):
  30. # in lightning, forward defines the prediction/inference actions
  31. embedding = self.encoder(x)
  32. return embedding
  33. def training_step(self, batch, batch_idx):
  34. # training_step defines the train loop. It is independent of forward
  35. x, y = batch
  36. x = x.view(x.size(0), -1)
  37. z = self.encoder(x)
  38. x_hat = self.decoder(z)
  39. loss = F.mse_loss(x_hat, x)
  40. self.log("train_loss", loss)
  41. return loss
  42. def configure_optimizers(self):
  43. optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
  44. return optimizer
  45. def main(num_workers: int = 2, use_gpu: bool = False, max_steps: int = 10):
  46. dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
  47. train, val = random_split(dataset, [55000, 5000])
  48. autoencoder = LitAutoEncoder()
  49. if rlt_use_master:
  50. trainer = pl.Trainer(
  51. strategy=RayStrategy(num_workers=num_workers, use_gpu=use_gpu),
  52. max_steps=max_steps,
  53. )
  54. else:
  55. trainer = pl.Trainer(
  56. plugins=[RayPlugin(num_workers=num_workers, use_gpu=use_gpu)],
  57. max_steps=max_steps,
  58. )
  59. trainer.fit(autoencoder, DataLoader(train), DataLoader(val))
  60. if __name__ == "__main__":
  61. parser = argparse.ArgumentParser(
  62. description="Ray Lightning Example",
  63. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  64. )
  65. parser.add_argument(
  66. "--num-workers",
  67. type=int,
  68. default=2,
  69. help="Number of workers to use for training.",
  70. )
  71. parser.add_argument(
  72. "--max-steps",
  73. type=int,
  74. default=10,
  75. help="Maximum number of steps to run for training.",
  76. )
  77. parser.add_argument(
  78. "--use-gpu",
  79. action="store_true",
  80. default=False,
  81. help="Whether to enable GPU training.",
  82. )
  83. args = parser.parse_args()
  84. main(num_workers=args.num_workers, max_steps=args.max_steps, use_gpu=args.use_gpu)