gemm_test.h

#pragma once
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <array>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <limits>
#include <memory>
#include "StopWatch.h"
#include "cublas_wrappers.h"
template <typename T>
void check(T result, char const* const func, const char* const file, int const line)
{
    if (result) {
        std::cout << "CUDA runtime error: " << func << " at " << file << ":" << line << "\n";
    }
}

#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
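
// GemmTest owns device buffers A (MxK), B (KxN), and C (MxN) and, for the
// forward GEMM plus the two backward (gradient) GEMMs, sweeps every tensor-op
// algorithm cuBLAS exposes (CUBLAS_GEMM_DEFAULT_TENSOR_OP through
// CUBLAS_GEMM_ALGO15_TENSOR_OP), returning the fastest algorithm id for each.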
template <typename T>
class GemmTest {
public:
    GemmTest(int m, int n, int k, cublasOperation_t ta, cublasOperation_t tb, cublasHandle_t h)
        : M(m), N(n), K(k), handle(h), transa(ta), transb(tb)
    {
        check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K));
        check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N));
        check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N));
    }

    ~GemmTest()
    {
        check_cuda_error(cudaFree(A));
        check_cuda_error(cudaFree(B));
        check_cuda_error(cudaFree(C));
    }
    std::array<int, 3> TestAlgo(int loops)
    {
        float alpha = 1.0f;
        float beta = 0.0f;

        // Forward GEMM (operands passed as B, A because cuBLAS is column-major).
        int algo_fw = Run(loops, [=](int algo) {
            cublas_gemm_ex(handle,
                           CUBLAS_OP_T,
                           CUBLAS_OP_N,
                           N,
                           M,
                           K,
                           &alpha,
                           &beta,
                           B,
                           A,
                           C,
                           static_cast<cublasGemmAlgo_t>(algo));
        });

        // Backward GEMM writing into B (gradient of B).
        int algo_bw1 = Run(loops, [=](int algo) {
            cublas_gemm_ex(handle,
                           CUBLAS_OP_N,
                           CUBLAS_OP_T,
                           K,
                           N,
                           M,
                           &alpha,
                           &beta,
                           A,
                           C,
                           B,
                           static_cast<cublasGemmAlgo_t>(algo));
        });

        // Backward GEMM writing into A (gradient of A).
        int algo_bw2 = Run(loops, [=](int algo) {
            cublas_gemm_ex(handle,
                           CUBLAS_OP_N,
                           CUBLAS_OP_N,
                           K,
                           M,
                           N,
                           &alpha,
                           &beta,
                           B,
                           C,
                           A,
                           static_cast<cublasGemmAlgo_t>(algo));
        });

        return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
    }
    template <typename Func>
    int Run(int loops, Func f)
    {
        float fast_latency = (std::numeric_limits<float>::max)();
        int fast_algo = 0;

        for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
             algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
             algo++) {
            // Warm up so clocks and caches settle before timing.
            int warm_up = 5;
            for (int i = 0; i < warm_up; ++i) f(algo);
            cudaDeviceSynchronize();

            Stopwatch timer;
            timer.Restart();
            for (int i = 0; i < loops; ++i) f(algo);
            cudaDeviceSynchronize();
            timer.Stop();

            // Average wall-clock latency per call, in milliseconds.
            float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops;
            printf("algo-%d: %.3f ms\n", algo, avg_latency);

            if (avg_latency < fast_latency) {
                fast_latency = avg_latency;
                fast_algo = algo;
            }
        }
        printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency);
        return fast_algo;
    }
private:
    int M, N, K;
    cublasHandle_t handle;
    cublasOperation_t transa, transb;
    T *A, *B, *C;
};
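
// --- Usage sketch (not part of the original header) ---
// A minimal driver for GemmTest, assuming cublas_wrappers.h provides a
// cublas_gemm_ex overload for the element type with the argument order used
// above. The guard macro GEMM_TEST_EXAMPLE and the problem sizes are
// illustrative placeholders, not values from the original code.
#ifdef GEMM_TEST_EXAMPLE
#include <cublas_v2.h>  // normally already pulled in via cublas_wrappers.h
inline void run_gemm_test_example()
{
    cublasHandle_t handle;
    cublasCreate(&handle);

    // Sweep algorithms for a 1024 x 1024 x 1024 half-precision GEMM.
    GemmTest<__half> test(1024, 1024, 1024, CUBLAS_OP_N, CUBLAS_OP_N, handle);
    std::array<int, 3> algos = test.TestAlgo(/*loops=*/100);
    printf("fw=%d bw1=%d bw2=%d\n", algos[0], algos[1], algos[2]);

    cublasDestroy(handle);
}
#endif  // GEMM_TEST_EXAMPLE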
// StridedGemmTest is the batched counterpart: one MxK, KxN, and MxN matrix
// per batch element, benchmarked with strided batched GEMMs.
template <typename T>
class StridedGemmTest {
public:
    StridedGemmTest(int b,
                    int m,
                    int n,
                    int k,
                    cublasOperation_t ta,
                    cublasOperation_t tb,
                    cublasHandle_t h)
        : bsz(b), M(m), N(n), K(k), handle(h), transa(ta), transb(tb)
    {
        check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K * bsz));
        check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N * bsz));
        check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N * bsz));
    }

    ~StridedGemmTest()
    {
        check_cuda_error(cudaFree(A));
        check_cuda_error(cudaFree(B));
        check_cuda_error(cudaFree(C));
    }
    std::array<int, 3> TestAlgo(int loops)
    {
        float alpha = 1.0f;
        float beta = 0.0f;

        // Forward batched GEMM: one MxNxK product per batch element.
        int algo_fw = Run(loops, [=](int algo) {
            int stride_a = M * K;
            int stride_b = N * K;
            int stride_c = M * N;
            cublas_strided_batched_gemm(handle,
                                        M,
                                        N,
                                        K,
                                        &alpha,
                                        &beta,
                                        A,
                                        B,
                                        C,
                                        transa,
                                        transb,
                                        stride_a,
                                        stride_b,
                                        stride_c,
                                        bsz,
                                        static_cast<cublasGemmAlgo_t>(algo));
        });

        // Backward batched GEMM computing d_A; the operand shapes and order
        // depend on whether A was transposed in the forward pass.
        int algo_bw1 = Run(loops, [=](int algo) {
            int mb = (transa == CUBLAS_OP_T ? K : M);
            int kb = (transa == CUBLAS_OP_T ? M : K);
            int stride_a = mb * N;
            int stride_b = N * kb;
            int stride_c = M * K;

            // B needs to be transposed relative to the forward pass.
            cublasOperation_t op_b = (transb == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);

            // Calculate d_A.
            cublas_strided_batched_gemm(handle,
                                        mb,
                                        kb,
                                        N,
                                        &alpha,
                                        &beta,
                                        (transa == CUBLAS_OP_T ? B : C),
                                        (transa == CUBLAS_OP_T ? C : B),
                                        A,
                                        CUBLAS_OP_N,
                                        op_b,
                                        stride_a,
                                        stride_b,
                                        stride_c,
                                        bsz,
                                        static_cast<cublasGemmAlgo_t>(algo));
        });

        // Backward batched GEMM computing d_B.
        int algo_bw2 = Run(loops, [=](int algo) {
            // A needs to be transposed relative to the forward pass.
            cublasOperation_t op_a = (transa == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
            int stride_a = M * K;
            int stride_b = M * N;
            int stride_c = N * K;

            // Calculate d_B.
            cublas_strided_batched_gemm(handle,
                                        K,
                                        N,
                                        M,
                                        &alpha,
                                        &beta,
                                        A,
                                        C,
                                        B,
                                        op_a,
                                        CUBLAS_OP_N,
                                        stride_a,
                                        stride_b,
                                        stride_c,
                                        bsz,
                                        static_cast<cublasGemmAlgo_t>(algo));
        });

        return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
    }
    template <typename Func>
    int Run(int loops, Func f)
    {
        float fast_latency = (std::numeric_limits<float>::max)();
        int fast_algo = 0;

        for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
             algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
             algo++) {
            // Warm up so clocks and caches settle before timing.
            int warm_up = 5;
            for (int i = 0; i < warm_up; ++i) f(algo);
            cudaDeviceSynchronize();

            Stopwatch timer;
            timer.Restart();
            for (int i = 0; i < loops; ++i) f(algo);
            cudaDeviceSynchronize();
            timer.Stop();

            // Average wall-clock latency per call, in milliseconds.
            float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops;
            printf("algo-%d: %.3f ms\n", algo, avg_latency);

            if (avg_latency < fast_latency) {
                fast_latency = avg_latency;
                fast_algo = algo;
            }
        }
        printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency);
        return fast_algo;
    }
private:
    int bsz, M, N, K;
    cublasHandle_t handle;
    cublasOperation_t transa, transb;
    T *A, *B, *C;
};
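
// --- Usage sketch (not part of the original header) ---
// A minimal driver for StridedGemmTest under the same assumptions as the
// GemmTest example above: cublas_wrappers.h provides a
// cublas_strided_batched_gemm overload with the argument order used in
// TestAlgo. The sizes below are hypothetical placeholders, chosen to look
// like a per-head attention GEMM where bsz = batch * heads.
#ifdef GEMM_TEST_EXAMPLE
inline void run_strided_gemm_test_example()
{
    cublasHandle_t handle;
    cublasCreate(&handle);

    StridedGemmTest<__half> test(/*b=*/128,
                                 /*m=*/64,
                                 /*n=*/128,
                                 /*k=*/128,
                                 CUBLAS_OP_T,
                                 CUBLAS_OP_N,
                                 handle);
    std::array<int, 3> algos = test.TestAlgo(/*loops=*/100);
    printf("fw=%d bw1=%d bw2=%d\n", algos[0], algos[1], algos[2]);

    cublasDestroy(handle);
}
#endif  // GEMM_TEST_EXAMPLE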