
Fault-Tolerant Fairseq Training
===============================

This document provides a walkthrough of adapting the `Fairseq library <https://github.com/pytorch/fairseq>`__ to perform fault-tolerant distributed training on AWS.
As an example, we use the WikiText-103 dataset to pretrain the RoBERTa model, following `this tutorial <https://github.com/pytorch/fairseq/blob/master/examples/roberta/README.pretraining.md>`__. The pipeline and configurations in this document also work for other models supported by Fairseq, such as sequence-to-sequence machine translation models.

To run this example, you will need to install Ray on your local machine to use the Ray cluster launcher.

You can view the `code for this example`_.

.. _`code for this example`: https://github.com/ray-project/ray/tree/master/doc/examples/lm

To use the Ray cluster launcher on AWS, install boto (``pip install boto3``) and configure your AWS credentials in ``~/.aws/credentials`` as described on the :ref:`Automatic Cluster Setup page <cluster-cloud>`.

We provide an `example config file <https://github.com/ray-project/ray/tree/master/doc/examples/lm/lm-cluster.yaml>`__ (``lm-cluster.yaml``).
In the example config file, we use an ``m5.xlarge`` on-demand instance as the head node, and ``p3.2xlarge`` GPU spot instances as the worker nodes. We set the minimum number of workers to 1 and the maximum to 2 in the config; you can adjust these values to suit your needs.
We also mount :ref:`Amazon EFS <aws-cluster-efs>` to store code, data, and checkpoints.

.. note::

    The ``{{SecurityGroupId}}`` and ``{{FileSystemId}}`` fields in the config file should be replaced by your own IDs.

In ``setup_commands``, we use the PyTorch environment in the Deep Learning AMI, and install Ray and Fairseq:

.. code-block:: yaml

    setup_commands:
        - echo 'export PATH="$HOME/anaconda3/envs/pytorch_p36/bin:$PATH"' >> ~/.bashrc;
          source ~/.bashrc;
          pip install -U ray;
          pip install -U fairseq==0.8.0;

Run the following command on your local machine to start the Ray cluster:

.. code-block:: bash

    ray up lm-cluster.yaml

``ray_train.sh`` also assumes that all of the ``lm/`` files are in ``$HOME/efs``.
You can move these files manually, or use the following command to upload
files from a local path:

.. code-block:: bash

    ray rsync-up lm-cluster.yaml PATH/TO/LM '~/efs/lm'

Preprocessing Data
------------------

Once the cluster is started, you can then SSH into the head node using ``ray attach lm-cluster.yaml`` and download or preprocess the data on EFS for training. We can run ``preprocess.sh`` (`code <https://github.com/ray-project/ray/tree/master/doc/examples/lm/preprocess.sh>`__) to do this, which adapts instructions from `the RoBERTa tutorial <https://github.com/pytorch/fairseq/blob/master/examples/roberta/README.pretraining.md>`__.

Training
--------

We provide ``ray_train.py`` (`code <https://github.com/ray-project/ray/tree/master/doc/examples/lm/ray_train.py>`__) as an entrypoint to the Fairseq library. Since we are training the model on spot instances, we provide fault tolerance in ``ray_train.py`` by checkpointing and restarting when a node fails. The code also checks whether new resources have become available after checkpointing. If so, the program makes use of them by restarting and resizing.

Two main components of ``ray_train.py`` are a ``RayDistributedActor`` class and a function ``run_fault_tolerant_loop()``. The ``RayDistributedActor`` sets the proper arguments for the different Ray actor processes, adds a checkpoint hook to enable the process to make use of newly available GPUs, and calls the ``main`` function of Fairseq:

.. code-block:: python

    import math
    import copy
    import socket
    import time

    import ray
    import fairseq
    from fairseq import options
    from fairseq_cli.train import main
    from contextlib import closing

    _original_save_checkpoint = fairseq.checkpoint_utils.save_checkpoint


    class RayDistributedActor:
        """Actor to perform distributed training."""

        def run(self, url, world_rank, args):
            """Runs the fairseq training.

            We set args for different ray actors for communication,
            add a checkpoint hook, and call the main function of fairseq.
            """

            # Set the init_method and rank of the process for distributed training.
            print("Ray worker at {url} rank {rank}".format(
                url=url, rank=world_rank))
            self.url = url
            self.world_rank = world_rank
            args.distributed_rank = world_rank
            args.distributed_init_method = url

            # Add a checkpoint hook to make use of new resources.
            self.add_checkpoint_hook(args)

            # Call the original main function of fairseq.
            main(args, init_distributed=(args.distributed_world_size > 1))

        def add_checkpoint_hook(self, args):
            """Add a hook to the original save_checkpoint function.

            This checks if there are new computational resources available.
            If so, raise an exception to restart the training process and
            make use of the new resources.
            """

            if args.cpu:
                original_n_cpus = args.distributed_world_size

                def _new_save_checkpoint(*args, **kwargs):
                    _original_save_checkpoint(*args, **kwargs)
                    n_cpus = int(ray.cluster_resources()["CPU"])
                    if n_cpus > original_n_cpus:
                        raise Exception(
                            "New CPUs found (original %d CPUs, now %d CPUs)" %
                            (original_n_cpus, n_cpus))
            else:
                original_n_gpus = args.distributed_world_size

                def _new_save_checkpoint(*args, **kwargs):
                    _original_save_checkpoint(*args, **kwargs)
                    n_gpus = int(ray.cluster_resources().get("GPU", 0))
                    if n_gpus > original_n_gpus:
                        raise Exception(
                            "New GPUs found (original %d GPUs, now %d GPUs)" %
                            (original_n_gpus, n_gpus))

            fairseq.checkpoint_utils.save_checkpoint = _new_save_checkpoint

        def get_node_ip(self):
            """Returns the IP address of the current node."""
            return ray._private.services.get_node_ip_address()

        def find_free_port(self):
            """Finds a free port on the current node."""
            with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
                s.bind(("", 0))
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                return s.getsockname()[1]

The function ``run_fault_tolerant_loop()`` provides fault tolerance by catching failures and restarting the computation:

.. code-block:: python

    def run_fault_tolerant_loop():
        """Entrance function to the fairseq library, providing fault-tolerance."""

        # Parse the command line arguments.
        parser = options.get_training_parser()
        add_ray_args(parser)
        args = options.parse_args_and_arch(parser)
        original_args = copy.deepcopy(args)

        # Main loop for fault-tolerant training.
        retry = True
        while retry:
            args = copy.deepcopy(original_args)

            # Initialize Ray.
            ray.init(address=args.ray_address)

            set_num_resources(args)
            set_batch_size(args)

            # Set up Ray distributed actors.
            Actor = ray.remote(
                num_cpus=1, num_gpus=int(not args.cpu))(RayDistributedActor)
            workers = [Actor.remote() for i in range(args.distributed_world_size)]

            # Get the IP address and a free port of actor 0, which is used for
            # fairseq distributed training.
            ip = ray.get(workers[0].get_node_ip.remote())
            port = ray.get(workers[0].find_free_port.remote())
            address = "tcp://{ip}:{port}".format(ip=ip, port=port)

            # Start the remote processes, and check whether any of them fails.
            # If so, restart all the processes.
            unfinished = [
                worker.run.remote(address, i, args)
                for i, worker in enumerate(workers)
            ]
            try:
                while len(unfinished) > 0:
                    finished, unfinished = ray.wait(unfinished)
                    finished = ray.get(finished)
                retry = False
            except Exception as inst:
                print("Ray will restart because the following error occurred:")
                print(inst)
                retry = True
            ray.shutdown()
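
If a worker fails, for example because its spot node is preempted while ``run()`` is still executing, the corresponding object reference becomes an error and the ``ray.get(finished)`` call above raises, which is what triggers the restart. The following standalone sketch (not part of ``ray_train.py``; the ``flaky_worker`` task is made up for illustration) shows this failure-propagation pattern in isolation:

.. code-block:: python

    import ray

    ray.init()

    @ray.remote
    def flaky_worker(rank):
        # Hypothetical stand-in for RayDistributedActor.run(); rank 1 fails.
        if rank == 1:
            raise RuntimeError("simulated node failure")
        return rank

    unfinished = [flaky_worker.remote(i) for i in range(3)]
    try:
        while unfinished:
            finished, unfinished = ray.wait(unfinished)
            ray.get(finished)  # re-raises the worker's exception on the driver
        print("all workers finished")
    except Exception as err:
        print("caught failure, would restart:", err)

    ray.shutdown()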

In ``ray_train.py``, we also define a set of helper functions. ``add_ray_args()`` adds Ray- and fault-tolerance-related arguments to the argument parser:

.. code-block:: python

    def add_ray_args(parser):
        """Add ray and fault-tolerance related parser arguments to the parser."""
        group = parser.add_argument_group("Ray related arguments")
        group.add_argument(
            "--ray-address",
            default="auto",
            type=str,
            help="address for ray initialization")
        group.add_argument(
            "--fix-batch-size",
            default=None,
            metavar="B1,B2,...,B_N",
            type=lambda uf: options.eval_str_list(uf, type=int),
            help="fix the actual batch size (max_sentences * update_freq "
            "* n_GPUs) to be the fixed input values by adjusting update_freq "
            "according to actual n_GPUs; the batch size is fixed to B_i for "
            "epoch i; all epochs >N are fixed to B_N")
        return group

``set_num_resources()`` sets the distributed world size to be the number of resources. If we want to use GPUs but no GPU is currently available, the function waits until a GPU becomes available:

.. code-block:: python

    def set_num_resources(args):
        """Get the number of resources and set the corresponding fields."""
        if args.cpu:
            args.distributed_world_size = int(ray.cluster_resources()["CPU"])
        else:
            n_gpus = int(ray.cluster_resources().get("GPU", 0))
            while n_gpus == 0:
                print("No GPUs available, wait 10 seconds")
                time.sleep(10)
                n_gpus = int(ray.cluster_resources().get("GPU", 0))
            args.distributed_world_size = n_gpus
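
``ray.cluster_resources()`` returns a dictionary mapping each resource name to the cluster-wide total, which is why the code above reads the ``"CPU"`` and ``"GPU"`` keys. A minimal sketch, assuming it is run on a node of an already started cluster (the numbers shown in the comment are illustrative):

.. code-block:: python

    import ray

    # Connect to the running cluster rather than starting a local one.
    ray.init(address="auto")

    # Example return value: {"CPU": 12.0, "GPU": 2.0, "memory": ..., ...}
    resources = ray.cluster_resources()
    print("GPUs currently in the cluster:", int(resources.get("GPU", 0)))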

``set_batch_size()`` keeps the effective batch size roughly constant across different numbers of GPUs:

.. code-block:: python

    def set_batch_size(args):
        """Fixes the total batch_size to be agnostic to the GPU count."""
        if args.fix_batch_size is not None:
            args.update_freq = [
                math.ceil(batch_size /
                          (args.max_sentences * args.distributed_world_size))
                for batch_size in args.fix_batch_size
            ]
            print("Training on %d GPUs, max_sentences=%d, update_freq=%s" %
                  (args.distributed_world_size, args.max_sentences,
                   repr(args.update_freq)))
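
To see how the adjustment works, here is a small worked example with made-up GPU counts (``max_sentences=8`` and a fixed batch size of 2048, matching the values used in ``ray_train.sh`` below):

.. code-block:: python

    import math

    fix_batch_size = 2048  # target effective batch size
    max_sentences = 8      # sequences per GPU per step

    for n_gpus in (2, 3):
        update_freq = math.ceil(fix_batch_size / (max_sentences * n_gpus))
        effective = max_sentences * update_freq * n_gpus
        # 2 GPUs -> update_freq=128, effective batch size 2048
        # 3 GPUs -> update_freq=86,  effective batch size 2064 (roughly 2048)
        print(n_gpus, update_freq, effective)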

To start training, run the `following commands <https://github.com/ray-project/ray/tree/master/doc/examples/lm/ray_train.sh>`__ (``ray_train.sh``) on the head machine:

.. code-block:: bash

    cd ~/efs/lm

    TOTAL_UPDATES=125000       # Total number of training steps
    WARMUP_UPDATES=10000       # Warm up the learning rate over this many updates
    PEAK_LR=0.0005             # Peak learning rate, adjust as needed
    TOKENS_PER_SAMPLE=512      # Max sequence length
    #MAX_POSITIONS=512         # Num. positional embeddings (usually same as above)
    MAX_SENTENCES=8            # Number of sequences per batch on one GPU (batch size)
    FIX_BATCH_SIZE=2048        # Total batch size (max_sentences * update_freq * n_gpus)
    SAVE_INTERVAL_UPDATES=1000 # Save a checkpoint every N updates
    LOG_DIR=$HOME/efs/lm/log/
    DATA_DIR=$HOME/efs/lm/data-bin/wikitext-103/
    mkdir -p $LOG_DIR

    python $HOME/efs/lm/ray_train.py --fp16 $DATA_DIR \
        --task masked_lm --criterion masked_lm \
        --arch roberta_base --sample-break-mode complete --tokens-per-sample $TOKENS_PER_SAMPLE \
        --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-6 --clip-norm 0.0 \
        --lr-scheduler polynomial_decay --lr $PEAK_LR --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_UPDATES \
        --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
        --max-sentences $MAX_SENTENCES \
        --fix-batch-size $FIX_BATCH_SIZE \
        --max-update $TOTAL_UPDATES --log-format simple --log-interval 1 \
        --save-interval-updates $SAVE_INTERVAL_UPDATES \
        --save-dir $LOG_DIR --ddp-backend=no_c10d

``SAVE_INTERVAL_UPDATES`` controls how often to save a checkpoint, which can be tuned based on the `stability of the chosen instances <https://aws.amazon.com/ec2/spot/instance-advisor/>`__. ``FIX_BATCH_SIZE`` keeps the total batch size roughly constant.

Helpful Ray Commands
--------------------

To let Ray automatically stop the cluster after training finishes, you can download ``ray_train.sh`` to ``~/efs/lm`` on the remote machine and run the following command on your local machine:

.. code-block:: bash

    ray exec --stop lm-cluster.yaml 'bash $HOME/efs/lm/ray_train.sh'

or run the following command on the remote head node:

.. code-block:: bash

    ray exec --stop ~/ray_bootstrap_config.yaml 'bash $HOME/efs/lm/ray_train.sh'

To test the fault tolerance, you can run the following command on your local machine to randomly kill one node:

.. code-block:: bash

    ray kill-random-node lm-cluster.yaml
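
After a node is killed, the cluster launcher replaces it (up to the configured maximum number of workers), and the training loop restarts from the latest checkpoint and resizes. One way to watch this happen is to poll the cluster resources from the head node; the snippet below is a hypothetical helper, not part of the example code:

.. code-block:: python

    import time
    import ray

    # Connect to the running cluster from the head node.
    ray.init(address="auto")

    # Print the GPU count every 10 seconds; it should drop when a worker is
    # killed and recover once the replacement spot instance joins.
    for _ in range(60):
        print("GPUs in cluster:", int(ray.cluster_resources().get("GPU", 0)))
        time.sleep(10)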