strided_batch_gemm.h 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. #pragma once
  2. #include <cuda.h>
  3. #include <cuda_fp16.h>
  4. #include <stdio.h>
  5. #include "context.h"
  6. template <typename T>
  7. class StridedBatchGemm {
  8. public:
  9. struct Config {
  10. int batch_size;
  11. int m;
  12. int n;
  13. int k;
  14. float alpha;
  15. float beta;
  16. cublasOperation_t op_A;
  17. cublasOperation_t op_B;
  18. std::array<int, 3> gemm_algos;
  19. Config(int batch,
  20. int mm,
  21. int nn,
  22. int kk,
  23. float param_alpha,
  24. float param_beta,
  25. cublasOperation_t opA,
  26. cublasOperation_t opB,
  27. const std::array<int, 3>& algos)
  28. : batch_size(batch),
  29. m(mm),
  30. n(nn),
  31. k(kk),
  32. alpha(param_alpha),
  33. beta(param_beta),
  34. op_A(opA),
  35. op_B(opB),
  36. gemm_algos(algos)
  37. {
  38. }
  39. void SetConfig(int mm, int nn, int kk)
  40. {
  41. m = mm;
  42. n = nn;
  43. k = kk;
  44. }
  45. };
  46. StridedBatchGemm(const Config& config) : _config(config) {}
  47. virtual ~StridedBatchGemm() {}
  48. void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle)
  49. {
  50. int stride_a = _config.m * _config.k;
  51. int stride_b = _config.n * _config.k;
  52. int stride_c = _config.m * _config.n;
  53. cublas_strided_batched_gemm(handle,
  54. _config.m,
  55. _config.n,
  56. _config.k,
  57. &_config.alpha,
  58. &_config.beta,
  59. _buffer_a,
  60. _buffer_b,
  61. output,
  62. _config.op_A,
  63. _config.op_B,
  64. stride_a,
  65. stride_b,
  66. stride_c,
  67. bsz,
  68. cublasGemmAlgo_t(_config.gemm_algos[0]));
  69. }
  70. void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle)
  71. {
  72. int stride_a = _config.m * _config.k;
  73. int stride_b = _config.n * _config.k;
  74. int stride_c = _config.m * _config.n;
  75. cublas_strided_batched_gemm(handle,
  76. _config.m,
  77. _config.n,
  78. _config.k,
  79. &_config.alpha,
  80. &_config.beta,
  81. _buffer_a,
  82. _buffer_b,
  83. output,
  84. _config.op_A,
  85. _config.op_B,
  86. stride_a,
  87. stride_b,
  88. stride_c,
  89. _config.batch_size,
  90. cublasGemmAlgo_t(_config.gemm_algos[0]));
  91. k_buf = _buffer_a;
  92. q_buf = _buffer_b;
  93. }
  94. void Backward(int bsz,
  95. const T* d_output,
  96. const T* _buffer_a,
  97. const T* _buffer_b,
  98. cublasHandle_t handle,
  99. T* inpGradA = nullptr,
  100. T* inpGradB = nullptr)
  101. {
  102. int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m);
  103. int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k);
  104. int stride_a = mb * _config.n;
  105. int stride_b = _config.n * kb;
  106. int stride_c = _config.m * _config.k;
  107. // B need to transpose.
  108. cublasOperation_t op_b = (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
  109. // Calculate d_A.
  110. cublas_strided_batched_gemm(handle,
  111. mb,
  112. kb,
  113. _config.n,
  114. &_config.alpha,
  115. &_config.beta,
  116. (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output),
  117. (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b),
  118. inpGradA,
  119. CUBLAS_OP_N,
  120. op_b,
  121. stride_a,
  122. stride_b,
  123. stride_c,
  124. bsz,
  125. cublasGemmAlgo_t(_config.gemm_algos[1]));
  126. // A need to transpose.
  127. cublasOperation_t op_a = (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
  128. stride_a = _config.m * _config.k;
  129. stride_b = _config.m * _config.n;
  130. stride_c = _config.n * _config.k;
  131. // Calculate d_B.
  132. cublas_strided_batched_gemm(handle,
  133. _config.k,
  134. _config.n,
  135. _config.m,
  136. &_config.alpha,
  137. &_config.beta,
  138. _buffer_a,
  139. d_output,
  140. inpGradB,
  141. op_a,
  142. CUBLAS_OP_N,
  143. stride_a,
  144. stride_b,
  145. stride_c,
  146. bsz,
  147. cublasGemmAlgo_t(_config.gemm_algos[2]));
  148. }
  149. inline int GetN() const { return _config.k; }
  150. inline const T* GetBufferA() const { return k_buf; }
  151. inline const T* GetBufferB() const { return q_buf; }
  152. inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
  153. private:
  154. Config _config;
  155. const T* q_buf;
  156. const T* k_buf;
  157. };