cublas_wrappers.h 8.8 KB


  1. #pragma once
  2. #include <assert.h>
  3. #include <cublas_v2.h>
  4. #include <cuda.h>
  5. #include <cuda_fp16.h>
  6. #include <cuda_runtime.h>
  7. #include <mma.h>
  8. #include <stdio.h>
  9. #include "cublas_wrappers.h"
  // Single-precision GEMM routed through cublasGemmEx:
  //   C = alpha * op(A) * op(B) + beta * C
  // with op(A) of size m x k, op(B) of size k x n and C of size m x n.
  // A, B and C are all CUDA_R_32F, and the compute type is CUDA_R_32F.
  // Leading dimensions assume tightly packed column-major storage: each one is
  // the row count of the operand as stored (before op is applied), and ldc = m.
  // Returns 0 on success. On failure it logs the shape and cuBLAS status to
  // stderr and returns EXIT_FAILURE.
  int cublas_gemm_ex(cublasHandle_t handle,
                     cublasOperation_t transa,
                     cublasOperation_t transb,
                     int m,
                     int n,
                     int k,
                     const float* alpha,
                     const float* beta,
                     const float* A,
                     const float* B,
                     float* C,
                     cublasGemmAlgo_t algo)
  {
      // Stored row counts of each operand (column-major leading dimensions).
      const int rows_a = (transa == CUBLAS_OP_N) ? m : k;
      const int rows_b = (transb == CUBLAS_OP_N) ? k : n;
      const int rows_c = m;

      const cublasStatus_t rc = cublasGemmEx(handle, transa, transb, m, n, k,
                                             static_cast<const void*>(alpha),
                                             static_cast<const void*>(A), CUDA_R_32F, rows_a,
                                             static_cast<const void*>(B), CUDA_R_32F, rows_b,
                                             static_cast<const void*>(beta),
                                             static_cast<void*>(C), CUDA_R_32F, rows_c,
                                             CUDA_R_32F,
                                             algo);
      if (rc == CUBLAS_STATUS_SUCCESS) { return 0; }

      fprintf(stderr,
              "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
              m, n, k, static_cast<int>(rc));
      return EXIT_FAILURE;
  }
  // Mixed-precision GEMM routed through cublasGemmEx:
  //   C = alpha * op(A) * op(B) + beta * C
  // A, B and C are stored as CUDA_R_16F (__half). The compute type is
  // CUDA_R_32F, which is why alpha and beta are passed as float pointers.
  // Leading dimensions assume tightly packed column-major storage: each one is
  // the row count of the operand as stored (before op is applied), and ldc = m.
  // Returns 0 on success. On failure it logs the shape and cuBLAS status to
  // stderr and returns EXIT_FAILURE.
  int cublas_gemm_ex(cublasHandle_t handle,
                     cublasOperation_t transa,
                     cublasOperation_t transb,
                     int m,
                     int n,
                     int k,
                     const float* alpha,
                     const float* beta,
                     const __half* A,
                     const __half* B,
                     __half* C,
                     cublasGemmAlgo_t algo)
  {
      const bool a_plain = (transa == CUBLAS_OP_N);
      const bool b_plain = (transb == CUBLAS_OP_N);
      const int ld_a = a_plain ? m : k;
      const int ld_b = b_plain ? k : n;

      const cublasStatus_t result = cublasGemmEx(handle, transa, transb, m, n, k,
                                                 static_cast<const void*>(alpha),
                                                 static_cast<const void*>(A), CUDA_R_16F, ld_a,
                                                 static_cast<const void*>(B), CUDA_R_16F, ld_b,
                                                 static_cast<const void*>(beta),
                                                 static_cast<void*>(C), CUDA_R_16F, m,
                                                 CUDA_R_32F,
                                                 algo);
      if (result == CUBLAS_STATUS_SUCCESS) { return 0; }

      fprintf(stderr,
              "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
              m, n, k, static_cast<int>(result));
      return EXIT_FAILURE;
  }
  // Batched single-precision GEMM routed through cublasGemmStridedBatchedEx.
  // For each i in [0, batch):
  //   C_i = alpha * op(A_i) * op(B_i) + beta * C_i, where X_i = X + i * stride_X.
  // All operands and the compute type are CUDA_R_32F.
  // Leading dimensions assume tightly packed column-major storage: each one is
  // the row count of the operand as stored (before op is applied), and ldc = m.
  // Strides are taken as int here; cuBLAS itself accepts long long strides.
  // Returns 0 on success. On failure it logs batch, shape and cuBLAS status to
  // stderr and returns EXIT_FAILURE.
  int cublas_strided_batched_gemm(cublasHandle_t handle,
                                  int m,
                                  int n,
                                  int k,
                                  const float* alpha,
                                  const float* beta,
                                  const float* A,
                                  const float* B,
                                  float* C,
                                  cublasOperation_t op_A,
                                  cublasOperation_t op_B,
                                  int stride_A,
                                  int stride_B,
                                  int stride_C,
                                  int batch,
                                  cublasGemmAlgo_t algo)
  {
      const int lead_a = (op_A == CUBLAS_OP_N) ? m : k;
      const int lead_b = (op_B == CUBLAS_OP_N) ? k : n;
      const int lead_c = m;

      const cublasStatus_t rc = cublasGemmStridedBatchedEx(
          handle, op_A, op_B, m, n, k,
          alpha,
          A, CUDA_R_32F, lead_a, stride_A,
          B, CUDA_R_32F, lead_b, stride_B,
          beta,
          C, CUDA_R_32F, lead_c, stride_C,
          batch,
          CUDA_R_32F,
          algo);

      if (rc == CUBLAS_STATUS_SUCCESS) { return 0; }

      fprintf(stderr,
              "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
              batch, m, n, k, static_cast<int>(rc));
      return EXIT_FAILURE;
  }
  // Batched mixed-precision GEMM routed through cublasGemmStridedBatchedEx.
  // For each i in [0, batch):
  //   C_i = alpha * op(A_i) * op(B_i) + beta * C_i, where X_i = X + i * stride_X.
  // A, B and C are stored as CUDA_R_16F (__half). The compute type is
  // CUDA_R_32F, which is why alpha and beta are float pointers.
  // Leading dimensions assume tightly packed column-major storage: each one is
  // the row count of the operand as stored (before op is applied), and ldc = m.
  // Strides are taken as int here; cuBLAS itself accepts long long strides.
  // Returns 0 on success. On failure it logs batch, shape and cuBLAS status to
  // stderr and returns EXIT_FAILURE.
  int cublas_strided_batched_gemm(cublasHandle_t handle,
                                  int m,
                                  int n,
                                  int k,
                                  const float* alpha,
                                  const float* beta,
                                  const __half* A,
                                  const __half* B,
                                  __half* C,
                                  cublasOperation_t op_A,
                                  cublasOperation_t op_B,
                                  int stride_A,
                                  int stride_B,
                                  int stride_C,
                                  int batch,
                                  cublasGemmAlgo_t algo)
  {
      cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
                                                         op_A,
                                                         op_B,
                                                         m,
                                                         n,
                                                         k,
                                                         alpha,
                                                         A,
                                                         CUDA_R_16F,
                                                         (op_A == CUBLAS_OP_N) ? m : k,
                                                         stride_A,
                                                         B,
                                                         CUDA_R_16F,
                                                         (op_B == CUBLAS_OP_N) ? k : n,
                                                         stride_B,
                                                         beta,
                                                         C,
                                                         CUDA_R_16F,
                                                         m,
                                                         stride_C,
                                                         batch,
                                                         CUDA_R_32F,
                                                         algo);
      if (status != CUBLAS_STATUS_SUCCESS) {
          // Include the batch count, as the fp32 strided-batched overload does,
          // so a failing batched call can be fully identified from the log.
          fprintf(stderr,
                  "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
                  batch,
                  m,
                  n,
                  k,
                  (int)status);
          return EXIT_FAILURE;
      }
      return 0;
  }