strided_batch_gemm.h

#pragma once

#include <array>  // std::array (used by Config::gemm_algos)
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>

#include "context.h"
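// Thin wrapper around cublas_strided_batched_gemm (assumed to be declared
// through "context.h" or its includes): each call launches a batch of GEMMs
//   C = alpha * op_A(A) * op_B(B) + beta * C
// with per-matrix strides derived from the configured m/n/k. ForwardPlusSave
// additionally caches its input pointers for later retrieval via
// GetBufferA()/GetBufferB().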
template <typename T>
class StridedBatchGemm {
public:
    struct Config {
        int batch_size;
        int m;
        int n;
        int k;
        float alpha;
        float beta;
        cublasOperation_t op_A;
        cublasOperation_t op_B;
        std::array<int, 3> gemm_algos;

        Config(int batch,
               int mm,
               int nn,
               int kk,
               float param_alpha,
               float param_beta,
               cublasOperation_t opA,
               cublasOperation_t opB,
               const std::array<int, 3>& algos)
            : batch_size(batch),
              m(mm),
              n(nn),
              k(kk),
              alpha(param_alpha),
              beta(param_beta),
              op_A(opA),
              op_B(opB),
              gemm_algos(algos)
        {
        }

        void SetConfig(int mm, int nn, int kk)
        {
            m = mm;
            n = nn;
            k = kk;
        }
    };
    StridedBatchGemm(const Config& config) : _config(config) {}

    virtual ~StridedBatchGemm() {}
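    // Runs bsz strided GEMMs:
    //   output = alpha * op_A(_buffer_a) * op_B(_buffer_b) + beta * output
    // using algorithm index gemm_algos[0].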
    void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle)
    {
        int stride_a = _config.m * _config.k;
        int stride_b = _config.n * _config.k;
        int stride_c = _config.m * _config.n;

        cublas_strided_batched_gemm(handle,
                                    _config.m,
                                    _config.n,
                                    _config.k,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    _buffer_b,
                                    output,
                                    _config.op_A,
                                    _config.op_B,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
#ifdef __HIP_PLATFORM_HCC__
                                    rocblas_gemm_algo(_config.gemm_algos[0]));
#else
                                    cublasGemmAlgo_t(_config.gemm_algos[0]));
#endif
    }
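    // Same GEMM as Forward, but over the configured batch_size rather than a
    // caller-supplied bsz; also records the input pointers so they can be
    // re-read later through GetBufferA()/GetBufferB().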
    void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle)
    {
        int stride_a = _config.m * _config.k;
        int stride_b = _config.n * _config.k;
        int stride_c = _config.m * _config.n;

        cublas_strided_batched_gemm(handle,
                                    _config.m,
                                    _config.n,
                                    _config.k,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    _buffer_b,
                                    output,
                                    _config.op_A,
                                    _config.op_B,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    _config.batch_size,
#ifdef __HIP_PLATFORM_HCC__
                                    rocblas_gemm_algo(_config.gemm_algos[0]));
#else
                                    cublasGemmAlgo_t(_config.gemm_algos[0]));
#endif

        k_buf = _buffer_a;
        q_buf = _buffer_b;
    }
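    // Two batched GEMMs produce the input gradients: d_A from d_output and
    // _buffer_b (with op_B's transpose flipped), then d_B from _buffer_a and
    // d_output (with op_A's transpose flipped), using algorithm indices
    // gemm_algos[1] and gemm_algos[2] respectively.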
    void Backward(int bsz,
                  const T* d_output,
                  const T* _buffer_a,
                  const T* _buffer_b,
                  cublasHandle_t handle,
                  T* inpGradA = nullptr,
                  T* inpGradB = nullptr)
    {
        int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m);
        int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k);

        int stride_a = mb * _config.n;
        int stride_b = _config.n * kb;
        int stride_c = _config.m * _config.k;

        // B needs to be transposed, so flip op_B.
        cublasOperation_t op_b = (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);

        // Calculate d_A.
        cublas_strided_batched_gemm(handle,
                                    mb,
                                    kb,
                                    _config.n,
                                    &_config.alpha,
                                    &_config.beta,
                                    (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output),
                                    (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b),
                                    inpGradA,
                                    CUBLAS_OP_N,
                                    op_b,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
#ifdef __HIP_PLATFORM_HCC__
                                    rocblas_gemm_algo(_config.gemm_algos[1]));
#else
                                    cublasGemmAlgo_t(_config.gemm_algos[1]));
#endif

        // A needs to be transposed, so flip op_A.
        cublasOperation_t op_a = (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);

        stride_a = _config.m * _config.k;
        stride_b = _config.m * _config.n;
        stride_c = _config.n * _config.k;

        // Calculate d_B.
        cublas_strided_batched_gemm(handle,
                                    _config.k,
                                    _config.n,
                                    _config.m,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    d_output,
                                    inpGradB,
                                    op_a,
                                    CUBLAS_OP_N,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
#ifdef __HIP_PLATFORM_HCC__
                                    rocblas_gemm_algo(_config.gemm_algos[2]));
#else
                                    cublasGemmAlgo_t(_config.gemm_algos[2]));
#endif
    }
    inline int GetN() const { return _config.k; }

    inline const T* GetBufferA() const { return k_buf; }
    inline const T* GetBufferB() const { return q_buf; }

    inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }

private:
    Config _config;
    const T* q_buf = nullptr;  // initialized so the getters never return garbage
    const T* k_buf = nullptr;  // before ForwardPlusSave has run
};
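/* Usage sketch (illustrative only; `batch`, `m`, `n`, `k`, the device
   pointers, and `cublas_handle` are hypothetical caller-side names):

   StridedBatchGemm<__half>::Config cfg(
       batch, m, n, k,
       1.0f, 0.0f,                        // alpha, beta
       CUBLAS_OP_T, CUBLAS_OP_N,          // op_A, op_B
       {CUBLAS_GEMM_DEFAULT_TENSOR_OP,    // algo for Forward
        CUBLAS_GEMM_DEFAULT_TENSOR_OP,    // algo for the d_A GEMM
        CUBLAS_GEMM_DEFAULT_TENSOR_OP});  // algo for the d_B GEMM
   StridedBatchGemm<__half> gemm(cfg);
   gemm.Forward(batch, out_ptr, a_ptr, b_ptr, cublas_handle);
*/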