cublas_wrappers.cu

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "cublas_wrappers.h"
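
// When building for ROCm (__HIP_PLATFORM_HCC__ defined), these wrappers dispatch to
// rocBLAS (rocblas_gemm_ex / rocblas_gemm_strided_batched_ex); otherwise they dispatch
// to cuBLAS (cublasGemmEx / cublasGemmStridedBatchedEx). Matrices are interpreted in
// column-major order, as both libraries require, and the leading dimensions below are
// derived from m/n/k assuming densely packed (unpadded) operands.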

#ifdef __HIP_PLATFORM_HCC__
int cublas_gemm_ex(rocblas_handle handle,
                   rocblas_operation transa,
                   rocblas_operation transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const float* A,
                   const float* B,
                   float* C,
                   rocblas_gemm_algo algo)
#else
int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const float* A,
                   const float* B,
                   float* C,
                   cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
    rocblas_status status = rocblas_gemm_ex(handle,
                                            transa,
                                            transb,
                                            m,
                                            n,
                                            k,
                                            (const void*)alpha,
                                            (const void*)A,
                                            rocblas_datatype_f32_r,
                                            (transa == rocblas_operation_none) ? m : k,
                                            (const void*)B,
                                            rocblas_datatype_f32_r,
                                            (transb == rocblas_operation_none) ? k : n,
                                            (const void*)beta,
                                            C,
                                            rocblas_datatype_f32_r,
                                            m,
                                            C,
                                            rocblas_datatype_f32_r,
                                            m,
                                            rocblas_datatype_f32_r,
                                            algo,
                                            0,
                                            0);
#else
    cublasStatus_t status = cublasGemmEx(handle,
                                         transa,
                                         transb,
                                         m,
                                         n,
                                         k,
                                         (const void*)alpha,
                                         (const void*)A,
                                         CUDA_R_32F,
                                         (transa == CUBLAS_OP_N) ? m : k,
                                         (const void*)B,
                                         CUDA_R_32F,
                                         (transb == CUBLAS_OP_N) ? k : n,
                                         (const void*)beta,
                                         C,
                                         CUDA_R_32F,
                                         m,
                                         CUDA_R_32F,
                                         algo);
#endif

#ifdef __HIP_PLATFORM_HCC__
    if (status != rocblas_status_success) {
#else
    if (status != CUBLAS_STATUS_SUCCESS) {
#endif
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m,
                n,
                k,
                (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}
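
// Example (illustrative sketch, not part of the original file): computing
// C = alpha * A * B + beta * C for FP32 matrices with the wrapper above. Here
// `handle`, `d_A`, `d_B`, and `d_C` are assumed to be a valid cuBLAS handle and
// device buffers allocated elsewhere; CUBLAS_GEMM_DEFAULT is one valid algorithm
// choice on the CUDA path.
//
//   float alpha = 1.0f, beta = 0.0f;
//   int rc = cublas_gemm_ex(handle,
//                           CUBLAS_OP_N,  // A is m x k, not transposed
//                           CUBLAS_OP_N,  // B is k x n, not transposed
//                           m, n, k,
//                           &alpha, &beta,
//                           d_A, d_B, d_C,
//                           CUBLAS_GEMM_DEFAULT);
//   if (rc != 0) { /* handle the failure */ }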

#ifdef __HIP_PLATFORM_HCC__
int cublas_gemm_ex(rocblas_handle handle,
                   rocblas_operation transa,
                   rocblas_operation transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const __half* A,
                   const __half* B,
                   __half* C,
                   rocblas_gemm_algo algo)
#else
int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const __half* A,
                   const __half* B,
                   __half* C,
                   cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
    rocblas_status status = rocblas_gemm_ex(handle,
                                            transa,
                                            transb,
                                            m,
                                            n,
                                            k,
                                            (const void*)alpha,
                                            (const void*)A,
                                            rocblas_datatype_f16_r,
                                            (transa == rocblas_operation_none) ? m : k,
                                            (const void*)B,
                                            rocblas_datatype_f16_r,
                                            (transb == rocblas_operation_none) ? k : n,
                                            (const void*)beta,
                                            (void*)C,
                                            rocblas_datatype_f16_r,
                                            m,
                                            (void*)C,
                                            rocblas_datatype_f16_r,
                                            m,
                                            rocblas_datatype_f32_r,
                                            algo,
                                            0,
                                            0);
#else
    cublasStatus_t status = cublasGemmEx(handle,
                                         transa,
                                         transb,
                                         m,
                                         n,
                                         k,
                                         (const void*)alpha,
                                         (const void*)A,
                                         CUDA_R_16F,
                                         (transa == CUBLAS_OP_N) ? m : k,
                                         (const void*)B,
                                         CUDA_R_16F,
                                         (transb == CUBLAS_OP_N) ? k : n,
                                         (const void*)beta,
                                         (void*)C,
                                         CUDA_R_16F,
                                         m,
                                         CUDA_R_32F,
                                         algo);
#endif

#ifdef __HIP_PLATFORM_HCC__
    if (status != rocblas_status_success) {
#else
    if (status != CUBLAS_STATUS_SUCCESS) {
#endif
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m,
                n,
                k,
                (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}
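
// Note: the FP16 overload above stores A, B, and C in __half while alpha/beta and the
// cuBLAS/rocBLAS compute type remain FP32, i.e. half-precision storage with
// single-precision accumulation.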

#ifdef __HIP_PLATFORM_HCC__
int cublas_strided_batched_gemm(rocblas_handle handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const float* A,
                                const float* B,
                                float* C,
                                rocblas_operation op_A,
                                rocblas_operation op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                rocblas_gemm_algo algo)
#else
int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const float* A,
                                const float* B,
                                float* C,
                                cublasOperation_t op_A,
                                cublasOperation_t op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
    rocblas_status status =
        rocblas_gemm_strided_batched_ex(handle,
                                        op_A,
                                        op_B,
                                        m,
                                        n,
                                        k,
                                        alpha,
                                        A,
                                        rocblas_datatype_f32_r,
                                        (op_A == rocblas_operation_none) ? m : k,
                                        stride_A,
                                        B,
                                        rocblas_datatype_f32_r,
                                        (op_B == rocblas_operation_none) ? k : n,
                                        stride_B,
                                        beta,
                                        C,
                                        rocblas_datatype_f32_r,
                                        m,
                                        stride_C,
                                        C,
                                        rocblas_datatype_f32_r,
                                        m,
                                        stride_C,
                                        batch,
                                        rocblas_datatype_f32_r,
                                        algo,
                                        0,
                                        0);
#else
    cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
                                                       op_A,
                                                       op_B,
                                                       m,
                                                       n,
                                                       k,
                                                       alpha,
                                                       A,
                                                       CUDA_R_32F,
                                                       (op_A == CUBLAS_OP_N) ? m : k,
                                                       stride_A,
                                                       B,
                                                       CUDA_R_32F,
                                                       (op_B == CUBLAS_OP_N) ? k : n,
                                                       stride_B,
                                                       beta,
                                                       C,
                                                       CUDA_R_32F,
                                                       m,
                                                       stride_C,
                                                       batch,
                                                       CUDA_R_32F,
                                                       algo);
#endif

#ifdef __HIP_PLATFORM_HCC__
    if (status != rocblas_status_success) {
#else
    if (status != CUBLAS_STATUS_SUCCESS) {
#endif
        fprintf(stderr,
                "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
                batch,
                m,
                n,
                k,
                (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}
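
// Example (illustrative sketch, not part of the original file): a strided batched GEMM
// over `batch` independent m x k / k x n problems stored contiguously, so each
// per-matrix stride is simply the matrix size in elements. `handle` and the device
// buffers are assumed to be set up elsewhere.
//
//   float alpha = 1.0f, beta = 0.0f;
//   int rc = cublas_strided_batched_gemm(handle,
//                                        m, n, k,
//                                        &alpha, &beta,
//                                        d_A, d_B, d_C,
//                                        CUBLAS_OP_N, CUBLAS_OP_N,
//                                        m * k,   // stride_A: elements between consecutive A matrices
//                                        k * n,   // stride_B: elements between consecutive B matrices
//                                        m * n,   // stride_C: elements between consecutive C matrices
//                                        batch,
//                                        CUBLAS_GEMM_DEFAULT);
//   if (rc != 0) { /* handle the failure */ }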

#ifdef __HIP_PLATFORM_HCC__
int cublas_strided_batched_gemm(rocblas_handle handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const __half* A,
                                const __half* B,
                                __half* C,
                                rocblas_operation op_A,
                                rocblas_operation op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                rocblas_gemm_algo algo)
#else
int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const __half* A,
                                const __half* B,
                                __half* C,
                                cublasOperation_t op_A,
                                cublasOperation_t op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
    rocblas_status status =
        rocblas_gemm_strided_batched_ex(handle,
                                        op_A,
                                        op_B,
                                        m,
                                        n,
                                        k,
                                        alpha,
                                        A,
                                        rocblas_datatype_f16_r,
                                        (op_A == rocblas_operation_none) ? m : k,
                                        stride_A,
                                        B,
                                        rocblas_datatype_f16_r,
                                        (op_B == rocblas_operation_none) ? k : n,
                                        stride_B,
                                        beta,
                                        C,
                                        rocblas_datatype_f16_r,
                                        m,
                                        stride_C,
                                        C,
                                        rocblas_datatype_f16_r,
                                        m,
                                        stride_C,
                                        batch,
                                        rocblas_datatype_f32_r,
                                        algo,
                                        0,
                                        0);
#else
    cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
                                                       op_A,
                                                       op_B,
                                                       m,
                                                       n,
                                                       k,
                                                       alpha,
                                                       A,
                                                       CUDA_R_16F,
                                                       (op_A == CUBLAS_OP_N) ? m : k,
                                                       stride_A,
                                                       B,
                                                       CUDA_R_16F,
                                                       (op_B == CUBLAS_OP_N) ? k : n,
                                                       stride_B,
                                                       beta,
                                                       C,
                                                       CUDA_R_16F,
                                                       m,
                                                       stride_C,
                                                       batch,
                                                       CUDA_R_32F,
                                                       algo);
#endif

#ifdef __HIP_PLATFORM_HCC__
    if (status != rocblas_status_success) {
#else
    if (status != CUBLAS_STATUS_SUCCESS) {
#endif
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m,
                n,
                k,
                (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}
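
// Note: as with the non-batched FP16 overload, the FP16 strided-batched path stores
// A, B, and C in __half but uses an FP32 compute type and FP32 alpha/beta.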