base_moe.py 5.8 KB

'''Copyright The Microsoft DeepSpeed Team'''
# Create a container object to save model-specific tensors using the policy file above.

from .base import *
from deepspeed import comm as dist
import deepspeed.ops.transformer as transformer_inference
from deepspeed.accelerator import get_accelerator


class BaseTransformerMoEContainer(BaseTransformerContainer):
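    """Base container for Mixture-of-Experts (MoE) transformer layers.

    Gathers attention, layernorm, and per-expert MLP tensors from the model
    policy and loads them into the DeepSpeed MoE inference module.
    """
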
    def __init__(self, **kwargs):
        # Call the init function of the parent class to initialize the tensors and configs from parent class
        super().__init__(**kwargs)

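        # Expert parallelism: the experts are spread across all ranks, so each
        # rank hosts num_experts // world_size experts (falling back to one
        # local expert when there are fewer experts than ranks).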
        self.num_experts = self.policy.get_num_experts()
        self.ep_world_size = dist.get_world_size()
        self.local_ep_size = 1 if self.num_experts < self.ep_world_size else self.num_experts // self.ep_world_size

        self.layer_norm_eps = self.config.layer_norm_eps if hasattr(
            self.config,
            'layer_norm_eps') else 1e-12

        # MoE models will have a list of mlp related tensors
        self._h4h_w = []
        self._h4h_b = []
        self._4hh_w = []
        self._4hh_b = []

        # Residual MoE needs extra parameters
        self._res_h4h_w = None
        self._res_h4h_b = None
        self._res_4hh_w = None
        self._res_4hh_b = None
        self._res_coef = None

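    # Build the DeepSpeedMoEInferenceConfig used to construct the DeepSpeed inference module.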
    def create_ds_model_config(self):
        self.set_hidden_heads(*self.policy.get_hidden_heads())
        assert self.num_attention_heads % self.mp_size == 0, \
            "To run the model parallel across the GPUs, the number of attention heads must be divisible by the world size! " \
            "This is because the attention computation is partitioned evenly among the parallel GPUs."

        self.ds_model_config = transformer_inference.DeepSpeedMoEInferenceConfig(
            hidden_size=self.hidden_size,
            heads=self.num_attention_heads,
            layer_norm_eps=self.layer_norm_eps,
            fp16=self.fp16,
            pre_layer_norm=self.pre_layer_norm,
            mp_size=self.mp_size,
            q_int8=self.quantize,
            moe_experts=self.local_ep_size,
            global_experts=self.num_experts,
            mlp_type=self.config.moe.type,
            scale_attn_by_inverse_layer_idx=self.scale_attn_by_inverse_layer_idx,
        )

        return self.ds_model_config

    def initialize_tensors(self):
        # Set the tensors from policy (user module) to container (DS module)
        self.set_attention(*self.policy.attention())
        self.set_mlp(self.config.moe.type)
        self.set_layernorm(*self.policy.layernorm())

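    # 'standard' MoE policies return only the expert MLP tensors; residual MoE
    # policies additionally return the dense residual MLP and residual
    # coefficient tensors.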
    def set_mlp(self, config_moe_type):
        if config_moe_type == 'standard':
            self._h4h_w, self._h4h_b, \
            self._4hh_w, self._4hh_b = self.policy.mlp()
        else:
            self._h4h_w, self._h4h_b, self._4hh_w, \
            self._4hh_b, self._res_h4h_w, self._res_h4h_b, \
            self._res_4hh_w, self._res_4hh_b, \
            self._res_coef = self.policy.mlp(config_moe_type)

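    # Transpose the expert (and, for residual MoE, the residual) MLP weights
    # via the base container's transpose_impl.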
    def transpose(self):
        self.transpose_attention()
        self.transpose_mlp()

        if self.config.moe.type == 'residual':
            self.transpose_residual()

    def transpose_mlp(self):
        self._h4h_w = [self.transpose_impl(moe_w1.data) for moe_w1 in self._h4h_w]
        self._4hh_w = [self.transpose_impl(moe_w1.data) for moe_w1 in self._4hh_w]

    def transpose_residual(self):
        self._res_h4h_w.data = self.transpose_impl(self._res_h4h_w.data)
        self._res_4hh_w.data = self.transpose_impl(self._res_4hh_w.data)
        self._res_coef.data = self.transpose_impl(self._res_coef.data)

    def apply_tensor_parallelism(self, mp_replace):
        # setup the new Attention module
        self.attention_qkv_mp(mp_replace)
        self.attention_o_mp(mp_replace)

        # quantize attention weights
        self.attention_quantization()

        # setup the new MLP module
        self.mlp_mp()

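    # Assign this rank's slice of the expert weights: rank r owns the global
    # experts with indices [r * local_ep_size, (r + 1) * local_ep_size).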
    def mlp_mp(self):
        gpu_index = dist.get_rank()
        for ep_index in range(self.local_ep_size):
            # mlp inter
            self.module.mlp[ep_index].inter_w.data = self._h4h_w[
                gpu_index * self.local_ep_size + ep_index].to(
                    get_accelerator().current_device_name())
            self.module.mlp[ep_index].inter_b.data = self._h4h_b[
                gpu_index * self.local_ep_size + ep_index].to(
                    get_accelerator().current_device_name())

            # mlp output
            self.module.mlp[ep_index].output_w.data = self._4hh_w[
                gpu_index * self.local_ep_size + ep_index].to(
                    get_accelerator().current_device_name())
            self.module.mlp[ep_index].output_b.data = self._4hh_b[
                gpu_index * self.local_ep_size + ep_index].to(
                    get_accelerator().current_device_name())

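    # Move the layernorm parameters (and, for residual MoE, the dense residual
    # MLP and residual coefficient) onto the current accelerator device of the
    # inference module.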
    def copy_data_to_new_module(self):
        self.module.attn_nw.data = self.attn_nw.to(
            get_accelerator().current_device_name())
        self.module.attn_nb.data = self.attn_nb.to(
            get_accelerator().current_device_name())

        self.module.norm_w.data.copy_(
            self.input_nw.to(get_accelerator().current_device_name()))
        self.module.norm_b.data.copy_(
            self.input_nb.to(get_accelerator().current_device_name()))

        if self.config.moe.type == 'residual':
            self.module.res_mlp.inter_w.data = self._res_h4h_w.to(
                get_accelerator().current_device_name())
            self.module.res_mlp.inter_b.data = self._res_h4h_b.to(
                get_accelerator().current_device_name())
            self.module.res_mlp.output_w.data = self._res_4hh_w.to(
                get_accelerator().current_device_name())
            self.module.res_mlp.output_b.data = self._res_4hh_b.to(
                get_accelerator().current_device_name())
            self.module.res_coef.data = self._res_coef.to(
                get_accelerator().current_device_name())