ds_llama2.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed import comm as dist
from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference

inference_module = None


class DeepSpeedLlama2Inference(DeepSpeedTransformerInference):
    """Initialize the DeepSpeed Llama2 Transformer Layer.
    """
    def __init__(self,
                 config,
                 mp_group=None,
                 quantize_scales=None,
                 quantize_groups=1,
                 merge_count=1,
                 mlp_extra_grouping=False):
        super().__init__(config, mp_group, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping)
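
    # forward() executes one decoder layer with the fused attention and MLP modules
    # set up by the DeepSpeedTransformerInference base class, maintaining a per-layer
    # key/value cache in self.layer_past.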
    def forward(self, *args, **kwargs):
        input = args[0]
        input_mask = None

        # Allocate memory only on first layer forward
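        # (The shared inference workspace -- activation and key/value-cache buffers --
        # is sized here from the hidden size, head count, batch size and sequence
        # length of this first call, then reused by all layers and later decode steps.)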
        if self.config.layer_id == 0 and self._alloc_workspace:
            self.allocate_workspace(self.config.hidden_size, self.config.heads,
                                    input.size()[1],
                                    input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size,
                                    self.config.bigscience_bloom,
                                    dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
                                    self.config.min_out_tokens)
            self._alloc_workspace = False

        get_present = True

        # We set the prev key/value to None when there is a prompt
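        # (A sequence length > 1 indicates the prompt/prefill pass, so any cached
        # key/values from a previous request are discarded; single-token decode
        # steps keep reusing self.layer_past.)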
        if input.shape[1] > 1:
            self.layer_past = None
        layer_past = self.layer_past

        input_type = input.dtype
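
        # Inputs arriving in fp32 are cast down to the configured inference dtype;
        # the original dtype is restored on the output at the end of forward().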
        if (self.config.dtype in [torch.float16, torch.bfloat16, torch.int8]) \
                and input.dtype == torch.float:
            target_dtype = torch.half if self.config.dtype == torch.int8 else self.config.dtype
            input = input.to(target_dtype)

        with torch.no_grad():
            attention_output, key, value, context_outputtn_ctx, inp_norm = \
                self.attention(input,
                               input_mask,
                               None,
                               layer_past,
                               get_present,
                               None, None, None,
                               self.norm_w,
                               self.norm_b,
                               None)
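
            # Cache the present key/value for the next decode step, then run the MLP
            # block on the attention output together with the residual input.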
            self.layer_past = (key, value)
            output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)

            output = output.to(input_type)
        return output
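

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). This layer is normally
# injected by deepspeed.init_inference() when kernel injection is enabled, not
# constructed by hand; whether this particular class or a sibling fused layer is
# selected depends on DeepSpeed's injection policy for the model class. The
# checkpoint name and arguments below are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import deepspeed
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Assumed checkpoint; any causal LM supported by kernel injection works.
    model_name = "meta-llama/Llama-2-7b-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

    # Replace the model's decoder layers with DeepSpeed's fused inference layers.
    engine = deepspeed.init_inference(model,
                                      tensor_parallel={"tp_size": 1},
                                      dtype=torch.float16,
                                      replace_with_kernel_inject=True)

    # Assumes a single CUDA device is available.
    inputs = tokenizer("DeepSpeed is", return_tensors="pt").to("cuda")
    outputs = engine.module.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))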