123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- import pytest
- import deepspeed.runtime.pipe.schedule as schedule
- def _count_type(cmds, classtype):
- return len(list(filter(lambda c: type(c) == classtype, cmds)))
def test_pipe_inference_schedule_singlestage():
    """A one-stage inference schedule loads and runs each micro-batch in place.

    Every step must be exactly ``[LoadMicroBatch, ForwardPass]`` acting on the
    same buffer, and there must be one step per micro-batch.
    """
    sched = schedule.InferenceSchedule(micro_batches=4, stages=1, stage_id=0)
    assert sched.num_micro_batches == 4

    steps = list(sched)
    for step in steps:
        assert len(step) == 2
        load, forward = step
        assert type(load) is schedule.LoadMicroBatch
        assert type(forward) is schedule.ForwardPass
        assert load.buffer_id == forward.buffer_id

    assert len(steps) == sched.num_micro_batches
def test_pipe_train_schedule_singlestage(verbose=False):
    """A one-stage training schedule must never talk to a neighboring stage.

    A single stage is both the first and the last stage of the pipeline, so
    it has no peer to send or receive activations or gradients from. Every
    buffer-based instruction must address a buffer the schedule allocates,
    and each micro-batch must be run forward exactly once.

    Args:
        verbose: print every step's command list (debugging aid only).
    """
    sched = schedule.TrainSchedule(micro_batches=4, stages=1, stage_id=0)
    assert sched.num_micro_batches == 4
    assert sched.is_first_stage
    assert sched.is_last_stage

    # Any of these would imply a neighbor stage that does not exist.
    forbidden = (schedule.SendActivation, schedule.RecvActivation,
                 schedule.SendGrad, schedule.RecvGrad)

    full = list(iter(sched))
    assert full, 'training schedule produced no steps'
    if verbose:
        print()
    for idx, cmds in enumerate(full):
        if verbose:
            print(idx, cmds)
        for instr in cmds:
            assert type(instr) not in forbidden
            if isinstance(instr, schedule.BufferOpInstruction):
                assert 0 <= instr.buffer_id < sched.num_pipe_buffers()

    forward_passes = sum(_count_type(cmds, schedule.ForwardPass) for cmds in full)
    assert forward_passes == sched.num_micro_batches
@pytest.mark.parametrize('micro_batches', [1, 3, 8, 10])
def test_pipe_inference_schedule_firststage(micro_batches, stages=3, verbose=False):
    """Walk the first stage's inference schedule step by step.

    Step 0 only loads and runs the first micro-batch. Each following step
    loads, runs, and sends. The step right after the last micro-batch only
    sends, and every later step is idle.
    """
    sched = schedule.InferenceSchedule(micro_batches=micro_batches,
                                       stages=stages,
                                       stage_id=0)
    assert sched.num_micro_batches == micro_batches
    n_micro = sched.num_micro_batches

    steps = list(sched)
    if verbose:
        print()
    for step_id, cmds in enumerate(steps):
        if verbose:
            print(step_id, cmds)

        if step_id == 0:
            # Nothing has been computed yet, so no activation may be sent.
            assert len(cmds) == 2
            first, second = cmds
            assert type(first) is schedule.LoadMicroBatch
            assert type(second) is schedule.ForwardPass
            assert first.buffer_id == second.buffer_id
        elif step_id == n_micro:
            # Drain step: only the final activation is still to be sent.
            assert len(cmds) == 1
            assert type(cmds[0]) is schedule.SendActivation
        elif step_id > n_micro:
            # The first stage has no more work once everything is sent.
            assert len(cmds) == 0
        else:
            # Steady state: one load, one forward, one send.
            assert len(cmds) == 3
            for kind in (schedule.LoadMicroBatch,
                         schedule.ForwardPass,
                         schedule.SendActivation):
                assert _count_type(cmds, kind) == 1

    assert len(steps) == micro_batches + stages - 1
