megatron_gpt.py

'''Copyright The Microsoft DeepSpeed Team'''

from .base import *
from .features.megatron import MegatronContainer
from deepspeed.model_implementations.transformers.ds_megatron_gpt import DeepSpeedMegatronGPTInference
import torch
from ..policy import TransformerPolicy
from packaging import version as pkg_version


class DS_MegatronGPTContainer(MegatronContainer, BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.

    def create_module(self, config=None):
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedMegatronGPTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention

        if self.megatron_v2:
            self.module.config.rotate_half = True
            self.module.config.rotate_every_two = False

        return self.module


# TODO: Megatron GPT MoE inherits from Megatron policy and replaces mlp
# TODO: Generalize MoE overall goal, expand beyond Megatron
class MegatronLayerPolicy(TransformerPolicy):
    _orig_layer_class = None
    version = 0
    moe_type = 'standard'
    megatron_v2 = True
    use_mup = False

    def __init__(self, client_module, inference=True):
        super().__init__(inference,
                         megatron_v2=MegatronLayerPolicy.megatron_v2,
                         use_mup=MegatronLayerPolicy.use_mup)
        self.client_module = client_module
        # we use megatron version to differentiate between the old and new
        # megatron-lm source code
        if MegatronLayerPolicy._orig_layer_class is None:
            if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"):
                MegatronLayerPolicy._orig_layer_class = None
            else:
                try:
                    from megatron.model.transformer import ParallelTransformerLayer
                    MegatronLayerPolicy._orig_layer_class = ParallelTransformerLayer
                except ImportError:
                    MegatronLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return self.client_module.attention.query_key_value.weight.shape[1], \
               self.client_module.attention.num_attention_heads

    def attention(self):
        if self.inference:
            if MegatronLayerPolicy.version == 0:
                attention = self.client_module.attention
            else:
                # newer Megatron-LM releases rename `attention` to `self_attention`
                attention = self.client_module.self_attention

        return attention.query_key_value.weight, \
               attention.query_key_value.bias, \
               attention.dense.weight, \
               attention.dense.bias

    def mlp(self, moe_type='standard'):
        from deepspeed.moe.utils import has_moe_layers
        moe, _ = has_moe_layers(self.client_module)
        if moe:
            moe_experts = self.client_module.mlp.deepspeed_moe.experts.deepspeed_experts if moe_type == 'standard' else \
                          self.client_module.mlp.moe.deepspeed_moe.experts.deepspeed_experts
            num_experts = len(moe_experts)
            if moe_type == 'standard':
                return [moe_experts[i].dense_h_to_4h.weight for i in range(num_experts)], \
                       [moe_experts[i].dense_h_to_4h.bias for i in range(num_experts)], \
                       [moe_experts[i].dense_4h_to_h.weight for i in range(num_experts)], \
                       [moe_experts[i].dense_4h_to_h.bias for i in range(num_experts)]
            else:
                # non-standard (residual) MoE additionally exposes the dense MLP
                # branch and the gating coefficient alongside the expert weights
                return [moe_experts[i].dense_h_to_4h.weight for i in range(num_experts)], \
                       [moe_experts[i].dense_h_to_4h.bias for i in range(num_experts)], \
                       [moe_experts[i].dense_4h_to_h.weight for i in range(num_experts)], \
                       [moe_experts[i].dense_4h_to_h.bias for i in range(num_experts)], \
                       self.client_module.mlp.mlp.dense_h_to_4h.weight, \
                       self.client_module.mlp.mlp.dense_h_to_4h.bias, \
                       self.client_module.mlp.mlp.dense_4h_to_h.weight, \
                       self.client_module.mlp.mlp.dense_4h_to_h.bias, \
                       self.client_module.mlp.coefficient.weight
        else:
            return self.client_module.mlp.dense_h_to_4h.weight, \
                   self.client_module.mlp.dense_h_to_4h.bias, \
                   self.client_module.mlp.dense_4h_to_h.weight, \
                   self.client_module.mlp.dense_4h_to_h.bias

    def layernorm(self):
        return self.client_module.post_attention_layernorm.weight, \
               self.client_module.post_attention_layernorm.bias, \
               self.client_module.input_layernorm.weight, \
               self.client_module.input_layernorm.bias
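

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal illustration of how
# this container/policy pair gets picked up when building a DeepSpeed inference
# engine around a Megatron-LM GPT model. `model` is assumed to be an
# already-constructed Megatron-LM model; `mp_size` and the half-precision
# dtype are illustrative choices, not requirements.
# ---------------------------------------------------------------------------
def _example_megatron_inference(model, mp_size=1):
    import deepspeed
    import torch

    # With kernel injection enabled, DeepSpeed matches transformer layers
    # against MegatronLayerPolicy._orig_layer_class (Megatron-LM's
    # ParallelTransformerLayer) and replaces them with the
    # DeepSpeedMegatronGPTInference module created by DS_MegatronGPTContainer.
    engine = deepspeed.init_inference(model,
                                      mp_size=mp_size,
                                      dtype=torch.half,
                                      replace_with_kernel_inject=True)
    return engine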