'''Copyright The Microsoft DeepSpeed Team'''

from ..utils import call_to_str

from abc import ABC, abstractmethod


class PipeSchedule(ABC):
    """Directs the execution of a pipeline engine by generating sequences of
    :class:`PipeInstruction`.

    Schedules are generators that yield sequences of
    :class:`PipeInstruction` to process the micro-batches in one batch.
    Each yielded step is atomic in the sense that a barrier
    synchronization can be placed between successive steps without
    deadlock.

    Below is an example schedule that implements data parallelism with gradient accumulation:

    .. code-block:: python

        class DataParallelSchedule(PipeSchedule):
            def steps(self):
                for step_id in range(self.micro_batches):
                    cmds = [
                        LoadMicroBatch(buffer_id=0),
                        ForwardPass(buffer_id=0),
                        BackwardPass(buffer_id=0),
                    ]
                    if step_id == self.micro_batches - 1:
                        cmds.extend([
                            ReduceGrads(),
                            OptimizerStep(),
                        ])
                    yield cmds

            def num_pipe_buffers(self):
                return 1

    Args:
        micro_batches (int): The number of micro-batches that comprise a batch.
        stages (int): The number of pipeline stages.
        stage_id (int): The pipe stage that will execute the generated schedule.
    """
    def __init__(self, micro_batches, stages, stage_id):
        super().__init__()
        self.micro_batches = micro_batches
        self.stages = stages
        self.stage_id = stage_id
        self.prev_stage = self.stage_id - 1
        self.next_stage = self.stage_id + 1

    @abstractmethod
    def steps(self):
        """Yield a list of :class:`PipeInstruction` for each step in the schedule.

        .. note::
            Schedules must implement ``steps()`` to define the schedule.

        Returns:
            Instructions to be executed as one step of the pipeline
        """
        pass

    def num_pipe_buffers(self):
        """The number of pipeline buffers that will be used by this stage.

        .. note::
            Schedules should specialize ``num_pipe_buffers()`` for memory savings at scale.

        Returns:
            The number of buffers for the engine to allocate.
        """
        return self.micro_batches

    def _valid_micro_batch(self, micro_batch_id):
        return 0 <= micro_batch_id < self.micro_batches

    def _valid_stage(self, stage_id):
        return 0 <= stage_id < self.stages

    @property
    def stage(self):
        """Stage index used to configure this schedule."""
        return self.stage_id

    @property
    def num_stages(self):
        """The number of total pipeline stages used to configure this schedule."""
        return self.stages

    @property
    def num_micro_batches(self):
        """The number of total micro_batches used to configure this schedule."""
        return self.micro_batches

    @property
    def is_first_stage(self):
        """True if the configured ``stage_id`` is the first stage in the pipeline."""
        return self.stage_id == 0

    @property
    def is_last_stage(self):
        """True if the configured ``stage_id`` is the last stage in the pipeline."""
        return self.stage_id == self.stages - 1

    def _buffer_idx(self, micro_batch_id):
        """Map a micro-batch index to a pipeline buffer index.

        This method uses a cyclic allocation strategy.

        Args:
            micro_batch_id (int): The micro-batch index relative to the beginning of the schedule.

        Returns:
            int: The index of the buffer that should store data.
        """
        assert self._valid_micro_batch(micro_batch_id)
        return micro_batch_id % self.num_pipe_buffers()

    def __iter__(self):
        self.it = None
        return self

    def __next__(self):
        if self.it is None:
            self.it = self.steps()
        return next(self.it)


class InferenceSchedule(PipeSchedule):
    """A schedule for inferencing batches using pipeline parallelism.
    """
    def steps(self):
        """"""
        prev_micro_batch_id = -1
        total_steps = self.micro_batches + self.stages - 1
        for step_id in range(total_steps):
            cmds = []
            micro_batch_id = step_id - self.stage_id

            # Alternate send/recv buffers
            if _is_even(self.stage_id):
                recv_buf = step_id % 2
                send_buf = (step_id + 1) % 2
            else:
                recv_buf = (step_id + 1) % 2
                send_buf = step_id % 2

            if self.is_first_stage or self.is_last_stage:
                if self._valid_micro_batch(micro_batch_id):
                    cmds.append(LoadMicroBatch(recv_buf))

            if _is_even(self.stage_id):
                if self._valid_stage(self.next_stage):
                    if self._valid_micro_batch(micro_batch_id - 1):
                        cmds.append(SendActivation(send_buf))
                if self._valid_stage(self.prev_stage):
                    if self._valid_micro_batch(micro_batch_id):
                        cmds.append(RecvActivation(recv_buf))
            else:
                if self._valid_stage(self.prev_stage):
                    if self._valid_micro_batch(micro_batch_id):
                        cmds.append(RecvActivation(recv_buf))

                if self._valid_stage(self.next_stage):
                    if self._valid_micro_batch(micro_batch_id - 1):
                        cmds.append(SendActivation(send_buf))

            if self._valid_micro_batch(micro_batch_id):
                cmds.append(ForwardPass(recv_buf))

            yield cmds

    def num_pipe_buffers(self):
        """Only two pipeline buffers are required for inferencing.

        Returns:
            ``2``
        """
        return 2
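

# Illustrative usage sketch (not part of the upstream DeepSpeed API): print the instruction
# stream that an InferenceSchedule generates for each stage of a small pipeline. The schedule
# only describes the work to perform; the pipeline engine is what actually executes it.
def _describe_inference_schedule(micro_batches=3, stages=2):
    for stage_id in range(stages):
        sched = InferenceSchedule(micro_batches=micro_batches, stages=stages, stage_id=stage_id)
        print(f'stage {stage_id}:')
        for step_id, cmds in enumerate(sched):
            # Each step is a (possibly empty) list of PipeInstruction objects, e.g.
            # [LoadMicroBatch(buffer_id=0), ForwardPass(buffer_id=0)] on the first stage.
            print(f'  step {step_id}: {cmds}')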


class TrainSchedule(PipeSchedule):
    """A schedule for training a batch using hybrid parallelism.

    Pipeline parallelism is extracted through gradient accumulation and thus
    convergence follows that of a data parallel approach with the same batch
    size.
    """
    def steps(self):
        """"""
        prev_micro_batch_id = -1
        total_steps = 2 * (self.micro_batches + self.stages - 1)
        for step_id in range(total_steps):
            # Map the step of the pipeline to the micro-batch id and also whether it is a
            # forward or backward pass step.
            micro_batch_id, is_forward = self._step_to_micro_batch(step_id)

            if self._valid_micro_batch(prev_micro_batch_id):
                prev_buffer = self._buffer_idx(prev_micro_batch_id)
            if self._valid_micro_batch(micro_batch_id):
                curr_buffer = self._buffer_idx(micro_batch_id)

            cmds = []

            # Exchange activations
            if is_forward:
                if self._valid_micro_batch(micro_batch_id) and self._valid_stage(
                        self.prev_stage):
                    cmds.append(RecvActivation(curr_buffer))
                if self._valid_micro_batch(prev_micro_batch_id) and self._valid_stage(
                        self.prev_stage):
                    cmds.append(SendGrad(prev_buffer))
            else:
                if self._valid_micro_batch(prev_micro_batch_id) and self._valid_stage(
                        self.next_stage):
                    cmds.append(SendActivation(prev_buffer))
                if self._valid_micro_batch(micro_batch_id) and self._valid_stage(
                        self.next_stage):
                    cmds.append(RecvGrad(curr_buffer))

            # First/last stage loads
            if self.stage_id == 0 or self.stage_id == self.stages - 1:
                if is_forward and self._valid_micro_batch(micro_batch_id):
                    cmds.append(LoadMicroBatch(curr_buffer))

            # Computation
            if self._valid_micro_batch(micro_batch_id):
                if is_forward:
                    cmds.append(ForwardPass(curr_buffer))
                else:
                    cmds.append(BackwardPass(curr_buffer))

            # Model step at the end of the batch
            if step_id == total_steps - 1:
                cmds.append(ReduceTiedGrads())
                cmds.append(ReduceGrads())
                cmds.append(OptimizerStep())

            # Prepare state for next time
            prev_micro_batch_id = micro_batch_id

            yield cmds

    def num_pipe_buffers(self):
        """As many buffers as the distance from this stage to the last stage.
        """
        buffers = min(self.stages - self.stage_id + 1, self.micro_batches)
        return max(2, buffers)

    def _step_to_micro_batch(self, step_id):
        if _is_even(step_id) and _is_even(self.stage_id):
            micro_batch_id = self._even_step_forward_id(step_id)
            is_forward = True

        elif _is_odd(step_id) and _is_odd(self.stage_id):
            micro_batch_id = self._odd_step_forward_id(step_id)
            is_forward = True

        elif _is_even(step_id) and _is_odd(self.stage_id):
            micro_batch_id = self._even_step_backward_id(step_id)
            is_forward = False

        elif _is_odd(step_id) and _is_even(self.stage_id):
            micro_batch_id = self._odd_step_backward_id(step_id)
            is_forward = False

        else:
            assert False

        return micro_batch_id, is_forward

    def _even_step_forward_id(self, step_id):
        base = step_id // 2
        micro_batch_id = int(base - self.stage_id // 2)
        return micro_batch_id

    def _odd_step_forward_id(self, step_id):
        base = (step_id - 1) // 2
        micro_batch_id = int(base - self.stage_id // 2)
        return micro_batch_id

    def _even_step_backward_id(self, step_id):
        base = step_id // 2
        micro_batch_id = int(base - self.stages + (self.stage_id + 1) // 2)
        return micro_batch_id

    def _odd_step_backward_id(self, step_id):
        base = ((step_id - 1) // 2) - self.stages + 1
        micro_batch_id = int(base + self.stage_id // 2)
        return micro_batch_id
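

# Illustrative usage sketch (not part of the upstream DeepSpeed API): show how TrainSchedule
# interleaves forward and backward work. For every stage of a small pipeline, map each step
# to its micro-batch id and direction, and report the buffer count from num_pipe_buffers().
def _describe_train_schedule(micro_batches=4, stages=2):
    for stage_id in range(stages):
        sched = TrainSchedule(micro_batches=micro_batches, stages=stages, stage_id=stage_id)
        print(f'stage {stage_id}: num_pipe_buffers={sched.num_pipe_buffers()}')
        total_steps = 2 * (micro_batches + stages - 1)
        for step_id in range(total_steps):
            micro_batch_id, is_forward = sched._step_to_micro_batch(step_id)
            # Steps whose micro-batch id falls outside [0, micro_batches) are pipeline
            # bubble steps for this stage and carry no compute.
            if sched._valid_micro_batch(micro_batch_id):
                direction = 'fwd' if is_forward else 'bwd'
                print(f'  step {step_id}: {direction} micro-batch {micro_batch_id}')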


class DataParallelSchedule(PipeSchedule):
    """An example schedule that trains using traditional data parallelism with gradient
    accumulation.
    """
    def steps(self):
        """"""
        for step_id in range(self.micro_batches):
            cmds = [
                LoadMicroBatch(buffer_id=0),
                ForwardPass(buffer_id=0),
                BackwardPass(buffer_id=0),
            ]
            if step_id == self.micro_batches - 1:
                cmds.extend([
                    ReduceGrads(),
                    OptimizerStep(),
                ])
            yield cmds

    def num_pipe_buffers(self):
        """Only one pipeline buffer needed.
        """
        return 1


class PipeInstruction:
    """Base class for all instructions to be executed by the pipeline engine.

    All keyword arguments are stored as members similar to a ``namedtuple``. These are
    then accessible to the :class:`PipeEngine` during execution.

    Args:
        kwargs (optional): keyword arguments to store as members
    """
    def __init__(self, **kwargs):
        self.name = self.__class__.__name__
        self.kwargs = kwargs
        for key, val in kwargs.items():
            setattr(self, key, val)

    def __repr__(self):
        return call_to_str(self.name, **self.kwargs)
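

# Illustrative sketch (not part of the upstream DeepSpeed API): keyword arguments passed to a
# PipeInstruction are stored both in ``.kwargs`` and as attributes, so an engine can dispatch
# on the instruction type and read its parameters directly.
def _describe_instruction():
    cmd = PipeInstruction(buffer_id=0, stage_id=1)
    assert cmd.buffer_id == 0 and cmd.kwargs['stage_id'] == 1
    # __repr__ renders the instruction as a call-like string via call_to_str().
    print(repr(cmd))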


class OptimizerStep(PipeInstruction):
    """Performs one step with the optimizer and zeros gradients.

    .. note:: Should be issued after :class:`ReduceGrads` and :class:`ReduceTiedGrads`.

    .. note:: Can be a synchronization point among data-parallel ranks.
    """
    pass


class ReduceGrads(PipeInstruction):
    """Reduce the computed gradients among data-parallel processes within the stage.
    """
    pass


class ReduceTiedGrads(PipeInstruction):
    """Reduce the computed gradients of tied modules within a pipeline-parallel group.

    .. warning::
        The stages included in this synchronization point are not known until
        the model is partitioned among pipeline stages. In the worst case, it
        includes all pipeline stages. This instruction should be scheduled
        carefully to avoid deadlocks.
    """
    pass


class BufferOpInstruction(PipeInstruction):
    """A pipeline instruction that operates on pipeline buffer(s).

    Args:
        buffer_id (int): the index of the pipeline buffer to modify.
    """
    def __init__(self, buffer_id, **kwargs):
        super().__init__(buffer_id=buffer_id, **kwargs)


# IO
class LoadMicroBatch(BufferOpInstruction):
    """Load a micro-batch into a buffer.

    Roughly:

    .. code-block:: python

        buffers['inputs'][buffer_id] = next(data_iter)
    """
    pass


# Compute
class ForwardPass(BufferOpInstruction):
    """Compute a forward pass.

    Roughly:

    .. code-block:: python

        buffers['outputs'][buffer_id] = forward(buffers['inputs'][buffer_id])
    """
    pass


class BackwardPass(BufferOpInstruction):
    """Compute a backward pass and accumulate gradients.

    Roughly:

    .. code-block:: python

        outputs = buffers['outputs'][buffer_id]
        gradients = buffers['gradients'][buffer_id]
        torch.autograd.backward(tensors=outputs,
                                grad_tensors=gradients)
    """
    pass


# Communication
class SendActivation(BufferOpInstruction):
    """Send activations to the next stage in the pipeline.

    Roughly:

    .. code-block:: python

        send(buffers['outputs'][buffer_id])

    .. note::
        The communication is blocking and must be paired with a :class:`RecvActivation`
        on the next pipeline stage to avoid deadlock.
    """
    pass


class RecvActivation(BufferOpInstruction):
    """Receive activations from the previous stage in the pipeline.

    Roughly:

    .. code-block:: python

        buffers['inputs'][buffer_id] = recv()

    .. note::
        The communication is blocking and must be paired with a :class:`SendActivation`
        on the previous pipeline stage to avoid deadlock.
    """
    pass


class SendGrad(BufferOpInstruction):
    """Send gradients computed with respect to the received activations
    to the previous pipeline stage.

    .. note::
        Only received tensors with ``requires_grad==True`` will produce gradients.
        Missing gradients will be replaced with ``None`` on the receiving stage.

    .. note::
        The communication is blocking and must be paired with a :class:`RecvGrad`
        on the previous pipeline stage to avoid deadlock.
    """
    pass


class RecvGrad(BufferOpInstruction):
    """Receive computed gradients from the next pipeline stage.

    .. note::
        Only activations with ``requires_grad==True`` will produce gradients.
        Missing gradients will be replaced with ``None``.

    .. note::
        The communication is blocking and must be paired with a :class:`SendGrad`
        on the next pipeline stage to avoid deadlock.
    """
    pass


def _is_even(x):
    return x % 2 == 0


def _is_odd(x):
    return x % 2 != 0
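

# Illustrative sanity check (not part of the upstream DeepSpeed API): every stage of a
# TrainSchedule should execute exactly one ForwardPass and one BackwardPass per micro-batch,
# and the final step should end with the optimizer instructions.
def _check_train_schedule(micro_batches=4, stages=3):
    for stage_id in range(stages):
        sched = TrainSchedule(micro_batches=micro_batches, stages=stages, stage_id=stage_id)
        steps = list(sched)
        flat = [cmd for step in steps for cmd in step]
        assert sum(isinstance(cmd, ForwardPass) for cmd in flat) == micro_batches
        assert sum(isinstance(cmd, BackwardPass) for cmd in flat) == micro_batches
        assert isinstance(steps[-1][-1], OptimizerStep)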