strided_batch_gemm.h 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. #pragma once
  5. #include <cuda.h>
  6. #include <cuda_fp16.h>
  7. #include <stdio.h>
  8. #include "context.h"
  9. template <typename T>
  10. class StridedBatchGemm {
  11. public:
  12. struct Config {
  13. int batch_size;
  14. int m;
  15. int n;
  16. int k;
  17. float alpha;
  18. float beta;
  19. cublasOperation_t op_A;
  20. cublasOperation_t op_B;
  21. std::array<int, 3> gemm_algos;
  22. Config(int batch,
  23. int mm,
  24. int nn,
  25. int kk,
  26. float param_alpha,
  27. float param_beta,
  28. cublasOperation_t opA,
  29. cublasOperation_t opB,
  30. const std::array<int, 3>& algos)
  31. : batch_size(batch),
  32. m(mm),
  33. n(nn),
  34. k(kk),
  35. alpha(param_alpha),
  36. beta(param_beta),
  37. op_A(opA),
  38. op_B(opB),
  39. gemm_algos(algos)
  40. {
  41. }
  42. void SetConfig(int mm, int nn, int kk)
  43. {
  44. m = mm;
  45. n = nn;
  46. k = kk;
  47. }
  48. };
  49. StridedBatchGemm(const Config& config) : _config(config) {}
  50. virtual ~StridedBatchGemm() {}
  51. void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle)
  52. {
  53. int stride_a = _config.m * _config.k;
  54. int stride_b = _config.n * _config.k;
  55. int stride_c = _config.m * _config.n;
  56. cublas_strided_batched_gemm(handle,
  57. _config.m,
  58. _config.n,
  59. _config.k,
  60. &_config.alpha,
  61. &_config.beta,
  62. _buffer_a,
  63. _buffer_b,
  64. output,
  65. _config.op_A,
  66. _config.op_B,
  67. stride_a,
  68. stride_b,
  69. stride_c,
  70. bsz,
  71. #ifdef __HIP_PLATFORM_HCC__
  72. rocblas_gemm_algo(_config.gemm_algos[0]));
  73. #else
  74. cublasGemmAlgo_t(_config.gemm_algos[0]));
  75. #endif
  76. }
  77. void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle)
  78. {
  79. int stride_a = _config.m * _config.k;
  80. int stride_b = _config.n * _config.k;
  81. int stride_c = _config.m * _config.n;
  82. cublas_strided_batched_gemm(handle,
  83. _config.m,
  84. _config.n,
  85. _config.k,
  86. &_config.alpha,
  87. &_config.beta,
  88. _buffer_a,
  89. _buffer_b,
  90. output,
  91. _config.op_A,
  92. _config.op_B,
  93. stride_a,
  94. stride_b,
  95. stride_c,
  96. _config.batch_size,
  97. #ifdef __HIP_PLATFORM_HCC__
  98. rocblas_gemm_algo(_config.gemm_algos[0]));
  99. #else
  100. cublasGemmAlgo_t(_config.gemm_algos[0]));
  101. #endif
  102. k_buf = _buffer_a;
  103. q_buf = _buffer_b;
  104. }
  105. void Backward(int bsz,
  106. const T* d_output,
  107. const T* _buffer_a,
  108. const T* _buffer_b,
  109. cublasHandle_t handle,
  110. T* inpGradA = nullptr,
  111. T* inpGradB = nullptr)
  112. {
  113. int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m);
  114. int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k);
  115. int stride_a = mb * _config.n;
  116. int stride_b = _config.n * kb;
  117. int stride_c = _config.m * _config.k;
  118. // B need to transpose.
  119. cublasOperation_t op_b = (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
  120. // Calculate d_A.
  121. cublas_strided_batched_gemm(handle,
  122. mb,
  123. kb,
  124. _config.n,
  125. &_config.alpha,
  126. &_config.beta,
  127. (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output),
  128. (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b),
  129. inpGradA,
  130. CUBLAS_OP_N,
  131. op_b,
  132. stride_a,
  133. stride_b,
  134. stride_c,
  135. bsz,
  136. #ifdef __HIP_PLATFORM_HCC__
  137. rocblas_gemm_algo(_config.gemm_algos[1]));
  138. #else
  139. cublasGemmAlgo_t(_config.gemm_algos[1]));
  140. #endif
  141. // A need to transpose.
  142. cublasOperation_t op_a = (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
  143. stride_a = _config.m * _config.k;
  144. stride_b = _config.m * _config.n;
  145. stride_c = _config.n * _config.k;
  146. // Calculate d_B.
  147. cublas_strided_batched_gemm(handle,
  148. _config.k,
  149. _config.n,
  150. _config.m,
  151. &_config.alpha,
  152. &_config.beta,
  153. _buffer_a,
  154. d_output,
  155. inpGradB,
  156. op_a,
  157. CUBLAS_OP_N,
  158. stride_a,
  159. stride_b,
  160. stride_c,
  161. bsz,
  162. #ifdef __HIP_PLATFORM_HCC__
  163. rocblas_gemm_algo(_config.gemm_algos[2]));
  164. #else
  165. cublasGemmAlgo_t(_config.gemm_algos[2]));
  166. #endif
  167. }
  168. inline int GetN() const { return _config.k; }
  169. inline const T* GetBufferA() const { return k_buf; }
  170. inline const T* GetBufferB() const { return q_buf; }
  171. inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
  172. private:
  173. Config _config;
  174. const T* q_buf;
  175. const T* k_buf;
  176. };