# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from ..utils import call_to_str

from abc import ABC, abstractmethod


class PipeSchedule(ABC):
    """Directs the execution of a pipeline engine by generating sequences of
    :class:`PipeInstruction`.

    Schedules are generators that yield sequences of
    :class:`PipeInstruction` to process the micro-batches in one batch.
    Each yielded step is atomic in the sense that a barrier
    synchronization can be placed between successive steps without
    deadlock.

    Below is an example schedule that implements data parallelism with gradient accumulation:

    .. code-block:: python

        class DataParallelSchedule(PipeSchedule):
            def steps(self):
                for step_id in range(self.micro_batches):
                    cmds = [
                        LoadMicroBatch(buffer_id=0),
                        ForwardPass(buffer_id=0),
                        BackwardPass(buffer_id=0),
                    ]
                    if step_id == self.micro_batches - 1:
                        cmds.extend([
                            ReduceGrads(),
                            OptimizerStep(),
                        ])
                    yield cmds

            def num_pipe_buffers(self):
                return 1

    Args:
        micro_batches (int): The number of micro-batches that comprise a batch.
        stages (int): The number of pipeline stages.
        stage_id (int): The pipe stage that will execute the generated schedule.
    """

    def __init__(self, micro_batches, stages, stage_id):
        super().__init__()
        self.micro_batches = micro_batches
        self.stages = stages
        self.stage_id = stage_id
        self.prev_stage = self.stage_id - 1
        self.next_stage = self.stage_id + 1

    @abstractmethod
    def steps(self):
        """Yield a list of :class:`PipeInstruction` for each step in the schedule.

        .. note::
            Schedules must implement ``steps()`` to define the schedule.

        Returns:
            Instructions to be executed as one step of the pipeline
        """
        pass

    def num_pipe_buffers(self):
        """The number of pipeline buffers that will be used by this stage.

        .. note::
            Schedules should specialize ``num_pipe_buffers()`` for memory savings at scale.

        Returns:
            The number of buffers for the engine to allocate.
        """
        return self.micro_batches

    def _valid_micro_batch(self, micro_batch_id):
        return 0 <= micro_batch_id < self.micro_batches

    def _valid_stage(self, stage_id):
        return 0 <= stage_id < self.stages

    @property
    def stage(self):
        """Stage index used to configure this schedule."""
        return self.stage_id

    @property
    def num_stages(self):
        """The number of total pipeline stages used to configure this schedule."""
        return self.stages

    @property
    def num_micro_batches(self):
        """The number of total micro_batches used to configure this schedule."""
        return self.micro_batches

    @property
    def is_first_stage(self):
        """True if the configured ``stage_id`` is the first stage in the pipeline."""
        return self.stage_id == 0

    @property
    def is_last_stage(self):
        """True if the configured ``stage_id`` is the last stage in the pipeline."""
        return self.stage_id == self.stages - 1

    def _buffer_idx(self, micro_batch_id):
        """Map a micro-batch index to a pipeline buffer index.

        This method uses a cyclic allocation strategy.

        Args:
            micro_batch_id (int): The micro-batch index relative to the beginning of the schedule.

        Returns:
            int: The index of the buffer that should store data.
        """
        assert self._valid_micro_batch(micro_batch_id)
        return micro_batch_id % self.num_pipe_buffers()
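
    # Illustrative note: with num_pipe_buffers() == 2, micro-batches 0, 1, 2, 3
    # map to buffers 0, 1, 0, 1 under this cyclic strategy, so a buffer is reused
    # only once the schedule is done with its previous occupant.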

    def __iter__(self):
        self.it = None
        return self

    def __next__(self):
        if self.it is None:
            self.it = self.steps()
        return next(self.it)
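

# A minimal sketch (for illustration only; the real engine's dispatch is not part
# of this module) of how a schedule is consumed: iterate the schedule and execute
# each step's instructions in order. The `engine._exec_instr` call below is a
# hypothetical name standing in for whatever dispatches on the instruction type.
#
#     sched = TrainSchedule(micro_batches=4, stages=2, stage_id=0)
#     for step_cmds in sched:          # one atomic step of the pipeline
#         for cmd in step_cmds:        # e.g. RecvActivation, ForwardPass, ...
#             engine._exec_instr(cmd)  # hypothetical dispatch on type(cmd)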


class InferenceSchedule(PipeSchedule):
    """A schedule for inferencing batches using pipeline parallelism.
    """

    def steps(self):
        """"""
        prev_micro_batch_id = -1
        total_steps = self.micro_batches + self.stages - 1
        for step_id in range(total_steps):
            cmds = []
            micro_batch_id = step_id - self.stage_id

            # Alternate send/recv buffers
            if _is_even(self.stage_id):
                recv_buf = step_id % 2
                send_buf = (step_id + 1) % 2
            else:
                recv_buf = (step_id + 1) % 2
                send_buf = step_id % 2

            if self.is_first_stage or self.is_last_stage:
                if self._valid_micro_batch(micro_batch_id):
                    cmds.append(LoadMicroBatch(recv_buf))

            if _is_even(self.stage_id):
                if self._valid_stage(self.next_stage):
                    if self._valid_micro_batch(micro_batch_id - 1):
                        cmds.append(SendActivation(send_buf))
                if self._valid_stage(self.prev_stage):
                    if self._valid_micro_batch(micro_batch_id):
                        cmds.append(RecvActivation(recv_buf))
            else:
                if self._valid_stage(self.prev_stage):
                    if self._valid_micro_batch(micro_batch_id):
                        cmds.append(RecvActivation(recv_buf))
                if self._valid_stage(self.next_stage):
                    if self._valid_micro_batch(micro_batch_id - 1):
                        cmds.append(SendActivation(send_buf))

            if self._valid_micro_batch(micro_batch_id):
                cmds.append(ForwardPass(recv_buf))

            yield cmds

    def num_pipe_buffers(self):
        """Only two pipeline buffers are required for inferencing.

        Returns:
            ``2``
        """
        return 2
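
    # Illustrative trace (worked by hand, for orientation only): with stages=2 and
    # micro_batches=2, total_steps = 2 + 2 - 1 = 3 and the yielded steps are
    #
    #   stage 0: [LoadMicroBatch(0), ForwardPass(0)]
    #            [LoadMicroBatch(1), SendActivation(0), ForwardPass(1)]
    #            [SendActivation(1)]
    #   stage 1: []
    #            [LoadMicroBatch(0), RecvActivation(0), ForwardPass(0)]
    #            [LoadMicroBatch(1), RecvActivation(1), ForwardPass(1)]
    #
    # Each SendActivation on a stage lines up with a RecvActivation on its neighbor
    # in the same step; the even/odd ordering of send vs. recv keeps the blocking
    # pairs from deadlocking.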


class TrainSchedule(PipeSchedule):
    """A schedule for training a batch using hybrid parallelism.

    Pipeline parallelism is extracted through gradient accumulation and thus
    convergence follows that of a data parallel approach with the same batch
    size.
    """

    def steps(self):
        """"""
        prev_micro_batch_id = -1
        total_steps = 2 * (self.micro_batches + self.stages - 1)
        for step_id in range(total_steps):
            # Map the step of the pipeline to the micro-batch id and also whether it is a
            # forward or backward pass step.
            micro_batch_id, is_forward = self._step_to_micro_batch(step_id)

            if self._valid_micro_batch(prev_micro_batch_id):
                prev_buffer = self._buffer_idx(prev_micro_batch_id)
            if self._valid_micro_batch(micro_batch_id):
                curr_buffer = self._buffer_idx(micro_batch_id)

            cmds = []

            # Exchange activations
            if is_forward:
                if self._valid_micro_batch(prev_micro_batch_id) and self._valid_stage(self.prev_stage):
                    cmds.append(SendGrad(prev_buffer))
                if self._valid_micro_batch(micro_batch_id) and self._valid_stage(self.prev_stage):
                    cmds.append(RecvActivation(curr_buffer))
            else:
                if self._valid_micro_batch(micro_batch_id) and self._valid_stage(self.next_stage):
                    cmds.append(RecvGrad(curr_buffer))
                if self._valid_micro_batch(prev_micro_batch_id) and self._valid_stage(self.next_stage):
                    cmds.append(SendActivation(prev_buffer))

            # First/last stage loads
            if self.stage_id == 0 or self.stage_id == self.stages - 1:
                if is_forward and self._valid_micro_batch(micro_batch_id):
                    cmds.append(LoadMicroBatch(curr_buffer))

            # Computation
            if self._valid_micro_batch(micro_batch_id):
                if is_forward:
                    cmds.append(ForwardPass(curr_buffer))
                else:
                    cmds.append(BackwardPass(curr_buffer))

            # Model step at the end of the batch
            if step_id == total_steps - 1:
                cmds.append(ReduceTiedGrads())
                cmds.append(ReduceGrads())
                cmds.append(OptimizerStep())

            # Prepare state for next time
            prev_micro_batch_id = micro_batch_id

            yield cmds

    def num_pipe_buffers(self):
        """Return the number of pipeline buffers required for this stage.

        This is equivalent to the maximum number of in-flight forward passes,
        since we need to remember the activations of forward passes in order
        to run backpropagation. For synchronous 1F1B, this is equivalent to
        the index difference between this stage and the last stage.
        """
        buffers = min(self.stages - self.stage_id, self.micro_batches)
        return max(2, buffers)
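
    # Worked example (illustrative): with stages=4 and micro_batches >= 4, the
    # first stage keeps min(4 - 0, micro_batches) = 4 buffers in flight, while the
    # last stage needs only max(2, min(4 - 3, micro_batches)) = 2.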

    def _step_to_micro_batch(self, step_id):
        if _is_even(step_id) and _is_even(self.stage_id):
            micro_batch_id = self._even_step_forward_id(step_id)
            is_forward = True
        elif _is_odd(step_id) and _is_odd(self.stage_id):
            micro_batch_id = self._odd_step_forward_id(step_id)
            is_forward = True
        elif _is_even(step_id) and _is_odd(self.stage_id):
            micro_batch_id = self._even_step_backward_id(step_id)
            is_forward = False
        elif _is_odd(step_id) and _is_even(self.stage_id):
            micro_batch_id = self._odd_step_backward_id(step_id)
            is_forward = False
        else:
            assert False

        return micro_batch_id, is_forward

    def _even_step_forward_id(self, step_id):
        base = step_id // 2
        micro_batch_id = int(base - self.stage_id // 2)
        return micro_batch_id

    def _odd_step_forward_id(self, step_id):
        base = (step_id - 1) // 2
        micro_batch_id = int(base - self.stage_id // 2)
        return micro_batch_id

    def _even_step_backward_id(self, step_id):
        base = step_id // 2
        micro_batch_id = int(base - self.stages + (self.stage_id + 1) // 2)
        return micro_batch_id

    def _odd_step_backward_id(self, step_id):
        base = ((step_id - 1) // 2) - self.stages + 1
        micro_batch_id = int(base + self.stage_id // 2)
        return micro_batch_id
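
    # Illustrative mapping (worked by hand): writing Fn/Bn for the forward/backward
    # pass of micro-batch n, stages=2 and micro_batches=3 give 8 total steps, with
    #
    #     stage 0 (even): F0, idle, F1, B0, F2, B1, idle, B2
    #     stage 1 (odd):  idle, F0, B0, F1, B1, F2, B2, idle
    #
    # i.e. a warmup forward, then alternating one-forward/one-backward (1F1B),
    # then a drain of the remaining backward passes.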


class DataParallelSchedule(PipeSchedule):
    """An example schedule that trains using traditional data parallelism with gradient
    accumulation.
    """

    def steps(self):
        """"""
        for step_id in range(self.micro_batches):
            cmds = [
                LoadMicroBatch(buffer_id=0),
                ForwardPass(buffer_id=0),
                BackwardPass(buffer_id=0),
            ]
            if step_id == self.micro_batches - 1:
                cmds.extend([
                    ReduceGrads(),
                    OptimizerStep(),
                ])
            yield cmds

    def num_pipe_buffers(self):
        """Only one pipeline buffer needed.
        """
        return 1


class PipeInstruction:
    """Base class for all instructions to be executed by the pipeline engine.

    All keyword arguments are stored as members similar to a ``namedtuple``. These are
    then accessible to the :class:`PipeEngine` during execution.

    Args:
        kwargs (optional): keyword arguments to store as members
    """

    def __init__(self, **kwargs):
        self.name = self.__class__.__name__
        self.kwargs = kwargs
        for key, val in kwargs.items():
            setattr(self, key, val)

    def __repr__(self):
        return call_to_str(self.name, **self.kwargs)
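
    # For example (illustrative): LoadMicroBatch(buffer_id=0).buffer_id == 0, and
    # repr() of that instruction renders via call_to_str as a call-like string,
    # e.g. something like 'LoadMicroBatch(buffer_id=0)'.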


class OptimizerStep(PipeInstruction):
    """Performs one step with the optimizer and zeros gradients.

    .. note:: Should be issued after :class:`ReduceGrads` and :class:`ReduceTiedGrads`.

    .. note:: Can be a synchronization point among data-parallel ranks.
    """
    pass


class ReduceGrads(PipeInstruction):
    """Reduce the computed gradients among data-parallel processes within the stage.
    """
    pass


class ReduceTiedGrads(PipeInstruction):
    """Reduce the computed gradients of tied modules within a pipeline-parallel group.

    .. warning::
        The stages included in this synchronization point are not known until
        the model is partitioned among pipeline stages. In the worst case, it
        includes all pipeline stages. This instruction should be scheduled
        carefully to avoid deadlocks.
    """
    pass


class BufferOpInstruction(PipeInstruction):
    """A pipeline instruction that operates on pipeline buffer(s).

    Args:
        buffer_id (int): the index of the pipeline buffer(s) to modify.
    """

    def __init__(self, buffer_id, **kwargs):
        super().__init__(buffer_id=buffer_id, **kwargs)


# IO
class LoadMicroBatch(BufferOpInstruction):
    """Load a micro-batch into a buffer.

    Roughly:

    .. code-block:: python

        buffers['inputs'][buffer_id] = next(data_iter)
    """
    pass


# Compute
class ForwardPass(BufferOpInstruction):
    """Compute a forward pass.

    Roughly:

    .. code-block:: python

        buffers['outputs'][buffer_id] = forward(buffers['inputs'][buffer_id])
    """
    pass


class BackwardPass(BufferOpInstruction):
    """Compute a backward pass and accumulate gradients.

    Roughly:

    .. code-block:: python

        outputs = buffers['outputs'][buffer_id]
        gradients = buffers['gradients'][buffer_id]
        torch.autograd.backward(tensors=outputs,
                                grad_tensors=gradients)
    """
    pass


# Communication
class SendActivation(BufferOpInstruction):
    """Send activations to the next stage in the pipeline.

    Roughly:

    .. code-block:: python

        send(buffers['outputs'][buffer_id])

    .. note::
        The communication is blocking and must be paired with a :class:`RecvActivation`
        on the next pipeline stage to avoid deadlock.
    """
    pass


class RecvActivation(BufferOpInstruction):
    """Receive activations from the previous stage in the pipeline.

    Roughly:

    .. code-block:: python

        buffers['inputs'][buffer_id] = recv()

    .. note::
        The communication is blocking and must be paired with a :class:`SendActivation`
        on the previous pipeline stage to avoid deadlock.
    """
    pass


class SendGrad(BufferOpInstruction):
    """Send the gradients computed with respect to the received activations
    to the previous pipeline stage.

    .. note::
        Only received tensors with ``requires_grad==True`` will produce gradients.
        Missing gradients will be replaced with ``None`` on the receiving stage.

    .. note::
        The communication is blocking and must be paired with a :class:`RecvGrad`
        on the previous pipeline stage to avoid deadlock.
    """
    pass


class RecvGrad(BufferOpInstruction):
    """Receive computed gradients from the next pipeline stage.

    .. note::
        Only activations with ``requires_grad==True`` will produce gradients.
        Missing gradients will be replaced with ``None``.

    .. note::
        The communication is blocking and must be paired with a :class:`SendGrad`
        on the next pipeline stage to avoid deadlock.
    """
    pass


def _is_even(x):
    return x % 2 == 0


def _is_odd(x):
    return x % 2 != 0