.. _torch-guide:

Distributed PyTorch
===================

The RaySGD ``TorchTrainer`` simplifies distributed model training for PyTorch.

.. image:: raysgd-actors.svg
    :align: center

.. tip:: Get in touch with us if you're using or considering using `RaySGD <https://forms.gle/26EMwdahdgm7Lscy9>`_!

The ``TorchTrainer`` is a wrapper around ``torch.distributed.launch`` with a Python API that lets you incorporate distributed training into a larger Python application, rather than wrapping your training code in bash scripts.

For end-to-end examples leveraging RaySGD TorchTrainer, jump to :ref:`raysgd-torch-examples`.

.. contents:: :local:
Basic Usage
-----------

Setting up training
~~~~~~~~~~~~~~~~~~~

.. tip:: If you want to leverage multi-node data parallel training with PyTorch while using RayTune *without* using RaySGD, check out the :ref:`Tune PyTorch user guide <tune-pytorch-cifar>` and Tune's :ref:`distributed pytorch integrations <tune-ddp-doc>`.

The :ref:`ref-torch-trainer` can be constructed from a custom :ref:`ref-torch-operator` subclass that defines training components like the model, data, optimizer, loss, and ``lr_scheduler``. These components are all automatically replicated across different machines and devices so that training can be executed in parallel.

.. warning:: You should call ``self.register(...)`` and ``self.register_data(...)`` inside the ``setup`` method of your custom ``TrainingOperator`` to register the necessary training components with Ray SGD.

.. literalinclude:: ../../../python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py
    :language: python
    :start-after: __torch_operator_start__
    :end-before: __torch_operator_end__

Under the hood, ``TorchTrainer`` will create *replicas* of your model (controlled by ``num_workers``), each of which is managed by a Ray actor.

Before instantiating the trainer, first start or connect to a Ray cluster:

.. literalinclude:: ../../../python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py
    :language: python
    :start-after: __torch_ray_start__
    :end-before: __torch_ray_end__

And then you can instantiate the trainer object using your custom ``TrainingOperator``:

.. literalinclude:: ../../../python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py
    :language: python
    :start-after: __torch_trainer_start__
    :end-before: __torch_trainer_end__

You can also set the number of workers and whether the workers will use GPUs:

.. code-block:: python
    :emphasize-lines: 4,5

    trainer = TorchTrainer(
        training_operator_cls=MyTrainingOperator,
        config={"lr": 0.001},
        num_workers=100,
        use_gpu=True)
Executing Training
~~~~~~~~~~~~~~~~~~

Now that the trainer is constructed, here's how to train the model.

.. code-block:: python

    for i in range(10):
        metrics = trainer.train()
        val_metrics = trainer.validate()

Each ``train`` call makes one pass over the training data (trains on 1 epoch), and each ``validate`` call runs the model on the validation data.

Override training and validation methods in your Training Operator (:ref:`raysgd-custom-training`) to calculate custom metrics or customize the training/validation process.

.. tip:: Setting the batch size: Using the provided ``ray.util.sgd.utils.BATCH_SIZE`` variable, you can specify a global batch size that will be divided among all workers automatically.

.. code-block:: python

    from torch.utils.data import DataLoader
    from ray.util.sgd.utils import BATCH_SIZE

    class MyTrainingOperator(TrainingOperator):
        def setup(self, config):
            ...
            # Create data loaders.
            # config[BATCH_SIZE] == provided BATCH_SIZE // num_workers
            train_dataset, val_dataset = LinearDataset(2, 5), LinearDataset(2, 5)
            train_loader = DataLoader(train_dataset, batch_size=config[BATCH_SIZE])
            val_loader = DataLoader(val_dataset, batch_size=config[BATCH_SIZE])
            ...

    trainer = TorchTrainer(
        training_operator_cls=MyTrainingOperator,
        config={BATCH_SIZE: 1024},
        num_workers=128)

    # Each worker will process 1024 // 128 = 8 samples per batch.
    stats = trainer.train()

You can also obtain profiling information:

.. code-block:: python

    >>> from ray.tune.logger import pretty_print
    >>> print(pretty_print(trainer.train(profile=True)))

    batch_count: 16
    epoch: 1
    last_train_loss: 0.15574650466442108
    mean_train_loss: 7.475177114367485
    num_samples: 1000
    profile:
      mean_apply_s: 2.639293670654297e-05
      mean_fwd_s: 0.00012960433959960938
      mean_grad_s: 0.00016553401947021483
      train_epoch_s: 0.023712158203125

After training, you may want to reuse the Ray cluster for other workloads. To release the Ray resources held by the Trainer:

.. code-block:: python

    trainer.shutdown()

.. note:: Be sure to call ``trainer.save()`` or ``trainer.get_model()`` before shutting down.

See the documentation on the TorchTrainer here: :ref:`ref-torch-trainer`.

See the documentation on the TrainingOperator here: :ref:`ref-torch-operator`.
.. _raysgd-custom-training:

Custom Training and Validation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If you would like to implement custom training and validation logic, you can do so by overriding the appropriate methods inside your :ref:`ref-torch-operator` subclass.

For both training and validation, there are two granularities at which you can provide customization: per epoch and per batch. These correspond to ``train_epoch``, ``train_batch``, ``validate``, and ``validate_batch``. Other useful methods to override include ``state_dict`` and ``load_state_dict``, which you can use to save and load additional state for your custom ``TrainingOperator``.

Custom training is necessary if you are using multiple models, optimizers, or schedulers.

Below is a partial example of a custom ``TrainingOperator`` that provides a ``train_batch`` implementation for a Deep Convolutional GAN.

.. code-block:: python

    import numpy as np
    import torch
    from torch.autograd import Variable

    from ray.util.sgd.torch import TrainingOperator


    class GANOperator(TrainingOperator):
        def setup(self, config):
            """Setup for this operator.

            This is where you define the training state and register it with Ray SGD.

            Args:
                config (dict): Custom configuration value to be passed to
                    all creator and operator constructors. Same as ``self.config``.
            """
            ...
            self.models, self.optimizers, ... = self.register(...)
            self.register_data(...)

        def train_batch(self, batch, batch_info):
            """Trains on one batch of data from the data creator.

            Example taken from:
            https://github.com/eriklindernoren/PyTorch-GAN/blob/
            a163b82beff3d01688d8315a3fd39080400e7c01/implementations/dcgan/dcgan.py

            Args:
                batch: One item of the training data iterator.
                batch_info (dict): Information dict passed in from ``train_epoch``.

            Returns:
                A dict of metrics. Defaults to "loss" and "num_samples",
                corresponding to the total number of datapoints in the batch.
            """
            Tensor = torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor
            discriminator, generator = self.models
            optimizer_D, optimizer_G = self.optimizers

            # Adversarial ground truths
            valid = Variable(Tensor(batch.shape[0], 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(batch.shape[0], 1).fill_(0.0), requires_grad=False)

            # Configure input
            real_imgs = Variable(batch.type(Tensor))

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()

            # Sample noise as generator input
            z = Variable(Tensor(np.random.normal(0, 1, (
                batch.shape[0], self.config["latent_dim"]))))

            # Generate a batch of images
            gen_imgs = generator(z)

            # Loss measures the generator's ability to fool the discriminator
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)
            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            # Measure the discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            return {
                "loss_g": g_loss.item(),
                "loss_d": d_loss.item(),
                "num_samples": batch.shape[0]
            }


    trainer = TorchTrainer(
        training_operator_cls=GANOperator,
        num_workers=num_workers,
        config=config,
        use_gpu=True)

    for i in range(5):
        stats = trainer.train()
        print(stats)

See the `DCGAN example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/dcgan.py>`__ for an end-to-end example. It constructs two models and two optimizers and uses a custom training operator to provide a non-standard training loop.
Custom DistributedDataParallel Wrappers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TorchTrainer automatically applies a DistributedDataParallel wrapper to your model:

.. code-block:: python

    DistributedDataParallel(model, device_ids=self.device_ids)

You can pass additional arguments to DistributedDataParallel by setting the ``ddp_args`` field when calling ``self.register`` in your ``TrainingOperator``.

.. code-block:: python
    :emphasize-lines: 6

    from ray.util.sgd.torch import TrainingOperator

    class CustomOperator(TrainingOperator):
        def setup(self, config):
            ...
            self.model, ... = self.register(..., ddp_args={"find_unused_parameters": True})

If you want to use a custom wrapper for distributed training, or if you want to wrap the model in DistributedDataParallel yourself, you can do so by setting ``TorchTrainer(wrap_ddp=False)``.

.. note:: Make sure to register the model before it is wrapped in DistributedDataParallel or a custom wrapper.

.. code-block:: python
    :emphasize-lines: 19

    from ray.util.sgd.torch import TrainingOperator

    class CustomOperator(TrainingOperator):
        def setup(self, config):
            ...
            self.model, ... = self.register(...)
            self.new_model = CustomDataParallel(
                self.model, device_ids=self.device_ids)

        def train_batch(self, batch, batch_idx):
            output = self.new_model(batch)
            # Calculate loss, etc.
            return {"loss": loss}

    trainer = TorchTrainer(
        training_operator_cls=CustomOperator,
        num_workers=2,
        use_gpu=True,
        wrap_ddp=False)
.. _backwards-compat:

Backwards Compatibility
~~~~~~~~~~~~~~~~~~~~~~~

In previous versions of Ray, *creator functions* (``model_creator``, ``optimizer_creator``, etc.) were necessary to set up the training components. These creator functions are no longer used; instead, training component setup should be specified inside the ``setup`` method of a ``TrainingOperator`` subclass.

However, if you already have these creator functions and do not want to change your code, you can easily use them to create a custom ``TrainingOperator``.

.. literalinclude:: ../../../python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py
    :language: python
    :start-after: __backwards_compat_start__
    :end-before: __backwards_compat_end__
Initialization Functions
------------------------

Use the ``initialization_hook`` parameter to initialize state on each worker process when the workers are started. This is useful when setting environment variables:

.. code-block:: python

    def initialization_hook():
        print("NCCL DEBUG SET")
        # Need this for avoiding a connection restart issue
        os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
        os.environ["NCCL_LL_THRESHOLD"] = "0"
        os.environ["NCCL_DEBUG"] = "INFO"

    trainer = TorchTrainer(
        training_operator_cls=MyTrainingOperator,
        initialization_hook=initialization_hook,
        config={"lr": 0.001},
        num_workers=100,
        use_gpu=True)
Save and Load
-------------

If you want to save or reload the training procedure, you can use ``trainer.save`` and ``trainer.load``, which wrap the relevant ``torch.save`` and ``torch.load`` calls. This should work across a distributed cluster even without an NFS because it takes advantage of Ray's distributed object store.

.. tip:: Make sure to override the ``state_dict`` and ``load_state_dict`` methods in your custom TrainingOperator if necessary (a minimal sketch is shown after the code block below).

.. code-block:: python

    checkpoint_path = os.path.join(tempfile.mkdtemp(), "checkpoint")
    trainer_1.save(checkpoint_path)
    # You can only have one trainer alive at a time.
    trainer_1.shutdown()

    trainer_2 = TorchTrainer(
        training_operator_cls=MyTrainingOperator,
        num_workers=num_workers)
    trainer_2.load(checkpoint_path)
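
If your operator keeps state beyond the registered models and optimizers, overriding ``state_dict`` and ``load_state_dict`` lets you include that state in checkpoints. Below is a minimal sketch; the ``epoch_count`` attribute is a made-up example and not part of the RaySGD API.

.. code-block:: python

    from ray.util.sgd.torch import TrainingOperator

    class StatefulOperator(TrainingOperator):
        def setup(self, config):
            ...
            # Hypothetical extra state not covered by the registered
            # models and optimizers.
            self.epoch_count = 0

        def state_dict(self):
            # Included in the checkpoint written by ``trainer.save``.
            return {"epoch_count": self.epoch_count}

        def load_state_dict(self, state_dict):
            # Called when ``trainer.load`` restores a checkpoint.
            self.epoch_count = state_dict["epoch_count"]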
Retrieving the model
--------------------

The trained torch model can be extracted for use within the same Python program with ``trainer.get_model()``. This will load the state dictionary of the model(s).

.. code-block:: python

    trainer.train()
    model = trainer.get_model()  # Returns multiple models if multiple were registered.
Training & Validation Results
-----------------------------

The outputs of ``trainer.train()`` and ``trainer.validate()`` are first collected on a per-batch basis. These results are then averaged: first across each batch in the epoch, and then across all workers.

By default, the output of ``train`` contains the following:

.. code-block:: python

    # Total number of samples trained on in this epoch.
    num_samples
    # Current training epoch.
    epoch
    # Number of batches trained on in this epoch, averaged across all workers.
    batch_count
    # Training loss, averaged across all batches on all workers.
    train_loss
    # Training loss for the last batch in the epoch, averaged across all workers.
    last_train_loss

And for ``validate``:

.. code-block:: python

    # Total number of samples validated on.
    num_samples
    # Number of batches validated on, averaged across all workers.
    batch_count
    # Validation loss, averaged across all batches on all workers.
    val_loss
    # Validation loss for the last batch, averaged across all workers.
    last_val_loss
    # Validation accuracy, averaged across all batches on all workers.
    val_accuracy
    # Validation accuracy for the last batch, averaged across all workers.
    last_val_accuracy

If ``train`` or ``validate`` is run with ``reduce_results=False``, results are not averaged across workers, and a list of results for each worker is returned instead.

If run with ``profile=True``, timing stats for a single worker are returned alongside the results above.
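
For example, to inspect what each worker reported individually (``reduce_results=False``) rather than the averaged result:

.. code-block:: python

    # A list with one result dict per worker, in the same format as above.
    per_worker_stats = trainer.train(reduce_results=False)
    print(per_worker_stats[0])  # Result from the first worker.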
To return additional metrics, implement your own custom training operator (:ref:`raysgd-custom-training`).

If you override ``train_batch`` or ``validate_batch``, the result outputs are automatically averaged across all batches, and the results for the last batch are automatically returned. If you override ``train_epoch`` or ``validate``, you may find ``ray.util.sgd.utils.AverageMeterCollection`` (:ref:`ref-utils`) useful to handle this averaging, as in the sketch below.
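
Below is a minimal sketch of such a custom ``train_epoch``. It assumes a single registered model assigned to ``self.model`` in ``setup``, and that ``AverageMeterCollection`` exposes ``update(metrics, n=...)`` and ``summary()`` as used by the default training loop.

.. code-block:: python

    from ray.util.sgd.utils import AverageMeterCollection
    from ray.util.sgd.torch import TrainingOperator

    class MyTrainingOperator(TrainingOperator):
        def train_epoch(self, iterator, info):
            meters = AverageMeterCollection()
            self.model.train()
            for batch_idx, batch in enumerate(iterator):
                # Reuse (or override) the per-batch logic.
                metrics = self.train_batch(batch, batch_info={"batch_idx": batch_idx})
                # Weight the running averages by the number of samples in the batch.
                meters.update(metrics, n=metrics.get("num_samples", 1))
            # Return metrics averaged over the epoch for this worker.
            return meters.summary()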
Mixed Precision (FP16) Training
-------------------------------

You can enable mixed precision training for PyTorch with the ``use_fp16`` flag. This automatically converts the model(s) and optimizer(s) to train using mixed precision. This requires NVIDIA ``Apex``, which can be installed from `the NVIDIA/Apex repository <https://github.com/NVIDIA/apex#quick-start>`_:

.. code-block:: python
    :emphasize-lines: 4

    trainer = TorchTrainer(
        training_operator_cls=MyTrainingOperator,
        num_workers=4,
        use_fp16=True)

``Apex`` is a PyTorch extension with NVIDIA-maintained utilities to streamline mixed precision and distributed training. When ``use_fp16=True``, you should not manually cast your model or data to ``.half()``. The flag informs the Trainer to call ``amp.initialize`` on the created models and optimizers and to optimize using the scaled loss: ``amp.scale_loss(loss, optimizer)``.

To specify particular parameters for ``amp.initialize``, you can use the ``apex_args`` field when calling ``self.register`` in your ``TrainingOperator``. Valid arguments can be found in the `Apex documentation <https://nvidia.github.io/apex/amp.html#apex.amp.initialize>`_:

.. code-block:: python
    :emphasize-lines: 8-12

    class MyTrainingOperator(TrainingOperator):
        def setup(self, config):
            models = [...]
            optimizers = [...]
            model, optimizer = self.register(
                models=models,
                optimizers=optimizers,
                apex_args={
                    "opt_level": "O3",
                    "num_losses": 2,
                    "verbosity": 0
                })

    trainer = TorchTrainer(
        training_operator_cls=MyTrainingOperator,
        num_workers=4,
        use_fp16=True)

Note that if implementing custom training (:ref:`raysgd-custom-training`), you will need to manage loss scaling manually, as in the sketch below.
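
Below is a minimal sketch of manual loss scaling in a custom ``train_batch``. It assumes ``use_fp16=True``, that Apex is installed, and that ``self.model``, ``self.optimizer``, and ``self.criterion`` were assigned from ``self.register`` in ``setup``.

.. code-block:: python

    from apex import amp
    from ray.util.sgd.torch import TrainingOperator

    class MyTrainingOperator(TrainingOperator):
        def train_batch(self, batch, batch_info):
            features, target = batch
            if self.use_gpu:
                features, target = features.cuda(), target.cuda()

            self.optimizer.zero_grad()
            output = self.model(features)
            loss = self.criterion(output, target)

            # Backpropagate through the scaled loss rather than calling
            # ``loss.backward()`` directly, so FP16 gradients do not underflow.
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            self.optimizer.step()

            return {"loss": loss.item(), "num_samples": features.size(0)}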
Distributed Multi-node Training
-------------------------------

You can scale your training to multiple nodes without making any modifications to your training code.

To train across a cluster, first make sure that the Ray cluster is started (see :ref:`cluster-index` for more details). Then, in your program, you'll need to connect to this cluster via ``ray.init``:

.. code-block:: python

    ray.init(address="auto")  # or a specific redis address of the form "ip-address:port"

After connecting, you can scale up the number of workers seamlessly across multiple nodes:

.. code-block:: python

    trainer = TorchTrainer(
        training_operator_cls=MyTrainingOperator,
        num_workers=100)

    trainer.train()
    model = trainer.get_model()
Advanced: Fault Tolerance
-------------------------

For distributed deep learning, jobs are often run on infrastructure where nodes can be preempted frequently (e.g., spot instances in the cloud). To overcome this, RaySGD provides **fault tolerance** features that enable training to continue regardless of node failures.

.. code-block:: python

    trainer.train(max_retries=N)

During each ``train`` call, each parallel worker iterates through the data iterable, synchronizing gradients and parameters at each batch. These synchronization primitives can hang when one or more of the parallel workers becomes unresponsive (e.g., when a node is lost). To address this, we've implemented the following protocol:

1. If any worker node is lost, Ray will mark the training task as complete (``ray.wait`` will return).
2. Ray will throw a ``RayActorException`` when the Trainer class calls ``ray.get`` on the "finished" training task.
3. Upon catching this exception, the Trainer class will kill all of its workers.
4. The Trainer will then detect the quantity of available resources (either CPUs or GPUs). It will then restart as many workers as it can, each resuming from the last checkpoint. Note that this may result in fewer workers than initially specified.
5. If there are no available resources, the Trainer will apply an exponential backoff before retrying to create workers.
6. If there are available resources and the Trainer has fewer workers than initially specified, then it will scale up its worker pool until it reaches the initially specified ``num_workers``.

Note that we assume the Trainer itself is not on a preemptible node. To allow the entire Trainer to recover from failure, you must use Tune to execute the training; a sketch is shown below.
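
For example, here is a rough sketch of running the trainer under Tune so that the Trainer itself can be restarted from a checkpoint if its node fails. It assumes ``TorchTrainer.as_trainable`` is available to convert the trainer configuration into a Tune trainable; check the Tune integration docs for the exact API.

.. code-block:: python

    import ray
    from ray import tune
    from ray.util.sgd.torch import TorchTrainer

    ray.init(address="auto")

    # Convert the trainer configuration into a Tune trainable class.
    TorchTrainable = TorchTrainer.as_trainable(
        training_operator_cls=MyTrainingOperator,
        num_workers=2,
        use_gpu=True,
        config={"lr": 0.001})

    # Tune checkpoints the trainable periodically and restarts it from the
    # latest checkpoint if the node running the Trainer fails.
    analysis = tune.run(
        TorchTrainable,
        stop={"training_iteration": 10},
        checkpoint_freq=1,
        checkpoint_at_end=True)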
Simultaneous Multi-model Training
---------------------------------

In certain scenarios, such as training GANs, you may want to use multiple models in the training loop. You can do this by registering multiple models, optimizers, or schedulers in the ``setup`` method of your ``TrainingOperator``. You must implement custom training and validation (:ref:`raysgd-custom-training`) to train across multiple models.

You can see the `DCGAN script <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/dcgan.py>`_ for an end-to-end example.

.. code-block:: python

    from ray.util.sgd.torch import TorchTrainer, TrainingOperator

    def train(*, model=None, criterion=None, optimizer=None, dataloader=None):
        model.train()
        train_loss = 0
        correct = 0
        total = 0
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        return {
            "accuracy": correct / total,
            "train_loss": train_loss / (batch_idx + 1)
        }

    # The older creator-function style (see Backwards Compatibility above);
    # not used by CustomOperator below, which sets up the same components itself.
    def model_creator(config):
        return Discriminator(), Generator()

    def optimizer_creator(models, config):
        net_d, net_g = models
        discriminator_opt = optim.Adam(
            net_d.parameters(), lr=config.get("lr", 0.01), betas=(0.5, 0.999))
        generator_opt = optim.Adam(
            net_g.parameters(), lr=config.get("lr", 0.01), betas=(0.5, 0.999))
        return discriminator_opt, generator_opt

    class CustomOperator(TrainingOperator):
        def setup(self, config):
            net_d = Discriminator()
            net_g = Generator()
            d_opt = optim.Adam(
                net_d.parameters(), lr=config.get("lr", 0.01), betas=(0.5, 0.999))
            g_opt = optim.Adam(
                net_g.parameters(), lr=config.get("lr", 0.01), betas=(0.5, 0.999))

            # Set up data loaders, loss, and schedulers here.
            ...

            # Register all of the components.
            self.models, self.optimizers, ... = self.register(
                models=(net_d, net_g), optimizers=(d_opt, g_opt), ...)
            self.register_data(...)

        def train_epoch(self, iterator, info):
            result = {}
            for i, (model, optimizer) in enumerate(
                    zip(self.models, self.optimizers)):
                result["model_{}".format(i)] = train(
                    model=model,
                    criterion=self.criterion,
                    optimizer=optimizer,
                    dataloader=iterator)
            return result

    trainer = TorchTrainer(training_operator_cls=CustomOperator)
    stats = trainer.train()
Benchmarks
----------

RaySGD TorchTrainer provides comparable or better performance than other existing solutions for parallel or distributed training.

**Multi-GPU (Single Node) benchmarks**:

.. code-block:: bash

    # Images per second for ResNet50
    # Batch size per worker = 128
    # GPU Type = V100
    # Run on AWS us-east-1c, p3dn.24xlarge instance.

    Number    DataParallel    Ray (PyTorch)    DataParallel    Ray (PyTorch)
    of GPUs                                    + Apex          + Apex
    =======   ============    =============    ============    ==============
    1         355.5           356              776             770
    2         656             701              1303            1346
    4         1289            1401             2606            2695
    8         2521            2795             4795            5862

**Multi-node benchmarks**:

.. code-block:: bash

    # Images per second for ResNet50
    # Batch size per worker = 128
    # GPU Type = V100
    # Run on AWS us-east-1c, p3dn.24xlarge instances.

    Number    Horovod    Ray (PyTorch)    Horovod    Ray (PyTorch)
    of GPUs                               + Apex     + Apex
    =======   =======    =============    =======    ==============
    1 * 8     2769.7     2962.7           5143       6172
    2 * 8     5492.2     5886.1           9463       10052.8
    4 * 8     10733.4    11705.9          18807      20319.5
    8 * 8     21872.5    23317.9          36911.8    38642

You can see more details in the `benchmarking README <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/benchmarks/README.rst>`_.

DISCLAIMER: RaySGD does not provide any custom communication primitives. If you see any performance issues, you may need to file them on the PyTorch GitHub repository.
Debugging/Tips
--------------

Here are some simple tips on how to debug the TorchTrainer.

**My TorchTrainer implementation is erroring after I ported things over from my previous code.**

Try using ``ipdb`` and ``num_workers=1``. This will give you introspection into what is being called and when.

.. code-block:: python

    # First run: pip install ipdb
    from ray.util.sgd.torch import TrainingOperator

    class CustomOperator(TrainingOperator):
        def setup(self, config):
            import ipdb; ipdb.set_trace()
            ...

        def train_batch(self, batch, batch_idx):
            import ipdb; ipdb.set_trace()
            ...  # press 'n' or 's' to navigate the session
            ...  # custom code, if any
            ...  # or super(CustomOperator, self).train_batch(batch, batch_idx)

    trainer = TorchTrainer(
        training_operator_cls=CustomOperator,
        num_workers=1)

**My TorchTrainer implementation is super slow.**

Try using a profiler. Either use:

.. code-block:: python

    trainer.train(profile=True)
    trainer.validate(profile=True)

or use `Python profiling <https://docs.python.org/3/library/debug.html>`_.

**My setup function downloads data, and I don't want multiple processes downloading to the same path at once.**

Use ``FileLock`` to create locks for critical regions. For example:

.. code-block:: python

    import os
    import tempfile

    from filelock import FileLock

    def create_dataset(config):
        dataset_path = config["dataset_path"]

        # Create a critical region of the code.
        # The first process to reach this block downloads the data, which
        # takes a while; other processes block at the ``with`` statement
        # until the download finishes, after which this block is fast.
        with FileLock(os.path.join(tempfile.gettempdir(), "download_data.lock")):
            if not os.path.exists(dataset_path):
                download_data(dataset_path)

        # load_data is assumed to safely support concurrent reads.
        data = load_data(dataset_path)
        return DataLoader(data)

**I get a 'socket timeout' error during training.**

Try increasing the length of the NCCL timeout. The current default timeout is 10 seconds.

.. code-block:: bash

    NCCL_TIMEOUT_S=1000 python ray_training_script.py

    # or

    NCCL_TIMEOUT_S=1000 ray start [--head | --address]
Feature Requests
----------------

Have features that you'd really like to see in RaySGD? Feel free to `open an issue <https://github.com/ray-project/ray>`_.

.. _raysgd-torch-examples:

TorchTrainer Examples
---------------------

Here are some examples of using RaySGD for training PyTorch models. If you'd like to contribute an example, feel free to create a `pull request here <https://github.com/ray-project/ray/>`_.

- `Torch training example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/train_example.py>`__
  Simple example of using Ray's TorchTrainer.

- `TorchTrainer and RayTune example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/tune_example.py>`__
  Simple example of hyperparameter tuning with Ray's TorchTrainer.

- `Semantic Segmentation example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/segmentation/train_segmentation.py>`__
  Fine-tuning a ResNet50 model on VOC with Batch Norm.

- `Huggingface Transformer GLUE fine-tuning example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/transformers/transformers_example.py>`__
  Fine-tuning a pre-trained Transformer model on GLUE tasks. Based on the `huggingface/transformers <https://github.com/huggingface/transformers/blob/master/examples/>`_ ``run_glue.py`` example.

- `ImageNet Models example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/image_models/train.py>`__
  Training state-of-the-art ImageNet models.

- `CIFAR10 example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py>`__
  Training a ResNet18 model on CIFAR10.

- `CIFAR10 RayTune example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/cifar_pytorch_pbt.py>`__
  Tuning a ResNet18 model on CIFAR10 with Population Based Training on RayTune.

- `DCGAN example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/dcgan.py>`__
  Training a Deep Convolutional GAN on MNIST. It constructs two models and two optimizers and uses a custom training operator.