llama.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from .features import HybridSplitQKVContainer, HybridGatedMLPContainer, MetaTensorContainer
from deepspeed.utils.types import ActivationFuncType, NormType
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
import torch
from torch.nn.parameter import Parameter

from ..policy import (
    TransformerPolicy,
    transformer_param_names,
    maybe_copy,
    maybe_copy_qkv,
    maybe_copy_geglu,
    maybe_get_lora,
)
class DS_LLAMAContainer(MetaTensorContainer, HybridGatedMLPContainer, HybridSplitQKVContainer,
                        BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.

    def create_module(self, config=None):
        _config = config if config is not None else self.ds_model_config

        _config.rotate_half = True
        _config.rotate_every_two = False
        _config.rotary_dim = self.hidden_size // self.num_attention_heads
        self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group)

        return self.module

    def set_lora_params(self):
        """
        Necessary to implement for `HybridEngineContainer`
        """
        self.lora_params = [
            maybe_get_lora(p) for p in [
                self.policy.client_module.mlp.up_proj.weight, self.policy.client_module.mlp.gate_proj.weight,
                self.policy.client_module.mlp.down_proj.weight, self.policy.client_module.self_attn.q_proj.weight,
                self.policy.client_module.self_attn.k_proj.weight, self.policy.client_module.self_attn.v_proj.weight,
                self.policy.client_module.self_attn.o_proj.weight
            ]
        ]

    def get_lora_matched_pair(self):
        up_proj_lora, gate_proj_lora, down_proj_lora, q_lora, k_lora, v_lora, out_lora = self.get_lora_params()
        ret = [(up_proj_lora, self.inter_up_w), (gate_proj_lora, self.inter_gate_w), (down_proj_lora, self._4hh_w),
               (out_lora, self.dense_w), (q_lora, self.qw), (k_lora, self.kw), (v_lora, self.vw)]
        return ret

    def set_q_k_v(self):
        """
        Necessary to implement for `HybridSplitQKVContainer`
        """
        self.qw = self.policy.client_module.self_attn.q_proj.weight
        self.qb = None
        self.kw = self.policy.client_module.self_attn.k_proj.weight
        self.kb = None
        self.vw = self.policy.client_module.self_attn.v_proj.weight
        self.vb = None

    def set_mlp_gate(self):
        """
        Necessary to implement for `HybridGatedMLPContainer`
        """
        self.inter_up_w = self.policy.client_module.mlp.up_proj.weight
        self.inter_up_b = None
        self.inter_gate_w = self.policy.client_module.mlp.gate_proj.weight
        self.inter_gate_b = None

    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        param_names = (
            'self_attn.q_proj.weight', \
            'self_attn.k_proj.weight', \
            'self_attn.v_proj.weight', \
            'self_attn.o_proj.weight', \
            'mlp.up_proj.weight', \
            'mlp.gate_proj.weight', \
            'mlp.down_proj.weight', \
            'post_attention_layernorm.weight', \
            'input_layernorm.weight',
        )

        maybe_copy_qkv(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       'attn_qkvw', [prefix + param_names[0], prefix + param_names[1], prefix + param_names[2]],
                       split_qkv=self.policy.split_qkv)
        for i in range(3, 4):
            maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i - 1],
                       prefix + param_names[i])
        maybe_copy_geglu(module.mlp, sd, weight_quantizer, mp_replace, 'inter_w',
                         [prefix + param_names[4], prefix + param_names[5]])
        maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, 'output_w', prefix + param_names[6])
        maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[8], prefix + param_names[7])
        maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[10], prefix + param_names[8])

        # This line is necessary for proper output when kernels + meta tensors are used in Llama models
        # TODO: Investigate root-cause and fix meta tensor loading
        module.mlp.output_b = None
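

# Editorial note: the DeepSpeed inference module consumes fused parameters. load_params()
# above copies the three Hugging Face attention projections (q/k/v) into the single
# attn_qkvw tensor and the two gated-MLP projections (up/gate) into inter_w, mirroring
# the dim=0 concatenations performed by LLAMALayerPolicy.attention() and .mlp() below.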
class LLAMALayerPolicy(TransformerPolicy):

    def __init__(self, client_module, inference=True):
        super().__init__(
            inference,
            mlp_act_func_type=ActivationFuncType.GATED_SILU,
            norm_type=NormType.RMSNorm,
        )
        self.client_module = client_module
        try:
            import transformers
            LLAMALayerPolicy._orig_layer_class = transformers.models.llama.modeling_llama.LlamaDecoderLayer  # type: ignore
        except:
            LLAMALayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        hidden_heads = (
            getattr(self.client_module.self_attn.q_proj.weight, "ds_shape",
                    self.client_module.self_attn.q_proj.weight.shape)[1],
            self.client_module.self_attn.num_heads,
            self.client_module.input_layernorm.variance_epsilon,
            getattr(self.client_module.mlp.gate_proj.weight, "ds_shape",
                    self.client_module.mlp.gate_proj.weight.shape)[0],
        )
        return hidden_heads

    def attention(self, enable_training=False):
        qw = self.client_module.self_attn.q_proj.weight
        kw = self.client_module.self_attn.k_proj.weight
        vw = self.client_module.self_attn.v_proj.weight

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)

        return qkvw, \
               None, \
               self.client_module.self_attn.o_proj.weight, \
               None

    def mlp(self, enable_training=False):
        mlp1_up = self.client_module.mlp.up_proj.weight
        mlp1_gate = self.client_module.mlp.gate_proj.weight
        mlp2 = self.client_module.mlp.down_proj.weight

        mlp1 = Parameter(torch.cat((mlp1_up, mlp1_gate), dim=0), requires_grad=enable_training)

        return mlp1, None, mlp2, None

    def layernorm(self):
        return self.client_module.post_attention_layernorm.weight, \
               None, \
               self.client_module.input_layernorm.weight, \
               None
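

# ---------------------------------------------------------------------------
# Illustrative usage sketch (editorial addition, not part of the original
# module). It assumes a CUDA-capable GPU and uses "huggyllama/llama-7b" purely
# as an example checkpoint; any model built from LlamaDecoderLayer should be
# routed through LLAMALayerPolicy / DS_LLAMAContainer the same way once kernel
# injection is requested via deepspeed.init_inference.
if __name__ == "__main__":
    import deepspeed
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "huggyllama/llama-7b"  # example only; substitute your own checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

    # replace_with_kernel_inject=True swaps each decoder layer for the
    # DeepSpeedGPTInference module built by DS_LLAMAContainer.create_module().
    engine = deepspeed.init_inference(model, dtype=torch.float16, replace_with_kernel_inject=True)

    inputs = tokenizer("DeepSpeed is", return_tensors="pt").to(engine.module.device)
    output_ids = engine.module.generate(inputs.input_ids, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
# ---------------------------------------------------------------------------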