# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import os

import torch
import torch.nn.functional as F

from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
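

# SoftmaxOp binds the attention softmax to a dtype-specific kernel from the
# compiled DeepSpeed inference module when one is available, and otherwise
# falls back to a pure-PyTorch implementation gated by the DS_KI_FALLBACK
# environment variable.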
class SoftmaxOp(BaseOp):

    def __init__(self, config: DeepSpeedInferenceConfig):
        super().__init__(config)
        # Attention heads owned by this tensor-parallel partition.
        self.num_attention_heads_per_partition = config.heads // config.mp_size
        try:
            # Bind the dtype-specific softmax kernel; the int8 path reuses
            # the fp16 kernel.
            if self.config.dtype in [torch.float16, torch.int8]:
                self.softmax_func = self.inference_module.softmax_fp16
            elif self.config.dtype == torch.bfloat16:
                self.softmax_func = self.inference_module.softmax_bf16
            else:
                self.softmax_func = self.inference_module.softmax_fp32
        except AttributeError:
            # The compiled module does not expose the kernel; use the
            # pure-PyTorch fallback instead.
            self.softmax_func = self.softmax_fallback

    def softmax_fallback(self, attn_scores, attn_mask, alibi, triangular, recompute, local_attention, window_size,
                         async_op, layer_scale, head_offset, mp_size):
        if os.environ.get('DS_KI_FALLBACK') == 'True':
            input_dtype = attn_scores.dtype
            if triangular:
                # Mask out the upper triangle with the dtype minimum so those
                # positions contribute nothing after the softmax.
                tri = ~torch.tril(torch.ones(attn_scores.size(), device=attn_scores.device)).to(bool)
                attn_scores = torch.masked_fill(attn_scores * layer_scale, tri, torch.finfo(input_dtype).min)
            if alibi is not None:
                # Slice out the ALiBi biases for the heads owned by this
                # partition before applying them (slicing an absent alibi
                # tensor would raise, so the slice is guarded by this check).
                alibi = alibi[head_offset:head_offset + self.num_attention_heads_per_partition]
                attn_scores += alibi
            if attn_mask is not None:
                # Expand attn_mask from 2-D to 4-D by inserting two singleton
                # dims in the middle so it broadcasts over heads and queries.
                attn_mask = attn_mask[:, None, None, :]
                attn_scores += attn_mask
            # Compute the softmax in fp32 for numerical stability, then cast
            # back to the input dtype.
            output = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(input_dtype)
            return output
        else:
            raise NotImplementedError("softmax fallback requires DS_KI_FALLBACK=True")

    def forward(self, attn_scores: torch.Tensor, attn_mask: torch.Tensor, alibi: torch.Tensor, triangular: bool,
                recompute: bool, local_attention: bool, window_size: int, async_op: bool, layer_scale: float,
                head_offset: int):
        output = self.softmax_func(attn_scores, attn_mask, alibi, triangular, recompute, local_attention, window_size,
                                   async_op, layer_scale, head_offset, self.config.mp_size)
        return output
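

# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal sketch of driving this op, assuming BaseOp derives from
# torch.nn.Module (so the instance is callable) and that `config` provides
# the `heads`, `mp_size`, and `dtype` attributes read in __init__ above.
# All argument values below are hypothetical.
#
#   config = DeepSpeedInferenceConfig(...)  # e.g. dtype=torch.float16
#   softmax_op = SoftmaxOp(config)
#   probs = softmax_op(attn_scores, attn_mask, alibi=None, triangular=True,
#                      recompute=False, local_attention=False, window_size=256,
#                      async_op=False, layer_scale=1.0, head_offset=0)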