layer.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import typing

import torch

from deepspeed.utils import groups, log_dist

from .experts import Experts
from .sharded_moe import MOELayer, TopKGate


class MoE(torch.nn.Module):
    """Initialize an MoE layer.

    Arguments:
        hidden_size (int): the hidden dimension of the model, importantly this is also the input and output dimension.
        expert (torch.nn.Module): the torch module that defines the expert (e.g., an MLP or torch.nn.Linear).
        num_experts (int, optional): default=1, the total number of experts per layer.
        ep_size (int, optional): default=1, number of ranks in the expert parallel world or group.
        k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
        capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
        eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
        min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
        use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer.
        noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.
        drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
        use_rts (bool, optional): default=True, whether to use Random Token Selection.
        use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).
        enable_expert_tensor_parallelism (bool, optional): default=False, whether to use tensor parallelism for experts.
    """

    def __init__(self,
                 hidden_size,
                 expert,
                 num_experts=1,
                 ep_size=1,
                 k=1,
                 capacity_factor=1.,
                 eval_capacity_factor=1.,
                 min_capacity=4,
                 use_residual=False,
                 noisy_gate_policy: typing.Optional[str] = None,
                 drop_tokens: bool = True,
                 use_rts=True,
                 use_tutel: bool = False,
                 enable_expert_tensor_parallelism: bool = False):

        super(MoE, self).__init__()

        self.use_residual = use_residual
        self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism
        assert num_experts % ep_size == 0, f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})"
        self.ep_size = ep_size
        self.expert_group_name = f"ep_size_{self.ep_size}"
        self.num_experts = num_experts
        self.num_local_experts = num_experts // self.ep_size

        log_dist(
            f'Creating MoE layer with num_experts: {num_experts} | num_local_experts: {self.num_local_experts} | expert_parallel_size: {self.ep_size}',
            [0])

        assert noisy_gate_policy is None or noisy_gate_policy in ['None', 'Jitter', 'RSample'], \
            'Unsupported noisy_gate_policy: ' + noisy_gate_policy

        experts = Experts(expert, self.num_local_experts, self.expert_group_name)
        self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,
                                               min_capacity, noisy_gate_policy, drop_tokens, use_rts),
                                      experts,
                                      self.expert_group_name,
                                      self.ep_size,
                                      self.num_local_experts,
                                      use_tutel=use_tutel)
        if self.use_residual:
            self.mlp = expert
            # coefficient is used for weighted sum of the output of expert and mlp
            self.coefficient = torch.nn.Linear(hidden_size, 2)

    def set_deepspeed_parallelism(self, use_data_before_expert_parallel_=False):
        self._create_process_groups(use_data_before_expert_parallel_=use_data_before_expert_parallel_)

    def _create_process_groups(self, use_data_before_expert_parallel_=False):
        # Create process group for a layer if needed
        if self.expert_group_name not in groups._get_expert_parallel_group_dict():
            print(f"No existing process group found, creating a new group named: {self.expert_group_name}")
            if (groups.mpu is None) or (not self.enable_expert_tensor_parallelism):
                # Condition 1 - no groups.mpu means no tensor parallelism
                # Condition 2 - disabling expert tensor parallelism on purpose
                groups._create_expert_and_data_parallel(
                    self.ep_size, use_data_before_expert_parallel_=use_data_before_expert_parallel_)
            else:
                # expert tensor parallelism is enabled
                groups._create_expert_data_and_model_parallel(
                    self.ep_size, mpu=groups.mpu, use_data_before_expert_parallel_=use_data_before_expert_parallel_)
        # Set the group handle for the MOELayer (deepspeed_moe) object
        self.deepspeed_moe._set_ep_group(groups._get_expert_parallel_group(self.expert_group_name))

    def forward(self, hidden_states, used_token=None):
        """ MoE forward

        Arguments:
            hidden_states (Tensor): input to the layer
            used_token (Tensor, optional): default: None, mask only used tokens

        Returns:
            A tuple including output, gate loss, and expert count.

            * output (Tensor): output of the model
            * l_aux (Tensor): gate loss value
            * exp_counts (int): expert count
        """
        output = self.deepspeed_moe(hidden_states, used_token)
        if self.use_residual:
            # Residual MoE
            output_mlp = self.mlp(hidden_states)
            if type(output_mlp) is tuple:
                output_mlp = output_mlp[0]  # Ignore the bias term for now
            coef = self.coefficient(hidden_states)
            coef = torch.nn.functional.softmax(coef, dim=-1)
            # Weighted sum: coef[..., 0] scales the MoE output, coef[..., 1] scales the dense MLP output
            output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
        return output, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts
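
# ---------------------------------------------------------------------------
# End-to-end usage sketch (illustrative, not part of the original source). It
# assumes an initialized distributed environment (e.g. deepspeed.init_distributed())
# and uses placeholder names: `some_task_loss` stands in for the model's own loss,
# and the 0.01 gate-loss weight is an arbitrary example value.
#
#   import deepspeed
#   import torch
#
#   deepspeed.init_distributed()
#
#   expert = torch.nn.Linear(1024, 1024)
#   moe = MoE(hidden_size=1024, expert=expert, num_experts=8, ep_size=2, k=1)
#   moe.set_deepspeed_parallelism()
#
#   hidden_states = torch.randn(8, 128, 1024)       # (batch, seq_len, hidden_size)
#   output, l_aux, exp_counts = moe(hidden_states)
#   loss = some_task_loss(output) + 0.01 * l_aux    # fold the auxiliary gate loss into training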