softmax.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch
import torch.nn.functional as F

from ..config import DeepSpeedInferenceConfig
from .base import BaseOp


class SoftmaxOp(BaseOp):

    def __init__(self, config: DeepSpeedInferenceConfig):
        super(SoftmaxOp, self).__init__(config)
        self.num_attention_heads_per_partition = config.heads // config.mp_size
        try:
            # Pick the compiled kernel that matches the configured inference dtype.
            if self.config.dtype in [torch.float16, torch.int8]:
                self.softmax_func = self.inference_module.softmax_fp16
            elif self.config.dtype == torch.bfloat16:
                self.softmax_func = self.inference_module.softmax_bf16
            else:
                self.softmax_func = self.inference_module.softmax_fp32
        except AttributeError:
            # The compiled op is not available in this build; use the PyTorch fallback.
            self.softmax_func = self.softmax_fallback

    def softmax_fallback(self, attn_scores, attn_mask, alibi, triangular, recompute, local_attention, window_size,
                         async_op, layer_scale, head_offset, mp_size):
        if os.environ.get('DS_KI_FALLBACK') == 'True':
            input_dtype = attn_scores.dtype
            if triangular:
                # Apply the causal (lower-triangular) mask: future positions get the dtype minimum.
                tri = ~torch.tril(torch.ones(attn_scores.size(), device=attn_scores.device)).to(torch.bool)
                attn_scores = torch.masked_fill(attn_scores * layer_scale, tri, torch.finfo(input_dtype).min)
            if alibi is not None:
                # Keep only the ALiBi biases belonging to this partition's heads.
                alibi = alibi[head_offset:head_offset + self.num_attention_heads_per_partition]
                attn_scores += alibi
            if attn_mask is not None:
                # Expand attn_mask from 2 dims to 4 dims by inserting two singleton dims in the middle.
                attn_mask = attn_mask[:, None, None, :]
                attn_scores += attn_mask
            # Compute softmax in fp32 for numerical stability, then cast back to the input dtype.
            output = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(input_dtype)
            return output
        else:
            raise NotImplementedError

    def forward(self, attn_scores: torch.Tensor, attn_mask: torch.Tensor, alibi: torch.Tensor, triangular: bool,
                recompute: bool, local_attention: bool, window_size: int, async_op: bool, layer_scale: float,
                head_offset: int):
        output = self.softmax_func(attn_scores, attn_mask, alibi, triangular, recompute, local_attention, window_size,
                                   async_op, layer_scale, head_offset, self.config.mp_size)
        return output
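

# Usage sketch (illustrative only, not part of DeepSpeed's API): the tensor shapes, mp_size of 1,
# and the __new__-based construction below are assumptions chosen so the fallback path can be
# exercised without the compiled inference kernels or a full DeepSpeedInferenceConfig.
if __name__ == "__main__":
    os.environ['DS_KI_FALLBACK'] = 'True'
    batch, heads, seq = 1, 4, 8
    scores = torch.randn(batch, heads, seq, seq)
    mask = torch.zeros(batch, seq)  # 2-D additive mask; broadcast to 4-D inside the fallback

    # Bypass __init__ so no kernel binaries or config object are needed (sketch only).
    op = SoftmaxOp.__new__(SoftmaxOp)
    op.num_attention_heads_per_partition = heads

    probs = op.softmax_fallback(scores, mask, None, True, False, False, 0, False, 1.0, 0, 1)
    print(probs.shape)          # torch.Size([1, 4, 8, 8])
    print(probs.sum(dim=-1))    # each row sums to 1 after the causal mask and softmax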