@pytest.mark.parametrize('micro_batches', [1, 3, 8, 10])
def test_pipe_inference_schedule_midstage(micro_batches, stages=3, verbose=False):
    """Walk a middle stage's inference schedule step by step.

    A middle stage never loads data. It is idle for its first ``stage``
    steps. Each active step then receives an activation and runs it forward,
    and from the second active step on it also sends the previous activation
    onward. The step after the last micro-batch only sends, and any later
    step is idle.
    """
    sched = schedule.InferenceSchedule(micro_batches=micro_batches,
                                       stages=stages,
                                       stage_id=1)
    # Same sanity check the first-stage test performs.
    assert sched.num_micro_batches == micro_batches
    full = list(iter(sched))
    if verbose:
        print()
    for idx, cmds in enumerate(full):
        if verbose:
            print(idx, cmds)
        if idx < sched.stage:
            assert len(cmds) == 0
            continue
        if idx == sched.stage + sched.num_micro_batches:
            assert len(cmds) == 1
            assert type(cmds[0]) == schedule.SendActivation
            continue
        if idx > sched.stage + sched.num_micro_batches:
            assert len(cmds) == 0
            continue
        assert _count_type(cmds, schedule.LoadMicroBatch) == 0
        assert _count_type(cmds, schedule.ForwardPass) == 1
        assert _count_type(cmds, schedule.RecvActivation) == 1
        if idx > sched.stage:
            assert _count_type(cmds, schedule.SendActivation) == 1
        else:
            # First active step: nothing computed yet, so nothing to send
            # (mirrors the step-0 check in the first-stage test).
            assert _count_type(cmds, schedule.SendActivation) == 0
    assert len(full) == micro_batches + stages - 1
@pytest.mark.parametrize('micro_batches', [1, 3, 8, 10])
def test_pipe_inference_schedule_laststage(micro_batches, stages=3, verbose=False):
    """Walk the last stage's inference schedule step by step.

    Outside its active window the last stage is idle. Inside it, every step
    loads a micro-batch, receives an activation, and runs forward. It never
    sends anything downstream.
    """
    sched = schedule.InferenceSchedule(micro_batches=micro_batches,
                                       stages=stages,
                                       stage_id=2)
    steps = list(sched)
    if verbose:
        print()

    # NOTE: the inclusive upper bound is permissive. The final assertion pins
    # the length to micro_batches + stages - 1, which for this stage is
    # stage + micro_batches, so the bound itself is never reached.
    window_start = sched.stage
    window_end = sched.stage + sched.num_micro_batches
    expected_counts = (
        (schedule.LoadMicroBatch, 1),
        (schedule.ForwardPass, 1),
        (schedule.RecvActivation, 1),
        (schedule.SendActivation, 0),
    )

    for step_id, cmds in enumerate(steps):
        if verbose:
            print(step_id, cmds)
        if window_start <= step_id <= window_end:
            for kind, expected in expected_counts:
                assert _count_type(cmds, kind) == expected
        else:
            assert len(cmds) == 0

    assert len(steps) == micro_batches + stages - 1
def test_pipe_schedule_firststage():
    """The first training stage never sends gradients or receives activations.

    Every buffer-based instruction must also address an allocated pipe buffer.
    """
    sched = schedule.TrainSchedule(micro_batches=8, stages=3, stage_id=0)
    for step in sched:
        kinds = [op.__class__ for op in step]
        assert schedule.SendGrad not in kinds
        assert schedule.RecvActivation not in kinds

        buffer_ops = (op for op in step
                      if isinstance(op, schedule.BufferOpInstruction))
        for op in buffer_ops:
            assert 0 <= op.buffer_id < sched.num_pipe_buffers()
def test_pipe_schedule_laststage(verbose=False):
    """The last training stage never sends activations or receives gradients.

    Every buffer-based instruction must also address an allocated pipe
    buffer. This is the same invariant checked for the first stage.

    Args:
        verbose: print each step's command list (debugging aid only).
    """
    sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=2)
    if verbose:
        print()
    for cmds in sched:
        if verbose:
            print(cmds)
        assert all(instr.__class__ != schedule.SendActivation for instr in cmds)
        assert all(instr.__class__ != schedule.RecvGrad for instr in cmds)
        for instr in cmds:
            if isinstance(instr, schedule.BufferOpInstruction):
                assert 0 <= instr.buffer_id < sched.num_pipe_buffers()
def test_pipe_stagequery():
    """Check first/last-stage queries on a 3-stage training pipeline.

    Stage 0 is the first stage, stage 2 is the last, and stage 1 is neither.
    """
    # (stage_id, expected is_first_stage, expected is_last_stage)
    expectations = (
        (0, True, False),
        (1, False, False),
        (2, False, True),
    )
    for stage_id, want_first, want_last in expectations:
        sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=stage_id)
        assert bool(sched.is_first_stage) is want_first
        assert bool(sched.is_last_stage) is want_last
|