gemm_test.h

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <cuda_fp16.h>
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <array>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include "StopWatch.h"
#include "cublas_wrappers.h"
// Prints a diagnostic when a CUDA/cuBLAS call returns a non-zero status.
template <typename T>
void check(T result, char const* const func, const char* const file, int const line)
{
    if (result) {
        std::cout << (std::string("CUDA runtime error: ") + func + " at " + file + ":" +
                      std::to_string(line) + "\n");
    }
}

#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
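
// Example (illustrative only): wrap any status-returning CUDA or cuBLAS call,
// e.g. check_cuda_error(cudaMemset(ptr, 0, bytes));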
template <typename T>
class GemmTest {
public:
    GemmTest(int m, int n, int k, cublasOperation_t ta, cublasOperation_t tb, cublasHandle_t h)
        : M(m), N(n), K(k), transa(ta), transb(tb), handle(h)
    {
        check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K));
        check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N));
        check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N));
    }

    ~GemmTest()
    {
        check_cuda_error(cudaFree(A));
        check_cuda_error(cudaFree(B));
        check_cuda_error(cudaFree(C));
    }
    std::array<int, 3> TestAlgo(int loops)
    {
        float alpha = 1.0f;
        float beta = 0.0f;

        // Forward GEMM.
        int algo_fw = Run(loops, [=](int algo) {
            cublas_gemm_ex(handle,
                           CUBLAS_OP_T,
                           CUBLAS_OP_N,
                           N,
                           M,
                           K,
                           &alpha,
                           &beta,
                           B,
                           A,
                           C,
#ifdef __HIP_PLATFORM_HCC__
                           static_cast<rocblas_gemm_algo>(algo));
#else
                           static_cast<cublasGemmAlgo_t>(algo));
#endif
        });

        // First backward GEMM.
        int algo_bw1 = Run(loops, [=](int algo) {
            cublas_gemm_ex(handle,
                           CUBLAS_OP_N,
                           CUBLAS_OP_T,
                           K,
                           N,
                           M,
                           &alpha,
                           &beta,
                           A,
                           C,
                           B,
#ifdef __HIP_PLATFORM_HCC__
                           static_cast<rocblas_gemm_algo>(algo));
#else
                           static_cast<cublasGemmAlgo_t>(algo));
#endif
        });

        // Second backward GEMM.
        int algo_bw2 = Run(loops, [=](int algo) {
            cublas_gemm_ex(handle,
                           CUBLAS_OP_N,
                           CUBLAS_OP_N,
                           K,
                           M,
                           N,
                           &alpha,
                           &beta,
                           B,
                           C,
                           A,
#ifdef __HIP_PLATFORM_HCC__
                           static_cast<rocblas_gemm_algo>(algo));
#else
                           static_cast<cublasGemmAlgo_t>(algo));
#endif
        });

        return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
    }
    template <typename Func>
    int Run(int loops, Func f)
    {
        float fast_latency = (std::numeric_limits<float>::max)();
        int fast_algo = 0;

        // On ROCm only the standard algorithm is enumerated; on CUDA, sweep
        // the tensor-op algorithms.
#ifdef __HIP_PLATFORM_HCC__
        for (int algo = (int)rocblas_gemm_algo_standard; algo <= (int)rocblas_gemm_algo_standard;
#else
        for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
             algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
#endif
             algo++) {
            int warm_up = 5;
            for (int i = 0; i < warm_up; ++i) f(algo);
            cudaDeviceSynchronize();

            Stopwatch timer;
            timer.Restart();
            for (int i = 0; i < loops; ++i) f(algo);
            cudaDeviceSynchronize();
            timer.Stop();

            float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops;
            printf("algo-%d: %.3f ms\n", algo, avg_latency);

            if (avg_latency < fast_latency) {
                fast_latency = avg_latency;
                fast_algo = algo;
            }
        }

        printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency);
        return fast_algo;
    }
private:
    int M, N, K;
    cublasHandle_t handle;
    cublasOperation_t transa, transb;
    T *A, *B, *C;
};
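
// Usage sketch (illustrative only; the handle setup and the 1024^3 problem
// size are assumptions, not part of this header):
//
//   cublasHandle_t handle;
//   cublasCreate(&handle);
//   GemmTest<__half> test(1024, 1024, 1024, CUBLAS_OP_N, CUBLAS_OP_N, handle);
//   std::array<int, 3> algos = test.TestAlgo(100);  // {fw, bw1, bw2}
//   cublasDestroy(handle);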
template <typename T>
class StridedGemmTest {
public:
    StridedGemmTest(int b,
                    int m,
                    int n,
                    int k,
                    cublasOperation_t ta,
                    cublasOperation_t tb,
                    cublasHandle_t h)
        : bsz(b), M(m), N(n), K(k), transa(ta), transb(tb), handle(h)
    {
        check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K * bsz));
        check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N * bsz));
        check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N * bsz));
    }

    ~StridedGemmTest()
    {
        check_cuda_error(cudaFree(A));
        check_cuda_error(cudaFree(B));
        check_cuda_error(cudaFree(C));
    }
    std::array<int, 3> TestAlgo(int loops)
    {
        float alpha = 1.0f;
        float beta = 0.0f;

        // Forward GEMM.
        int algo_fw = Run(loops, [=](int algo) {
            int stride_a = M * K;
            int stride_b = N * K;
            int stride_c = M * N;
            cublas_strided_batched_gemm(handle,
                                        M,
                                        N,
                                        K,
                                        &alpha,
                                        &beta,
                                        A,
                                        B,
                                        C,
                                        transa,
                                        transb,
                                        stride_a,
                                        stride_b,
                                        stride_c,
                                        bsz,
#ifdef __HIP_PLATFORM_HCC__
                                        static_cast<rocblas_gemm_algo>(algo));
#else
                                        static_cast<cublasGemmAlgo_t>(algo));
#endif
        });

        int algo_bw1 = Run(loops, [=](int algo) {
            int mb = (transa == CUBLAS_OP_T ? K : M);
            int kb = (transa == CUBLAS_OP_T ? M : K);
            int stride_a = mb * N;
            int stride_b = N * kb;
            int stride_c = M * K;

            // B needs to be transposed.
            cublasOperation_t op_b = (transb == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);

            // Calculate d_A.
            cublas_strided_batched_gemm(handle,
                                        mb,
                                        kb,
                                        N,
                                        &alpha,
                                        &beta,
                                        (transa == CUBLAS_OP_T ? B : C),
                                        (transa == CUBLAS_OP_T ? C : B),
                                        A,
                                        CUBLAS_OP_N,
                                        op_b,
                                        stride_a,
                                        stride_b,
                                        stride_c,
                                        bsz,
#ifdef __HIP_PLATFORM_HCC__
                                        static_cast<rocblas_gemm_algo>(algo));
#else
                                        static_cast<cublasGemmAlgo_t>(algo));
#endif
        });

        int algo_bw2 = Run(loops, [=](int algo) {
            // A needs to be transposed.
            cublasOperation_t op_a = (transa == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
            int stride_a = M * K;
            int stride_b = M * N;
            int stride_c = N * K;

            // Calculate d_B.
            cublas_strided_batched_gemm(handle,
                                        K,
                                        N,
                                        M,
                                        &alpha,
                                        &beta,
                                        A,
                                        C,
                                        B,
                                        op_a,
                                        CUBLAS_OP_N,
                                        stride_a,
                                        stride_b,
                                        stride_c,
                                        bsz,
#ifdef __HIP_PLATFORM_HCC__
                                        static_cast<rocblas_gemm_algo>(algo));
#else
                                        static_cast<cublasGemmAlgo_t>(algo));
#endif
        });

        return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
    }
    template <typename Func>
    int Run(int loops, Func f)
    {
        float fast_latency = (std::numeric_limits<float>::max)();
        int fast_algo = 0;

        // On ROCm only the standard algorithm is enumerated; on CUDA, sweep
        // the tensor-op algorithms.
#ifdef __HIP_PLATFORM_HCC__
        for (int algo = (int)rocblas_gemm_algo_standard; algo <= (int)rocblas_gemm_algo_standard;
#else
        for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
             algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
#endif
             algo++) {
            int warm_up = 5;
            for (int i = 0; i < warm_up; ++i) f(algo);
            cudaDeviceSynchronize();

            Stopwatch timer;
            timer.Restart();
            for (int i = 0; i < loops; ++i) f(algo);
            cudaDeviceSynchronize();
            timer.Stop();

            float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops;
            printf("algo-%d: %.3f ms\n", algo, avg_latency);

            if (avg_latency < fast_latency) {
                fast_latency = avg_latency;
                fast_algo = algo;
            }
        }

        printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency);
        return fast_algo;
    }
private:
    int bsz, M, N, K;
    cublasHandle_t handle;
    cublasOperation_t transa, transb;
    T *A, *B, *C;
};
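
// Usage sketch (illustrative only; the batch size and shapes are assumptions):
//
//   cublasHandle_t handle;
//   cublasCreate(&handle);
//   StridedGemmTest<__half> test(8, 512, 512, 64, CUBLAS_OP_T, CUBLAS_OP_N, handle);
//   std::array<int, 3> algos = test.TestAlgo(100);  // {fw, d_A, d_B}
//   cublasDestroy(handle);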