residual_add.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import triton
import triton.language as tl
from deepspeed.accelerator import get_accelerator
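

# Fused elementwise kernel: adds the residual, attention output and bias terms
# to the hidden state in one pass, dividing selected terms by mp_size
# (the model-parallel world size).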
@triton.jit
def residual_add_bias_kernel(
    hidden_state_ptr,
    residual_ptr,
    attn_output_ptr,
    hidden_state_size,
    attn_bias_ptr,
    final_bias_ptr,
    bias_size,
    output_ptr,
    mp_size: tl.constexpr,
    mlp_after_attn: tl.constexpr,
    pre_attn_norm: tl.constexpr,
    add_attn_bias: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
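    # each program instance handles one contiguous BLOCK_SIZE chunk of the
    # flattened tensors; bias values are indexed modulo bias_size so the
    # biases broadcast across every row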
    pid = tl.program_id(axis=0)

    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < hidden_state_size

    bias_offsets = offsets % bias_size
    bias_mask = bias_offsets < bias_size

    tl_hidden_state = tl.load(hidden_state_ptr + offsets, mask=mask)
    tl_residual = tl.load(residual_ptr + offsets, mask=mask)
    tl_attn_output = tl.load(attn_output_ptr + offsets, mask=mask)
    tl_attn_bias = tl.load(attn_bias_ptr + bias_offsets, mask=bias_mask)
    tl_final_bias = tl.load(final_bias_ptr + bias_offsets, mask=bias_mask)
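
    # two fusion modes: with mlp_after_attn the attention output is folded into
    # the residual path; otherwise the attention and hidden-state outputs are
    # summed side by side, with the residual/bias terms scaled by 1/mp_size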
    if mlp_after_attn:
        if pre_attn_norm:
            output = tl_hidden_state + (tl_residual + tl_final_bias + tl_attn_output + tl_attn_bias) / mp_size
        else:
            output = tl_hidden_state + tl_residual + tl_final_bias
    else:
        output = tl_hidden_state + tl_attn_output + (tl_residual + tl_final_bias) / mp_size
        if add_attn_bias:
            output += tl_attn_bias / mp_size

    tl.store(output_ptr + offsets, output, mask=mask)


def residual_add_bias(hidden_state: torch.Tensor, residual: torch.Tensor, attn_output: torch.Tensor,
                      attn_bias: torch.Tensor, final_bias: torch.Tensor, mp_size: int, mlp_after_attn: bool,
                      add_attn_bias: bool, pre_attn_norm: bool):
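    # wrapper that validates the inputs and launches the Triton kernel over the
    # flattened hidden state, returning a new tensor of the same shape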
    # check that all tensors are on the same device
    assert get_accelerator().on_accelerator(hidden_state) \
        and get_accelerator().on_accelerator(residual) \
        and get_accelerator().on_accelerator(attn_output) \
        and get_accelerator().on_accelerator(attn_bias) \
        and get_accelerator().on_accelerator(final_bias)

    # check that all tensors have the same dtype
    assert hidden_state.dtype == residual.dtype == attn_output.dtype \
        == attn_bias.dtype == final_bias.dtype

    # check that all tensors have the right shape
    assert hidden_state.shape == residual.shape == attn_output.shape
    assert attn_bias.shape == final_bias.shape
    assert attn_bias.shape[0] == hidden_state.shape[2]
    output = torch.empty_like(hidden_state)

    hidden_state_size = output.numel()
    bias_size = attn_bias.numel()
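
    # 1-D launch: one program per BLOCK_SIZE elements of the flattened output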
    grid = lambda meta: (triton.cdiv(hidden_state_size, meta['BLOCK_SIZE']), )

    residual_add_bias_kernel[grid](hidden_state, residual, attn_output, hidden_state_size, \
                                   attn_bias, final_bias, bias_size, output, mp_size, mlp_after_attn, \
                                   pre_attn_norm, add_attn_bias, \
                                   BLOCK_SIZE=1024)

    return output
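

# Minimal usage sketch. The shapes, dtype and device below are illustrative
# assumptions, not requirements of this module; the asserts above only demand
# 3-D activations whose dim-2 size matches the bias length, a shared dtype,
# and tensors resident on the accelerator device.
if __name__ == "__main__":
    device = get_accelerator().device_name()  # e.g. "cuda"
    batch, seq_len, hidden = 2, 16, 768       # illustrative sizes
    hidden_state = torch.randn(batch, seq_len, hidden, dtype=torch.float16, device=device)
    residual = torch.randn_like(hidden_state)
    attn_output = torch.randn_like(hidden_state)
    attn_bias = torch.randn(hidden, dtype=torch.float16, device=device)
    final_bias = torch.randn(hidden, dtype=torch.float16, device=device)

    out = residual_add_bias(hidden_state, residual, attn_output, attn_bias, final_bias,
                            mp_size=1, mlp_after_attn=True, add_attn_bias=True, pre_attn_norm=True)

    # with mp_size=1, mlp_after_attn=True and pre_attn_norm=True the kernel reduces to
    # hidden_state + residual + final_bias + attn_output + attn_bias
    expected = hidden_state + residual + final_bias + attn_output + attn_bias
    print(torch.allclose(out, expected, rtol=1e-2, atol=1e-2))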