// inference_cublas_wrappers.h
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
#pragma once

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>

#include <type_traits>

#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#ifndef __HIP_PLATFORM_HCC__
#include <mma.h>
#endif
  15. #ifdef __HIP_PLATFORM_HCC__
  16. int cublas_gemm_ex(rocblas_handle handle,
  17. rocblas_operation transa,
  18. rocblas_operation transb,
  19. int m,
  20. int n,
  21. int k,
  22. const float* alpha,
  23. const float* beta,
  24. const float* A,
  25. const float* B,
  26. float* C,
  27. rocblas_gemm_algo algo,
  28. int b_stride = -1)
  29. #else
  30. int cublas_gemm_ex(cublasHandle_t handle,
  31. cublasOperation_t transa,
  32. cublasOperation_t transb,
  33. int m,
  34. int n,
  35. int k,
  36. const float* alpha,
  37. const float* beta,
  38. const float* A,
  39. const float* B,
  40. float* C,
  41. cublasGemmAlgo_t algo,
  42. int b_stride = -1)
  43. #endif
  44. {
  45. const int ldb = (b_stride == -1) ? ((transb == CUBLAS_OP_N) ? k : n) : b_stride;
  46. #ifdef __HIP_PLATFORM_HCC__
  47. rocblas_status status = rocblas_gemm_ex(handle,
  48. transa,
  49. transb,
  50. m,
  51. n,
  52. k,
  53. (const void*)alpha,
  54. (const void*)A,
  55. rocblas_datatype_f32_r,
  56. (transa == rocblas_operation_none) ? m : k,
  57. (const void*)B,
  58. rocblas_datatype_f32_r,
  59. ldb,
  60. (const void*)beta,
  61. C,
  62. rocblas_datatype_f32_r,
  63. m,
  64. C,
  65. rocblas_datatype_f32_r,
  66. m,
  67. rocblas_datatype_f32_r,
  68. algo,
  69. 0,
  70. 0);
  71. #else
  72. cublasStatus_t status = cublasGemmEx(handle,
  73. transa,
  74. transb,
  75. m,
  76. n,
  77. k,
  78. (const void*)alpha,
  79. (const void*)A,
  80. CUDA_R_32F,
  81. (transa == CUBLAS_OP_N) ? m : k,
  82. (const void*)B,
  83. CUDA_R_32F,
  84. ldb,
  85. (const void*)beta,
  86. C,
  87. CUDA_R_32F,
  88. m,
  89. CUDA_R_32F,
  90. algo);
  91. #endif
  92. #ifdef __HIP_PLATFORM_HCC__
  93. if (status != rocblas_status_success) {
  94. #else
  95. if (status != CUBLAS_STATUS_SUCCESS) {
  96. #endif
  97. fprintf(stderr,
  98. "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
  99. m,
  100. n,
  101. k,
  102. (int)status);
  103. return EXIT_FAILURE;
  104. }
  105. return 0;
  106. }
  107. template <typename T>
  108. #ifdef __HIP_PLATFORM_HCC__
  109. int cublas_gemm_ex(rocblas_handle handle,
  110. rocblas_operation transa,
  111. rocblas_operation transb,
  112. int m,
  113. int n,
  114. int k,
  115. const float* alpha,
  116. const float* beta,
  117. const T* A,
  118. const T* B,
  119. T* C,
  120. rocblas_gemm_algo algo,
  121. int b_stride = -1)
  122. #else
  123. int cublas_gemm_ex(cublasHandle_t handle,
  124. cublasOperation_t transa,
  125. cublasOperation_t transb,
  126. int m,
  127. int n,
  128. int k,
  129. const float* alpha,
  130. const float* beta,
  131. const T* A,
  132. const T* B,
  133. T* C,
  134. cublasGemmAlgo_t algo,
  135. int b_stride = -1)
  136. #endif
  137. {
  138. const int ldb = (b_stride == -1) ? ((transb == CUBLAS_OP_N) ? k : n) : b_stride;
  139. #ifdef __HIP_PLATFORM_HCC__
  140. constexpr auto rocblas_dtype_16 = std::is_same<T, __half>::value ? rocblas_datatype_f16_r
  141. : rocblas_datatype_bf16_r;
  142. rocblas_status status = rocblas_gemm_ex(handle,
  143. transa,
  144. transb,
  145. m,
  146. n,
  147. k,
  148. (const void*)alpha,
  149. (const void*)A,
  150. rocblas_dtype_16,
  151. (transa == rocblas_operation_none) ? m : k,
  152. (const void*)B,
  153. rocblas_dtype_16,
  154. ldb,
  155. (const void*)beta,
  156. (void*)C,
  157. rocblas_dtype_16,
  158. m,
  159. (void*)C,
  160. rocblas_dtype_16,
  161. m,
  162. rocblas_datatype_f32_r,
  163. algo,
  164. 0,
  165. 0);
  166. #else
  167. constexpr auto cublas_dtype_16 = std::is_same<T, __half>::value ? CUDA_R_16F : CUDA_R_16BF;
  168. cublasStatus_t status = cublasGemmEx(handle,
  169. transa,
  170. transb,
  171. m,
  172. n,
  173. k,
  174. (const void*)alpha,
  175. (const void*)A,
  176. cublas_dtype_16,
  177. (transa == CUBLAS_OP_N) ? m : k,
  178. (const void*)B,
  179. cublas_dtype_16,
  180. ldb,
  181. (const void*)beta,
  182. (void*)C,
  183. cublas_dtype_16,
  184. m,
  185. CUDA_R_32F,
  186. algo);
  187. #endif
  188. #ifdef __HIP_PLATFORM_HCC__
  189. if (status != rocblas_status_success) {
  190. #else
  191. if (status != CUBLAS_STATUS_SUCCESS) {
  192. #endif
  193. fprintf(stderr,
  194. "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
  195. m,
  196. n,
  197. k,
  198. (int)status);
  199. return EXIT_FAILURE;
  200. }
  201. return 0;
  202. }
  203. #ifdef __HIP_PLATFORM_HCC__
  204. int cublas_strided_batched_gemm(rocblas_handle handle,
  205. int m,
  206. int n,
  207. int k,
  208. const float* alpha,
  209. const float* beta,
  210. const float* A,
  211. const float* B,
  212. float* C,
  213. rocblas_operation op_A,
  214. rocblas_operation op_B,
  215. int stride_A,
  216. int stride_B,
  217. int stride_C,
  218. int batch,
  219. rocblas_gemm_algo algo)
  220. #else
  221. int cublas_strided_batched_gemm(cublasHandle_t handle,
  222. int m,
  223. int n,
  224. int k,
  225. const float* alpha,
  226. const float* beta,
  227. const float* A,
  228. const float* B,
  229. float* C,
  230. cublasOperation_t op_A,
  231. cublasOperation_t op_B,
  232. int stride_A,
  233. int stride_B,
  234. int stride_C,
  235. int batch,
  236. cublasGemmAlgo_t algo)
  237. #endif
  238. {
  239. #ifdef __HIP_PLATFORM_HCC__
  240. rocblas_status status =
  241. rocblas_gemm_strided_batched_ex(handle,
  242. op_A,
  243. op_B,
  244. m,
  245. n,
  246. k,
  247. alpha,
  248. A,
  249. rocblas_datatype_f32_r,
  250. (op_A == rocblas_operation_none) ? m : k,
  251. stride_A,
  252. B,
  253. rocblas_datatype_f32_r,
  254. (op_B == rocblas_operation_none) ? k : n,
  255. stride_B,
  256. beta,
  257. C,
  258. rocblas_datatype_f32_r,
  259. m,
  260. stride_C,
  261. C,
  262. rocblas_datatype_f32_r,
  263. m,
  264. stride_C,
  265. batch,
  266. rocblas_datatype_f32_r,
  267. algo,
  268. 0,
  269. 0);
  270. #else
  271. cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
  272. op_A,
  273. op_B,
  274. m,
  275. n,
  276. k,
  277. alpha,
  278. A,
  279. CUDA_R_32F,
  280. (op_A == CUBLAS_OP_N) ? m : k,
  281. stride_A,
  282. B,
  283. CUDA_R_32F,
  284. (op_B == CUBLAS_OP_N) ? k : n,
  285. stride_B,
  286. beta,
  287. C,
  288. CUDA_R_32F,
  289. m,
  290. stride_C,
  291. batch,
  292. CUDA_R_32F,
  293. algo);
  294. #endif
  295. #ifdef __HIP_PLATFORM_HCC__
  296. if (status != rocblas_status_success) {
  297. #else
  298. if (status != CUBLAS_STATUS_SUCCESS) {
  299. #endif
  300. fprintf(stderr,
  301. "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
  302. batch,
  303. m,
  304. n,
  305. k,
  306. (int)status);
  307. return EXIT_FAILURE;
  308. }
  309. return 0;
  310. }
  311. template <typename T>
  312. #ifdef __HIP_PLATFORM_HCC__
  313. int cublas_strided_batched_gemm(rocblas_handle handle,
  314. int m,
  315. int n,
  316. int k,
  317. const float* alpha,
  318. const float* beta,
  319. const T* A,
  320. const T* B,
  321. T* C,
  322. rocblas_operation op_A,
  323. rocblas_operation op_B,
  324. int stride_A,
  325. int stride_B,
  326. int stride_C,
  327. int batch,
  328. rocblas_gemm_algo algo)
  329. #else
  330. int cublas_strided_batched_gemm(cublasHandle_t handle,
  331. int m,
  332. int n,
  333. int k,
  334. const float* alpha,
  335. const float* beta,
  336. const T* A,
  337. const T* B,
  338. T* C,
  339. cublasOperation_t op_A,
  340. cublasOperation_t op_B,
  341. int stride_A,
  342. int stride_B,
  343. int stride_C,
  344. int batch,
  345. cublasGemmAlgo_t algo)
  346. #endif
  347. {
  348. #ifdef __HIP_PLATFORM_HCC__
  349. constexpr auto rocblas_dtype_16 = std::is_same<T, __half>::value ? rocblas_datatype_f16_r
  350. : rocblas_datatype_bf16_r;
  351. rocblas_status status =
  352. rocblas_gemm_strided_batched_ex(handle,
  353. op_A,
  354. op_B,
  355. m,
  356. n,
  357. k,
  358. alpha,
  359. A,
  360. rocblas_dtype_16,
  361. (op_A == rocblas_operation_none) ? m : k,
  362. stride_A,
  363. B,
  364. rocblas_dtype_16,
  365. (op_B == rocblas_operation_none) ? k : n,
  366. stride_B,
  367. beta,
  368. C,
  369. rocblas_dtype_16,
  370. m,
  371. stride_C,
  372. C,
  373. rocblas_dtype_16,
  374. m,
  375. stride_C,
  376. batch,
  377. rocblas_datatype_f32_r,
  378. algo,
  379. 0,
  380. 0);
  381. #else
  382. constexpr auto cublas_dtype_16 = std::is_same<T, __half>::value ? CUDA_R_16F : CUDA_R_16BF;
  383. cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
  384. op_A,
  385. op_B,
  386. m,
  387. n,
  388. k,
  389. alpha,
  390. A,
  391. cublas_dtype_16,
  392. (op_A == CUBLAS_OP_N) ? m : k,
  393. stride_A,
  394. B,
  395. cublas_dtype_16,
  396. (op_B == CUBLAS_OP_N) ? k : n,
  397. stride_B,
  398. beta,
  399. C,
  400. cublas_dtype_16,
  401. m,
  402. stride_C,
  403. batch,
  404. CUDA_R_32F,
  405. algo);
  406. #endif
  407. #ifdef __HIP_PLATFORM_HCC__
  408. if (status != rocblas_status_success) {
  409. #else
  410. if (status != CUBLAS_STATUS_SUCCESS) {
  411. #endif
  412. fprintf(stderr,
  413. "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
  414. m,
  415. n,
  416. k,
  417. (int)status);
  418. return EXIT_FAILURE;
  419. }
  420. return 0;
  421. }