// cublas_wrappers.cu — cuBLAS (CUDA) / rocBLAS (HIP) GEMM wrapper implementations.
  1. #include "cublas_wrappers.h"
  2. #ifdef __HIP_PLATFORM_HCC__
  3. int cublas_gemm_ex(rocblas_handle handle,
  4. rocblas_operation transa,
  5. rocblas_operation transb,
  6. int m,
  7. int n,
  8. int k,
  9. const float* alpha,
  10. const float* beta,
  11. const float* A,
  12. const float* B,
  13. float* C,
  14. rocblas_gemm_algo algo)
  15. #else
  16. int cublas_gemm_ex(cublasHandle_t handle,
  17. cublasOperation_t transa,
  18. cublasOperation_t transb,
  19. int m,
  20. int n,
  21. int k,
  22. const float* alpha,
  23. const float* beta,
  24. const float* A,
  25. const float* B,
  26. float* C,
  27. cublasGemmAlgo_t algo)
  28. #endif
  29. {
  30. #ifdef __HIP_PLATFORM_HCC__
  31. rocblas_status status = rocblas_gemm_ex(handle,
  32. transa,
  33. transb,
  34. m,
  35. n,
  36. k,
  37. (const void*)alpha,
  38. (const void*)A,
  39. rocblas_datatype_f32_r,
  40. (transa == rocblas_operation_none) ? m : k,
  41. (const void*)B,
  42. rocblas_datatype_f32_r,
  43. (transb == rocblas_operation_none) ? k : n,
  44. (const void*)beta,
  45. C,
  46. rocblas_datatype_f32_r,
  47. m,
  48. C,
  49. rocblas_datatype_f32_r,
  50. m,
  51. rocblas_datatype_f32_r,
  52. algo,
  53. 0,
  54. 0);
  55. #else
  56. cublasStatus_t status = cublasGemmEx(handle,
  57. transa,
  58. transb,
  59. m,
  60. n,
  61. k,
  62. (const void*)alpha,
  63. (const void*)A,
  64. CUDA_R_32F,
  65. (transa == CUBLAS_OP_N) ? m : k,
  66. (const void*)B,
  67. CUDA_R_32F,
  68. (transb == CUBLAS_OP_N) ? k : n,
  69. (const void*)beta,
  70. C,
  71. CUDA_R_32F,
  72. m,
  73. CUDA_R_32F,
  74. algo);
  75. #endif
  76. #ifdef __HIP_PLATFORM_HCC__
  77. if (status != rocblas_status_success) {
  78. #else
  79. if (status != CUBLAS_STATUS_SUCCESS) {
  80. #endif
  81. fprintf(stderr,
  82. "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
  83. m,
  84. n,
  85. k,
  86. (int)status);
  87. return EXIT_FAILURE;
  88. }
  89. return 0;
  90. }
  91. #ifdef __HIP_PLATFORM_HCC__
  92. int cublas_gemm_ex(rocblas_handle handle,
  93. rocblas_operation transa,
  94. rocblas_operation transb,
  95. int m,
  96. int n,
  97. int k,
  98. const float* alpha,
  99. const float* beta,
  100. const __half* A,
  101. const __half* B,
  102. __half* C,
  103. rocblas_gemm_algo algo)
  104. #else
  105. int cublas_gemm_ex(cublasHandle_t handle,
  106. cublasOperation_t transa,
  107. cublasOperation_t transb,
  108. int m,
  109. int n,
  110. int k,
  111. const float* alpha,
  112. const float* beta,
  113. const __half* A,
  114. const __half* B,
  115. __half* C,
  116. cublasGemmAlgo_t algo)
  117. #endif
  118. {
  119. #ifdef __HIP_PLATFORM_HCC__
  120. rocblas_status status = rocblas_gemm_ex(handle,
  121. transa,
  122. transb,
  123. m,
  124. n,
  125. k,
  126. (const void*)alpha,
  127. (const void*)A,
  128. rocblas_datatype_f16_r,
  129. (transa == rocblas_operation_none) ? m : k,
  130. (const void*)B,
  131. rocblas_datatype_f16_r,
  132. (transb == rocblas_operation_none) ? k : n,
  133. (const void*)beta,
  134. (void*)C,
  135. rocblas_datatype_f16_r,
  136. m,
  137. (void*)C,
  138. rocblas_datatype_f16_r,
  139. m,
  140. rocblas_datatype_f32_r,
  141. algo,
  142. 0,
  143. 0);
  144. #else
  145. cublasStatus_t status = cublasGemmEx(handle,
  146. transa,
  147. transb,
  148. m,
  149. n,
  150. k,
  151. (const void*)alpha,
  152. (const void*)A,
  153. CUDA_R_16F,
  154. (transa == CUBLAS_OP_N) ? m : k,
  155. (const void*)B,
  156. CUDA_R_16F,
  157. (transb == CUBLAS_OP_N) ? k : n,
  158. (const void*)beta,
  159. (void*)C,
  160. CUDA_R_16F,
  161. m,
  162. CUDA_R_32F,
  163. algo);
  164. #endif
  165. #ifdef __HIP_PLATFORM_HCC__
  166. if (status != rocblas_status_success) {
  167. #else
  168. if (status != CUBLAS_STATUS_SUCCESS) {
  169. #endif
  170. fprintf(stderr,
  171. "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
  172. m,
  173. n,
  174. k,
  175. (int)status);
  176. return EXIT_FAILURE;
  177. }
  178. return 0;
  179. }
  180. #ifdef __HIP_PLATFORM_HCC__
  181. int cublas_strided_batched_gemm(rocblas_handle handle,
  182. int m,
  183. int n,
  184. int k,
  185. const float* alpha,
  186. const float* beta,
  187. const float* A,
  188. const float* B,
  189. float* C,
  190. rocblas_operation op_A,
  191. rocblas_operation op_B,
  192. int stride_A,
  193. int stride_B,
  194. int stride_C,
  195. int batch,
  196. rocblas_gemm_algo algo)
  197. #else
  198. int cublas_strided_batched_gemm(cublasHandle_t handle,
  199. int m,
  200. int n,
  201. int k,
  202. const float* alpha,
  203. const float* beta,
  204. const float* A,
  205. const float* B,
  206. float* C,
  207. cublasOperation_t op_A,
  208. cublasOperation_t op_B,
  209. int stride_A,
  210. int stride_B,
  211. int stride_C,
  212. int batch,
  213. cublasGemmAlgo_t algo)
  214. #endif
  215. {
  216. #ifdef __HIP_PLATFORM_HCC__
  217. rocblas_status status =
  218. rocblas_gemm_strided_batched_ex(handle,
  219. op_A,
  220. op_B,
  221. m,
  222. n,
  223. k,
  224. alpha,
  225. A,
  226. rocblas_datatype_f32_r,
  227. (op_A == rocblas_operation_none) ? m : k,
  228. stride_A,
  229. B,
  230. rocblas_datatype_f32_r,
  231. (op_B == rocblas_operation_none) ? k : n,
  232. stride_B,
  233. beta,
  234. C,
  235. rocblas_datatype_f32_r,
  236. m,
  237. stride_C,
  238. C,
  239. rocblas_datatype_f32_r,
  240. m,
  241. stride_C,
  242. batch,
  243. rocblas_datatype_f32_r,
  244. algo,
  245. 0,
  246. 0);
  247. #else
  248. cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
  249. op_A,
  250. op_B,
  251. m,
  252. n,
  253. k,
  254. alpha,
  255. A,
  256. CUDA_R_32F,
  257. (op_A == CUBLAS_OP_N) ? m : k,
  258. stride_A,
  259. B,
  260. CUDA_R_32F,
  261. (op_B == CUBLAS_OP_N) ? k : n,
  262. stride_B,
  263. beta,
  264. C,
  265. CUDA_R_32F,
  266. m,
  267. stride_C,
  268. batch,
  269. CUDA_R_32F,
  270. algo);
  271. #endif
  272. #ifdef __HIP_PLATFORM_HCC__
  273. if (status != rocblas_status_success) {
  274. #else
  275. if (status != CUBLAS_STATUS_SUCCESS) {
  276. #endif
  277. fprintf(stderr,
  278. "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
  279. batch,
  280. m,
  281. n,
  282. k,
  283. (int)status);
  284. return EXIT_FAILURE;
  285. }
  286. return 0;
  287. }
  288. #ifdef __HIP_PLATFORM_HCC__
  289. int cublas_strided_batched_gemm(rocblas_handle handle,
  290. int m,
  291. int n,
  292. int k,
  293. const float* alpha,
  294. const float* beta,
  295. const __half* A,
  296. const __half* B,
  297. __half* C,
  298. rocblas_operation op_A,
  299. rocblas_operation op_B,
  300. int stride_A,
  301. int stride_B,
  302. int stride_C,
  303. int batch,
  304. rocblas_gemm_algo algo)
  305. #else
  306. int cublas_strided_batched_gemm(cublasHandle_t handle,
  307. int m,
  308. int n,
  309. int k,
  310. const float* alpha,
  311. const float* beta,
  312. const __half* A,
  313. const __half* B,
  314. __half* C,
  315. cublasOperation_t op_A,
  316. cublasOperation_t op_B,
  317. int stride_A,
  318. int stride_B,
  319. int stride_C,
  320. int batch,
  321. cublasGemmAlgo_t algo)
  322. #endif
  323. {
  324. #ifdef __HIP_PLATFORM_HCC__
  325. rocblas_status status =
  326. rocblas_gemm_strided_batched_ex(handle,
  327. op_A,
  328. op_B,
  329. m,
  330. n,
  331. k,
  332. alpha,
  333. A,
  334. rocblas_datatype_f16_r,
  335. (op_A == rocblas_operation_none) ? m : k,
  336. stride_A,
  337. B,
  338. rocblas_datatype_f16_r,
  339. (op_B == rocblas_operation_none) ? k : n,
  340. stride_B,
  341. beta,
  342. C,
  343. rocblas_datatype_f16_r,
  344. m,
  345. stride_C,
  346. C,
  347. rocblas_datatype_f16_r,
  348. m,
  349. stride_C,
  350. batch,
  351. rocblas_datatype_f32_r,
  352. algo,
  353. 0,
  354. 0);
  355. #else
  356. cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
  357. op_A,
  358. op_B,
  359. m,
  360. n,
  361. k,
  362. alpha,
  363. A,
  364. CUDA_R_16F,
  365. (op_A == CUBLAS_OP_N) ? m : k,
  366. stride_A,
  367. B,
  368. CUDA_R_16F,
  369. (op_B == CUBLAS_OP_N) ? k : n,
  370. stride_B,
  371. beta,
  372. C,
  373. CUDA_R_16F,
  374. m,
  375. stride_C,
  376. batch,
  377. CUDA_R_32F,
  378. algo);
  379. #endif
  380. #ifdef __HIP_PLATFORM_HCC__
  381. if (status != rocblas_status_success) {
  382. #else
  383. if (status != CUBLAS_STATUS_SUCCESS) {
  384. #endif
  385. fprintf(stderr,
  386. "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
  387. m,
  388. n,
  389. k,
  390. (int)status);
  391. return EXIT_FAILURE;
  392. }
  393. return 0;
  394. }