gptneo.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from .features.meta_tensor import MetaTensorContainer
from .features.split_qkv import HybridSplitQKVContainer
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
import torch
from torch.nn.parameter import Parameter
from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from ..policy import maybe_copy_qkv
from ..policy import maybe_get_lora


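# DS_GPTNEOContainer wires a Hugging Face GPT-Neo transformer block into DeepSpeed's
# fused inference module. It combines meta-tensor loading (MetaTensorContainer),
# split-QKV handling (HybridSplitQKVContainer), and the generic transformer container.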
class DS_GPTNEOContainer(MetaTensorContainer, HybridSplitQKVContainer, BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.

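    # create_module() builds the fused DeepSpeedGPTInference layer from the DeepSpeed
    # model config (or an explicitly passed config) and the model-parallel group, and
    # propagates the policy's scale_attention setting onto it.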
    def create_module(self, config=None):
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        return self.module

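    # Collect the (possibly absent) LoRA parameters for the six linear projections of
    # the original layer: the MLP's c_fc / c_proj and the attention q/k/v/out projections.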
    def set_lora_params(self):
        """
        Necessary to implement for `HybridEngineContainer`
        """
        self.lora_params = [
            maybe_get_lora(p) for p in [
                self.policy.client_module.mlp.c_fc, self.policy.client_module.mlp.c_proj,
                self.policy.client_module.attn.attention.q_proj, self.policy.client_module.attn.attention.k_proj,
                self.policy.client_module.attn.attention.v_proj, self.policy.client_module.attn.attention.out_proj
            ]
        ]

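    # Expose the separate q/k/v projection weights to HybridSplitQKVContainer.
    # GPT-Neo's attention projections are defined without biases, so the bias slots
    # stay None.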
    def set_q_k_v(self):
        """
        Necessary to implement for `HybridSplitQKVContainer`
        """
        self.qw = self.policy.client_module.attn.attention.q_proj.weight
        self.qb = None
        self.kw = self.policy.client_module.attn.attention.k_proj.weight
        self.kb = None
        self.vw = self.policy.client_module.attn.attention.v_proj.weight
        self.vb = None

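    # Pair each LoRA parameter with the fused container weight it adapts, in the order
    # returned by get_lora_params() / set_lora_params().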
    def get_lora_matched_pair(self):
        """
        Necessary to implement for `HybridEngineContainer`
        """
        fc1_lora, fc2_lora, q_lora, k_lora, v_lora, out_lora = self.get_lora_params()
        ret = [(fc1_lora, self._h4h_w), (fc2_lora, self._4hh_w), (out_lora, self.dense_w), (q_lora, self.qw),
               (k_lora, self.kw), (v_lora, self.vw)]
        return ret

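    # Copy checkpoint tensors from the state dict `sd` into the fused inference module:
    # the q/k/v weights are fused into attn_qkvw via maybe_copy_qkv, then the remaining
    # attention, MLP, and layer-norm parameters are copied by index through maybe_copy.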
    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        param_names = (
            'attn.attention.q_proj.weight', \
            'attn.attention.k_proj.weight', \
            'attn.attention.v_proj.weight', \
            'attn.attention.out_proj.weight', \
            'attn.attention.out_proj.bias', \
            'mlp.c_fc.weight', \
            'mlp.c_fc.bias', \
            'mlp.c_proj.weight', \
            'mlp.c_proj.bias', \
            'ln_2.weight', \
            'ln_2.bias', \
            'ln_1.weight', \
            'ln_1.bias'
        )
        maybe_copy_qkv(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       'attn_qkvw', [prefix + param_names[0], prefix + param_names[1], prefix + param_names[2]],
                       split_qkv=self.policy.split_qkv)
        for i in range(3, 5):
            maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i - 1],
                       prefix + param_names[i])
        for i in range(5, 11):
            maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i - 1],
                       prefix + param_names[i])
        for i in range(11, 13):
            maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i - 1],
                       prefix + param_names[i])


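# HFGPTNEOLayerPolicy maps the attributes of a Hugging Face GPTNeoBlock onto the
# tensors DeepSpeed's module injection expects: fused QKV, attention output,
# MLP, and layer-norm parameters.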
class HFGPTNEOLayerPolicy(TransformerPolicy):

    def __init__(self, client_module, inference=True):
        super().__init__(inference, scale_attention=False)
        self.client_module = client_module
        try:
            import transformers
            HFGPTNEOLayerPolicy._orig_layer_class = transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock
        except:
            HFGPTNEOLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return self.client_module.attn.attention.embed_dim, \
                self.client_module.attn.attention.num_heads, \
                self.client_module.ln_1.eps, \
                DEFAULT_INTERMEDIATE_SIZE

    def get_q_k_v(self):
        return self.client_module.attn.attention.q_proj.weight, \
               None, \
               self.client_module.attn.attention.k_proj.weight, \
               None, \
               self.client_module.attn.attention.v_proj.weight, \
               None

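    # Fuse the separate q/k/v weights into a single QKV matrix. Each projection weight
    # has shape [embed_dim, embed_dim], so concatenating along dim 0 yields a
    # [3 * embed_dim, embed_dim] tensor; GPT-Neo's q/k/v projections carry no bias,
    # hence the None in the returned tuple.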
    def attention(self, enable_training=False):
        qw = self.client_module.attn.attention.q_proj.weight
        kw = self.client_module.attn.attention.k_proj.weight
        vw = self.client_module.attn.attention.v_proj.weight

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)

        return qkvw, \
               None, \
               self.client_module.attn.attention.out_proj.weight, \
               self.client_module.attn.attention.out_proj.bias

    def mlp(self, enable_training=False):
        return self.client_module.mlp.c_fc.weight, \
               self.client_module.mlp.c_fc.bias, \
               self.client_module.mlp.c_proj.weight, \
               self.client_module.mlp.c_proj.bias

    def layernorm(self):
        return self.client_module.ln_2.weight, \
               self.client_module.ln_2.bias, \
               self.client_module.ln_1.weight, \
               self.client_module.ln_1.bias
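
# Illustrative sketch (not part of this file): with kernel injection enabled, DeepSpeed
# picks up HFGPTNEOLayerPolicy for GPT-Neo models, roughly as below. The model name is
# an assumption chosen only for illustration.
#
#   import torch
#   import deepspeed
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
#   engine = deepspeed.init_inference(model,
#                                     dtype=torch.half,
#                                     replace_with_kernel_inject=True)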