opt.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from .features import MetaTensorContainer, HybridSplitQKVContainer
from deepspeed.model_implementations.transformers.ds_opt import DeepSpeedOPTInference

import torch
from torch.nn.parameter import Parameter

from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from ..policy import maybe_copy_qkv
from ..policy import maybe_get_lora
from deepspeed.utils.types import ActivationFuncType
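
# DS_OPTContainer builds the DeepSpeedOPTInference module and copies (and, for q/k/v,
# fuses) the Hugging Face checkpoint weights into it; HFOPTLayerPolicy describes where
# those tensors live on a transformers OPTDecoderLayer (attention projections, fc1/fc2,
# and the two layer norms).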


class DS_OPTContainer(MetaTensorContainer, HybridSplitQKVContainer, BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.

    def create_module(self, config=None):
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedOPTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        return self.module

    def set_lora_params(self):
        """
        Necessary to implement for `HybridEngineContainer`
        """
        self.lora_params = [
            maybe_get_lora(p) for p in [
                self.policy.client_module.fc1,
                self.policy.client_module.fc2,
                self.policy.client_module.self_attn.q_proj,
                self.policy.client_module.self_attn.k_proj,
                self.policy.client_module.self_attn.v_proj,
                self.policy.client_module.self_attn.out_proj,
            ]
        ]

    def set_q_k_v(self):
        """
        Necessary to implement for `HybridSplitQKVContainer`
        """
        self.qw = self.policy.client_module.self_attn.q_proj.weight
        self.qb = self.policy.client_module.self_attn.q_proj.bias
        self.kw = self.policy.client_module.self_attn.k_proj.weight
        self.kb = self.policy.client_module.self_attn.k_proj.bias
        self.vw = self.policy.client_module.self_attn.v_proj.weight
        self.vb = self.policy.client_module.self_attn.v_proj.bias

    def get_lora_matched_pair(self):
        fc1_lora, fc2_lora, q_lora, k_lora, v_lora, out_lora = self.get_lora_params()
        ret = [(fc1_lora, self._h4h_w), (fc2_lora, self._4hh_w), (out_lora, self.dense_w), (q_lora, self.qw),
               (k_lora, self.kw), (v_lora, self.vw)]
        return ret

    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        param_names = (
            'self_attn.q_proj.weight', \
            'self_attn.k_proj.weight', \
            'self_attn.v_proj.weight', \
            'self_attn.q_proj.bias', \
            'self_attn.k_proj.bias', \
            'self_attn.v_proj.bias', \
            'self_attn.out_proj.weight', \
            'self_attn.out_proj.bias', \
            'fc1.weight', \
            'fc1.bias', \
            'fc2.weight', \
            'fc2.bias', \
            'final_layer_norm.weight', \
            'final_layer_norm.bias', \
            'self_attn_layer_norm.weight', \
            'self_attn_layer_norm.bias'
        )
        # Indices 0-5: fuse the separate q/k/v projection weights (then biases) into the
        # inference module's combined QKV parameters.
        for i in range(0, 6, 3):
            maybe_copy_qkv(module.attention,
                           sd,
                           weight_quantizer,
                           mp_replace,
                           transformer_param_names[i // 3],
                           [prefix + param_names[i], prefix + param_names[i + 1], prefix + param_names[i + 2]],
                           split_qkv=self.policy.split_qkv)
        # Indices 6-7: attention output projection weight and bias.
        for i in range(6, 8):
            maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i - 4],
                       prefix + param_names[i])
        # Indices 8-13: fc1/fc2 weights and biases plus final_layer_norm, copied into the MLP block.
        for i in range(8, 14):
            maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i - 4],
                       prefix + param_names[i])
        # Indices 14-15: self_attn_layer_norm weight and bias, copied onto the module itself.
        for i in range(14, 16):
            maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i - 4],
                       prefix + param_names[i])


class HFOPTLayerPolicy(TransformerPolicy):
    _orig_layer_class = None

    def __init__(self, client_module, inference=True, use_load_prefix=True):
        super().__init__(inference, linear_layer=True, pre_attn_norm=True, use_load_prefix=use_load_prefix)
        self.client_module = client_module
        try:
            import transformers
            HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer
        except:
            HFOPTLayerPolicy._orig_layer_class = None

        if hasattr(TransformerPolicy, "hf_model_config") and hasattr(TransformerPolicy.hf_model_config,
                                                                     "activation_function"):
            if TransformerPolicy.hf_model_config.activation_function == "relu":
                self.mlp_act_func_type = ActivationFuncType.ReLU
            elif TransformerPolicy.hf_model_config.activation_function in ["gelu", "gelu_new"]:
                self.mlp_act_func_type = ActivationFuncType.GELU
            else:
                raise ValueError("Unsupported activation function: {}".format(
                    TransformerPolicy.hf_model_config.activation_function))
        else:
            self.mlp_act_func_type = ActivationFuncType.ReLU  # default

    def get_hidden_heads(self):
        return self.client_module.self_attn.embed_dim, \
               self.client_module.self_attn.num_heads, \
               self.client_module.self_attn_layer_norm.eps, \
               DEFAULT_INTERMEDIATE_SIZE

    def attention(self, enable_training=False):
        qw = self.client_module.self_attn.q_proj.weight
        qb = self.client_module.self_attn.q_proj.bias
        kw = self.client_module.self_attn.k_proj.weight
        kb = self.client_module.self_attn.k_proj.bias
        vw = self.client_module.self_attn.v_proj.weight
        vb = self.client_module.self_attn.v_proj.bias

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training)
        return qkvw, \
               qkvb, \
               self.client_module.self_attn.out_proj.weight, \
               self.client_module.self_attn.out_proj.bias

    def mlp(self, enable_training=False):
        return self.client_module.fc1.weight, \
               self.client_module.fc1.bias, \
               self.client_module.fc2.weight, \
               self.client_module.fc2.bias

    def layernorm(self):
        return self.client_module.final_layer_norm.weight, \
               self.client_module.final_layer_norm.bias, \
               self.client_module.self_attn_layer_norm.weight, \
               self.client_module.self_attn_layer_norm.bias
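
# Illustrative usage (a sketch, not part of this module): container/policy pairs like the
# ones above are selected by DeepSpeed's kernel-injection path when an OPT model is passed
# to init_inference, e.g.:
#
#   import torch, deepspeed
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
#   model = deepspeed.init_inference(model, dtype=torch.float16, replace_with_kernel_inject=True)
#
# The exact init_inference arguments depend on the installed DeepSpeed version.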