123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- import torch
- from deepspeed import comm as dist
- from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
# Module-level placeholder, initialised to None; nothing in the visible code reads or reassigns it.
- inference_module = None
class DeepSpeedLlama2Inference(DeepSpeedTransformerInference):
    """DeepSpeed inference transformer layer for Llama-2 models.

    Construction is delegated unchanged to ``DeepSpeedTransformerInference``.
    ``forward`` runs ``self.attention`` followed by ``self.mlp`` and stores the
    returned key/value pair on the instance (``self.layer_past``) so the next
    call can reuse it.
    """

    def __init__(self,
                 config,
                 mp_group=None,
                 quantize_scales=None,
                 quantize_groups=1,
                 merge_count=1,
                 mlp_extra_grouping=False):
        """Forward every argument, in order, to the base-class constructor.

        The parameters are not interpreted here; see
        ``DeepSpeedTransformerInference`` for their meaning.
        """
        super().__init__(config, mp_group, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping)

    def forward(self, *args, **kwargs):
        """Run attention and MLP on the tensor passed as the first positional argument.

        Only ``args[0]`` is read; any other positional or keyword arguments are
        accepted and ignored.

        Returns:
            The MLP output, cast back to the dtype the input arrived in.

        Raises:
            IndexError: if no positional argument is supplied.
        """
        hidden_states = args[0]

        # Allocate memory only on first layer forward
        if self.config.layer_id == 0 and self._alloc_workspace:
            self.allocate_workspace(
                self.config.hidden_size,
                self.config.heads,
                hidden_states.size()[1],
                hidden_states.size()[0],
                DeepSpeedTransformerInference.layer_id,
                self.config.mp_size,
                self.config.bigscience_bloom,
                dist.get_rank() if dist.is_initialized() else 0,
                self.config.max_out_tokens,
                self.config.min_out_tokens,
            )
            self._alloc_workspace = False

        # We set the prev key/value to None when there is a prompt
        if hidden_states.shape[1] > 1:
            self.layer_past = None

        # Remember the caller's dtype so the result can be cast back at the end.
        input_type = hidden_states.dtype

        # Read the dtype from the config for both the guard and the cast target:
        # nothing in this class sets a ``dtype`` attribute on the instance.
        config_dtype = self.config.dtype
        if config_dtype in (torch.float16, torch.bfloat16, torch.int8) and input_type == torch.float:
            # An int8 config still takes half-precision activations as input.
            target_dtype = torch.half if config_dtype == torch.int8 else config_dtype
            hidden_states = hidden_states.to(target_dtype)

        with torch.no_grad():
            # The fourth return value is not used by this layer.
            attention_output, key, value, _, inp_norm = self.attention(
                hidden_states,
                None,  # input_mask
                None,
                self.layer_past,
                True,  # get_present
                None,
                None,
                None,
                self.norm_w,
                self.norm_b,
                None,
            )
            self.layer_past = (key, value)
            output = self.mlp(attention_output, hidden_states, inp_norm, self.attention.attn_ob)
            output = output.to(input_type)
        return output
|