
.. _train-dl-guide:

Distributed Deep Learning with Ray Train User Guide
===================================================

This guide explains how to use Train to scale PyTorch, TensorFlow and Horovod.
In this guide, we cover examples for the following use cases:

* How do I :ref:`port my code <train-porting-code>` to use Ray Train?
* How do I use Ray Train to :ref:`train with a large dataset <train-datasets>`?
* How do I :ref:`monitor <train-monitoring>` my training?
* How do I run my training on pre-emptible instances
  (:ref:`fault tolerance <train-fault-tolerance>`)?
* How do I :ref:`tune <train-tune>` my Ray Train model?

.. _train-backends:

Using Deep Learning Frameworks as Backends
------------------------------------------

Ray Train provides a thin API around different backend frameworks for
distributed deep learning. At the moment, Ray Train allows you to perform
training with:

* **PyTorch:** Ray Train initializes your distributed process group, allowing
  you to run your ``DistributedDataParallel`` training script. See `PyTorch
  Distributed Overview <https://pytorch.org/tutorials/beginner/dist_overview.html>`_
  for more information.

* **TensorFlow:** Ray Train configures ``TF_CONFIG`` for you, allowing you to run
  your ``MultiWorkerMirroredStrategy`` training script. See `Distributed
  training with TensorFlow <https://www.tensorflow.org/guide/distributed_training>`_
  for more information.

* **Horovod:** Ray Train configures the Horovod environment and Rendezvous
  server for you, allowing you to run your ``DistributedOptimizer`` training
  script. See `Horovod documentation <https://horovod.readthedocs.io/en/stable/index.html>`_
  for more information.

.. _train-porting-code:

Porting code from PyTorch, TensorFlow, or Horovod to Ray Train
--------------------------------------------------------------

The following instructions assume you have a training function
that can already be run on a single worker for one of the supported
:ref:`backend <train-backends>` frameworks.

Updating your training function
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

First, you'll want to update your training function to support distributed
training.

.. tab-set::

    .. tab-item:: PyTorch

        Ray Train will set up your distributed process group for you and also provides utility methods
        to automatically prepare your model and data for distributed training.

        .. note::

            Ray Train will still work even if you don't use the :func:`ray.train.torch.prepare_model`
            and :func:`ray.train.torch.prepare_data_loader` utilities below,
            and instead handle the logic directly inside your training function.

        First, use the :func:`~ray.train.torch.prepare_model` function to automatically move your model to the right device and wrap it in
        ``DistributedDataParallel``:

        .. code-block:: diff

             import torch
             from torch.nn.parallel import DistributedDataParallel
            +from ray.air import session
            +from ray import train
            +import ray.train.torch

             def train_func():
            -    device = torch.device(f"cuda:{session.get_local_rank()}" if
            -        torch.cuda.is_available() else "cpu")
            -    torch.cuda.set_device(device)

                 # Create model.
                 model = NeuralNetwork()

            -    model = model.to(device)
            -    model = DistributedDataParallel(model,
            -        device_ids=[session.get_local_rank()] if torch.cuda.is_available() else None)
            +    model = train.torch.prepare_model(model)

                 ...

        Then, use the ``prepare_data_loader`` function to automatically add a ``DistributedSampler`` to your ``DataLoader``
        and move the batches to the right device. This step is not necessary if you are passing in Ray Data to your Trainer
        (see :ref:`train-datasets`):

        .. code-block:: diff

             import torch
             from torch.utils.data import DataLoader, DistributedSampler
            +from ray.air import session
            +from ray import train
            +import ray.train.torch

             def train_func():
            -    device = torch.device(f"cuda:{session.get_local_rank()}" if
            -        torch.cuda.is_available() else "cpu")
            -    torch.cuda.set_device(device)

                 ...

            -    data_loader = DataLoader(my_dataset, batch_size=worker_batch_size, sampler=DistributedSampler(dataset))
            +    data_loader = DataLoader(my_dataset, batch_size=worker_batch_size)
            +    data_loader = train.torch.prepare_data_loader(data_loader)

                 for X, y in data_loader:
            -        X = X.to_device(device)
            -        y = y.to_device(device)

        .. tip::

            Keep in mind that ``DataLoader`` takes in a ``batch_size`` which is the batch size for each worker.
            The global batch size can be calculated from the worker batch size (and vice-versa) with the following equation:

            .. code-block:: python

                global_batch_size = worker_batch_size * session.get_world_size()

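        After both changes, a complete training function might look like the following minimal
        sketch. Note that ``NeuralNetwork``, ``my_dataset``, ``worker_batch_size``, and
        ``num_epochs`` are assumed to be defined elsewhere in your script.

        .. code-block:: python

            import torch
            from torch.utils.data import DataLoader

            from ray import train
            import ray.train.torch

            def train_func():
                # Moves the model to the right device and wraps it in DistributedDataParallel.
                model = train.torch.prepare_model(NeuralNetwork())

                # Adds a DistributedSampler and moves batches to the right device.
                data_loader = train.torch.prepare_data_loader(
                    DataLoader(my_dataset, batch_size=worker_batch_size)
                )

                loss_fn = torch.nn.MSELoss()
                optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

                for epoch in range(num_epochs):
                    for X, y in data_loader:
                        optimizer.zero_grad()
                        loss = loss_fn(model(X), y)
                        loss.backward()
                        optimizer.step()
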
    .. tab-item:: TensorFlow

        .. note::

            The current TensorFlow implementation supports
            ``MultiWorkerMirroredStrategy`` (and ``MirroredStrategy``). If there are
            other strategies you wish to see supported by Ray Train, please let us know
            by submitting a `feature request on GitHub <https://github.com/ray-project/ray/issues>`_.

        These instructions closely follow TensorFlow's `Multi-worker training
        with Keras <https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras>`_
        tutorial. One key difference is that Ray Train will handle the environment
        variable set up for you.

        **Step 1:** Wrap your model in ``MultiWorkerMirroredStrategy``.

        The `MultiWorkerMirroredStrategy <https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy>`_
        enables synchronous distributed training. The ``Model`` *must* be built and
        compiled within the scope of the strategy.

        .. code-block:: python

            with tf.distribute.MultiWorkerMirroredStrategy().scope():
                model = ...  # build model
                model.compile()

        **Step 2:** Update your ``Dataset`` batch size to the *global* batch
        size.

        The `batch <https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch>`_
        will be split evenly across worker processes, so ``batch_size`` should be
        set appropriately.

        .. code-block:: diff

            -batch_size = worker_batch_size
            +batch_size = worker_batch_size * session.get_world_size()

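        For example, in the training function, a per-worker ``tf.data`` input pipeline could apply
        the global batch size as follows. This is only a sketch: ``X``, ``Y``, ``worker_batch_size``,
        and ``multi_worker_model`` are assumed to be defined inside your training function.

        .. code-block:: python

            # Scale the batch size by the number of workers so that each worker
            # still processes `worker_batch_size` samples per step.
            global_batch_size = worker_batch_size * session.get_world_size()
            dataset = tf.data.Dataset.from_tensor_slices((X, Y)).batch(global_batch_size)
            multi_worker_model.fit(dataset, epochs=2)
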
    .. tab-item:: Horovod

        If you have a training function that already runs with the `Horovod Ray
        Executor <https://horovod.readthedocs.io/en/stable/ray_include.html#horovod-ray-executor>`_,
        you should not need to make any additional changes!

        To onboard onto Horovod, please visit the `Horovod guide
        <https://horovod.readthedocs.io/en/stable/index.html#get-started>`_.

Creating a Ray Train Trainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

``Trainer``\s are the primary Ray Train classes that are used to manage state and
execute training. You can create a simple ``Trainer`` for the backend of choice
with one of the following:

.. tab-set::

    .. tab-item:: PyTorch

        .. code-block:: python

            from ray.air import ScalingConfig
            from ray.train.torch import TorchTrainer

            # For GPU Training, set `use_gpu` to True.
            use_gpu = False

            trainer = TorchTrainer(
                train_func,
                scaling_config=ScalingConfig(use_gpu=use_gpu, num_workers=2)
            )

    .. tab-item:: TensorFlow

        .. warning::

            Ray will not automatically set any environment variables or configuration
            related to local parallelism / threading
            :ref:`aside from "OMP_NUM_THREADS" <omp-num-thread-note>`.
            If you desire greater control over TensorFlow threading, use
            the ``tf.config.threading`` module (e.g.
            ``tf.config.threading.set_inter_op_parallelism_threads(num_cpus)``)
            at the beginning of your ``train_loop_per_worker`` function.

        .. code-block:: python

            from ray.air import ScalingConfig
            from ray.train.tensorflow import TensorflowTrainer

            # For GPU Training, set `use_gpu` to True.
            use_gpu = False

            trainer = TensorflowTrainer(
                train_func,
                scaling_config=ScalingConfig(use_gpu=use_gpu, num_workers=2)
            )

    .. tab-item:: Horovod

        .. code-block:: python

            from ray.air import ScalingConfig
            from ray.train.horovod import HorovodTrainer

            # For GPU Training, set `use_gpu` to True.
            use_gpu = False

            trainer = HorovodTrainer(
                train_func,
                scaling_config=ScalingConfig(use_gpu=use_gpu, num_workers=2)
            )

To customize the backend setup, you can use the :ref:`framework-specific config objects <train-integration-api>`.

.. tab-set::

    .. tab-item:: PyTorch

        .. code-block:: python

            from ray.air import ScalingConfig
            from ray.train.torch import TorchTrainer, TorchConfig

            trainer = TorchTrainer(
                train_func,
                torch_backend=TorchConfig(...),
                scaling_config=ScalingConfig(num_workers=2),
            )

    .. tab-item:: TensorFlow

        .. code-block:: python

            from ray.air import ScalingConfig
            from ray.train.tensorflow import TensorflowTrainer, TensorflowConfig

            trainer = TensorflowTrainer(
                train_func,
                tensorflow_backend=TensorflowConfig(...),
                scaling_config=ScalingConfig(num_workers=2),
            )

    .. tab-item:: Horovod

        .. code-block:: python

            from ray.air import ScalingConfig
            from ray.train.horovod import HorovodTrainer, HorovodConfig

            trainer = HorovodTrainer(
                train_func,
                horovod_backend=HorovodConfig(...),
                scaling_config=ScalingConfig(num_workers=2),
            )

For more configurability, please reference the :py:class:`~ray.train.data_parallel_trainer.DataParallelTrainer` API.

Running your training function
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

With a distributed training function and a Ray Train ``Trainer``, you are now
ready to start training!

.. code-block:: python

    trainer.fit()

Configuring Training
--------------------

With Ray Train, you can execute a training function (``train_func``) in a
distributed manner by calling ``Trainer.fit``. To pass arguments
into the training function, you can expose a single ``config`` dictionary parameter:

.. code-block:: diff

    -def train_func():
    +def train_func(config):

Then, you can pass in the config dictionary as an argument to ``Trainer``:

.. code-block:: diff

    +config = {} # This should be populated.
     trainer = TorchTrainer(
         train_func,
    +    train_loop_config=config,
         scaling_config=ScalingConfig(num_workers=2)
     )

Putting this all together, you can run your training function with different
configurations. As an example:

.. code-block:: python

    from ray.air import session, ScalingConfig
    from ray.train.torch import TorchTrainer

    def train_func(config):
        for i in range(config["num_epochs"]):
            session.report({"epoch": i})

    trainer = TorchTrainer(
        train_func,
        train_loop_config={"num_epochs": 2},
        scaling_config=ScalingConfig(num_workers=2)
    )
    result = trainer.fit()
    print(result.metrics["epoch"])
    # 1

A primary use-case for ``config`` is to try different hyperparameters. To
perform hyperparameter tuning with Ray Train, please refer to the
:ref:`Ray Tune integration <train-tune>`.

.. TODO add support for with_parameters

.. _train-result-object:

Accessing Training Results
--------------------------

.. TODO(ml-team) Flesh this section out.

The return of a ``Trainer.fit`` is a :py:class:`~ray.air.result.Result` object, containing
information about the training run. You can access it to obtain saved checkpoints,
metrics and other relevant data.

For example, you can:

* Print the metrics for the last training iteration:

  .. code-block:: python

      from pprint import pprint

      pprint(result.metrics)
      # {'_time_this_iter_s': 0.001016855239868164,
      #  '_timestamp': 1657829125,
      #  '_training_iteration': 2,
      #  'config': {},
      #  'date': '2022-07-14_20-05-25',
      #  'done': True,
      #  'episodes_total': None,
      #  'epoch': 1,
      #  'experiment_id': '5a3f8b9bf875437881a8ddc7e4dd3340',
      #  'experiment_tag': '0',
      #  'hostname': 'ip-172-31-43-110',
      #  'iterations_since_restore': 2,
      #  'node_ip': '172.31.43.110',
      #  'pid': 654068,
      #  'time_since_restore': 3.4353830814361572,
      #  'time_this_iter_s': 0.00809168815612793,
      #  'time_total_s': 3.4353830814361572,
      #  'timestamp': 1657829125,
      #  'timesteps_since_restore': 0,
      #  'timesteps_total': None,
      #  'training_iteration': 2,
      #  'trial_id': '4913f_00000',
      #  'warmup_time': 0.003167867660522461}

* View the dataframe containing the metrics from all iterations:

  .. code-block:: python

      print(result.metrics_dataframe)

* Obtain the :py:class:`~ray.air.checkpoint.Checkpoint`, used for resuming training, prediction and serving.

  .. code-block:: python

      result.checkpoint  # last saved checkpoint
      result.best_checkpoints  # N best saved checkpoints, as configured in run_config

.. _train-log-dir:

Log Directory Structure
~~~~~~~~~~~~~~~~~~~~~~~

Each ``Trainer`` will have a local directory created for logs and checkpoints.

You can obtain the path to the directory by accessing the ``log_dir`` attribute
of the :py:class:`~ray.air.result.Result` object returned by ``Trainer.fit()``.

.. code-block:: python

    print(result.log_dir)
    # '/home/ubuntu/ray_results/TorchTrainer_2022-06-13_20-31-06/checkpoint_000003'

.. _train-datasets:

Distributed Data Ingest with Ray Data and Ray Train
---------------------------------------------------

:ref:`Ray Data <data>` is the recommended way to work with large datasets in Ray Train. Ray Data provides automatic loading, sharding, and streamed ingest of Data across multiple Train workers.
To get started, pass in one or more datasets under the ``datasets`` keyword argument for Trainer (e.g., ``Trainer(datasets={...})``).

Here's a simple code overview of the Ray Data integration:

.. code-block:: python

    import ray
    from ray.air import session, ScalingConfig
    from ray.train.torch import TorchTrainer

    # Datasets can be accessed in your train_func via ``get_dataset_shard``.
    def train_func(config):
        train_data_shard = session.get_dataset_shard("train")
        validation_data_shard = session.get_dataset_shard("validation")
        ...

    # Random split the dataset into 80% training data and 20% validation data.
    dataset = ray.data.read_csv("...")
    train_dataset, validation_dataset = dataset.train_test_split(
        test_size=0.2, shuffle=True,
    )

    trainer = TorchTrainer(
        train_func,
        datasets={"train": train_dataset, "validation": validation_dataset},
        scaling_config=ScalingConfig(num_workers=8),
    )
    trainer.fit()

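Inside the training function, each worker typically iterates over its dataset shard once per
epoch. The following is a minimal sketch of that pattern, assuming the dataset has ``"x"`` and
``"y"`` columns and that the Torch backend is used:

.. code-block:: python

    def train_func(config):
        train_data_shard = session.get_dataset_shard("train")
        for epoch in range(config["num_epochs"]):
            # The column names here are placeholders; use the columns of your own dataset.
            for batch in train_data_shard.iter_torch_batches(batch_size=32):
                features, labels = batch["x"], batch["y"]
                ...  # forward and backward pass as usual
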
For more details on how to configure data ingest for Train, please refer to :ref:`air-ingest`.

.. TODO link to Training Run Iterator API as a 3rd option for logging.

.. _train-monitoring:

Logging, Checkpointing and Callbacks in Ray Train
-------------------------------------------------

Ray Train has mechanisms to easily collect intermediate results from the training workers during the training run
and also has a :ref:`Callback interface <train-callbacks>` to perform actions on these intermediate results (such as logging, aggregations, etc.).
You can use either the :ref:`built-in callbacks <air-builtin-callbacks>` that Ray AIR provides,
or implement a :ref:`custom callback <train-custom-callbacks>` for your use case. The callback API
is shared with Ray Tune.

.. _train-checkpointing:

Ray Train also provides a way to save :ref:`Checkpoints <air-checkpoints-doc>` during the training process. This is
useful for:

1. :ref:`Integration with Ray Tune <train-tune>` to use certain Ray Tune
   schedulers.
2. Running a long-running training job on a cluster of pre-emptible machines/pods.
3. Persisting trained model state to later use for serving/inference.
4. In general, storing any model artifacts.

Reporting intermediate results and handling checkpoints
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Ray AIR provides a *Session* API for reporting intermediate
results and checkpoints from the training function (run on distributed workers) up to the
``Trainer`` (where your python script is executed) by calling ``session.report(metrics)``.
The results will be collected from the distributed workers and passed to the driver to
be logged and displayed.

.. warning::

    Only the results from the rank 0 worker will be used. However, in order to ensure
    consistency, ``session.report()`` has to be called on each worker. If you
    want to aggregate results from multiple workers, see :ref:`train-aggregating-results`.

The primary use-case for reporting is for metrics (accuracy, loss, etc.) at
the end of each training epoch.

.. code-block:: python

    from ray.air import session

    def train_func():
        ...
        for i in range(num_epochs):
            result = model.train(...)
            session.report({"result": result})

The session concept exists on several levels: The execution layer (called `Tune Session`) and the Data Parallel training layer
(called `Train Session`).
The following figure shows what these two sessions look like in a Data Parallel training scenario.

.. image:: ../ray-air/images/session.svg
    :width: 650px
    :align: center

..
    https://docs.google.com/drawings/d/1g0pv8gqgG29aPEPTcd4BC0LaRNbW1sAkv3H6W1TCp0c/edit

.. _train-dl-saving-checkpoints:

Saving checkpoints
++++++++++++++++++

:ref:`Checkpoints <air-checkpoints-doc>` can be saved by calling ``session.report(metrics, checkpoint=Checkpoint(...))`` in the
training function. This will cause the checkpoint state from the distributed
workers to be saved on the ``Trainer`` (where your python script is executed).

The latest saved checkpoint can be accessed through the ``checkpoint`` attribute of
the :py:class:`~ray.air.result.Result`, and the best saved checkpoints can be accessed by the ``best_checkpoints``
attribute.

Concrete examples are provided to demonstrate how checkpoints (model weights but not models) are saved
appropriately in distributed training.

.. tab-set::

    .. tab-item:: PyTorch

        .. code-block:: python
            :emphasize-lines: 36, 37, 38, 39, 40, 41

            import ray.train.torch
            from ray.air import session, Checkpoint, ScalingConfig
            from ray.train.torch import TorchTrainer

            import torch
            import torch.nn as nn
            from torch.optim import Adam
            import numpy as np


            def train_func(config):
                n = 100
                # create a toy dataset
                # data : X - dim = (n, 4)
                # target : Y - dim = (n, 1)
                X = torch.Tensor(np.random.normal(0, 1, size=(n, 4)))
                Y = torch.Tensor(np.random.uniform(0, 1, size=(n, 1)))

                # toy neural network : 1-layer
                # wrap the model in DDP
                model = ray.train.torch.prepare_model(nn.Linear(4, 1))
                criterion = nn.MSELoss()

                optimizer = Adam(model.parameters(), lr=3e-4)

                for epoch in range(config["num_epochs"]):
                    y = model.forward(X)
                    # compute loss
                    loss = criterion(y, Y)
                    # back-propagate loss
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    state_dict = model.state_dict()
                    checkpoint = Checkpoint.from_dict(
                        dict(epoch=epoch, model_weights=state_dict)
                    )
                    session.report({}, checkpoint=checkpoint)


            trainer = TorchTrainer(
                train_func,
                train_loop_config={"num_epochs": 5},
                scaling_config=ScalingConfig(num_workers=2),
            )
            result = trainer.fit()

            print(result.checkpoint.to_dict())
            # {'epoch': 4, 'model_weights': OrderedDict([('bias', tensor([-0.1215])), ('weight', tensor([[0.3253, 0.1979, 0.4525, 0.2850]]))]), '_timestamp': 1656107095, '_preprocessor': None, '_current_checkpoint_id': 4}

    .. tab-item:: TensorFlow

        .. code-block:: python
            :emphasize-lines: 23

            from ray.air import session, Checkpoint, ScalingConfig
            from ray.train.tensorflow import TensorflowTrainer

            import numpy as np


            def train_func(config):
                import tensorflow as tf

                n = 100
                # create a toy dataset
                # data : X - dim = (n, 4)
                # target : Y - dim = (n, 1)
                X = np.random.normal(0, 1, size=(n, 4))
                Y = np.random.uniform(0, 1, size=(n, 1))

                strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
                with strategy.scope():
                    # toy neural network : 1-layer
                    model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="linear", input_shape=(4,))])
                    model.compile(optimizer="Adam", loss="mean_squared_error", metrics=["mse"])

                for epoch in range(config["num_epochs"]):
                    model.fit(X, Y, batch_size=20)
                    checkpoint = Checkpoint.from_dict(
                        dict(epoch=epoch, model_weights=model.get_weights())
                    )
                    session.report({}, checkpoint=checkpoint)


            trainer = TensorflowTrainer(
                train_func,
                train_loop_config={"num_epochs": 5},
                scaling_config=ScalingConfig(num_workers=2),
            )
            result = trainer.fit()

            print(result.checkpoint.to_dict())
            # {'epoch': 4, 'model_weights': [array([[-0.31858477],
            #        [ 0.03747174],
            #        [ 0.28266194],
            #        [ 0.8626015 ]], dtype=float32), array([0.02230084], dtype=float32)], '_timestamp': 1656107383, '_preprocessor': None, '_current_checkpoint_id': 4}

By default, checkpoints will be persisted to local disk in the :ref:`log
directory <train-log-dir>` of each run.

.. code-block:: python

    print(result.checkpoint.get_internal_representation())
    # ('local_path', '/home/ubuntu/ray_results/TorchTrainer_2022-06-24_21-34-49/TorchTrainer_7988b_00000_0_2022-06-24_21-34-49/checkpoint_000003')

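The saved weights can also be loaded back outside of Ray Train. For instance, continuing the
PyTorch example above, one way to rebuild the toy model from the checkpoint dictionary is:

.. code-block:: python

    import torch.nn as nn

    # Rebuild the same toy model and load the weights stored in the checkpoint.
    checkpoint_dict = result.checkpoint.to_dict()
    model = nn.Linear(4, 1)
    model.load_state_dict(checkpoint_dict["model_weights"])
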
Configuring checkpoints
+++++++++++++++++++++++

For more configurability of checkpointing behavior (specifically saving
checkpoints to disk), a :py:class:`~ray.air.config.CheckpointConfig` can be passed into
``Trainer``.

As an example, to completely disable writing checkpoints to disk:

.. code-block:: python
    :emphasize-lines: 9,14

    from ray.air import session, Checkpoint, RunConfig, CheckpointConfig, ScalingConfig
    from ray.train.torch import TorchTrainer

    def train_func():
        for epoch in range(3):
            checkpoint = Checkpoint.from_dict(dict(epoch=epoch))
            session.report({}, checkpoint=checkpoint)

    checkpoint_config = CheckpointConfig(num_to_keep=0)

    trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=2),
        run_config=RunConfig(checkpoint_config=checkpoint_config)
    )
    trainer.fit()

You may also configure ``CheckpointConfig`` to keep the "N best" checkpoints persisted to disk. The following example shows how you could keep the 2 checkpoints with the lowest "loss" value:

.. code-block:: python

    from ray.air import session, Checkpoint, RunConfig, CheckpointConfig, ScalingConfig
    from ray.train.torch import TorchTrainer

    def train_func():
        # first checkpoint
        session.report(dict(loss=2), checkpoint=Checkpoint.from_dict(dict(loss=2)))
        # second checkpoint
        session.report(dict(loss=4), checkpoint=Checkpoint.from_dict(dict(loss=4)))
        # third checkpoint
        session.report(dict(loss=1), checkpoint=Checkpoint.from_dict(dict(loss=1)))
        # fourth checkpoint
        session.report(dict(loss=3), checkpoint=Checkpoint.from_dict(dict(loss=3)))

    # Keep the 2 checkpoints with the smallest "loss" value.
    checkpoint_config = CheckpointConfig(
        num_to_keep=2, checkpoint_score_attribute="loss", checkpoint_score_order="min"
    )

    trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=2),
        run_config=RunConfig(checkpoint_config=checkpoint_config),
    )
    result = trainer.fit()

    print(result.best_checkpoints[0][0].get_internal_representation())
    # ('local_path', '/home/ubuntu/ray_results/TorchTrainer_2022-06-24_21-34-49/TorchTrainer_7988b_00000_0_2022-06-24_21-34-49/checkpoint_000000')
    print(result.best_checkpoints[1][0].get_internal_representation())
    # ('local_path', '/home/ubuntu/ray_results/TorchTrainer_2022-06-24_21-34-49/TorchTrainer_7988b_00000_0_2022-06-24_21-34-49/checkpoint_000002')

.. _train-dl-loading-checkpoints:

Loading checkpoints
+++++++++++++++++++

Checkpoints can be loaded into the training function in 2 steps:

1. From the training function, :func:`ray.air.session.get_checkpoint` can be used to access
   the most recently saved :py:class:`~ray.air.checkpoint.Checkpoint`. This is useful to continue training even
   if there's a worker failure.
2. The checkpoint to start training with can be bootstrapped by passing in a
   :py:class:`~ray.air.checkpoint.Checkpoint` to ``Trainer`` as the ``resume_from_checkpoint`` argument.

.. tab-set::

    .. tab-item:: PyTorch

        .. code-block:: python
            :emphasize-lines: 23, 25, 26, 29, 30, 31, 35

            import ray.train.torch
            from ray.air import session, Checkpoint, ScalingConfig
            from ray.train.torch import TorchTrainer

            import torch
            import torch.nn as nn
            from torch.optim import Adam
            import numpy as np


            def train_func(config):
                n = 100
                # create a toy dataset
                # data : X - dim = (n, 4)
                # target : Y - dim = (n, 1)
                X = torch.Tensor(np.random.normal(0, 1, size=(n, 4)))
                Y = torch.Tensor(np.random.uniform(0, 1, size=(n, 1)))

                # toy neural network : 1-layer
                model = nn.Linear(4, 1)
                criterion = nn.MSELoss()
                optimizer = Adam(model.parameters(), lr=3e-4)

                start_epoch = 0
                checkpoint = session.get_checkpoint()
                if checkpoint:
                    # assume that we have run the session.report() example
                    # and successfully saved some model weights
                    checkpoint_dict = checkpoint.to_dict()
                    model.load_state_dict(checkpoint_dict.get("model_weights"))
                    start_epoch = checkpoint_dict.get("epoch", -1) + 1

                # wrap the model in DDP
                model = ray.train.torch.prepare_model(model)

                for epoch in range(start_epoch, config["num_epochs"]):
                    y = model.forward(X)
                    # compute loss
                    loss = criterion(y, Y)
                    # back-propagate loss
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    state_dict = model.state_dict()
                    checkpoint = Checkpoint.from_dict(
                        dict(epoch=epoch, model_weights=state_dict)
                    )
                    session.report({}, checkpoint=checkpoint)


            trainer = TorchTrainer(
                train_func,
                train_loop_config={"num_epochs": 2},
                scaling_config=ScalingConfig(num_workers=2),
            )
            # save a checkpoint
            result = trainer.fit()

            # load checkpoint
            trainer = TorchTrainer(
                train_func,
                train_loop_config={"num_epochs": 4},
                scaling_config=ScalingConfig(num_workers=2),
                resume_from_checkpoint=result.checkpoint,
            )
            result = trainer.fit()

            print(result.checkpoint.to_dict())
            # {'epoch': 3, 'model_weights': OrderedDict([('bias', tensor([0.0902])), ('weight', tensor([[-0.1549, -0.0861, 0.4353, -0.4116]]))]), '_timestamp': 1656108265, '_preprocessor': None, '_current_checkpoint_id': 2}

    .. tab-item:: TensorFlow

        .. code-block:: python
            :emphasize-lines: 15, 21, 22, 25, 26, 27, 30

            from ray.air import session, Checkpoint, ScalingConfig
            from ray.train.tensorflow import TensorflowTrainer

            import numpy as np


            def train_func(config):
                import tensorflow as tf

                n = 100
                # create a toy dataset
                # data : X - dim = (n, 4)
                # target : Y - dim = (n, 1)
                X = np.random.normal(0, 1, size=(n, 4))
                Y = np.random.uniform(0, 1, size=(n, 1))

                start_epoch = 0
                strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
                with strategy.scope():
                    # toy neural network : 1-layer
                    model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="linear", input_shape=(4,))])

                    checkpoint = session.get_checkpoint()
                    if checkpoint:
                        # assume that we have run the session.report() example
                        # and successfully saved some model weights
                        checkpoint_dict = checkpoint.to_dict()
                        model.set_weights(checkpoint_dict.get("model_weights"))
                        start_epoch = checkpoint_dict.get("epoch", -1) + 1

                    model.compile(optimizer="Adam", loss="mean_squared_error", metrics=["mse"])

                for epoch in range(start_epoch, config["num_epochs"]):
                    model.fit(X, Y, batch_size=20)
                    checkpoint = Checkpoint.from_dict(
                        dict(epoch=epoch, model_weights=model.get_weights())
                    )
                    session.report({}, checkpoint=checkpoint)


            trainer = TensorflowTrainer(
                train_func,
                train_loop_config={"num_epochs": 2},
                scaling_config=ScalingConfig(num_workers=2),
            )
            # save a checkpoint
            result = trainer.fit()

            # load a checkpoint
            trainer = TensorflowTrainer(
                train_func,
                train_loop_config={"num_epochs": 5},
                scaling_config=ScalingConfig(num_workers=2),
                resume_from_checkpoint=result.checkpoint,
            )
            result = trainer.fit()

            print(result.checkpoint.to_dict())
            # {'epoch': 4, 'model_weights': [array([[-0.70056134],
            #        [-0.8839263 ],
            #        [-1.0043601 ],
            #        [-0.61634773]], dtype=float32), array([0.01889327], dtype=float32)], '_timestamp': 1656108446, '_preprocessor': None, '_current_checkpoint_id': 3}

.. _train-callbacks:

Callbacks
~~~~~~~~~

You may want to plug in your training code with your favorite experiment management framework.
Ray AIR provides an interface to fetch intermediate results and callbacks to process/log your intermediate results
(the values passed into :func:`ray.air.session.report`).

Ray AIR contains :ref:`built-in callbacks <air-builtin-callbacks>` for popular tracking frameworks, or you can implement your own callback via the :ref:`Callback <tune-callbacks-docs>` interface.

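For example, a built-in callback is passed to the ``Trainer`` through the ``callbacks`` argument of
:class:`~ray.air.RunConfig`. The snippet below is only a sketch; the exact import path of
``MLflowLoggerCallback`` may differ across Ray versions:

.. code-block:: python

    from ray.air import RunConfig, ScalingConfig
    from ray.air.integrations.mlflow import MLflowLoggerCallback
    from ray.train.torch import TorchTrainer

    trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=2),
        # Log everything reported via session.report() to MLflow.
        run_config=RunConfig(callbacks=[MLflowLoggerCallback(experiment_name="train_experiment")]),
    )
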
Example: Logging to MLflow and TensorBoard
++++++++++++++++++++++++++++++++++++++++++

**Step 1: Install the necessary packages**

.. code-block:: bash

    $ pip install mlflow
    $ pip install tensorboardX

**Step 2: Run the following training script**

.. literalinclude:: /../../python/ray/train/examples/mlflow_simple_example.py
    :language: python

.. _train-custom-callbacks:

Custom Callbacks
++++++++++++++++

If the provided callbacks do not cover your desired integrations or use-cases,
you may always implement a custom callback by subclassing :py:class:`~ray.tune.logger.LoggerCallback`. If
the callback is general enough, please feel welcome to :ref:`add it <getting-involved>`
to the ``ray`` `repository <https://github.com/ray-project/ray>`_.

A simple example for creating a callback that will print out results:

.. code-block:: python

    from typing import List, Dict

    from ray.air import session, RunConfig, ScalingConfig
    from ray.train.torch import TorchTrainer
    from ray.tune.logger import LoggerCallback

    # LoggerCallback is a higher level API of Callback.
    class LoggingCallback(LoggerCallback):
        def __init__(self) -> None:
            self.results = []

        def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
            self.results.append(trial.last_result)

    def train_func():
        for i in range(3):
            session.report({"epoch": i})

    callback = LoggingCallback()
    trainer = TorchTrainer(
        train_func,
        run_config=RunConfig(callbacks=[callback]),
        scaling_config=ScalingConfig(num_workers=2),
    )
    trainer.fit()

    print("\n".join([str(x) for x in callback.results]))
    # {'trial_id': '0f1d0_00000', 'experiment_id': '494a1d050b4a4d11aeabd87ba475fcd3', 'date': '2022-06-27_17-03-28', 'timestamp': 1656349408, 'pid': 23018, 'hostname': 'ip-172-31-43-110', 'node_ip': '172.31.43.110', 'config': {}}
    # {'epoch': 0, '_timestamp': 1656349412, '_time_this_iter_s': 0.0026497840881347656, '_training_iteration': 1, 'time_this_iter_s': 3.433483362197876, 'done': False, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 1, 'trial_id': '0f1d0_00000', 'experiment_id': '494a1d050b4a4d11aeabd87ba475fcd3', 'date': '2022-06-27_17-03-32', 'timestamp': 1656349412, 'time_total_s': 3.433483362197876, 'pid': 23018, 'hostname': 'ip-172-31-43-110', 'node_ip': '172.31.43.110', 'config': {}, 'time_since_restore': 3.433483362197876, 'timesteps_since_restore': 0, 'iterations_since_restore': 1, 'warmup_time': 0.003779172897338867, 'experiment_tag': '0'}
    # {'epoch': 1, '_timestamp': 1656349412, '_time_this_iter_s': 0.0013833045959472656, '_training_iteration': 2, 'time_this_iter_s': 0.016670703887939453, 'done': False, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 2, 'trial_id': '0f1d0_00000', 'experiment_id': '494a1d050b4a4d11aeabd87ba475fcd3', 'date': '2022-06-27_17-03-32', 'timestamp': 1656349412, 'time_total_s': 3.4501540660858154, 'pid': 23018, 'hostname': 'ip-172-31-43-110', 'node_ip': '172.31.43.110', 'config': {}, 'time_since_restore': 3.4501540660858154, 'timesteps_since_restore': 0, 'iterations_since_restore': 2, 'warmup_time': 0.003779172897338867, 'experiment_tag': '0'}

.. _train-aggregating-results:

How to obtain and aggregate results from different workers?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In real applications, you may want to calculate optimization metrics besides accuracy and loss: recall, precision, Fbeta, etc.
You may also want to collect metrics from multiple workers. While Ray Train currently only reports metrics from the rank 0
worker, you can use third-party libraries or distributed primitives of your machine learning framework to report
metrics from multiple workers.

.. tab-set::

    .. tab-item:: PyTorch

        Ray Train natively supports `TorchMetrics <https://torchmetrics.readthedocs.io/en/latest/>`_, which provides a collection of machine learning metrics for distributed, scalable PyTorch models.

        Here is an example of reporting both the aggregated R2 score and mean train and validation loss from all workers.

        .. literalinclude:: doc_code/torchmetrics_example.py
            :language: python
            :start-after: __start__

    .. tab-item:: TensorFlow

        TensorFlow Keras automatically aggregates metrics from all workers. If you wish to have more
        control over that, consider implementing a `custom training loop <https://www.tensorflow.org/tutorials/distribute/custom_training>`_.

.. Running on the cloud
.. --------------------
.. Use Ray Train with the Ray cluster launcher by changing the following:
.. .. code-block:: bash
..     ray up cluster.yaml
.. TODO.

.. _train-fault-tolerance:

Fault Tolerance
---------------

Automatically Recover from Train Worker Failures
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Ray Train has built-in fault tolerance to recover from worker failures (i.e.
``RayActorError``\s). When a failure is detected, the workers will be shut
down and new workers will be added in.

.. note:: Elastic Training is not yet supported.

The training function will be restarted, but progress from the previous execution can
be resumed through checkpointing.

.. tip::

    In order to retain progress upon recovery, your training function
    **must** implement logic for both :ref:`saving <train-dl-saving-checkpoints>`
    *and* :ref:`loading checkpoints <train-dl-loading-checkpoints>`.

Each instance of recovery from a worker failure is considered a retry. The
number of retries is configurable through the ``max_failures`` attribute of the
:class:`~ray.air.FailureConfig` argument set in the :class:`~ray.air.RunConfig`
passed to the ``Trainer``:

.. literalinclude:: doc_code/key_concepts.py
    :language: python
    :start-after: __failure_config_start__
    :end-before: __failure_config_end__

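A minimal sketch of this configuration looks like the following (the value of ``max_failures``
is illustrative):

.. code-block:: python

    from ray.air import RunConfig, FailureConfig, ScalingConfig
    from ray.train.torch import TorchTrainer

    trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=2),
        # Retry up to 3 times when a worker failure is detected.
        # Set max_failures=-1 to retry indefinitely.
        run_config=RunConfig(failure_config=FailureConfig(max_failures=3)),
    )
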
.. _train-restore-guide:

Restore a Ray Train Experiment
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

At the experiment level, :ref:`Trainer restoration <trainer-restore>`
allows you to resume a previously interrupted experiment from where it left off.

A Train experiment may be interrupted due to one of the following reasons:

- The experiment was manually interrupted (e.g., Ctrl+C, or pre-empted head node instance).
- The head node crashed (e.g., OOM or some other runtime error).
- The entire cluster went down (e.g., network error affecting all nodes).

Trainer restoration is possible for all of Ray Train's built-in trainers,
but we use ``TorchTrainer`` in the examples for demonstration.
We also use ``<Framework>Trainer`` to refer to methods that are shared across all
built-in trainers.

Let's say your initial Train experiment is configured as follows.
The actual training loop is just for demonstration purposes: the important detail is that
:ref:`saving <train-dl-saving-checkpoints>` *and* :ref:`loading checkpoints <train-dl-loading-checkpoints>`
has been implemented.

.. literalinclude:: doc_code/dl_guide.py
    :language: python
    :start-after: __ft_initial_run_start__
    :end-before: __ft_initial_run_end__

The results and checkpoints of the experiment are saved to the path configured by :class:`~ray.air.config.RunConfig`.
If the experiment has been interrupted due to one of the reasons listed above, use this path to resume:

.. literalinclude:: doc_code/dl_guide.py
    :language: python
    :start-after: __ft_restored_run_start__
    :end-before: __ft_restored_run_end__

.. tip::

    You can also restore from a remote path (e.g., from an experiment directory stored in an S3 bucket).

    .. literalinclude:: doc_code/dl_guide.py
        :language: python
        :dedent:
        :start-after: __ft_restore_from_cloud_initial_start__
        :end-before: __ft_restore_from_cloud_initial_end__

    .. literalinclude:: doc_code/dl_guide.py
        :language: python
        :dedent:
        :start-after: __ft_restore_from_cloud_restored_start__
        :end-before: __ft_restore_from_cloud_restored_end__

.. note::

    Different trainers may allow more parameters to be optionally re-specified on restore.
    Only **datasets** are required to be re-specified on restore, if they were supplied originally.

    See :ref:`train-framework-specific-restore` for more details.

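For example, restoring an interrupted run that originally used Ray Data might look like the
following sketch, where the experiment path and dataset are placeholders:

.. code-block:: python

    from ray.train.torch import TorchTrainer

    restored_trainer = TorchTrainer.restore(
        "~/ray_results/dl_restore_experiment",  # Path of the interrupted experiment.
        # Re-specify the datasets that were passed to the original Trainer.
        datasets={"train": train_dataset},
    )
    result = restored_trainer.fit()
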
Auto-resume
+++++++++++

Adding the branching logic below will allow you to run the same script after the interrupt,
picking up training from where you left off on the previous run. Notice that we use the
:meth:`<Framework>Trainer.can_restore <ray.train.trainer.BaseTrainer.can_restore>` utility method
to determine the existence and validity of the given experiment directory.

.. literalinclude:: doc_code/dl_guide.py
    :language: python
    :start-after: __ft_autoresume_start__
    :end-before: __ft_autoresume_end__

.. seealso::

    See the :meth:`BaseTrainer.restore <ray.train.trainer.BaseTrainer.restore>` docstring
    for a full example.

.. note::

    `<Framework>Trainer.restore` is different from
    :class:`<Framework>Trainer(..., resume_from_checkpoint=...) <ray.train.trainer.BaseTrainer>`.
    `resume_from_checkpoint` is meant to be used to start a *new* Train experiment,
    which writes results to a new directory and starts over from iteration 0.

    `<Framework>Trainer.restore` is used to continue an existing experiment, where
    new results will continue to be appended to existing logs.

.. Running on pre-emptible machines
.. --------------------------------
.. You may want to
.. TODO.

.. We do not have a profiling callback in AIR as the execution engine has changed to Tune. The behavior of the callback can be replicated with checkpoints (do a trace, save it to checkpoint, it gets downloaded to driver every iteration).

.. .. _train-profiling:
.. Profiling
.. ---------
.. Ray Train comes with an integration with `PyTorch Profiler <https://pytorch.org/blog/introducing-pytorch-profiler-the-new-and-improved-performance-tool/>`_.
.. Specifically, it comes with a :ref:`TorchWorkerProfiler <train-api-torch-worker-profiler>` utility class and :ref:`train-api-torch-tensorboard-profiler-callback` callback
.. that allow you to use the PyTorch Profiler as you would in a non-distributed PyTorch script, and synchronize the generated Tensorboard traces onto
.. the disk from which your script was executed.
.. **Step 1: Update training function with** ``TorchWorkerProfiler``
.. .. code-block:: bash
..     from ray.train.torch import TorchWorkerProfiler
..     def train_func():
..         twp = TorchWorkerProfiler()
..         with profile(..., on_trace_ready=twp.trace_handler) as p:
..             ...
..             profile_results = twp.get_and_clear_profile_traces()
..             train.report(..., **profile_results)
..         ...
.. **Step 2: Run training function with** ``TorchTensorboardProfilerCallback``
.. .. code-block:: python
..     from ray.train import Trainer
..     from ray.train.callbacks import TorchTensorboardProfilerCallback
..     trainer = Trainer(backend="torch", num_workers=2)
..     trainer.start()
..     trainer.run(train_func, callbacks=[TorchTensorboardProfilerCallback()])
..     trainer.shutdown()
.. **Step 3: Visualize the logs**
.. .. code-block:: bash
..     # Navigate to the run directory of the trainer.
..     # For example `cd /home/ray_results/train_2021-09-01_12-00-00/run_001/pytorch_profiler`
..     $ cd <TRAINER_RUN_DIR>/pytorch_profiler
..     # Install the PyTorch Profiler TensorBoard Plugin.
..     $ pip install torch_tb_profiler
..     # Start the TensorBoard UI.
..     $ tensorboard --logdir .
..     # View the PyTorch Profiler traces.
..     $ open http://localhost:6006/#pytorch_profiler

.. _train-tune:

Hyperparameter tuning (Ray Tune)
--------------------------------

Hyperparameter tuning with :ref:`Ray Tune <tune-main>` is natively supported
with Ray Train. Specifically, you can take an existing ``Trainer`` and simply
pass it into a :py:class:`~ray.tune.tuner.Tuner`.

.. code-block:: python

    from ray import tune
    from ray.air import session, ScalingConfig
    from ray.train.torch import TorchTrainer
    from ray.tune.tuner import Tuner, TuneConfig

    def train_func(config):
        # In this example, nothing is expected to change over epochs,
        # and the output metric is equivalent to the input value.
        for _ in range(config["num_epochs"]):
            session.report(dict(output=config["input"]))

    trainer = TorchTrainer(train_func, scaling_config=ScalingConfig(num_workers=2))
    tuner = Tuner(
        trainer,
        param_space={
            "train_loop_config": {
                "num_epochs": 2,
                "input": tune.grid_search([1, 2, 3]),
            }
        },
        tune_config=TuneConfig(num_samples=5, metric="output", mode="max"),
    )
    result_grid = tuner.fit()
    print(result_grid.get_best_result().metrics["output"])
    # 3

.. _torch-amp:

Automatic Mixed Precision
-------------------------

Automatic mixed precision (AMP) lets you train your models faster by using a lower
precision datatype for operations like linear layers and convolutions.

.. tab-set::

    .. tab-item:: PyTorch

        You can train your Torch model with AMP by:

        1. Adding :func:`ray.train.torch.accelerate` with ``amp=True`` to the top of your training function.
        2. Wrapping your optimizer with :func:`ray.train.torch.prepare_optimizer`.
        3. Replacing your backward call with :func:`ray.train.torch.backward`.

        .. code-block:: diff

             def train_func():
            +    train.torch.accelerate(amp=True)

                 model = NeuralNetwork()
                 model = train.torch.prepare_model(model)

                 data_loader = DataLoader(my_dataset, batch_size=worker_batch_size)
                 data_loader = train.torch.prepare_data_loader(data_loader)

                 optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
            +    optimizer = train.torch.prepare_optimizer(optimizer)

                 model.train()
                 for epoch in range(90):
                     for images, targets in data_loader:
                         optimizer.zero_grad()

                         outputs = model(images)
                         loss = torch.nn.functional.cross_entropy(outputs, targets)

            -            loss.backward()
            +            train.torch.backward(loss)
                         optimizer.step()
                 ...

.. note:: The performance of AMP varies based on GPU architecture, model type,
    and data shape. For certain workflows, AMP may perform worse than
    full-precision training.

.. _train-reproducibility:

Reproducibility
---------------

.. tab-set::

    .. tab-item:: PyTorch

        To limit sources of nondeterministic behavior, add
        :func:`ray.train.torch.enable_reproducibility` to the top of your training
        function.

        .. code-block:: diff

             def train_func():
            +    train.torch.enable_reproducibility()

                 model = NeuralNetwork()
                 model = train.torch.prepare_model(model)

                 ...

        .. warning:: :func:`ray.train.torch.enable_reproducibility` can't guarantee
            completely reproducible results across executions. To learn more, read
            the `PyTorch notes on randomness <https://pytorch.org/docs/stable/notes/randomness.html>`_.

..
    import ray
    from ray import tune

    def training_func(config):
        dataloader = ray.train.get_dataset()\
            .get_shard(torch.rank())\
            .iter_torch_batches(batch_size=config["batch_size"])

        for i in config["epochs"]:
            ray.train.report(...)  # use same intermediate reporting API

    # Declare the specification for training.
    trainer = Trainer(backend="torch", num_workers=12, use_gpu=True)
    dataset = ray.dataset.window()

    # Convert this to a trainable.
    trainable = trainer.to_tune_trainable(training_func, dataset=dataset)

    tuner = tune.Tuner(trainable,
        param_space={"lr": tune.uniform(), "batch_size": tune.randint(1, 2, 3)},
        tune_config=tune.TuneConfig(num_samples=12))
    results = tuner.fit()

..
    Advanced APIs
    -------------

    TODO

    Training Run Iterator API
    ~~~~~~~~~~~~~~~~~~~~~~~~~

    TODO

    Stateful Class API
    ~~~~~~~~~~~~~~~~~~

    TODO