// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <cuda_fp16.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cuda_profiler_api.h>
#endif
#ifdef __HIP_PLATFORM_AMD__
#include <rocblas/rocblas.h>
#endif
#include <array>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include "StopWatch.h"
#include "cublas_wrappers.h"

// Reports a nonzero CUDA/cuBLAS status together with the offending call site.
template <typename T>
void check(T result, char const* const func, const char* const file, int const line)
{
    if (result) {
        std::cout << (std::string("CUDA runtime error: ") + func + " at " + file + ":" +
                      std::to_string(line) + "\n");
    }
}

#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
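
// Example (illustrative): any runtime call can be wrapped at its call site,
// e.g. check_cuda_error(cudaSetDevice(0));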

template <typename T>
class GemmTest {
public:
    GemmTest(int m, int n, int k, cublasOperation_t ta, cublasOperation_t tb, cublasHandle_t h)
        : M(m), N(n), K(k), transa(ta), transb(tb), handle(h)
    {
        // Device buffers for the operands of an M x N x K GEMM.
        check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K));
        check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N));
        check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N));
    }

    ~GemmTest()
    {
        check_cuda_error(cudaFree(A));
        check_cuda_error(cudaFree(B));
        check_cuda_error(cudaFree(C));
    }

    // Benchmarks every candidate GEMM algorithm for the forward pass and the
    // two backward-pass shapes, returning the fastest algorithm index of each.
    std::array<int, 3> TestAlgo(int loops)
    {
        float alpha = 1.0f;
        float beta = 0.0f;

        // Forward GEMM producing C.
        int algo_fw = Run(loops, [=](int algo) {
            cublas_gemm_ex(handle,
                           CUBLAS_OP_T,
                           CUBLAS_OP_N,
                           N,
                           M,
                           K,
                           &alpha,
                           &beta,
                           B,
                           A,
                           C,
#ifdef __HIP_PLATFORM_AMD__
                           static_cast<rocblas_gemm_algo>(algo));
#else
                           static_cast<cublasGemmAlgo_t>(algo));
#endif
        });

        // Backward shape 1: gradient GEMM writing into B's buffer (d_B).
        int algo_bw1 = Run(loops, [=](int algo) {
            cublas_gemm_ex(handle,
                           CUBLAS_OP_N,
                           CUBLAS_OP_T,
                           K,
                           N,
                           M,
                           &alpha,
                           &beta,
                           A,
                           C,
                           B,
#ifdef __HIP_PLATFORM_AMD__
                           static_cast<rocblas_gemm_algo>(algo));
#else
                           static_cast<cublasGemmAlgo_t>(algo));
#endif
        });

        // Backward shape 2: gradient GEMM writing into A's buffer (d_A).
        int algo_bw2 = Run(loops, [=](int algo) {
            cublas_gemm_ex(handle,
                           CUBLAS_OP_N,
                           CUBLAS_OP_N,
                           K,
                           M,
                           N,
                           &alpha,
                           &beta,
                           B,
                           C,
                           A,
#ifdef __HIP_PLATFORM_AMD__
                           static_cast<rocblas_gemm_algo>(algo));
#else
                           static_cast<cublasGemmAlgo_t>(algo));
#endif
        });

        return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
    }

    // Times f(algo) over `loops` iterations for every candidate algorithm and
    // returns the index of the fastest one. Five untimed warm-up calls run
    // before each measurement.
    template <typename Func>
    int Run(int loops, Func f)
    {
        float fast_latency = (std::numeric_limits<float>::max)();
        int fast_algo = 0;

#ifdef __HIP_PLATFORM_AMD__
        // On ROCm only the standard algorithm is benchmarked.
        const int algo_begin = (int)rocblas_gemm_algo_standard;
        const int algo_end = (int)rocblas_gemm_algo_standard;
#else
        const int algo_begin = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
        const int algo_end = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
#endif
        for (int algo = algo_begin; algo <= algo_end; algo++) {
            int warm_up = 5;
            for (int i = 0; i < warm_up; ++i) f(algo);
            cudaDeviceSynchronize();

            Stopwatch timer;
            timer.Restart();
            for (int i = 0; i < loops; ++i) f(algo);
            cudaDeviceSynchronize();
            timer.Stop();

            float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops;
            printf("algo-%d: %.3fms\n", algo, avg_latency);

            if (avg_latency < fast_latency) {
                fast_latency = avg_latency;
                fast_algo = algo;
            }
        }

        printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency);
        return fast_algo;
    }

private:
    int M, N, K;
    cublasHandle_t handle;
    cublasOperation_t transa, transb;
    T *A, *B, *C;
};
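
/*
Illustrative usage sketch (not part of the original header); sizes, operand
layouts, and the loop count below are example values chosen for this sketch:

    cublasHandle_t handle;
    cublasCreate(&handle);
    GemmTest<__half> test(1024, 1024, 1024, CUBLAS_OP_N, CUBLAS_OP_N, handle);
    std::array<int, 3> algos = test.TestAlgo(100);  // {forward, d_B, d_A}
    cublasDestroy(handle);
*/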

template <typename T>
class StridedGemmTest {
public:
    StridedGemmTest(int b,
                    int m,
                    int n,
                    int k,
                    cublasOperation_t ta,
                    cublasOperation_t tb,
                    cublasHandle_t h)
        : bsz(b), M(m), N(n), K(k), transa(ta), transb(tb), handle(h)
    {
        // One contiguous device buffer per operand, holding bsz batches.
        check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K * bsz));
        check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N * bsz));
        check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N * bsz));
    }

    ~StridedGemmTest()
    {
        check_cuda_error(cudaFree(A));
        check_cuda_error(cudaFree(B));
        check_cuda_error(cudaFree(C));
    }

    // Benchmarks the forward and the two backward strided-batched GEMMs,
    // returning the fastest algorithm index of each.
    std::array<int, 3> TestAlgo(int loops)
    {
        float alpha = 1.0f;
        float beta = 0.0f;

        // Forward batched GEMM producing C.
        int algo_fw = Run(loops, [=](int algo) {
            int stride_a = M * K;
            int stride_b = N * K;
            int stride_c = M * N;

            cublas_strided_batched_gemm(handle,
                                        M,
                                        N,
                                        K,
                                        &alpha,
                                        &beta,
                                        A,
                                        B,
                                        C,
                                        transa,
                                        transb,
                                        stride_a,
                                        stride_b,
                                        stride_c,
                                        bsz,
#ifdef __HIP_PLATFORM_AMD__
                                        static_cast<rocblas_gemm_algo>(algo));
#else
                                        static_cast<cublasGemmAlgo_t>(algo));
#endif
        });

        int algo_bw1 = Run(loops, [=](int algo) {
            int mb = (transa == CUBLAS_OP_T ? K : M);
            int kb = (transa == CUBLAS_OP_T ? M : K);
            int stride_a = mb * N;
            int stride_b = N * kb;
            int stride_c = M * K;

            // B needs to be transposed.
            cublasOperation_t op_b = (transb == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);

            // Calculate d_A.
            cublas_strided_batched_gemm(handle,
                                        mb,
                                        kb,
                                        N,
                                        &alpha,
                                        &beta,
                                        (transa == CUBLAS_OP_T ? B : C),
                                        (transa == CUBLAS_OP_T ? C : B),
                                        A,
                                        CUBLAS_OP_N,
                                        op_b,
                                        stride_a,
                                        stride_b,
                                        stride_c,
                                        bsz,
#ifdef __HIP_PLATFORM_AMD__
                                        static_cast<rocblas_gemm_algo>(algo));
#else
                                        static_cast<cublasGemmAlgo_t>(algo));
#endif
        });

        int algo_bw2 = Run(loops, [=](int algo) {
            // A needs to be transposed.
            cublasOperation_t op_a = (transa == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
            int stride_a = M * K;
            int stride_b = M * N;
            int stride_c = N * K;

            // Calculate d_B.
            cublas_strided_batched_gemm(handle,
                                        K,
                                        N,
                                        M,
                                        &alpha,
                                        &beta,
                                        A,
                                        C,
                                        B,
                                        op_a,
                                        CUBLAS_OP_N,
                                        stride_a,
                                        stride_b,
                                        stride_c,
                                        bsz,
#ifdef __HIP_PLATFORM_AMD__
                                        static_cast<rocblas_gemm_algo>(algo));
#else
                                        static_cast<cublasGemmAlgo_t>(algo));
#endif
        });

        return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
    }

    // Times f(algo) over `loops` iterations for every candidate algorithm and
    // returns the index of the fastest one. Five untimed warm-up calls run
    // before each measurement.
    template <typename Func>
    int Run(int loops, Func f)
    {
        float fast_latency = (std::numeric_limits<float>::max)();
        int fast_algo = 0;

#ifdef __HIP_PLATFORM_AMD__
        // On ROCm only the standard algorithm is benchmarked.
        const int algo_begin = (int)rocblas_gemm_algo_standard;
        const int algo_end = (int)rocblas_gemm_algo_standard;
#else
        const int algo_begin = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
        const int algo_end = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
#endif
        for (int algo = algo_begin; algo <= algo_end; algo++) {
            int warm_up = 5;
            for (int i = 0; i < warm_up; ++i) f(algo);
            cudaDeviceSynchronize();

            Stopwatch timer;
            timer.Restart();
            for (int i = 0; i < loops; ++i) f(algo);
            cudaDeviceSynchronize();
            timer.Stop();

            float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops;
            printf("algo-%d: %.3fms\n", algo, avg_latency);

            if (avg_latency < fast_latency) {
                fast_latency = avg_latency;
                fast_algo = algo;
            }
        }

        printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency);
        return fast_algo;
    }

private:
    int bsz, M, N, K;
    cublasHandle_t handle;
    cublasOperation_t transa, transb;
    T *A, *B, *C;
};
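
/*
Illustrative usage sketch (not part of the original header); the batch size,
dimensions, and loop count are example values chosen for this sketch:

    cublasHandle_t handle;
    cublasCreate(&handle);
    StridedGemmTest<__half> test(16, 64, 128, 128, CUBLAS_OP_N, CUBLAS_OP_N, handle);
    std::array<int, 3> algos = test.TestAlgo(100);  // {forward, d_A, d_B}
    cublasDestroy(handle);
*/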