cublas_wrappers.cu 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. #include "cublas_wrappers.h"
  2. int cublas_gemm_ex(cublasHandle_t handle,
  3. cublasOperation_t transa,
  4. cublasOperation_t transb,
  5. int m,
  6. int n,
  7. int k,
  8. const float* alpha,
  9. const float* beta,
  10. const float* A,
  11. const float* B,
  12. float* C,
  13. cublasGemmAlgo_t algo)
  14. {
  15. cublasStatus_t status = cublasGemmEx(handle,
  16. transa,
  17. transb,
  18. m,
  19. n,
  20. k,
  21. (const void*)alpha,
  22. (const void*)A,
  23. CUDA_R_32F,
  24. (transa == CUBLAS_OP_N) ? m : k,
  25. (const void*)B,
  26. CUDA_R_32F,
  27. (transb == CUBLAS_OP_N) ? k : n,
  28. (const void*)beta,
  29. C,
  30. CUDA_R_32F,
  31. m,
  32. CUDA_R_32F,
  33. algo);
  34. if (status != CUBLAS_STATUS_SUCCESS) {
  35. fprintf(stderr,
  36. "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
  37. m,
  38. n,
  39. k,
  40. (int)status);
  41. return EXIT_FAILURE;
  42. }
  43. return 0;
  44. }
  45. int cublas_gemm_ex(cublasHandle_t handle,
  46. cublasOperation_t transa,
  47. cublasOperation_t transb,
  48. int m,
  49. int n,
  50. int k,
  51. const float* alpha,
  52. const float* beta,
  53. const __half* A,
  54. const __half* B,
  55. __half* C,
  56. cublasGemmAlgo_t algo)
  57. {
  58. cublasStatus_t status = cublasGemmEx(handle,
  59. transa,
  60. transb,
  61. m,
  62. n,
  63. k,
  64. (const void*)alpha,
  65. (const void*)A,
  66. CUDA_R_16F,
  67. (transa == CUBLAS_OP_N) ? m : k,
  68. (const void*)B,
  69. CUDA_R_16F,
  70. (transb == CUBLAS_OP_N) ? k : n,
  71. (const void*)beta,
  72. (void*)C,
  73. CUDA_R_16F,
  74. m,
  75. CUDA_R_32F,
  76. algo);
  77. if (status != CUBLAS_STATUS_SUCCESS) {
  78. fprintf(stderr,
  79. "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
  80. m,
  81. n,
  82. k,
  83. (int)status);
  84. return EXIT_FAILURE;
  85. }
  86. return 0;
  87. }
  88. int cublas_strided_batched_gemm(cublasHandle_t handle,
  89. int m,
  90. int n,
  91. int k,
  92. const float* alpha,
  93. const float* beta,
  94. const float* A,
  95. const float* B,
  96. float* C,
  97. cublasOperation_t op_A,
  98. cublasOperation_t op_B,
  99. int stride_A,
  100. int stride_B,
  101. int stride_C,
  102. int batch,
  103. cublasGemmAlgo_t algo)
  104. {
  105. cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
  106. op_A,
  107. op_B,
  108. m,
  109. n,
  110. k,
  111. alpha,
  112. A,
  113. CUDA_R_32F,
  114. (op_A == CUBLAS_OP_N) ? m : k,
  115. stride_A,
  116. B,
  117. CUDA_R_32F,
  118. (op_B == CUBLAS_OP_N) ? k : n,
  119. stride_B,
  120. beta,
  121. C,
  122. CUDA_R_32F,
  123. m,
  124. stride_C,
  125. batch,
  126. CUDA_R_32F,
  127. algo);
  128. if (status != CUBLAS_STATUS_SUCCESS) {
  129. fprintf(stderr,
  130. "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
  131. batch,
  132. m,
  133. n,
  134. k,
  135. (int)status);
  136. return EXIT_FAILURE;
  137. }
  138. return 0;
  139. }
  140. int cublas_strided_batched_gemm(cublasHandle_t handle,
  141. int m,
  142. int n,
  143. int k,
  144. const float* alpha,
  145. const float* beta,
  146. const __half* A,
  147. const __half* B,
  148. __half* C,
  149. cublasOperation_t op_A,
  150. cublasOperation_t op_B,
  151. int stride_A,
  152. int stride_B,
  153. int stride_C,
  154. int batch,
  155. cublasGemmAlgo_t algo)
  156. {
  157. cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
  158. op_A,
  159. op_B,
  160. m,
  161. n,
  162. k,
  163. alpha,
  164. A,
  165. CUDA_R_16F,
  166. (op_A == CUBLAS_OP_N) ? m : k,
  167. stride_A,
  168. B,
  169. CUDA_R_16F,
  170. (op_B == CUBLAS_OP_N) ? k : n,
  171. stride_B,
  172. beta,
  173. C,
  174. CUDA_R_16F,
  175. m,
  176. stride_C,
  177. batch,
  178. CUDA_R_32F,
  179. algo);
  180. if (status != CUBLAS_STATUS_SUCCESS) {
  181. fprintf(stderr,
  182. "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
  183. m,
  184. n,
  185. k,
  186. (int)status);
  187. return EXIT_FAILURE;
  188. }
  189. return 0;
  190. }