# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import importlib
import torch
from torch.nn.parameter import Parameter

from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
from deepspeed.utils.types import ActivationFuncType, NormType

from ..policy import (TransformerPolicy, maybe_copy, maybe_copy_geglu, maybe_copy_qkv, maybe_get_lora,
                      transformer_param_names)
from .base import *
from .features import HybridGatedMLPContainer, HybridSplitQKVContainer
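

# Container and injection policy for InternLM: together they map the Hugging Face
# InternLMDecoderLayer parameters (separate q/k/v projections, gated-SiLU MLP,
# RMSNorm) onto DeepSpeed's fused GPT-style inference module.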
class DS_InternLMContainer(HybridGatedMLPContainer, HybridSplitQKVContainer, BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.

    def create_module(self, config=None):
        _config = config if config is not None else self.ds_model_config
        _config.rotate_half = True
        _config.rotate_every_two = False
        _config.rotary_dim = self.hidden_size // self.num_attention_heads
        self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group)

        return self.module

    def set_lora_params(self):
        """
        Necessary to implement for `HybridEngineContainer`
        """
        self.lora_params = [
            maybe_get_lora(p) for p in [
                self.policy.client_module.mlp.up_proj.weight, self.policy.client_module.mlp.gate_proj.weight,
                self.policy.client_module.mlp.down_proj.weight, self.policy.client_module.self_attn.q_proj.weight,
                self.policy.client_module.self_attn.k_proj.weight, self.policy.client_module.self_attn.v_proj.weight,
                self.policy.client_module.self_attn.o_proj.weight
            ]
        ]

    def get_lora_matched_pair(self):
        up_proj_lora, gate_proj_lora, down_proj_lora, q_lora, k_lora, v_lora, out_lora = self.get_lora_params()
        ret = [(up_proj_lora, self.inter_up_w), (gate_proj_lora, self.inter_gate_w), (down_proj_lora, self._4hh_w),
               (out_lora, self.dense_w), (q_lora, self.qw), (k_lora, self.kw), (v_lora, self.vw)]
        return ret

    def set_q_k_v(self):
        """
        Necessary to implement for `HybridSplitQKVContainer`
        """
        self.qw = self.policy.client_module.self_attn.q_proj.weight
        self.qb = self.policy.client_module.self_attn.q_proj.bias
        self.kw = self.policy.client_module.self_attn.k_proj.weight
        self.kb = self.policy.client_module.self_attn.k_proj.bias
        self.vw = self.policy.client_module.self_attn.v_proj.weight
        self.vb = self.policy.client_module.self_attn.v_proj.bias

    def set_mlp_gate(self):
        """
        Necessary to implement for `HybridGatedMLPContainer`
        """
        self.inter_up_w = self.policy.client_module.mlp.up_proj.weight
        self.inter_up_b = None
        self.inter_gate_w = self.policy.client_module.mlp.gate_proj.weight
        self.inter_gate_b = None
    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        # Note: indices 9-12 below are used for the attention biases, so the tuple
        # must contain 13 entries; the comma after 'post_attention_layernorm.weight'
        # is required to avoid implicit string concatenation.
        param_names = (
            'self_attn.q_proj.weight', \
            'self_attn.k_proj.weight', \
            'self_attn.v_proj.weight', \
            'self_attn.o_proj.weight', \
            'mlp.up_proj.weight', \
            'mlp.gate_proj.weight', \
            'mlp.down_proj.weight', \
            'input_layernorm.weight', \
            'post_attention_layernorm.weight', \
            'self_attn.q_proj.bias', \
            'self_attn.k_proj.bias', \
            'self_attn.v_proj.bias', \
            'self_attn.o_proj.bias', \
        )
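        # Copy the Hugging Face InternLM weights from the state dict into the fused
        # inference module: q/k/v are fused into attn_qkvw/attn_qkvb, up/gate into
        # the gated-MLP inter_w, while o_proj, down_proj and the two RMSNorm weights
        # map onto the remaining attention/MLP/normalization parameters.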
        maybe_copy_qkv(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       'attn_qkvw', [prefix + param_names[0], prefix + param_names[1], prefix + param_names[2]],
                       split_qkv=self.policy.split_qkv)
        maybe_copy_qkv(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       'attn_qkvb', [prefix + param_names[9], prefix + param_names[10], prefix + param_names[11]],
                       split_qkv=self.policy.split_qkv)
        maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[2],
                   prefix + param_names[3])
        maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[3],
                   prefix + param_names[12])
        maybe_copy_geglu(module.mlp, sd, weight_quantizer, mp_replace, 'inter_w',
                         [prefix + param_names[4], prefix + param_names[5]])
        maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, 'output_w', prefix + param_names[6])
        maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[8], prefix + param_names[7])
        maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[10], prefix + param_names[8])


class InternLMLayerPolicy(TransformerPolicy):
    _orig_layer_class = []
    _orig_layer_class_inited = False

    def __init__(self, client_module, inference=True):
        super().__init__(
            inference,
            mlp_act_func_type=ActivationFuncType.GATED_SILU,
            norm_type=NormType.RMSNorm,
        )
        self.client_module = client_module
        self._init_orig_layer_class_once()

    def _init_orig_layer_class_once(self):
        if InternLMLayerPolicy._orig_layer_class_inited:
            return

        for sub_pkg in ['', '.internlm-7b', '.internlm-chat-7b']:
            try:
                from transformers.utils import TRANSFORMERS_DYNAMIC_MODULE_NAME
                module = importlib.import_module(f"{TRANSFORMERS_DYNAMIC_MODULE_NAME}{sub_pkg}.modeling_internlm")
                if module.InternLMDecoderLayer not in InternLMLayerPolicy._orig_layer_class:
                    InternLMLayerPolicy._orig_layer_class.append(module.InternLMDecoderLayer)
            except ImportError:
                continue

        InternLMLayerPolicy._orig_layer_class_inited = True

    def get_hidden_heads(self):
        return self.client_module.self_attn.q_proj.weight.shape[1], \
               self.client_module.self_attn.num_heads, \
               self.client_module.input_layernorm.variance_epsilon, \
               self.client_module.mlp.gate_proj.weight.shape[0]
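
    # attention() and mlp() return fused parameter views: the q/k/v weights and
    # biases are concatenated along dim 0 into single QKV tensors, and the up/gate
    # projections are concatenated for the gated-SiLU MLP expected by the DeepSpeed
    # inference kernels.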
    def attention(self, enable_training=False):
        qw = self.client_module.self_attn.q_proj.weight
        kw = self.client_module.self_attn.k_proj.weight
        vw = self.client_module.self_attn.v_proj.weight

        qb = self.client_module.self_attn.q_proj.bias
        kb = self.client_module.self_attn.k_proj.bias
        vb = self.client_module.self_attn.v_proj.bias

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training)

        return qkvw, \
               qkvb, \
               self.client_module.self_attn.o_proj.weight, \
               self.client_module.self_attn.o_proj.bias

    def mlp(self, enable_training=False):
        mlp1_up = self.client_module.mlp.up_proj.weight
        mlp1_gate = self.client_module.mlp.gate_proj.weight
        mlp2 = self.client_module.mlp.down_proj.weight

        mlp1 = Parameter(torch.cat((mlp1_up, mlp1_gate), dim=0), requires_grad=enable_training)

        return mlp1, None, mlp2, None

    def layernorm(self):
        return self.client_module.post_attention_layernorm.weight, \
               None, \
               self.client_module.input_layernorm.weight, \
               None
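

# A minimal usage sketch (not part of the container itself): once the InternLM
# decoder layer class has been registered in InternLMLayerPolicy._orig_layer_class,
# kernel injection is requested through the standard DeepSpeed inference entry
# point. The checkpoint name below is only an example.
#
#   import deepspeed
#   import torch
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b",
#                                                trust_remote_code=True)
#   engine = deepspeed.init_inference(model,
#                                     dtype=torch.float16,
#                                     replace_with_kernel_inject=True)
#   outputs = engine.module.generate(...)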