gelu.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import triton
import triton.language as tl

from deepspeed.accelerator import get_accelerator


@triton.jit
def gelu_functor(x):
    # Using the sigmoid approximation introduces greater parity errors.
    # return tl.sigmoid(1.702 * x) * x
    # Exact GELU: x * Phi(x), with Phi evaluated via erf(x / sqrt(2)).
    return x * 0.5 * (1.0 + tl.libdevice.erf(x / 1.41421356237))


@triton.jit
def gelu_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance processes one contiguous block of BLOCK_SIZE elements.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    output = gelu_functor(x)
    tl.store(output_ptr + offsets, output, mask=mask)


def gelu(activations: torch.Tensor) -> torch.Tensor:
    assert activations.is_contiguous()
    assert get_accelerator().on_accelerator(activations)

    output = torch.empty_like(activations)
    n_elements = output.numel()
    # 1D launch grid: enough programs to cover all elements at the chosen block size.
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    gelu_kernel[grid](activations, output, n_elements, BLOCK_SIZE=1024)
    return output
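

# Minimal usage sketch, assuming a CUDA-capable accelerator and a Triton-enabled
# DeepSpeed build; the tensor shape, dtype, and tolerances below are illustrative
# assumptions, not part of the original module.
if __name__ == "__main__":
    import torch.nn.functional as F

    device = get_accelerator().device_name()  # e.g. 'cuda'
    x = torch.randn(4, 1024, dtype=torch.float16, device=device)
    y = gelu(x)
    # Compare against PyTorch's exact (erf-based) GELU for parity.
    torch.testing.assert_close(y, F.gelu(x), rtol=1e-3, atol=1e-3)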