test_pipe_schedule.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import pytest
  2. import deepspeed.runtime.pipe.schedule as schedule
  3. def _count_type(cmds, classtype):
  4. return len(list(filter(lambda c: type(c) == classtype, cmds)))
  5. def test_pipe_inference_schedule_singlestage():
  6. sched = schedule.InferenceSchedule(micro_batches=4, stages=1, stage_id=0)
  7. assert sched.num_micro_batches == 4
  8. full = list(iter(sched))
  9. for idx, cmds in enumerate(full):
  10. assert len(cmds) == 2
  11. assert type(cmds[0]) == schedule.LoadMicroBatch
  12. assert type(cmds[1]) == schedule.ForwardPass
  13. assert cmds[0].buffer_id == cmds[1].buffer_id
  14. assert len(full) == sched.num_micro_batches
  15. def test_pipe_train_schedule_singlestage():
  16. sched = schedule.TrainSchedule(micro_batches=4, stages=1, stage_id=0)
  17. assert sched.num_micro_batches == 4
  18. full = list(iter(sched))
  19. print()
  20. for idx, cmds in enumerate(full):
  21. print(idx, cmds)
  22. #assert len(cmds) == 2
  23. #assert type(cmds[0]) == schedule.LoadMicroBatch
  24. #assert type(cmds[1]) == schedule.ForwardPass
  25. #assert cmds[0].buffer_id == cmds[1].buffer_id
  26. #assert len(full) == sched.num_micro_batches
  27. @pytest.mark.parametrize('micro_batches', [1, 3, 8, 10])
  28. def test_pipe_inference_schedule_firststage(micro_batches, stages=3, verbose=False):
  29. sched = schedule.InferenceSchedule(micro_batches=micro_batches,
  30. stages=stages,
  31. stage_id=0)
  32. assert sched.num_micro_batches == micro_batches
  33. full = list(iter(sched))
  34. if verbose:
  35. print()
  36. for idx, cmds in enumerate(full):
  37. if verbose:
  38. print(idx, cmds)
  39. # Ensure we don't send an activation the first step
  40. if idx == 0:
  41. assert len(cmds) == 2
  42. assert type(cmds[0]) == schedule.LoadMicroBatch
  43. assert type(cmds[1]) == schedule.ForwardPass
  44. assert cmds[0].buffer_id == cmds[1].buffer_id
  45. continue
  46. # the last active step is only a send
  47. if idx == sched.num_micro_batches:
  48. assert len(cmds) == 1
  49. assert type(cmds[0]) == schedule.SendActivation
  50. continue
  51. # no work later on
  52. if idx > sched.num_micro_batches:
  53. assert len(cmds) == 0
  54. continue
  55. # Normally we need to load/forward/send
  56. assert len(cmds) == 3
  57. assert _count_type(cmds, schedule.LoadMicroBatch) == 1
  58. assert _count_type(cmds, schedule.ForwardPass) == 1
  59. assert _count_type(cmds, schedule.SendActivation) == 1
  60. assert len(full) == micro_batches + stages - 1
  61. @pytest.mark.parametrize('micro_batches', [1, 3, 8, 10])
  62. def test_pipe_inference_schedule_midstage(micro_batches, stages=3, verbose=False):
  63. sched = schedule.InferenceSchedule(micro_batches=micro_batches,
  64. stages=stages,
  65. stage_id=1)
  66. full = list(iter(sched))
  67. if verbose:
  68. print()
  69. for idx, cmds in enumerate(full):
  70. if verbose:
  71. print(idx, cmds)
  72. if idx < sched.stage:
  73. assert len(cmds) == 0
  74. continue
  75. if idx == sched.stage + sched.num_micro_batches:
  76. assert len(cmds) == 1
  77. assert type(cmds[0]) == schedule.SendActivation
  78. continue
  79. if idx > sched.stage + sched.num_micro_batches:
  80. assert len(cmds) == 0
  81. continue
  82. assert _count_type(cmds, schedule.LoadMicroBatch) == 0
  83. assert _count_type(cmds, schedule.ForwardPass) == 1
  84. assert _count_type(cmds, schedule.RecvActivation) == 1
  85. if idx > sched.stage:
  86. assert _count_type(cmds, schedule.SendActivation) == 1
  87. assert len(full) == micro_batches + stages - 1
  88. @pytest.mark.parametrize('micro_batches', [1, 3, 8, 10])
  89. def test_pipe_inference_schedule_laststage(micro_batches, stages=3, verbose=False):
  90. sched = schedule.InferenceSchedule(micro_batches=micro_batches,
  91. stages=stages,
  92. stage_id=2)
  93. full = list(iter(sched))
  94. if verbose:
  95. print()
  96. for idx, cmds in enumerate(full):
  97. if verbose:
  98. print(idx, cmds)
  99. if idx < sched.stage or idx > sched.stage + sched.num_micro_batches:
  100. assert len(cmds) == 0
  101. continue
  102. assert _count_type(cmds, schedule.LoadMicroBatch) == 1
  103. assert _count_type(cmds, schedule.ForwardPass) == 1
  104. assert _count_type(cmds, schedule.RecvActivation) == 1
  105. assert _count_type(cmds, schedule.SendActivation) == 0
  106. assert len(full) == micro_batches + stages - 1
  107. def test_pipe_schedule_firststage():
  108. sched = schedule.TrainSchedule(micro_batches=8, stages=3, stage_id=0)
  109. for cmds in sched:
  110. assert all(instr.__class__ != schedule.SendGrad for instr in cmds)
  111. assert all(instr.__class__ != schedule.RecvActivation for instr in cmds)
  112. for instr in cmds:
  113. if isinstance(instr, schedule.BufferOpInstruction):
  114. assert 0 <= instr.buffer_id < sched.num_pipe_buffers()
  115. def test_pipe_schedule_laststage():
  116. sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=2)
  117. #assert len(sched) == 2 * (sched.micro_batches + sched.stages - 1)
  118. print()
  119. for cmds in sched:
  120. print(cmds)
  121. assert all(instr.__class__ != schedule.SendActivation for instr in cmds)
  122. assert all(instr.__class__ != schedule.RecvGrad for instr in cmds)
  123. def test_pipe_stagequery():
  124. sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=0)
  125. assert sched.is_first_stage
  126. assert not sched.is_last_stage
  127. sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=1)
  128. assert not sched.is_first_stage
  129. assert not sched.is_last_stage
  130. sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=2)
  131. assert not sched.is_first_stage
  132. assert sched.is_last_stage