
PyTorch Lightning with RaySGD
=============================

.. image:: /images/sgd_ptl.png
    :align: center
    :scale: 50 %

RaySGD includes an integration with PyTorch Lightning's `LightningModule <https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html>`_.
Easily take your existing ``LightningModule`` and use it with Ray SGD's ``TorchTrainer`` to take advantage of all of Ray SGD's distributed training features with minimal code changes.

.. tip:: This LightningModule integration is currently under active development. If you encounter any bugs, please raise an issue on `GitHub <https://github.com/ray-project/ray/issues>`_!

.. note:: Not all PyTorch Lightning features are supported. A full list of unsupported model hooks is given :ref:`below <ptl-unsupported-features>`. Please post any feature requests on `GitHub <https://github.com/ray-project/ray/issues>`_ and we will get to it shortly!
.. contents::
    :local:
Quick Start
-----------

Step 1: Define your ``LightningModule`` just as you would with PyTorch Lightning.

.. code-block:: python

    from pytorch_lightning.core.lightning import LightningModule

    class MyLightningModule(LightningModule):
        ...
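To make the skeleton above concrete, here is a minimal sketch of what ``MyLightningModule`` could contain. The layer sizes, loss, and optimizer here are illustrative assumptions, not part of the original example:

.. code-block:: python

    import torch
    from torch import nn
    import torch.nn.functional as F
    from pytorch_lightning import LightningModule

    class MyLightningModule(LightningModule):
        def __init__(self):
            super().__init__()
            # A single linear layer over flattened 28x28 images (e.g. MNIST).
            self.layer = nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.layer(x.view(x.size(0), -1))

        def training_step(self, batch, batch_idx):
            x, y = batch
            # Standard classification loss on the batch.
            return F.cross_entropy(self(x), y)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

Any ``LightningModule`` that defines ``training_step`` and ``configure_optimizers`` like this can be handed to Ray SGD in the next step.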
Step 2: Use the ``TrainingOperator.from_ptl`` method to convert the ``LightningModule`` to a Ray SGD compatible ``LightningOperator``.

.. code-block:: python

    from ray.util.sgd.torch import TrainingOperator

    MyLightningOperator = TrainingOperator.from_ptl(MyLightningModule)
Step 3: Use the operator with Ray SGD's ``TorchTrainer``, just as you normally would. See :ref:`torch-guide` for a fuller guide on ``TorchTrainer``.

.. code-block:: python

    import ray
    from ray.util.sgd.torch import TorchTrainer

    ray.init()
    trainer = TorchTrainer(
        training_operator_cls=MyLightningOperator,
        num_workers=4,
        use_gpu=True)
    train_stats = trainer.train()
And that's it! For a more comprehensive guide, see the MNIST tutorial :ref:`below <ptl-mnist>`.
.. _ptl-mnist:

MNIST Tutorial
--------------

In this walkthrough, we will train an MNIST classifier with PyTorch Lightning's ``LightningModule`` and Ray SGD.
We will follow `this tutorial from the PyTorch Lightning documentation
<https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html>`_ for specifying our MNIST LightningModule.
Setup / Imports
~~~~~~~~~~~~~~~

Let's start with some basic imports:

.. literalinclude:: /../../python/ray/util/sgd/torch/examples/pytorch-lightning/mnist-ptl.py
    :language: python
    :start-after: __import_begin__
    :end-before: __import_end__

Most of these imports are needed for building our PyTorch model and training components.
Only a few additional imports are needed for Ray and PyTorch Lightning.
MNIST LightningModule
~~~~~~~~~~~~~~~~~~~~~

We now define our PyTorch Lightning ``LightningModule``:

.. literalinclude:: /../../python/ray/util/sgd/torch/examples/pytorch-lightning/mnist-ptl.py
    :language: python
    :start-after: __ptl_begin__
    :end-before: __ptl_end__
This is the same code that would normally be used in PyTorch Lightning, and is taken directly from `this PTL guide <https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html>`_.
The only difference is that the ``__init__`` method can optionally take in a ``config`` argument,
as a way to pass hyperparameters to your model, optimizer, or schedulers. The ``config`` will be passed in directly from
the ``TorchTrainer``, or, if you are using Ray SGD in conjunction with Tune (:ref:`raysgd-tune`), it will come directly from the config in your
``tune.run`` call.
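As a sketch of that pattern, the ``__init__`` method can read hyperparameters out of the ``config`` dict with defaults. The class name and config keys below are hypothetical, not the ones used in the example file:

.. code-block:: python

    import torch
    from torch import nn
    import torch.nn.functional as F
    from pytorch_lightning import LightningModule

    class LitClassifier(LightningModule):
        def __init__(self, config=None):
            super().__init__()
            config = config or {}
            # Hyperparameters arrive via the config dict passed in by
            # TorchTrainer (or by Tune's trial config).
            self.lr = config.get("lr", 1e-3)
            hidden = config.get("hidden_size", 64)
            self.net = nn.Sequential(
                nn.Linear(28 * 28, hidden),
                nn.ReLU(),
                nn.Linear(hidden, 10))

        def forward(self, x):
            return self.net(x.view(x.size(0), -1))

        def training_step(self, batch, batch_idx):
            x, y = batch
            return F.cross_entropy(self(x), y)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.lr)

Because every hyperparameter has a default, the same class works standalone, under ``TorchTrainer``, or inside a Tune search space.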
Training with Ray SGD
~~~~~~~~~~~~~~~~~~~~~

We can now define our training function using our LitMNIST module and Ray SGD.

.. literalinclude:: /../../python/ray/util/sgd/torch/examples/pytorch-lightning/mnist-ptl.py
    :language: python
    :start-after: __train_begin__
    :end-before: __train_end__
With just a single ``from_ptl`` call, we can convert our LightningModule to a ``TrainingOperator`` class that's compatible
with Ray SGD. Now we can take full advantage of all of Ray SGD's distributed training features without having to rewrite our existing
LightningModule.
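The training function in the example file follows roughly this shape. This is a sketch, not the exact example code; ``LitMNIST`` and its dataloader setup live in the linked example file:

.. code-block:: python

    import ray
    from ray.util.sgd.torch import TorchTrainer, TrainingOperator

    def train_mnist(num_workers=1, use_gpu=False, num_epochs=1):
        # Convert the LightningModule (LitMNIST, defined elsewhere in the
        # example file) into a Ray SGD-compatible TrainingOperator.
        MnistOperator = TrainingOperator.from_ptl(LitMNIST)
        trainer = TorchTrainer(
            training_operator_cls=MnistOperator,
            num_workers=num_workers,
            use_gpu=use_gpu)
        stats = []
        for _ in range(num_epochs):
            # Each trainer.train() call runs one pass over the training data.
            stats.append(trainer.train())
        trainer.shutdown()
        return stats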
The last thing to do is initialize Ray and run our training function!

.. code-block:: python

    # Use ray.init(address="auto") if running on a Ray cluster.
    ray.init()
    train_mnist(num_workers=32, use_gpu=True, num_epochs=5)
.. _ptl-unsupported-features:

Unsupported Features
--------------------

This integration is currently under active development, so not all PyTorch Lightning features are supported.
Please post any feature requests on `GitHub
<https://github.com/ray-project/ray/issues>`_ and we will get to it shortly!

A list of unsupported model hooks (as of v1.0.0) is as follows:
``test_dataloader``, ``on_test_batch_start``, ``on_test_epoch_start``, ``on_test_batch_end``, ``on_test_epoch_end``,
``get_progress_bar_dict``, ``on_fit_end``, ``on_pretrain_routine_end``, ``manual_backward``, ``tbptt_split_batch``.