ds_transformer.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import torch.nn as nn
from deepspeed import comm as dist
from deepspeed.utils.logging import log_dist

from deepspeed.ops.transformer.inference.ds_mlp import DeepSpeedMLP
from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention, BloomSelfAttention
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
import deepspeed

if deepspeed.HAS_TRITON:
    from deepspeed.ops.transformer.inference.triton.mlp import TritonMLP
    from deepspeed.ops.transformer.inference.triton.attention import TritonSelfAttention
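
# Module-level handle to the compiled inference kernel extension; it is built lazily by
# the first layer that is constructed and then shared by every layer in the model.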
inference_module = None


class DeepSpeedTransformerInference(nn.Module):
    """Initialize the DeepSpeed Transformer Layer.

    Arguments:
        layer_id: The layer index starting from 0; e.g. if the model has 24 transformer layers,
            layer_id will be 0, 1, 2 ... 23 as each layer object is instantiated.
        config: An object of DeepSpeedInferenceConfig.
        mp_group: Model-parallelism group initialized on the modeling side.
        quantize_scales: Groups all the layers' scales used for quantization.
        quantize_groups: Number of groups used for quantizing the model.
        merge_count: The number of model-parallel checkpoints merged before running inference.
            We use this argument to control the quantization scale for the model parameters when a
            quantize-grouping larger than 1 is used.
        mlp_extra_grouping: Use twice as many quantization groups for the MLP part of the
            Transformer layer. We use this feature to reduce the convergence impact of
            quantization for specific downstream tasks.
    """
    layer_id = 0

    def __init__(self,
                 config,
                 mp_group=None,
                 quantize_scales=None,
                 quantize_groups=1,
                 merge_count=1,
                 mlp_extra_grouping=False):
        super(DeepSpeedTransformerInference, self).__init__()

        self.config = config
        self.config.layer_id = DeepSpeedTransformerInference.layer_id
        DeepSpeedTransformerInference.layer_id += 1
        data_type = torch.half if self.config.dtype == torch.int8 else self.config.dtype
        global inference_module
        if inference_module is None:
            builder = InferenceBuilder()
            inference_module = builder.load()

        if DeepSpeedTransformerInference.layer_id == 1:
            log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
            if deepspeed.HAS_TRITON and self.config.use_triton:
                log_dist(f"Injecting Triton kernels ...", [0])
        if self.config.bigscience_bloom:
            self.attention = BloomSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, merge_count)
            assert not self.config.use_triton
        else:
            if deepspeed.HAS_TRITON and self.config.use_triton:
                self.attention = TritonSelfAttention(self.config)
            else:
                self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups,
                                                        merge_count)

        if deepspeed.HAS_TRITON and self.config.use_triton:
            self.mlp = TritonMLP(self.config)
        else:
            self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count,
                                    mlp_extra_grouping)
        device = get_accelerator().current_device_name()  # if config.bigscience_bloom else 'cpu'
        if self.config.set_empty_params:
            self.norm_w = None
            self.norm_b = None
        else:
            self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
                                       requires_grad=False)
            self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
                                       requires_grad=False)
        self.layer_past = None
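        # Bind a workspace allocator matching the configured dtype. Older builds of the
        # inference op may not expose these bindings, in which case the explicit
        # pre-allocation in forward() is skipped.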
        try:
            if config.dtype == torch.float32:
                self.allocate_workspace = inference_module.allocate_workspace_fp32
            elif config.dtype == torch.bfloat16:
                self.allocate_workspace = inference_module.allocate_workspace_bf16
            else:
                self.allocate_workspace = inference_module.allocate_workspace_fp32
            self._alloc_workspace = True
        except AttributeError:
            self.allocate_workspace = None
            self._alloc_workspace = False

    @classmethod
    def reset_cache(cls):
        if inference_module is not None:
            inference_module.reset_cache()

    def forward(
            self,
            input=None,
            input_mask=None,
            attention_mask=None,
            attn_mask=None,
            head_mask=None,
            layer_past=None,
            get_key_value=False,
            get_present=False,
            encoder_output=None,
            enc_dec_attn_mask=None,
            x=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            use_cache=False,
            alibi=None,
            output_attentions=False,
            # TODO(arashb): 'layer_head_mask' and 'past_key_value' are only added to satisfy the OPT models API.
            # This needs to be redesigned later!
            layer_head_mask=None,
            past_key_value=None,
            **kwargs):
        if x is not None:
            input = x
        if "hidden_states" in kwargs:
            input = kwargs["hidden_states"]

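        # Resolve the mask: attention_mask takes precedence over attn_mask, which in turn
        # takes precedence over input_mask (callers pass the same tensor under different names).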
        input_mask = (input_mask if attn_mask is None else attn_mask) if attention_mask is None else attention_mask

        # Allocate memory only on first layer forward
        if self.config.layer_id == 0 and self._alloc_workspace:
            self.allocate_workspace(self.config.hidden_size, self.config.heads,
                                    input.size()[1],
                                    input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size,
                                    self.config.bigscience_bloom,
                                    dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
                                    self.config.min_out_tokens)
            self._alloc_workspace = False
        get_present = (get_present or get_key_value or use_cache)
        input_mask = input_mask if attention_mask is None else attention_mask

        # We set the prev key/value to None when there is a prompt
        if input.shape[1] > 1:
            self.layer_past = None
        layer_past = layer_past if layer_past is not None else self.layer_past
        head_mask = layer_head_mask if layer_head_mask is not None else head_mask

        attn_mask = None
        if isinstance(input, tuple):
            attn_mask = input[1]
            input = input[0]
        input_type = input.dtype

        if (self.config.dtype in [torch.float16, torch.bfloat16, torch.int8]) \
                and input.dtype == torch.float:
            target_dtype = torch.half if self.config.dtype == torch.int8 else self.config.dtype
            input = input.to(target_dtype)
        with torch.no_grad():
            attention_output, key, value, context_outputtn_ctx, inp_norm = \
                self.attention(input,
                               input_mask,
                               head_mask,
                               layer_past,
                               get_present,
                               encoder_hidden_states,
                               encoder_attention_mask,
                               output_attentions,
                               self.norm_w,
                               self.norm_b,
                               alibi)
            presents = (key, value)
            self.layer_past = presents if layer_past is None else None
            output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)

            if not self.config.pre_layer_norm:
                output = inference_module.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon)

            output = output.to(input_type)

        if get_present:
            output = (output, presents)

        if self.config.return_single_tuple:
            return (output, )
        elif self.config.return_tuple:
            return output if type(output) is tuple else (output, attn_mask)
        else:
            return output