# ray_train.py
  1. import torch
  2. import ray.train as train
  3. from ray.train.torch import TorchTrainer, TorchCheckpoint
  4. from ray.air import ScalingConfig, session
  5. def train_func():
  6. # Setup model.
  7. model = torch.nn.Linear(1, 1)
  8. model = train.torch.prepare_model(model)
  9. loss_fn = torch.nn.MSELoss()
  10. optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
  11. # Setup data.
  12. input = torch.randn(1000, 1)
  13. labels = input * 2
  14. dataset = torch.utils.data.TensorDataset(input, labels)
  15. dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
  16. dataloader = train.torch.prepare_data_loader(dataloader)
  17. # Train.
  18. for _ in range(5):
  19. for X, y in dataloader:
  20. pred = model(X)
  21. loss = loss_fn(pred, y)
  22. optimizer.zero_grad()
  23. loss.backward()
  24. optimizer.step()
  25. session.report({"loss": loss.item()})
  26. session.report({}, checkpoint=TorchCheckpoint.from_model(model))
  27. trainer = TorchTrainer(train_func, scaling_config=ScalingConfig(num_workers=4))
  28. results = trainer.fit()
  29. print(results.metrics)
  30. print(results.checkpoint)