// gemm_test.h
#pragma once

#include <cuda_fp16.h>
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif

#include <array>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <limits>
#include <memory>
#include <string>

#include "StopWatch.h"
#include "cublas_wrappers.h"
  14. template <typename T>
  15. void check(T result, char const* const func, const char* const file, int const line)
  16. {
  17. if (result) {
  18. std::cout << (std::string("CUDA runtime error: ") + +file + ":" + std::to_string(line) +
  19. " \n");
  20. }
  21. }
  22. #define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
  23. template <typename T>
  24. class GemmTest {
  25. public:
  26. GemmTest(int m, int n, int k, cublasOperation_t ta, cublasOperation_t tb, cublasHandle_t h)
  27. : M(m), N(n), K(k), transa(ta), transb(tb), handle(h)
  28. {
  29. check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K));
  30. check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N));
  31. check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N));
  32. }
  33. ~GemmTest()
  34. {
  35. check_cuda_error(cudaFree(A));
  36. check_cuda_error(cudaFree(B));
  37. check_cuda_error(cudaFree(C));
  38. }
  39. std::array<int, 3> TestAlgo(int loops)
  40. {
  41. float alpha = (T)1.0f;
  42. float beta = (T)0.0f;
  43. int algo_fw = Run(loops, [=](int algo) {
  44. cublas_gemm_ex(handle,
  45. CUBLAS_OP_T,
  46. CUBLAS_OP_N,
  47. N,
  48. M,
  49. K,
  50. &alpha,
  51. &beta,
  52. B,
  53. A,
  54. C,
  55. #ifdef __HIP_PLATFORM_HCC__
  56. static_cast<rocblas_gemm_algo>(algo));
  57. #else
  58. static_cast<cublasGemmAlgo_t>(algo));
  59. #endif
  60. });
  61. int algo_bw1 = Run(loops, [=](int algo) {
  62. cublas_gemm_ex(handle,
  63. CUBLAS_OP_N,
  64. CUBLAS_OP_T,
  65. K,
  66. N,
  67. M,
  68. &alpha,
  69. &beta,
  70. A,
  71. C,
  72. B,
  73. #ifdef __HIP_PLATFORM_HCC__
  74. static_cast<rocblas_gemm_algo>(algo));
  75. #else
  76. static_cast<cublasGemmAlgo_t>(algo));
  77. #endif
  78. });
  79. int algo_bw2 = Run(loops, [=](int algo) {
  80. cublas_gemm_ex(handle,
  81. CUBLAS_OP_N,
  82. CUBLAS_OP_N,
  83. K,
  84. M,
  85. N,
  86. &alpha,
  87. &beta,
  88. B,
  89. C,
  90. A,
  91. #ifdef __HIP_PLATFORM_HCC__
  92. static_cast<rocblas_gemm_algo>(algo));
  93. #else
  94. static_cast<cublasGemmAlgo_t>(algo));
  95. #endif
  96. });
  97. return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
  98. }
  99. template <typename Func>
  100. int Run(int loops, Func f)
  101. {
  102. float fast_latency = (std::numeric_limits<float>::max)();
  103. int fast_algo = 0;
  104. #ifdef __HIP_PLATFORM_HCC__
  105. for (int algo = (int)rocblas_gemm_algo_standard; algo <= (int)rocblas_gemm_algo_standard;
  106. #else
  107. for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
  108. algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
  109. #endif
  110. algo++) {
  111. int warm_up = 5;
  112. for (int i = 0; i < warm_up; ++i) f(algo);
  113. cudaDeviceSynchronize();
  114. Stopwatch timer;
  115. timer.Restart();
  116. for (int i = 0; i < loops; ++i) f(algo);
  117. cudaDeviceSynchronize();
  118. timer.Stop();
  119. float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops;
  120. printf("algo-%d: %.3fms\n", algo, avg_latency);
  121. if (avg_latency < fast_latency) {
  122. fast_latency = avg_latency;
  123. fast_algo = algo;
  124. }
  125. }
  126. printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency);
  127. return fast_algo;
  128. }
  129. private:
  130. int M, N, K;
  131. cublasHandle_t handle;
  132. cublasOperation_t transa, transb;
  133. T *A, *B, *C;
  134. };
  135. template <typename T>
  136. class StridedGemmTest {
  137. public:
  138. StridedGemmTest(int b,
  139. int m,
  140. int n,
  141. int k,
  142. cublasOperation_t ta,
  143. cublasOperation_t tb,
  144. cublasHandle_t h)
  145. : bsz(b), M(m), N(n), K(k), transa(ta), transb(tb), handle(h)
  146. {
  147. check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K * bsz));
  148. check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N * bsz));
  149. check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N * bsz));
  150. }
  151. ~StridedGemmTest()
  152. {
  153. check_cuda_error(cudaFree(A));
  154. check_cuda_error(cudaFree(B));
  155. check_cuda_error(cudaFree(C));
  156. }
  157. std::array<int, 3> TestAlgo(int loops)
  158. {
  159. float alpha = (T)1.0f;
  160. float beta = (T)0.0f;
  161. int algo_fw = Run(loops, [=](int algo) {
  162. int stride_a = M * K;
  163. int stride_b = N * K;
  164. int stride_c = M * N;
  165. cublas_strided_batched_gemm(handle,
  166. M,
  167. N,
  168. K,
  169. &alpha,
  170. &beta,
  171. A,
  172. B,
  173. C,
  174. transa,
  175. transb,
  176. stride_a,
  177. stride_b,
  178. stride_c,
  179. bsz,
  180. #ifdef __HIP_PLATFORM_HCC__
  181. static_cast<rocblas_gemm_algo>(algo));
  182. #else
  183. static_cast<cublasGemmAlgo_t>(algo));
  184. #endif
  185. });
  186. int algo_bw1 = Run(loops, [=](int algo) {
  187. int mb = (transa == CUBLAS_OP_T ? K : M);
  188. int kb = (transa == CUBLAS_OP_T ? M : K);
  189. int stride_a = mb * N;
  190. int stride_b = N * kb;
  191. int stride_c = M * K;
  192. // B need to transpose.
  193. cublasOperation_t op_b = (transb == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
  194. // Calculate d_A.
  195. cublas_strided_batched_gemm(handle,
  196. mb,
  197. kb,
  198. N,
  199. &alpha,
  200. &beta,
  201. (transa == CUBLAS_OP_T ? B : C),
  202. (transa == CUBLAS_OP_T ? C : B),
  203. A,
  204. CUBLAS_OP_N,
  205. op_b,
  206. stride_a,
  207. stride_b,
  208. stride_c,
  209. bsz,
  210. #ifdef __HIP_PLATFORM_HCC__
  211. static_cast<rocblas_gemm_algo>(algo));
  212. #else
  213. static_cast<cublasGemmAlgo_t>(algo));
  214. #endif
  215. });
  216. int algo_bw2 = Run(loops, [=](int algo) {
  217. // A need to transpose.
  218. cublasOperation_t op_a = (transa == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
  219. int stride_a = M * K;
  220. int stride_b = M * N;
  221. int stride_c = N * K;
  222. // Calculate d_B.
  223. cublas_strided_batched_gemm(handle,
  224. K,
  225. N,
  226. M,
  227. &alpha,
  228. &beta,
  229. A,
  230. C,
  231. B,
  232. op_a,
  233. CUBLAS_OP_N,
  234. stride_a,
  235. stride_b,
  236. stride_c,
  237. bsz,
  238. #ifdef __HIP_PLATFORM_HCC__
  239. static_cast<rocblas_gemm_algo>(algo));
  240. #else
  241. static_cast<cublasGemmAlgo_t>(algo));
  242. #endif
  243. });
  244. return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
  245. }
  246. template <typename Func>
  247. int Run(int loops, Func f)
  248. {
  249. float fast_latency = (std::numeric_limits<float>::max)();
  250. int fast_algo = 0;
  251. #ifdef __HIP_PLATFORM_HCC__
  252. for (int algo = (int)rocblas_gemm_algo_standard; algo <= (int)rocblas_gemm_algo_standard;
  253. #else
  254. for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
  255. algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
  256. #endif
  257. algo++) {
  258. int warm_up = 5;
  259. for (int i = 0; i < warm_up; ++i) f(algo);
  260. cudaDeviceSynchronize();
  261. Stopwatch timer;
  262. timer.Restart();
  263. for (int i = 0; i < loops; ++i) f(algo);
  264. cudaDeviceSynchronize();
  265. timer.Stop();
  266. float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops;
  267. printf("algo-%d: %.3fms\n", algo, avg_latency);
  268. if (avg_latency < fast_latency) {
  269. fast_latency = avg_latency;
  270. fast_algo = algo;
  271. }
  272. }
  273. printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency);
  274. return fast_algo;
  275. }
  276. private:
  277. int bsz, M, N, K;
  278. cublasHandle_t handle;
  279. cublasOperation_t transa, transb;
  280. T *A, *B, *C;
  281. };