softmax.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import triton
import triton.language as tl
'''
Fused softmax, modified from the Triton tutorial kernel at
https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
A usage sketch appears at the end of this file.
'''


@triton.jit
def softmax_kernel(output_ptr, input_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one row of the (n_rows, n_cols) input.
    row_idx = tl.program_id(0)
    row_start_ptr = input_ptr + row_idx * stride
    # BLOCK_SIZE is the next power of two >= n_cols, so one block covers the whole row.
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
    # Numerically stable softmax: subtract the row max before exponentiating.
    row_minus_max = row - tl.max(row, axis=0)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    output_row_start_ptr = output_ptr + row_idx * stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)


@triton.jit
def masked_softmax_kernel(output_ptr, input_ptr, stride, mask_ptr, mask_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    row_idx = tl.program_id(0)
    row_start_ptr = input_ptr + row_idx * stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    mask_ptrs = mask_ptr + col_offsets + row_idx * mask_stride  # mask_stride is 0 for 1d mask
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
    mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
    row_minus_max = row - tl.max(row, axis=0)
    # The mask is additive: 0 keeps a position, a large negative value suppresses it.
    row_minus_max = row_minus_max + mask
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    output_row_start_ptr = output_ptr + row_idx * stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)


def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
    assert input.is_contiguous()
    assert (dim == -1) or (dim == len(input.shape) - 1), "Only dim=-1 is supported"

    use_mask = False if mask is None else True
    input_arg = input.view(-1, input.shape[-1])
    n_rows, n_cols = input_arg.shape
    BLOCK_SIZE = max(triton.next_power_of_2(n_cols), 2)
    # Use more warps for wider rows.
    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16

    # Allocate output
    output = torch.empty_like(input)
    if use_mask:
        assert mask.is_contiguous()
        mask = mask.view(-1, mask.shape[-1])
        # A single-row mask is broadcast across all rows via a zero stride.
        mask_stride = mask.shape[-1] if mask.shape[-2] > 1 else 0
        # Launch one program per row.
        masked_softmax_kernel[(n_rows, )](
            output,
            input,
            input_arg.stride(0),
            mask,
            mask_stride,
            n_cols,
            num_warps=num_warps,
            BLOCK_SIZE=BLOCK_SIZE,
        )
    else:
        softmax_kernel[(n_rows, )](
            output,
            input,
            input_arg.stride(0),
            n_cols,
            num_warps=num_warps,
            BLOCK_SIZE=BLOCK_SIZE,
        )

    return output
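

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal example of how the
# wrapper above could be exercised, assuming a CUDA device with Triton
# available. The shapes and the -10000.0 additive mask value are illustrative
# assumptions, not values mandated by DeepSpeed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    x = torch.randn(8, 16, 1000, device="cuda", dtype=torch.float32)

    # Unmasked path: should match PyTorch's softmax along the last dimension.
    y = softmax(x)
    ref = torch.nn.functional.softmax(x, dim=-1)
    print("unmasked max abs diff:", (y - ref).abs().max().item())

    # Masked path: the mask is additive (0 keeps a position, a large negative
    # value suppresses it) and is broadcast across rows since it has one row.
    mask = torch.zeros(1, 1, 1000, device="cuda", dtype=torch.float32)
    mask[..., 500:] = -10000.0
    y_masked = softmax(x, mask=mask)
    ref_masked = torch.nn.functional.softmax(x + mask, dim=-1)
    print("masked max abs diff:", (y_masked - ref_masked).abs().max().item())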