// feed_forward.h
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. #ifndef __FEEDFORWARD_H__
  5. #define __FEEDFORWARD_H__
#include <array>
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include "custom_cuda_layers.h"
  10. template <typename T>
  11. class FeedForward {
  12. public:
  13. struct Config {
  14. int batchSize, outputSize;
  15. int inputSize;
  16. std::array<int, 3> gemm_algos;
  17. Config(int batch, int outputs, int inputs, const std::array<int, 3>& algos)
  18. : batchSize(batch), outputSize(outputs), inputSize(inputs), gemm_algos(algos)
  19. {
  20. }
  21. };
  22. FeedForward(Config config) : config_(config) {}
  23. ~FeedForward() {}
  24. void Forward(int bsz,
  25. const T* input_ptr,
  26. const T* weights,
  27. T* out,
  28. cublasHandle_t& _cublasHandle)
  29. {
  30. float alpha = T(1.);
  31. float beta = T(0.);
  32. cublas_gemm_ex(_cublasHandle,
  33. CUBLAS_OP_T,
  34. CUBLAS_OP_N,
  35. config_.outputSize,
  36. bsz,
  37. config_.inputSize,
  38. &alpha,
  39. &beta,
  40. weights,
  41. input_ptr,
  42. out,
  43. #ifdef __HIP_PLATFORM_AMD__
  44. rocblas_gemm_algo(config_.gemm_algos[0]));
  45. #else
  46. cublasGemmAlgo_t(config_.gemm_algos[0]));
  47. #endif
  48. }
  49. void Backward(int bsz,
  50. const T* out_grad,
  51. const T* input_ptr,
  52. const T* weights,
  53. T* weights_grad,
  54. T* bias_grad,
  55. cublasHandle_t& _cublasHandle,
  56. cudaStream_t& stream,
  57. T* inp_grad_out = nullptr,
  58. T* out_grad_trans_out = nullptr)
  59. {
  60. float alpha = (T)1.0, beta = (T)0.0;
  61. cublas_gemm_ex(_cublasHandle,
  62. CUBLAS_OP_N,
  63. CUBLAS_OP_T,
  64. config_.inputSize,
  65. config_.outputSize,
  66. bsz,
  67. &alpha,
  68. &beta,
  69. input_ptr,
  70. out_grad,
  71. weights_grad,
  72. #ifdef __HIP_PLATFORM_AMD__
  73. rocblas_gemm_algo(config_.gemm_algos[1]));
  74. #else
  75. cublasGemmAlgo_t(config_.gemm_algos[1]));
  76. #endif
  77. cublas_gemm_ex(_cublasHandle,
  78. CUBLAS_OP_N,
  79. CUBLAS_OP_N,
  80. config_.inputSize,
  81. bsz,
  82. config_.outputSize,
  83. &alpha,
  84. &beta,
  85. weights,
  86. out_grad,
  87. inp_grad_out,
  88. #ifdef __HIP_PLATFORM_AMD__
  89. rocblas_gemm_algo(config_.gemm_algos[2]));
  90. #else
  91. cublasGemmAlgo_t(config_.gemm_algos[2]));
  92. #endif
  93. launch_fuse_transpose_bias_kernel<T>(out_grad, bias_grad, bsz, config_.outputSize, stream);
  94. }
  95. private:
  96. Config config_;
  97. };
  98. #endif