# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import torch.nn as nn
from deepspeed import comm as dist
from deepspeed.utils.logging import log_dist

from deepspeed.ops.transformer.inference.ds_mlp import DeepSpeedMLP
from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention, BloomSelfAttention
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
import deepspeed

if deepspeed.HAS_TRITON:
    from deepspeed.ops.transformer.inference.triton.mlp import TritonMLP
    from deepspeed.ops.transformer.inference.triton.attention import TritonSelfAttention
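
# Module-level handle to the compiled inference kernels; it is loaded lazily by the
# first layer that needs it and then shared by every layer in the process.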
inference_module = None


class DeepSpeedTransformerInference(nn.Module):
    """Initialize the DeepSpeed Transformer Layer.

        Arguments:
            layer_id: The layer index, starting from 0; e.g. if the model has 24 transformer layers,
                layer_id will be 0, 1, 2, ... 23 as each layer object is instantiated.
            config: An object of DeepSpeedInferenceConfig.
            mp_group: Model-parallel group initialized on the modeling side.
            quantize_scales: Groups all the layers' scales used for quantization.
            quantize_groups: Number of groups used for quantizing the model.
            merge_count: Number of model-parallel checkpoints merged before running inference.
                We use this argument to control the quantization scale for the model parameters
                when a quantize-grouping larger than 1 is used.
            mlp_extra_grouping: When set, uses twice as many quantization groups for the MLP part
                of the Transformer layer; we use this to reduce the quantization impact on
                convergence for specific downstream tasks.
    """
    layer_id = 0

    def __init__(self,
                 config,
                 mp_group=None,
                 quantize_scales=None,
                 quantize_groups=1,
                 merge_count=1,
                 mlp_extra_grouping=False):
        super(DeepSpeedTransformerInference, self).__init__()

        self.config = config
        self.config.layer_id = DeepSpeedTransformerInference.layer_id
        DeepSpeedTransformerInference.layer_id += 1
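
        # INT8-quantized layers still keep their small per-layer parameters
        # (the LayerNorm weight/bias created below) in fp16.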
        data_type = torch.half if self.config.dtype == torch.int8 else self.config.dtype
        global inference_module
        if inference_module is None:
            builder = InferenceBuilder()
            inference_module = builder.load()

        if DeepSpeedTransformerInference.layer_id == 1:
            log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
            if deepspeed.HAS_TRITON and self.config.use_triton:
                log_dist("Injecting Triton kernels ...", [0])

        if self.config.bigscience_bloom:
            self.attention = BloomSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, merge_count)
            assert not self.config.use_triton
        else:
            if deepspeed.HAS_TRITON and self.config.use_triton:
                self.attention = TritonSelfAttention(self.config)
            else:
                self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups,
                                                        merge_count)

        if deepspeed.HAS_TRITON and self.config.use_triton:
            self.mlp = TritonMLP(self.config)
        else:
            self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count,
                                    mlp_extra_grouping)

        device = get_accelerator().current_device_name()  # if config.bigscience_bloom else 'cpu'
        if self.config.set_empty_params:
            self.norm_w = None
            self.norm_b = None
        else:
            self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
                                       requires_grad=False)
            self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
                                       requires_grad=False)
        self.layer_past = None
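
        # Resolve the dtype-specific workspace (KV-cache) allocator exported by the kernel
        # module; older kernel builds may not expose these entry points, in which case
        # workspace allocation is skipped.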
        try:
            if config.dtype == torch.float32:
                self.allocate_workspace = inference_module.allocate_workspace_fp32
            elif config.dtype == torch.bfloat16:
                self.allocate_workspace = inference_module.allocate_workspace_bf16
            else:
                # fp16 and int8 configs both run their activations in half precision.
                self.allocate_workspace = inference_module.allocate_workspace_fp16
            self._alloc_workspace = True
        except AttributeError:
            self.allocate_workspace = None
            self._alloc_workspace = False

    @classmethod
    def reset_cache(cls):
        if inference_module is not None:
            inference_module.reset_cache()

    def forward(
            self,
            input=None,
            input_mask=None,
            attention_mask=None,
            attn_mask=None,
            head_mask=None,
            layer_past=None,
            get_key_value=False,
            get_present=False,
            encoder_output=None,
            enc_dec_attn_mask=None,
            x=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            use_cache=False,
            alibi=None,
            output_attentions=False,
            # TODO(arashb): 'layer_head_mask' and 'past_key_value' are only added to satisfy the OPT models API.
            # This needs to be redesigned later!
            layer_head_mask=None,
            past_key_value=None,
            **kwargs):

        if x is not None:
            input = x
        if "hidden_states" in kwargs:
            input = kwargs["hidden_states"]
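
        # HF-style models pass the mask under different names; prefer attention_mask,
        # then attn_mask, then fall back to input_mask.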
        input_mask = (input_mask if attn_mask is None else attn_mask) if attention_mask is None else attention_mask

        # Allocate memory only on first layer forward
        if self.config.layer_id == 0 and self._alloc_workspace:
            self.allocate_workspace(self.config.hidden_size, self.config.heads,
                                    input.size()[1],
                                    input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size,
                                    self.config.bigscience_bloom,
                                    dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
                                    self.config.min_out_tokens)
            self._alloc_workspace = False
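
        # Any of the caching flags means the caller wants the new key/value pair returned.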
        get_present = (get_present or get_key_value or use_cache)
        input_mask = input_mask if attention_mask is None else attention_mask

        # We set the prev key/value to None when there is a prompt
        if input.shape[1] > 1:
            self.layer_past = None
        layer_past = layer_past if layer_past is not None else self.layer_past
        head_mask = layer_head_mask if layer_head_mask is not None else head_mask
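
        # Some pipelines pass (hidden_states, attention_mask) together as a single tuple.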
        attn_mask = None
        if isinstance(input, tuple):
            attn_mask = input[1]
            input = input[0]
        input_type = input.dtype

        if (self.config.dtype in [torch.float16, torch.bfloat16, torch.int8]) \
                and input.dtype == torch.float:
            target_dtype = torch.half if self.config.dtype == torch.int8 else self.config.dtype
            input = input.to(target_dtype)
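
        # Inference only: run attention (with its fused input LayerNorm) without autograd.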
        with torch.no_grad():
            attention_output, key, value, context_outputtn_ctx, inp_norm = \
                self.attention(input,
                               input_mask,
                               head_mask,
                               layer_past,
                               get_present,
                               encoder_hidden_states,
                               encoder_attention_mask,
                               output_attentions,
                               self.norm_w,
                               self.norm_b,
                               alibi)

            presents = (key, value)
            self.layer_past = presents if layer_past is None else None
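            # The MLP fuses the residual add: it takes the attention output, the layer
            # input (residual), the pre-attention normalized input, and the attention
            # output bias.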
            output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)
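
            # Post-LayerNorm models (pre_layer_norm=False) normalize after the residual add.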
            if not self.config.pre_layer_norm:
                output = inference_module.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon)

            output = output.to(input_type)
        if get_present:
            output = (output, presents)

        if self.config.return_single_tuple:
            return (output, )
        elif self.config.return_tuple:
            return output if type(output) is tuple else (output, attn_mask)
        else:
            return output
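
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, kept as a comment so it never runs on
# import). It assumes an accelerator device is available, that the inference
# kernels can be built, and that DeepSpeedInferenceConfig exposes hidden_size /
# intermediate_size / heads / dtype as in recent DeepSpeed releases; field names
# may differ across versions.
#
#   from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig
#
#   config = DeepSpeedInferenceConfig(hidden_size=1024,
#                                     intermediate_size=4096,
#                                     heads=16,
#                                     dtype=torch.float16)
#   layer = DeepSpeedTransformerInference(config)
#   hidden_states = torch.randn(1, 128, 1024, dtype=torch.float16,
#                               device=get_accelerator().current_device_name())
#   out = layer(hidden_states, use_cache=True)  # also returns the new (key, value) pair
# ---------------------------------------------------------------------------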