base.py

'''Copyright The Microsoft DeepSpeed Team'''
# Create a container object to save model-specific tensors using the policy file above.

from abc import ABC

import torch

from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig
from deepspeed.accelerator import get_accelerator


class BaseConvolutionContainer(ABC):
    # not implemented
    def __init__(self):
        pass
class BaseTransformerContainer(ABC):
    def __init__(self, policy, config, model_config, layer_id, child):
        self.policy = policy
        self.config = config
        self.model_config = model_config
        self.layer_id = layer_id
        self.child = child

        self.megatron_v2 = self.policy.is_megatron_v2
        self.scale_attention = self.policy.scale_attention
        self.ckpt_load_enabled = False

        # configuration for models. todo: can this be moved to a pydantic model config?
        self.hidden_size = None
        self.num_attention_heads = None
        self.mp_size = self.config.tensor_parallel.tp_size
        self.pre_layer_norm = self.policy.pre_attn_norm
        self.fp16 = False
        self.attn_linear_layer = self.policy.linear_layer
        self.mlp_linear_layer = self.policy.linear_layer
        # Different model families expose the layer-norm epsilon under different attribute names.
        self.layer_norm_eps = self.model_config.layer_norm_eps if \
            hasattr(self.model_config, 'layer_norm_eps') else (self.model_config.layer_norm_epsilon if \
            hasattr(self.model_config, 'layer_norm_epsilon') else self.model_config.layernorm_epsilon if \
            hasattr(self.model_config, 'layernorm_epsilon') else 1.0e-12)
        self.return_tuple = self.config.return_tuple
        self.triangular_masking = True
        self.local_attention = ((self.model_config.attention_layers[self.layer_id] == "local")
                                if hasattr(self.model_config, 'attention_layers') else False)
        self.window_size = getattr(self.model_config, "window_size", 1)
        self.mlp_act_func_type = self.policy.mlp_act_func_type
        self.training_mp_size = self.config.training_mp_size
        self.bigscience_bloom = False
        self.max_out_tokens = self.config.max_out_tokens
        self.scale_attn_by_inverse_layer_idx = getattr(self.config,
                                                       "scale_attn_by_inverse_layer_idx",
                                                       False)
        self.use_mup = self.policy.use_mup
        self.return_single_tuple = False
        self.rotary_dim = self.model_config.rotary_dim if hasattr(self.model_config, 'rotary_dim') \
            else self.child.attention.rotary_ndims if \
            hasattr(self.child, 'attention') and hasattr(self.child.attention, 'rotary_ndims') else -1
        self.mlp_after_attn = (self.rotary_dim is None or self.rotary_dim < 0)

        # Attention tensors
        self.qkvw = None
        self.qkvb = None
        self.dense_w = None
        self.dense_b = None
        # MLP tensors
        self._h4h_w = None
        self._h4h_b = None
        self._4hh_w = None
        self._4hh_b = None
        # LayerNorm tensors
        self.attn_nw = None
        self.attn_nb = None
        self.input_nw = None
        self.input_nb = None
    def create_ds_model_config(self):
        self.set_hidden_heads(*self.policy.get_hidden_heads())
        assert self.num_attention_heads % self.mp_size == 0, \
            "To run the model parallel across the GPUs, the number of attention heads must be divisible by the world size! " \
            "This is because the attention computation is partitioned evenly among the parallel GPUs."

        self.ds_model_config = DeepSpeedInferenceConfig(
            hidden_size=self.hidden_size,
            heads=self.num_attention_heads,
            layer_norm_eps=self.layer_norm_eps,
            fp16=self.fp16,
            pre_layer_norm=self.pre_layer_norm,
            mp_size=self.mp_size,
            q_int8=self.quantize,
            return_tuple=self.return_tuple,
            triangular_masking=self.triangular_masking,
            local_attention=self.local_attention,
            window_size=self.window_size,
            rotary_dim=self.rotary_dim,
            mlp_after_attn=self.mlp_after_attn,
            mlp_act_func_type=self.mlp_act_func_type,
            training_mp_size=self.training_mp_size,
            bigscience_bloom=self.bigscience_bloom,
            max_out_tokens=self.max_out_tokens,
            scale_attn_by_inverse_layer_idx=self.scale_attn_by_inverse_layer_idx,
            use_mup=self.use_mup,
            return_single_tuple=self.return_single_tuple,
        )

        return self.ds_model_config
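    # Note on `self.module`, referenced by the quantization, tensor-parallelism,
    # and copy methods below: it is not created in this base class. A concrete
    # container subclass is expected to build the DeepSpeed inference module from
    # the DeepSpeedInferenceConfig returned above (typically in a separate
    # create_module step defined elsewhere in this package) and assign it to
    # `self.module`; the exact module class and its constructor arguments depend
    # on the DeepSpeed version.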
    def initialize_tensors(self):
        # Set the tensors from policy (user module) to container (DS module)
        self.set_attention(*self.policy.attention())
        self.set_mlp(*self.policy.mlp())
        self.set_layernorm(*self.policy.layernorm())
    def convert_to_required_dtype(self, dtype):
        # Note: converting tensors to fp16 requires that we do it in-place using self.__dict__ and not make a list/dict copy
        if dtype == torch.half:
            for k, v in self.__dict__.items():
                # The list comprehension is used for MoE tensor lists
                if isinstance(v, list) and all(
                        isinstance(tensor, (torch.Tensor, torch.nn.Parameter)) for tensor in v):
                    self.__dict__[k] = [moe_tensor.half() for moe_tensor in v]

                if isinstance(v, (torch.Tensor, torch.nn.Parameter)):
                    self.__dict__[k] = v.half()
    def set_dtype(self, fp16=False):
        self.fp16 = fp16

    def set_moe(self, moe=False):
        self.moe = moe

    def set_tensor_parallel_config(self, mp_size, mp_group):
        self.mp_size = mp_size
        self.mp_group = mp_group

    def set_quantization_config(self, quantize, quantizer):
        self.quantize = quantize
        self.quantizer = quantizer

    def set_hidden_heads(self, hidden_size, num_attention_heads):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads

    def set_attention(self, qkvw, qkvb, dense_w, dense_b):
        self.qkvw = qkvw
        self.qkvb = qkvb
        self.dense_w = dense_w
        self.dense_b = dense_b

    def set_mlp(self, _h4h_w, _h4h_b, _4hh_w, _4hh_b):
        self._h4h_w = _h4h_w
        self._h4h_b = _h4h_b
        self._4hh_w = _4hh_w
        self._4hh_b = _4hh_b

    def set_layernorm(self, attn_nw, attn_nb, input_nw, input_nb):
        self.attn_nw = attn_nw
        self.attn_nb = attn_nb
        self.input_nw = input_nw
        self.input_nb = input_nb
    def apply_weight_quantization(self):
        # quantize attention weights
        self.attention_quantization()

        # quantize mlp weights
        self.mlp_quantization()

    def attention_quantization(self):
        self.module.attention.attn_qkvw = self.quantizer.quantize(
            self.module.attention.attn_qkvw)
        self.module.attention.attn_ow = self.quantizer.quantize(
            self.module.attention.attn_ow)

    def mlp_quantization(self):
        self.module.mlp.inter_w = self.quantizer.quantize(self.module.mlp.inter_w)
        self.module.mlp.output_w = self.quantizer.quantize(self.module.mlp.output_w)
    def apply_tensor_parallelism(self, mp_replace):
        # setup the new Attention module
        self.attention_qkv_mp(mp_replace)
        self.attention_o_mp(mp_replace)

        # setup the new MLP module
        self.mlp_inter_mp(mp_replace)
        self.mlp_output_mp(mp_replace)

        # Apply weight quantization
        self.apply_weight_quantization()

    def attention_qkv_mp(self, mp_replace):
        self.module.attention.attn_qkvw = mp_replace.qkv_copy(
            self.module.attention.attn_qkvw,
            self.qkvw)
        self.module.attention.attn_qkvb = mp_replace.qkv_copy(
            self.module.attention.attn_qkvb,
            self.qkvb)

    def attention_o_mp(self, mp_replace):
        self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow,
                                                        self.dense_w)
        self.module.attention.attn_ob = mp_replace.copy(self.module.attention.attn_ob,
                                                        self.dense_b)

    def mlp_inter_mp(self, mp_replace):
        self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w, self._h4h_w)
        self.module.mlp.inter_b = mp_replace.copy(self.module.mlp.inter_b, self._h4h_b)

    def mlp_output_mp(self, mp_replace):
        self.module.mlp.output_w = mp_replace.copy(self.module.mlp.output_w, self._4hh_w)
        self.module.mlp.output_b = mp_replace.copy(self.module.mlp.output_b, self._4hh_b)
    def copy_data_to_new_module(self):
        if self.attn_nw is None:
            self.module.mlp.attn_nw = self.attn_nw
            self.module.mlp.attn_nb = self.attn_nb
        else:
            self.module.mlp.attn_nw.data.copy_(
                self.attn_nw.to(get_accelerator().current_device_name()))
            self.module.mlp.attn_nb.data.copy_(
                self.attn_nb.to(get_accelerator().current_device_name()))

        self.module.norm_w.data.copy_(
            self.input_nw.to(get_accelerator().current_device_name()))
        self.module.norm_b.data.copy_(
            self.input_nb.to(get_accelerator().current_device_name()))
    def transpose(self):
        self.transpose_attention()
        self.transpose_mlp()

    def transpose_attention(self):
        if self.attn_linear_layer:
            self.qkvw = self.transpose_impl(self.qkvw.data)
            self.dense_w = self.transpose_impl(self.dense_w.data)

    def transpose_mlp(self):
        if self.mlp_linear_layer:
            self._h4h_w = self.transpose_impl(self._h4h_w.data)
            self._4hh_w = self.transpose_impl(self._4hh_w.data)
    def transpose_impl(self, data):
        # Transpose the 2D weight in place (reusing the original storage), then
        # view it with the swapped shape.
        data = data.contiguous()
        data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
        data = data.reshape(data.shape[-1], data.shape[-2])
        # Assign the result of .to() so the move to the accelerator actually takes effect.
        data = data.to(get_accelerator().current_device_name())
        return data
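
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; `MyContainer`, `policy`, `config`,
# `model_config`, `orig_layer`, `mp_replace`, and `quantizer` are hypothetical
# stand-ins for objects supplied by the module-injection code, not APIs defined
# in this file; the exact call order varies by DeepSpeed version):
#
#   container = MyContainer(policy, config, model_config, layer_id=0,
#                           child=orig_layer)
#   container.set_dtype(fp16=True)
#   container.set_quantization_config(quantize=False, quantizer=quantizer)
#   ds_config = container.create_ds_model_config()   # build the kernel config
#   # ... the concrete container builds and assigns `container.module` here ...
#   container.initialize_tensors()                   # pull tensors from the policy
#   container.convert_to_required_dtype(torch.half)  # cast gathered tensors
#   container.transpose()                            # transpose linear-layer weights if needed
#   container.apply_tensor_parallelism(mp_replace)   # shard (and optionally quantize) weights
#   container.copy_data_to_new_module()              # copy layer-norm params to the device
# ---------------------------------------------------------------------------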