ds_transformer.py

'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
import torch
import torch.nn as nn
from deepspeed import comm as dist
from deepspeed.utils.logging import log_dist
from deepspeed.ops.transformer.inference.ds_mlp import DeepSpeedMLP
from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention, BloomSelfAttention
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
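
# Module-level handle to the compiled DeepSpeed inference kernels; built lazily the
# first time a layer is constructed.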
inference_cuda_module = None


class DeepSpeedTransformerInference(nn.Module):
  14. """Initialize the DeepSpeed Transformer Layer.
  15. Arguments:
  16. layer_id: The layer index starting from 0, e.g. if model has 24 transformer layers,
  17. layer_id will be 0,1,2...23 when each layer object is instantiated
  18. config: An object of DeepSpeedInferenceConfig
  19. mp_group: Model parallelism group initialized on the modeling side.
  20. quantize_scales: This argument groups all the layers' scales used for quantization
  21. quantize_groups: Number of groups used for quantizing the model
  22. merge_count: Shows the number of model-parallel checkpoints merged before running inference.
  23. We use this argument to control the quantization scale for the model parameters if a bigger
  24. quantize-grouping than 1 is used.
  25. mlp_extra_grouping: This flag is used to show a 2x higher number of groups used for the MLP part
  26. of a Transformer layer. We use this feature for quantization to reduce the convergence impact
  27. for specific downstream tasks.
  28. """
    layer_id = 0

    def __init__(self,
                 config,
                 mp_group=None,
                 quantize_scales=None,
                 quantize_groups=1,
                 merge_count=1,
                 mlp_extra_grouping=False):
        super(DeepSpeedTransformerInference, self).__init__()

        self.config = config
        self.config.layer_id = DeepSpeedTransformerInference.layer_id
        DeepSpeedTransformerInference.layer_id += 1

        data_type = torch.half if config.fp16 else torch.float
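
        # Build and load the custom inference kernels once per process, the first time
        # any layer is created.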
        global inference_cuda_module
        if inference_cuda_module is None:
            builder = InferenceBuilder()
            inference_cuda_module = builder.load()

        if DeepSpeedTransformerInference.layer_id == 1:
            log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
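
        # BLOOM models use a dedicated self-attention variant (it consumes the alibi bias);
        # all other models go through the generic DeepSpeed self-attention path.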
        if self.config.bigscience_bloom:
            self.attention = BloomSelfAttention(self.config,
                                                mp_group,
                                                quantize_scales,
                                                quantize_groups,
                                                merge_count)
        else:
            self.attention = DeepSpeedSelfAttention(self.config,
                                                    mp_group,
                                                    quantize_scales,
                                                    quantize_groups,
                                                    merge_count)

        self.mlp = DeepSpeedMLP(self.config,
                                mp_group,
                                quantize_scales,
                                quantize_groups,
                                merge_count,
                                mlp_extra_grouping)
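
        # LayerNorm weight/bias for this layer; allocated empty here and expected to be
        # populated when the model checkpoint is loaded (hence requires_grad=False).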
        device = get_accelerator().current_device_name()  # if config.bigscience_bloom else 'cpu'
        self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size,
                                               dtype=data_type,
                                               device=device),
                                   requires_grad=False)
        self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size,
                                               dtype=data_type,
                                               device=device),
                                   requires_grad=False)
        self.layer_past = None
        self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if (not config.fp16) else \
            inference_cuda_module.allocate_workspace_fp16
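
    # Clears the cached inference state (e.g. the key/value cache) held by the custom kernels.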
    @classmethod
    def reset_cache(cls):
        if inference_cuda_module is not None:
            inference_cuda_module.reset_cache()

    def forward(
            self,
            input=None,
            input_mask=None,
            attention_mask=None,
            attn_mask=None,
            head_mask=None,
            layer_past=None,
            get_key_value=False,
            get_present=False,
            encoder_output=None,
            enc_dec_attn_mask=None,
            x=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            use_cache=False,
            alibi=None,
            output_attentions=False,
            # TODO(arashb): 'layer_head_mask' and 'past_key_value' are only added to satisfy the OPT models API.
            # This needs to be redesigned later!
            layer_head_mask=None,
            past_key_value=None):
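        # `x` and the various mask arguments are aliases kept for compatibility with the
        # different model forward signatures (e.g. Hugging Face); fold them into `input`
        # and `input_mask`.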
        if x is not None:
            input = x
        input_mask = (input_mask if attn_mask is None else
                      attn_mask) if attention_mask is None else attention_mask

        # Allocate memory only on first layer forward
        if self.config.layer_id == 0:
            self.allocate_workspace(self.config.hidden_size,
                                    self.config.heads,
                                    input.size()[1],
                                    input.size()[0],
                                    DeepSpeedTransformerInference.layer_id,
                                    self.config.mp_size,
                                    self.config.bigscience_bloom,
                                    dist.get_rank() if dist.is_initialized() else 0,
                                    self.config.max_out_tokens)
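
        # Any of the cache-related flags implies that the present key/value pair should be returned.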
        get_present = (get_present or get_key_value or use_cache)
        input_mask = input_mask if attention_mask is None else attention_mask

        # We set the prev key/value to None when there is a prompt
        if input.shape[1] > 1:
            self.layer_past = None
        layer_past = layer_past if layer_past is not None else self.layer_past
        head_mask = layer_head_mask if layer_head_mask is not None else head_mask

        attn_mask = None
        if isinstance(input, tuple):
            attn_mask = input[1]
            input = input[0]
        input_type = input.dtype

        if (self.config.fp16 or self.config.q_int8) \
                and input.dtype == torch.float:
            input = input.half()

        with torch.no_grad():
            attention_output, key, value, context_outputtn_ctx, inp_norm = \
                self.attention(input,
                               input_mask,
                               head_mask,
                               layer_past,
                               get_present,
                               encoder_hidden_states,
                               encoder_attention_mask,
                               output_attentions,
                               self.norm_w,
                               self.norm_b,
                               alibi)
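            # Cache the present key/value pair on the layer when no external cache is supplied,
            # so subsequent decoding steps can reuse it.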
            presents = (key, value)
            self.layer_past = presents if layer_past is None else None
            output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)

            if not self.config.pre_layer_norm:
                output = inference_cuda_module.layer_norm(output,
                                                          self.norm_w,
                                                          self.norm_b,
                                                          self.config.epsilon)

            output = output.to(input_type)
        if get_present:
            output = (output, presents)

        if self.config.return_single_tuple:
            return (output, )
        elif self.config.return_tuple:
            return output if type(output) is tuple else (output, attn_mask)
        else:
            return output
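
# Illustrative usage (a hypothetical sketch, not part of this module): these layers are
# normally created by DeepSpeed's module-injection path (e.g. deepspeed.init_inference)
# rather than instantiated by hand, but construction is roughly:
#
#   config = DeepSpeedInferenceConfig(...)  # hidden_size, heads, fp16, ... set per model
#   layer = DeepSpeedTransformerInference(config, mp_group=None)
#   out = layer(input=hidden_states, attention_mask=mask, use_cache=True)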