# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import triton
import triton.language as tl

from deepspeed.accelerator import get_accelerator


@triton.jit
def residual_add_bias_kernel(
    hidden_state_ptr,
    residual_ptr,
    attn_output_ptr,
    hidden_state_size,
    attn_bias_ptr,
    final_bias_ptr,
    bias_size,
    output_ptr,
    mp_size: tl.constexpr,
    mlp_after_attn: tl.constexpr,
    pre_attn_norm: tl.constexpr,
    add_attn_bias: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # each program instance handles one BLOCK_SIZE-wide tile of the flattened tensors
    pid = tl.program_id(axis=0)

    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < hidden_state_size

    # the biases are broadcast along the hidden dimension, so index them modulo bias_size
    bias_offsets = offsets % bias_size
    bias_mask = bias_offsets < bias_size

    tl_hidden_state = tl.load(hidden_state_ptr + offsets, mask=mask)
    tl_residual = tl.load(residual_ptr + offsets, mask=mask)
    tl_attn_output = tl.load(attn_output_ptr + offsets, mask=mask)
    tl_attn_bias = tl.load(attn_bias_ptr + bias_offsets, mask=bias_mask)
    tl_final_bias = tl.load(final_bias_ptr + bias_offsets, mask=bias_mask)

    if mlp_after_attn:
        if pre_attn_norm:
            # fold the attention output and both biases into the residual stream,
            # rescaled by the tensor-parallel degree
            output = tl_hidden_state + (tl_residual + tl_final_bias + tl_attn_output + tl_attn_bias) / mp_size
        else:
            output = tl_hidden_state + tl_residual + tl_final_bias
    else:
        output = tl_hidden_state + tl_attn_output + (tl_residual + tl_final_bias) / mp_size
        if add_attn_bias:
            output += tl_attn_bias / mp_size

    tl.store(output_ptr + offsets, output, mask=mask)
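

# A minimal pure-PyTorch sketch of the same element-wise computation, handy for
# unit-testing the kernel above. The helper name `_residual_add_bias_ref` is
# illustrative only; it mirrors the semantics implied by the kernel's branches.
def _residual_add_bias_ref(hidden_state, residual, attn_output, attn_bias, final_bias,
                           mp_size, mlp_after_attn, add_attn_bias, pre_attn_norm):
    if mlp_after_attn:
        if pre_attn_norm:
            return hidden_state + (residual + final_bias + attn_output + attn_bias) / mp_size
        return hidden_state + residual + final_bias
    out = hidden_state + attn_output + (residual + final_bias) / mp_size
    if add_attn_bias:
        out = out + attn_bias / mp_size
    return out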


def residual_add_bias(hidden_state: torch.Tensor, residual: torch.Tensor, attn_output: torch.Tensor,
                      attn_bias: torch.Tensor, final_bias: torch.Tensor, mp_size: int, mlp_after_attn: bool,
                      add_attn_bias: bool, pre_attn_norm: bool):
    # check that all tensors are on the same device
    assert get_accelerator().on_accelerator(hidden_state) \
        and get_accelerator().on_accelerator(residual) \
        and get_accelerator().on_accelerator(attn_output) \
        and get_accelerator().on_accelerator(attn_bias) \
        and get_accelerator().on_accelerator(final_bias)

    # check that all tensors have the same dtype
    assert hidden_state.dtype == residual.dtype == attn_output.dtype \
        == attn_bias.dtype == final_bias.dtype

    # check that all tensors have the right shape
    assert hidden_state.shape == residual.shape == attn_output.shape
    assert attn_bias.shape == final_bias.shape
    assert attn_bias.shape[0] == hidden_state.shape[2]

    output = torch.empty_like(hidden_state)

    hidden_state_size = output.numel()
    bias_size = attn_bias.numel()

    grid = lambda meta: (triton.cdiv(hidden_state_size, meta['BLOCK_SIZE']), )

    residual_add_bias_kernel[grid](hidden_state, residual, attn_output, hidden_state_size,
                                   attn_bias, final_bias, bias_size, output, mp_size,
                                   mlp_after_attn, pre_attn_norm, add_attn_bias,
                                   BLOCK_SIZE=1024)

    return output
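

# Illustrative smoke test: the shapes, fp16 dtype, and mp_size=1 below are assumed
# values chosen for demonstration, and the check relies on the `_residual_add_bias_ref`
# sketch above.
if __name__ == "__main__":
    device = get_accelerator().current_device_name()
    hidden_state = torch.randn(2, 16, 256, dtype=torch.float16, device=device)  # [batch, seq, hidden]
    residual = torch.randn_like(hidden_state)
    attn_output = torch.randn_like(hidden_state)
    attn_bias = torch.randn(256, dtype=torch.float16, device=device)
    final_bias = torch.randn(256, dtype=torch.float16, device=device)

    out = residual_add_bias(hidden_state, residual, attn_output, attn_bias, final_bias,
                            mp_size=1, mlp_after_attn=True, add_attn_bias=True, pre_attn_norm=True)
    ref = _residual_add_bias_ref(hidden_state, residual, attn_output, attn_bias, final_bias,
                                 1, True, True, True)
    assert torch.allclose(out, ref, atol=1e-2, rtol=1e-2)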