test_simple.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. from pytorch_lightning import LightningModule, Trainer
  6. from pytorch_lightning.strategies import DeepSpeedStrategy
  7. from torch.utils.data import DataLoader, Dataset
  8. class RandomDataset(Dataset):
  9. def __init__(self, size, length):
  10. self.len = length
  11. self.data = torch.randn(length, size)
  12. def __getitem__(self, index):
  13. return self.data[index]
  14. def __len__(self):
  15. return self.len
  16. class BoringModel(LightningModule):
  17. def __init__(self):
  18. super().__init__()
  19. self.layer = torch.nn.Linear(32, 2)
  20. def forward(self, x):
  21. return self.layer(x)
  22. def training_step(self, batch, batch_idx):
  23. loss = self(batch).sum()
  24. self.log("train_loss", loss)
  25. return {"loss": loss}
  26. def validation_step(self, batch, batch_idx):
  27. loss = self(batch).sum()
  28. self.log("valid_loss", loss)
  29. def test_step(self, batch, batch_idx):
  30. loss = self(batch).sum()
  31. self.log("test_loss", loss)
  32. def configure_optimizers(self):
  33. return torch.optim.SGD(self.layer.parameters(), lr=0.1)
  34. def train_dataloader(self):
  35. return DataLoader(RandomDataset(32, 64), batch_size=2)
  36. def val_dataloader(self):
  37. return DataLoader(RandomDataset(32, 64), batch_size=2)
  38. def test_lightning_model():
  39. """Test that DeepSpeed works with a simple LightningModule and LightningDataModule."""
  40. model = BoringModel()
  41. trainer = Trainer(strategy=DeepSpeedStrategy(), max_epochs=1, precision=16, accelerator="gpu", devices=1)
  42. trainer.fit(model)