gptj.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from .features.meta_tensor import MetaTensorContainer
from .features.split_qkv import HybridSplitQKVContainer
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
import torch
from torch.nn.parameter import Parameter
from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from ..policy import maybe_copy_qkv
from ..policy import maybe_get_lora
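

# Container for a single GPT-J transformer layer. MetaTensorContainer adds support for
# loading from meta-tensor checkpoints, and HybridSplitQKVContainer keeps the separate
# q/k/v projections that the hybrid (training + inference) engine needs.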
class DS_GPTJContainer(MetaTensorContainer, HybridSplitQKVContainer, BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.

    def create_module(self, config=None):
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        return self.module
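
    # The LoRA modules below are listed in the order expected by get_lora_matched_pair():
    # mlp.fc_in, mlp.fc_out, attn.q_proj, attn.k_proj, attn.v_proj, attn.out_proj.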
    def set_lora_params(self):
        """
        Necessary to implement for `HybridEngineContainer`
        """
        self.lora_params = [
            maybe_get_lora(p) for p in [
                self.policy.client_module.mlp.fc_in, self.policy.client_module.mlp.fc_out,
                self.policy.client_module.attn.q_proj, self.policy.client_module.attn.k_proj,
                self.policy.client_module.attn.v_proj, self.policy.client_module.attn.out_proj
            ]
        ]
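
    # Pair each LoRA-wrapped client module with the inference-side parameter it updates
    # (fc_in -> _h4h_w, fc_out -> _4hh_w, out_proj -> dense_w, q/k/v -> qw/kw/vw).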
    def get_lora_matched_pair(self):
        fc1_lora, fc2_lora, q_lora, k_lora, v_lora, out_lora = self.get_lora_params()
        ret = [(fc1_lora, self._h4h_w), (fc2_lora, self._4hh_w), (out_lora, self.dense_w), (q_lora, self.qw),
               (k_lora, self.kw), (v_lora, self.vw)]
        return ret
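
    # GPT-J attention projections are bias-free, so only the weights are exposed here.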
    def set_q_k_v(self):
        """
        Necessary to implement for `HybridSplitQKVContainer`
        """
        self.qw = self.policy.client_module.attn.q_proj.weight
        self.qb = None
        self.kw = self.policy.client_module.attn.k_proj.weight
        self.kb = None
        self.vw = self.policy.client_module.attn.v_proj.weight
        self.vb = None
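
    # Copy the HuggingFace checkpoint tensors in `sd` into the fused inference module,
    # applying quantization and tensor-parallel slicing along the way.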
    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        param_names = (
            'attn.q_proj.weight', \
            'attn.k_proj.weight', \
            'attn.v_proj.weight', \
            'attn.out_proj.weight', \
            'mlp.fc_in.weight', \
            'mlp.fc_in.bias', \
            'mlp.fc_out.weight', \
            'mlp.fc_out.bias', \
            'ln_1.weight', \
            'ln_1.bias'
        )
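
        # Fuse the separate q/k/v checkpoint weights into the single attn_qkvw parameter
        # (kept logically split when the policy requests split_qkv).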
        maybe_copy_qkv(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       'attn_qkvw', [prefix + param_names[0], prefix + param_names[1], prefix + param_names[2]],
                       split_qkv=self.policy.split_qkv)
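
        # Remaining tensors: index 3 is attn.out_proj, indices 4-7 are the MLP
        # weights/biases, and indices 8-9 are ln_1 (the layer's input layernorm).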
        for i in range(3, 4):
            maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i - 1],
                       prefix + param_names[i])
        for i in range(4, 8):
            maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i],
                       prefix + param_names[i])
        for i in range(8, 10):
            maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i + 2],
                       prefix + param_names[i])
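

# Policy describing how to read weights out of a HuggingFace `GPTJBlock` so they can be
# injected into the fused DeepSpeedGPTInference module above.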
class HFGPTJLayerPolicy(TransformerPolicy):
    _orig_layer_class = None

    def __init__(self, client_module, inference=True):
        super().__init__(inference, scale_attention=True)
        self.client_module = client_module

        try:
            import transformers
            HFGPTJLayerPolicy._orig_layer_class = transformers.models.gptj.modeling_gptj.GPTJBlock
        except:
            HFGPTJLayerPolicy._orig_layer_class = None
    def get_hidden_heads(self):
        return self.client_module.attn.embed_dim, \
                self.client_module.attn.num_attention_heads, \
                self.client_module.ln_1.eps, \
                DEFAULT_INTERMEDIATE_SIZE
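
    # Concatenate the separate q/k/v projection weights into one fused QKV weight.
    # GPT-J attention has no biases, so the bias slots are returned as None.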
    def attention(self, enable_training=False):
        qw = self.client_module.attn.q_proj.weight
        kw = self.client_module.attn.k_proj.weight
        vw = self.client_module.attn.v_proj.weight

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)

        return qkvw, \
               None, \
               self.client_module.attn.out_proj.weight, \
               None
    def mlp(self, enable_training=False):
        return self.client_module.mlp.fc_in.weight, \
               self.client_module.mlp.fc_in.bias, \
               self.client_module.mlp.fc_out.weight, \
               self.client_module.mlp.fc_out.bias
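
    # GPT-J applies a single LayerNorm (ln_1) ahead of its parallel attention and MLP
    # branches, so there is no separate post-attention layernorm to return here.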
    def layernorm(self):
        return None, \
               None, \
               self.client_module.ln_1.weight, \
               self.client_module.ln_1.bias
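

# Minimal usage sketch, assuming the standard DeepSpeed kernel-injection path: with
# replace_with_kernel_inject enabled, deepspeed.init_inference matches GPTJBlock layers
# through HFGPTJLayerPolicy._orig_layer_class and swaps them for DeepSpeedGPTInference,
# roughly as follows (exact init_inference arguments may vary across DeepSpeed versions):
#
#   import torch
#   import deepspeed
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
#   model = deepspeed.init_inference(model, dtype=torch.float16, replace_with_kernel_inject=True)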