123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- # This file is duplicated in ray/tests/ray_lightning
- import argparse
- import os
- import torch
- from torch import nn
- import torch.nn.functional as F
- from torchvision.datasets import MNIST
- from torch.utils.data import DataLoader, random_split
- from torchvision import transforms
- import pytorch_lightning as pl
- from importlib_metadata import version
- from packaging.version import parse as v_parse
# ray_lightning renamed its Lightning integration class in 0.3.0, so import
# whichever one the installed release provides. ``rlt_use_master`` is True
# for ray_lightning >= 0.3.0 and False for older releases.
rlt_use_master = v_parse(version("ray_lightning")) >= v_parse("0.3.0")
if not rlt_use_master:
    # ray_lightning < 0.3.0
    from ray_lightning import RayPlugin
else:
    # ray_lightning >= 0.3.0
    from ray_lightning import RayStrategy
class LitAutoEncoder(pl.LightningModule):
    """Fully connected autoencoder with a 3-dimensional bottleneck.

    The encoder maps a flattened 28 * 28 input through a 128-unit hidden
    layer down to 3 values; the decoder mirrors that path back up to
    28 * 28 values.
    """

    def __init__(self):
        """Build the encoder and decoder stacks."""
        super().__init__()
        flat_size = 28 * 28
        hidden_size = 128
        code_size = 3
        # The encoder is created before the decoder so parameter
        # initialisation draws from the RNG in a fixed order.
        self.encoder = nn.Sequential(
            nn.Linear(flat_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, code_size),
        )
        self.decoder = nn.Sequential(
            nn.Linear(code_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, flat_size),
        )

    def forward(self, x):
        """Return the encoder output for ``x``; the decoder is not applied."""
        return self.encoder(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss for one batch.

        Args:
            batch: A two-element sequence; the first element is the input
                tensor and the second element is ignored.
            batch_idx: Index of the batch (unused).

        Returns:
            The mean-squared error between the decoded output and the
            flattened input, also logged under ``"train_loss"``.
        """
        inputs, _ = batch
        # Collapse everything after the batch dimension into one axis.
        flat_inputs = inputs.view(inputs.size(0), -1)
        # The sub-modules are called directly, so this path does not go
        # through ``forward``.
        reconstruction = self.decoder(self.encoder(flat_inputs))
        loss = F.mse_loss(reconstruction, flat_inputs)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)
def main(num_workers: int = 2, use_gpu: bool = False, max_steps: int = 10):
    """Fit a ``LitAutoEncoder`` on MNIST through Ray Lightning.

    MNIST is downloaded into the current working directory and split into
    55000 training and 5000 validation samples.

    Args:
        num_workers: Forwarded to ``RayStrategy`` / ``RayPlugin``.
        use_gpu: Forwarded to ``RayStrategy`` / ``RayPlugin``.
        max_steps: Forwarded to ``pl.Trainer`` as ``max_steps``.
    """
    mnist = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    train_split, val_split = random_split(mnist, [55000, 5000])
    model = LitAutoEncoder()

    # The Ray integration is passed to the Trainer under a different keyword
    # depending on the installed ray_lightning release.
    if rlt_use_master:
        ray_kwargs = {
            "strategy": RayStrategy(num_workers=num_workers, use_gpu=use_gpu),
        }
    else:
        ray_kwargs = {
            "plugins": [RayPlugin(num_workers=num_workers, use_gpu=use_gpu)],
        }
    trainer = pl.Trainer(max_steps=max_steps, **ray_kwargs)

    trainer.fit(model, DataLoader(train_split), DataLoader(val_split))
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    cli = argparse.ArgumentParser(
        description="Ray Lightning Example",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # (flag, add_argument keyword arguments), registered in this order so
    # the generated --help output keeps the same layout.
    option_table = (
        (
            "--num-workers",
            {
                "type": int,
                "default": 2,
                "help": "Number of workers to use for training.",
            },
        ),
        (
            "--max-steps",
            {
                "type": int,
                "default": 10,
                "help": "Maximum number of steps to run for training.",
            },
        ),
        (
            "--use-gpu",
            {
                "action": "store_true",
                "default": False,
                "help": "Whether to enable GPU training.",
            },
        ),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()
    main(
        num_workers=opts.num_workers,
        max_steps=opts.max_steps,
        use_gpu=opts.use_gpu,
    )
|