policy.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from abc import ABC, abstractmethod

import torch

from deepspeed.utils.types import ActivationFuncType, NormType
from deepspeed.accelerator import get_accelerator

transformer_param_names = (
    'attn_qkvw',
    'attn_qkvb',
    'attn_ow',
    'attn_ob',
    'inter_w',
    'inter_b',
    'output_w',
    'output_b',
    'attn_nw',
    'attn_nb',
    'norm_w',
    'norm_b')

class DSPolicy(ABC):
    _orig_layer_class = None

    def __init__(self):
        self.cuda_graph_supported = False

    @abstractmethod
    def attention(self):
        """
        Returns attention qkv and dense parameters
        weight: (3*hidden, hidden) and (hidden, hidden)
        bias: (3*hidden) and (hidden)
        """
        raise NotImplementedError

class TransformerPolicy(DSPolicy):
    # a static class variable containing the HuggingFace model configuration.
    # see e.g., transformers.models.opt.configuration_opt.OPTConfig
    hf_model_config = None

    def __init__(
            self,
            inference=True,
            linear_layer=True,
            scale_attention=True,
            megatron_v2=False,
            use_mup=False,
            # the type of activation function used in the MLP
            mlp_act_func_type=ActivationFuncType.GELU,
            # applies layer norm before attention if `pre_attn_norm` is set to True
            pre_attn_norm=True,
            # whether or not to use a prefix when loading parameters from the checkpoint
            use_load_prefix=False,
            # whether or not the qkv weights are stored in split format in the checkpoint
            split_qkv=True,
            # type of normalization to perform
            norm_type=NormType.LayerNorm):
        super().__init__()
        self.cuda_graph_supported = False
        self.inference = inference
        self.linear_layer = linear_layer
        self.scale_attention = scale_attention
        self.is_megatron_v2 = megatron_v2
        self.use_mup = use_mup
        self.mlp_act_func_type = mlp_act_func_type
        self.pre_attn_norm = pre_attn_norm
        self.use_load_prefix = use_load_prefix
        self.split_qkv = split_qkv
        self.norm_type = norm_type

    @abstractmethod
    def attention(self):
        """
        Returns attention qkv and dense parameters
        weight: (3*hidden, hidden) and (hidden, hidden)
        bias: (3*hidden) and (hidden)
        """
        raise NotImplementedError

    @abstractmethod
    def get_hidden_heads(self):
        """
        Returns the hidden size and the number of attention heads
        """
        raise NotImplementedError

    @abstractmethod
    def mlp(self):
        """
        Returns mlp intermediate and output
        weight: (intermediate, hidden) and (hidden, intermediate)
        bias: (intermediate) and (hidden)
        """
        raise NotImplementedError

    @abstractmethod
    def layernorm(self):
        """
        Returns the LayerNorms used in the transformer layer:
        post-attention and pre/post layer norm
        gamma and beta with shape: (hidden)
        """
        raise NotImplementedError
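
# Illustrative sketch (added comment, not part of the DeepSpeed source): a concrete policy
# for a hypothetical HuggingFace-style decoder layer. The attribute names reached through
# `client_module` (self_attn.q_proj, out_proj, num_heads, ...) are assumptions made for
# this example only; real policies live in deepspeed.module_inject.containers.
#
#   class ExampleLayerPolicy(TransformerPolicy):
#
#       def __init__(self, client_module, inference=True):
#           super().__init__(inference, pre_attn_norm=True)
#           self.client_module = client_module
#
#       def get_hidden_heads(self):
#           attn = self.client_module.self_attn
#           return attn.q_proj.weight.shape[1], attn.num_heads
#
#       def attention(self):
#           attn = self.client_module.self_attn
#           qkvw = torch.cat((attn.q_proj.weight, attn.k_proj.weight, attn.v_proj.weight), dim=0)
#           qkvb = torch.cat((attn.q_proj.bias, attn.k_proj.bias, attn.v_proj.bias), dim=0)
#           return qkvw, qkvb, attn.out_proj.weight, attn.out_proj.bias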

# TODO (lekurile): This function exists in the base container as well, consolidate at some point
def transpose(data):
    with torch.no_grad():
        data = data.contiguous()
        data1 = data.transpose(-1, -2).reshape(-1)
        data.reshape(-1).copy_(data1)
        data1 = None
    return data.reshape(data.shape[-1], data.shape[-2])
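
# Illustrative usage (added comment): `transpose` swaps the last two dimensions in place,
# reusing the tensor's storage, so an HF-style (hidden, 3*hidden) weight comes back with
# shape (3*hidden, hidden):
#
#   w = torch.randn(8, 24)
#   w_t = transpose(w)
#   assert w_t.shape == (24, 8)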

# TODO (lekurile): This function exists in the megatron feature container as well, consolidate at some point
def _transpose(x, heads=1, mp_replace=None):
    heads = heads // mp_replace.mp_size  # type: ignore
    outer_dim = -1
    attention_head_size = x.shape[outer_dim] // heads
    new_x_shape = x.size()[:outer_dim] + (heads, attention_head_size)
    x_1 = x.view(*new_x_shape)
    (q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=-1)
    if len(q.shape) > 2:
        new_shape = (q.shape[0], ) + (-1, )
        return torch.cat((q.reshape(new_shape), k.reshape(new_shape), v.reshape(new_shape)),
                         dim=outer_dim).reshape(x.shape)
    else:
        return torch.cat((q.reshape(-1), k.reshape(-1), v.reshape(-1)), dim=-1).reshape(x.shape)
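
# Illustrative usage (added comment): `_transpose` regroups a megatron-v2 style fused qkv
# tensor whose last dimension interleaves q/k/v per head ([q1 k1 v1 q2 k2 v2 ...]) into
# contiguous q, k, v blocks ([q1 q2 ... k1 k2 ... v1 v2 ...]). `FakeMPReplace` is a
# stand-in for the real mp_replace helper, which is only used here for its `mp_size`:
#
#   class FakeMPReplace:
#       mp_size = 1
#
#   bias = torch.arange(12, dtype=torch.float32)    # 2 heads, 2 elements each for q, k, v
#   out = _transpose(bias, heads=2, mp_replace=FakeMPReplace())
#   # out: [0, 1, 6, 7, 2, 3, 8, 9, 4, 5, 10, 11]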

# This checks if the parameter exists in the checkpoint file and maybe copies it into the corresponding destination tensor.
# Note that not all parameters are saved in one checkpoint, that's why we always need to check if they exist!
def maybe_copy(module,
               sd,
               weight_quantizer,
               mp_replace,
               dst_name,
               src_name,
               qkv=False,
               megatron_v2=False,
               split_qkv=False,
               heads=1):
    if src_name in sd:
        dst = getattr(module, dst_name)
        tmp = sd[src_name]
        if len(dst.shape) == 1:
            if split_qkv:
                dst = mp_replace.strided_copy(dst, tmp, num_splits=3)
            else:
                dst = mp_replace.copy(dst, tmp)
            if qkv and megatron_v2:
                dst = torch.nn.parameter.Parameter(_transpose(dst, heads=heads, mp_replace=mp_replace).contiguous())
        else:
            if split_qkv:
                dst = mp_replace.strided_copy(dst,
                                              weight_quantizer.quantize(
                                                  tmp if weight_quantizer.q_int8 else transpose(tmp).contiguous()),
                                              num_splits=3,
                                              int8=weight_quantizer.q_int8)
            else:
                if qkv and megatron_v2:
                    tmp = _transpose(transpose(tmp), heads=heads, mp_replace=mp_replace).contiguous()
                    if weight_quantizer.q_int8:
                        tmp = transpose(tmp)
                dst = mp_replace.copy(dst,
                                      weight_quantizer.quantize(tmp if weight_quantizer.q_int8 else transpose(tmp)),
                                      int8=weight_quantizer.q_int8)
        setattr(module, dst_name, dst)
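
# Illustrative call (added comment, mirroring how the model containers use this helper;
# `module.attention`, `prefix` and the checkpoint key are example names, not a fixed API):
#
#   maybe_copy(module.attention, sd, weight_quantizer, mp_replace,
#              'attn_qkvw', prefix + 'attention.query_key_value.weight',
#              qkv=True, megatron_v2=policy.is_megatron_v2,
#              split_qkv=policy.split_qkv, heads=num_attention_heads)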

# Extending the maybe_copy function for when the q, k, and v are in separate parameters!
def maybe_copy_qkv(module, sd, weight_quantizer, mp_replace, dst_name, src_names, split_qkv=False):
    if src_names[0] in sd:
        q = sd[src_names[0]]
        k = sd[src_names[1]]
        v = sd[src_names[2]]
        qkv_data = torch.cat((q, k, v), dim=0)
        dst = getattr(module, dst_name)
        if len(dst.shape) == 1:
            if split_qkv:
                dst = mp_replace.strided_copy(dst, qkv_data.contiguous(), num_splits=3)
            else:
                dst = mp_replace.copy(dst, qkv_data)
        else:
            if split_qkv:
                dst = mp_replace.strided_copy(dst,
                                              weight_quantizer.quantize(
                                                  qkv_data.to(get_accelerator().device_name())
                                                  if weight_quantizer.q_int8 else transpose(qkv_data).contiguous()),
                                              num_splits=3,
                                              int8=weight_quantizer.q_int8)
            else:
                dst = mp_replace.copy(dst,
                                      weight_quantizer.quantize(
                                          qkv_data.to(get_accelerator().device_name())
                                          if weight_quantizer.q_int8 else transpose(qkv_data)),
                                      int8=weight_quantizer.q_int8)
        setattr(module, dst_name, dst)
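
# Illustrative call (added comment): when a checkpoint stores q, k and v as three separate
# tensors, the container passes their (example) names as a list and the helper fuses them:
#
#   maybe_copy_qkv(module.attention, sd, weight_quantizer, mp_replace, 'attn_qkvw',
#                  [prefix + 'self_attn.q_proj.weight',
#                   prefix + 'self_attn.k_proj.weight',
#                   prefix + 'self_attn.v_proj.weight'],
#                  split_qkv=policy.split_qkv)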

# Extending the `maybe_copy` function for when mlp1 is in separate parameters for GeGLU
def maybe_copy_geglu(module, sd, weight_quantizer, mp_replace, dst_name, src_names):
    if src_names[0] in sd:
        reg_proj = sd[src_names[0]]
        gate_proj = sd[src_names[1]]
        mlp1_data = torch.cat((reg_proj, gate_proj), dim=0)
        dst = getattr(module, dst_name)
        dst = mp_replace.strided_copy(dst,
                                      weight_quantizer.quantize(
                                          mlp1_data.to(get_accelerator().device_name())
                                          if weight_quantizer.q_int8 else transpose(mlp1_data)),
                                      num_splits=2,
                                      int8=weight_quantizer.q_int8)
        setattr(module, dst_name, dst)
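
# Illustrative call (added comment): for a GeGLU-style MLP the first projection is split
# into a regular and a gate branch in the checkpoint; the source names below are example
# assumptions, not fixed DeepSpeed keys:
#
#   maybe_copy_geglu(module.mlp, sd, weight_quantizer, mp_replace, 'inter_w',
#                    [prefix + 'mlp.up_proj.weight',
#                     prefix + 'mlp.gate_proj.weight'])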

def pack_lora_weights(p):
    return [
        p.lora_right_weight,
        p.lora_left_weight,
        p.lora_scaling
    ]


def maybe_get_lora(p):
    if hasattr(p, 'lora_right_weight'):
        lora_param = pack_lora_weights(p)
    else:
        lora_param = []
    return lora_param
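
# Illustrative usage (added comment): callers probe each client-side projection for attached
# LoRA weights; modules without `lora_right_weight` simply contribute an empty list:
#
#   lora_params = [maybe_get_lora(proj) for proj in (q_proj, k_proj, v_proj, out_proj)]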