# test_pipe_module.py

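"""Checks that a model partitioned with PipelineModule matches the plain
nn.Sequential it was built from: the parameter counts summed across pipeline
stages must agree, and eval_batch() must reproduce the baseline forward output.
"""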
import copy

import torch
import torch.nn as nn
import torch.distributed as dist

import pytest

import deepspeed

from deepspeed.runtime.pipe.topology import PipeDataParallelTopology, PipeModelDataParallelTopology
PipeTopo = PipeDataParallelTopology
from deepspeed.pipe import PipelineModule, LayerSpec

from deepspeed.utils import RepeatingLoader

from common import distributed_test
from simple_model import args_from_dict

HIDDEN_DIM = 32
LAYERS = 8

@pytest.fixture
def sequential_model():
    model = torch.nn.Sequential(
        *[nn.Linear(HIDDEN_DIM, HIDDEN_DIM) for _ in range(LAYERS)],
        nn.Linear(HIDDEN_DIM, 1),
    )
    return model

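# The fixture above builds the baseline as a plain nn.Sequential, which
# PipelineModule can consume directly. A minimal alternative sketch (not used
# by this test) builds the same stack lazily from LayerSpec entries, so each
# rank only materializes the layers of its own stage:
#
#   specs = [LayerSpec(nn.Linear, HIDDEN_DIM, HIDDEN_DIM) for _ in range(LAYERS)]
#   specs.append(LayerSpec(nn.Linear, HIDDEN_DIM, 1))
#   lazy_pipe = PipelineModule(layers=specs, num_stages=4)
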
@pytest.fixture
def simple_args(tmpdir):
    config_dict = {
        "train_batch_size": 1,
        "train_micro_batch_size_per_gpu": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.001,
                "betas": [0.9, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "pipeline": {
            "activation_checkpoint_interval": 1
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    return args

def test_pipe_module_sequential(sequential_model, simple_args):
    """Partition an nn.Sequential across 4 pipeline stages and check that it
    matches the unpartitioned baseline."""
    batch_input = torch.randn(1, HIDDEN_DIM)

    @distributed_test(world_size=4)
    def _helper():
        # Baseline: run the unpartitioned model on a single process.
        base_model = copy.deepcopy(sequential_model)
        base_input = batch_input.clone().detach()
        base_output = base_model(base_input)
        base_params = sum(p.numel() for p in base_model.parameters())

        # Partition a copy of the same model across 4 pipeline stages.
        pipe_model = copy.deepcopy(sequential_model)
        pipe_model = PipelineModule(layers=pipe_model, num_stages=4)

        # Ensure all parameters are accounted for.
        my_params = sum(p.numel() for p in pipe_model.parameters())
        total_pipe_params = torch.LongTensor([my_params]).to('cuda')
        dist.all_reduce(total_pipe_params)
        total_pipe_params = total_pipe_params.item()
        assert total_pipe_params == base_params

        pipe_model, _, _, _ = deepspeed.initialize(
            args=simple_args,
            model=pipe_model,
            model_parameters=[p for p in pipe_model.parameters()])

        # Only the first and last stages feed/consume data.
        if pipe_model.is_first_stage() or pipe_model.is_last_stage():
            pipe_input = base_input.clone().detach().to('cuda')
            # label 0 is meaningless
            dataset = [(pipe_input, 0)]
            loader = RepeatingLoader(dataset)
            data_iter = iter(loader)
        else:
            data_iter = None

        pipe_output = pipe_model.eval_batch(data_iter=data_iter)

        # The pipelined forward pass must reproduce the baseline output.
        base_output = base_output.to('cpu')
        pipe_output = pipe_output.to('cpu')
        assert torch.allclose(base_output, pipe_output)

    _helper()
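
# Note: the topology imports at the top are not exercised by this test. A
# PipelineModule can also be given an explicit process grid instead of
# num_stages. A rough sketch (assuming the keyword names num_pp/num_dp and a
# world size of num_pp * num_dp = 8):
#
#   topo = PipeTopo(num_pp=4, num_dp=2)
#   pipe_model = PipelineModule(layers=copy.deepcopy(sequential_model), topology=topo)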