feed_forward.h 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. #ifndef __FEEDFORWARD_H__
  2. #define __FEEDFORWARD_H__
  3. #include <cuda.h>
  4. #include <cuda_fp16.h>
  5. #include <stdio.h>
  6. #include "custom_cuda_layers.h"
  7. template <typename T>
  8. class FeedForward {
  9. public:
  10. struct Config {
  11. int batchSize, outputSize;
  12. int inputSize;
  13. std::array<int, 3> gemm_algos;
  14. Config(int batch, int outputs, int inputs, const std::array<int, 3>& algos)
  15. : batchSize(batch), outputSize(outputs), inputSize(inputs), gemm_algos(algos)
  16. {
  17. }
  18. };
  19. FeedForward(Config config) : config_(config) {}
  20. ~FeedForward() {}
  21. void Forward(int bsz,
  22. const T* input_ptr,
  23. const T* weights,
  24. T* out,
  25. cublasHandle_t& _cublasHandle)
  26. {
  27. float alpha = T(1.);
  28. float beta = T(0.);
  29. cublas_gemm_ex(_cublasHandle,
  30. CUBLAS_OP_T,
  31. CUBLAS_OP_N,
  32. config_.outputSize,
  33. bsz,
  34. config_.inputSize,
  35. &alpha,
  36. &beta,
  37. weights,
  38. input_ptr,
  39. out,
  40. cublasGemmAlgo_t(config_.gemm_algos[0]));
  41. }
  42. void Backward(int bsz,
  43. const T* out_grad,
  44. const T* input_ptr,
  45. const T* weights,
  46. T* weights_grad,
  47. T* bias_grad,
  48. cublasHandle_t& _cublasHandle,
  49. cudaStream_t& stream,
  50. T* inp_grad_out = nullptr,
  51. T* out_grad_trans_out = nullptr)
  52. {
  53. float alpha = (T)1.0, beta = (T)0.0;
  54. cublas_gemm_ex(_cublasHandle,
  55. CUBLAS_OP_N,
  56. CUBLAS_OP_T,
  57. config_.inputSize,
  58. config_.outputSize,
  59. bsz,
  60. &alpha,
  61. &beta,
  62. input_ptr,
  63. out_grad,
  64. weights_grad,
  65. cublasGemmAlgo_t(config_.gemm_algos[1]));
  66. cublas_gemm_ex(_cublasHandle,
  67. CUBLAS_OP_N,
  68. CUBLAS_OP_N,
  69. config_.inputSize,
  70. bsz,
  71. config_.outputSize,
  72. &alpha,
  73. &beta,
  74. weights,
  75. out_grad,
  76. inp_grad_out,
  77. cublasGemmAlgo_t(config_.gemm_algos[2]));
  78. launch_fuse_transpose_bias_kernel<T>(out_grad, bias_grad, bsz, config_.outputSize, stream);
  79. }
  80. private:
  81. Config config_;
  82. };
  83. #endif