feed_forward.h

#ifndef __FEEDFORWARD_H__
#define __FEEDFORWARD_H__

#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <array>

#include "custom_cuda_layers.h"
template <typename T>
class FeedForward {
public:
    struct Config {
        int batchSize, outputSize;
        int inputSize;
        // GEMM algorithm selectors: [0] forward, [1] weight gradient, [2] input gradient.
        std::array<int, 3> gemm_algos;
        Config(int batch, int outputs, int inputs, const std::array<int, 3>& algos)
            : batchSize(batch), outputSize(outputs), inputSize(inputs), gemm_algos(algos)
        {
        }
    };

    FeedForward(Config config) : config_(config) {}

    ~FeedForward() {}
    // Forward pass: out[bsz, outputSize] = input[bsz, inputSize] * weights[outputSize, inputSize]^T
    // (row-major view; the cuBLAS call below expresses this in column-major terms).
    void Forward(int bsz,
                 const T* input_ptr,
                 const T* weights,
                 T* out,
                 cublasHandle_t& _cublasHandle)
    {
        float alpha = T(1.);
        float beta = T(0.);

        cublas_gemm_ex(_cublasHandle,
                       CUBLAS_OP_T,
                       CUBLAS_OP_N,
                       config_.outputSize,
                       bsz,
                       config_.inputSize,
                       &alpha,
                       &beta,
                       weights,
                       input_ptr,
                       out,
#ifdef __HIP_PLATFORM_HCC__
                       rocblas_gemm_algo(config_.gemm_algos[0]));
#else
                       cublasGemmAlgo_t(config_.gemm_algos[0]));
#endif
    }
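
    // Shape sketch (illustrative only, not part of the original header): in the
    // row-major view of the buffers, the GEMM in Forward is equivalent to the
    // naive reference below, assuming a float-convertible element type T.
    //
    //   for (int b = 0; b < bsz; ++b)
    //       for (int o = 0; o < config_.outputSize; ++o) {
    //           float acc = 0.f;
    //           for (int i = 0; i < config_.inputSize; ++i)
    //               acc += (float)input_ptr[b * config_.inputSize + i] *
    //                      (float)weights[o * config_.inputSize + i];
    //           out[b * config_.outputSize + o] = (T)acc;
    //       }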

    // Backward pass:
    //   weights_grad[outputSize, inputSize] = out_grad[bsz, outputSize]^T * input[bsz, inputSize]
    //   inp_grad_out[bsz, inputSize]        = out_grad[bsz, outputSize] * weights[outputSize, inputSize]
    //   bias_grad[outputSize]               = reduction of out_grad over the batch dimension
    void Backward(int bsz,
                  const T* out_grad,
                  const T* input_ptr,
                  const T* weights,
                  T* weights_grad,
                  T* bias_grad,
                  cublasHandle_t& _cublasHandle,
                  cudaStream_t& stream,
                  T* inp_grad_out = nullptr,
                  T* out_grad_trans_out = nullptr)
    {
        float alpha = (T)1.0, beta = (T)0.0;

        // Weight gradient.
        cublas_gemm_ex(_cublasHandle,
                       CUBLAS_OP_N,
                       CUBLAS_OP_T,
                       config_.inputSize,
                       config_.outputSize,
                       bsz,
                       &alpha,
                       &beta,
                       input_ptr,
                       out_grad,
                       weights_grad,
#ifdef __HIP_PLATFORM_HCC__
                       rocblas_gemm_algo(config_.gemm_algos[1]));
#else
                       cublasGemmAlgo_t(config_.gemm_algos[1]));
#endif

        // Input gradient.
        cublas_gemm_ex(_cublasHandle,
                       CUBLAS_OP_N,
                       CUBLAS_OP_N,
                       config_.inputSize,
                       bsz,
                       config_.outputSize,
                       &alpha,
                       &beta,
                       weights,
                       out_grad,
                       inp_grad_out,
#ifdef __HIP_PLATFORM_HCC__
                       rocblas_gemm_algo(config_.gemm_algos[2]));
#else
                       cublasGemmAlgo_t(config_.gemm_algos[2]));
#endif

        // Bias gradient: fused transpose + reduction kernel over the batch dimension.
        launch_fuse_transpose_bias_kernel<T>(out_grad, bias_grad, bsz, config_.outputSize, stream);
    }

private:
    Config config_;
};

#endif
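
// Usage sketch (illustrative, not part of this header): assuming a valid cuBLAS
// handle and device buffers sized for batch 8, input width 768 and output width
// 1024 already exist, a forward call might look like the following. The names
// `handle`, `d_input`, `d_weights` and `d_output` are placeholders.
//
//   FeedForward<float>::Config cfg(8, 1024, 768, {-1, -1, -1});  // -1 == CUBLAS_GEMM_DEFAULT
//   FeedForward<float> ff(cfg);
//   ff.Forward(8, d_input, d_weights, d_output, handle);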