.. _sgd-index:

=====================================
RaySGD: Distributed Training Wrappers
=====================================

.. _`issue on GitHub`: https://github.com/ray-project/ray/issues

RaySGD is a lightweight library for distributed deep learning, providing thin wrappers around PyTorch and TensorFlow native modules for data parallel training.

The main features are:

- **Ease of use**: Scale PyTorch's native ``DistributedDataParallel`` and TensorFlow's ``tf.distribute.MirroredStrategy`` without needing to monitor individual nodes.
- **Composability**: RaySGD is built on top of the Ray Actor API, enabling seamless integration with existing Ray applications such as RLlib, Tune, and Ray.Serve.
- **Scale up and down**: Start on a single CPU. Scale up to multi-node, multi-CPU, or multi-GPU clusters by changing two lines of code.
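
For readers new to data parallel training, the idea can be sketched in plain Python: each worker computes a gradient on its own shard of the batch, and the gradients are averaged before the shared weights are updated. This is only a toy illustration under stated assumptions (a one-parameter model, equally sized shards, hypothetical helper names ``shard_gradient`` and ``data_parallel_step``); it is not how RaySGD, PyTorch, or TensorFlow implement it.

```python
# Toy data-parallel SGD for the 1-D model y = w * x with squared error.
# Pure Python; no Ray, PyTorch, or TensorFlow involved.
# All names here are illustrative and not part of any library API.


def shard_gradient(w, shard):
    """Mean gradient of (w * x - y) ** 2 with respect to w over one shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)


def data_parallel_step(w, shards, lr):
    """Each simulated worker computes a gradient on its shard; the
    gradients are averaged and applied once to the shared weight."""
    grads = [shard_gradient(w, shard) for shard in shards]
    avg_grad = sum(grads) / len(grads)
    return w - lr * avg_grad


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # samples of y = 2x
shards = [data[:2], data[2:]]  # two equally sized shards, one per worker

w = 0.0
for _ in range(50):
    w = data_parallel_step(w, shards, lr=0.05)

print(round(w, 3))  # converges toward 2.0
```

Because the shards are the same size, the averaged gradient equals the gradient over the full batch, so the result matches single-worker training on the whole dataset.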

Getting Started
---------------

You can start a ``TorchTrainer`` with the following:

.. code-block:: python

    import ray
    from ray.util.sgd import TorchTrainer
    from ray.util.sgd.torch import TrainingOperator
    from ray.util.sgd.torch.examples.train_example import LinearDataset

    import torch
    from torch.utils.data import DataLoader


    class CustomTrainingOperator(TrainingOperator):
        def setup(self, config):
            # Load data.
            train_loader = DataLoader(LinearDataset(2, 5), config["batch_size"])
            val_loader = DataLoader(LinearDataset(2, 5), config["batch_size"])

            # Create model.
            model = torch.nn.Linear(1, 1)

            # Create optimizer.
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

            # Create loss.
            loss = torch.nn.MSELoss()

            # Register model, optimizer, and loss.
            self.model, self.optimizer, self.criterion = self.register(
                models=model,
                optimizers=optimizer,
                criterion=loss)

            # Register data loaders.
            self.register_data(train_loader=train_loader, validation_loader=val_loader)


    ray.init()

    trainer1 = TorchTrainer(
        training_operator_cls=CustomTrainingOperator,
        num_workers=2,
        use_gpu=False,
        config={"batch_size": 64})

    stats = trainer1.train()
    print(stats)
    trainer1.shutdown()
    print("success!")
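
As a rough illustration of the "two lines" claim, the sketch below compares the keyword arguments from the ``TorchTrainer`` call above with a scaled-up variant. It deliberately avoids importing Ray so that it runs anywhere. The scaled-up values (8 workers, GPUs enabled) are hypothetical and assume a cluster with enough GPUs.

```python
# Keyword arguments mirroring the TorchTrainer call in the example above.
cpu_kwargs = {
    "num_workers": 2,
    "use_gpu": False,
    "config": {"batch_size": 64},
}

# Hypothetical scaled-up setup: more workers, GPUs enabled.
# The worker count of 8 is an assumption, not a recommended value.
gpu_kwargs = dict(cpu_kwargs, num_workers=8, use_gpu=True)

# Only the two scaling arguments differ; the training config is unchanged.
changed = sorted(k for k in cpu_kwargs if cpu_kwargs[k] != gpu_kwargs[k])
print(changed)  # ['num_workers', 'use_gpu']
```

Either dictionary could then be unpacked into the constructor alongside ``training_operator_cls``, leaving the training operator itself untouched.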

.. tip:: Get in touch with us if you're using or considering using `RaySGD <https://forms.gle/26EMwdahdgm7Lscy9>`_!