gelu.h 1018 B

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. #pragma once
  5. #include <cuda.h>
  6. #include <cuda_fp16.h>
  7. #include <stdio.h>
  8. #include "custom_cuda_layers.h"
  // Gelu<T>: host-side wrapper around the GeLU launch helpers
  // launch_bias_gelu<T> / launch_d_gelu<T> (presumably declared in
  // custom_cuda_layers.h — confirm). The only state it keeps is the
  // intermediate width, which is forwarded unchanged to every launch.
  template <typename T>
  class Gelu {
  public:
      // Construction-time settings for a Gelu layer.
      struct Config {
          // Per-row width handed to the launch helpers
          // (presumably the FFN intermediate dimension — confirm with callers).
          uint32_t intermediate_size;

          // Not marked explicit, so a bare uint32_t converts implicitly to
          // Config; keep it that way for existing callers.
          Config(uint32_t inter_size) : intermediate_size{inter_size} {}
      };

      // Stores a copy of `config`; nothing is allocated or launched here.
      Gelu(const Config& config) : _config{config} {}

      virtual ~Gelu() = default;

      // Forward pass: hands input_buf, bias and output to launch_bias_gelu<T>
      // along with the stored intermediate_size and bsz, targeting `stream`.
      // Judging by the name, output receives GeLU(input_buf + bias), with
      // buffers presumably laid out as bsz rows of intermediate_size elements
      // — TODO(review): confirm against launch_bias_gelu.
      // This wrapper performs no synchronization and no error checking itself.
      void ForwardWithBiasAdd(int bsz,
                              const T* input_buf,
                              const T* bias,
                              T* output,
                              cudaStream_t stream)
      {
          const uint32_t width = _config.intermediate_size;
          launch_bias_gelu<T>(input_buf, bias, output, width, bsz, stream);
      }

      // Backward pass: hands d_output, input_buf and bias to launch_d_gelu<T>
      // along with the stored intermediate_size and bsz, targeting `stream`.
      // d_output is the only non-const buffer, so the gradient is presumably
      // rewritten in place; input_buf/bias are presumably the same buffers
      // given to ForwardWithBiasAdd — verify against launch_d_gelu and callers.
      // No synchronization or error checking happens in this wrapper.
      void Backward(int bsz, T* d_output, const T* input_buf, const T* bias, cudaStream_t stream)
      {
          const uint32_t width = _config.intermediate_size;
          launch_d_gelu<T>(d_output, input_buf, bias, width, bsz, stream);
      }

  private:
      Config _config;  // copied at construction and never modified afterwards
  };