test_moe_tp.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed
import pytest
from unit.common import DistributedTest
from deepspeed.utils.torch import required_torch_version
from deepspeed.moe.layer import MoE


class MPU():

    def __init__(self, tp_world_size):
        self.rank = deepspeed.comm.get_rank()
        self.world_size = deepspeed.comm.get_world_size()
        self.tp_world_size = tp_world_size

        # Tensor-parallel (model-parallel) groups: contiguous blocks of
        # tp_world_size ranks, e.g. [0, 1], [2, 3] for world_size=4, tp=2.
        for i in range(0, self.world_size, tp_world_size):
            ranks = range(i, i + tp_world_size)
            group = deepspeed.comm.new_group(ranks)
            if self.rank in ranks:
                self.tp_group = group

        # Data-parallel groups: ranks strided by tp_world_size,
        # e.g. [0, 2], [1, 3] for world_size=4, tp=2.
        for i in range(0, tp_world_size):
            ranks = range(i, self.world_size, tp_world_size)
            group = deepspeed.comm.new_group(ranks)
            if self.rank in ranks:
                self.dp_group = group

    def get_model_parallel_rank(self):
        return self.rank % self.tp_world_size

    def get_model_parallel_world_size(self):
        return self.tp_world_size

    def get_data_parallel_rank(self):
        return self.rank // self.tp_world_size

    def get_data_parallel_world_size(self):
        return self.world_size // self.tp_world_size

    def get_data_parallel_group(self):
        return self.dp_group

    def get_model_parallel_group(self):
        return self.tp_group


@pytest.mark.parametrize("ep_size, tp_size", [(1, 2), (1, 4), (2, 2)])
@pytest.mark.parametrize("enable_expert_tp", [True, False])
@pytest.mark.parametrize("use_residual", [True, False])
class TestMOETensorParallel(DistributedTest):
    world_size = 4

    def test(self, ep_size, tp_size, enable_expert_tp, use_residual):
        # TODO: replace this with a true parallel mlp in the future
        # and run convergence tests
        if not required_torch_version(min_version=1.8):
            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

        config_dict = {"train_batch_size": 8, "steps_per_print": 1, "fp16": {"enabled": True}}
        hidden_dim = 16

        # Each expert MLP is sharded along its hidden dimension across the
        # tp_size ranks of a tensor-parallel group.
        tensor_parallel_expert = torch.nn.Sequential(torch.nn.Linear(hidden_dim, 4 * hidden_dim // tp_size),
                                                     torch.nn.ReLU(),
                                                     torch.nn.Linear(4 * hidden_dim // tp_size, hidden_dim))

        # set num experts to world size
        world_size = deepspeed.comm.get_world_size()
        model = MoE(
            hidden_size=hidden_dim,
            expert=tensor_parallel_expert,
            num_experts=world_size,
            ep_size=ep_size,
            use_residual=use_residual,
            enable_expert_tensor_parallelism=enable_expert_tp,
        )
        optimizer = torch.optim.AdamW(params=model.parameters())
        model, _, _, _ = deepspeed.initialize(config=config_dict,
                                              model=model,
                                              optimizer=optimizer,
                                              dist_init_required=False,
                                              mpu=MPU(tp_size))

        # Experts are distributed evenly across the expert-parallel group;
        # expert tensor parallelism is only active when explicitly enabled.
        assert model.num_local_experts == world_size // ep_size
        if enable_expert_tp:
            assert deepspeed.utils.groups._get_expert_model_parallel_world_size() == tp_size
        else:
            assert deepspeed.utils.groups._get_expert_model_parallel_world_size() == 1
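
For reference, the following is a minimal, illustrative sketch (not part of the test file) of the rank layout that the MPU helper above computes for a 4-rank run with tp_world_size=2: tensor-parallel groups are contiguous blocks of ranks, while data-parallel groups are the ranks strided by the tensor-parallel size. It only reproduces the index arithmetic from MPU.__init__ and needs no distributed backend.

# Illustrative only: group layout for world_size=4, tp_world_size=2.
world_size, tp = 4, 2

tp_groups = [list(range(i, i + tp)) for i in range(0, world_size, tp)]
dp_groups = [list(range(i, world_size, tp)) for i in range(tp)]

assert tp_groups == [[0, 1], [2, 3]]  # contiguous ranks form each tensor-parallel group
assert dp_groups == [[0, 2], [1, 3]]  # strided ranks form each data-parallel group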