// cublas_wrappers.h (2.6 KB, 69 lines)

  1. #pragma once
  2. #include <assert.h>
  3. #include <cublas_v2.h>
  4. #include <cuda.h>
  5. #include <cuda_fp16.h>
  6. #include <cuda_runtime.h>
  7. #include <mma.h>
  8. #include <stdio.h>
  9. int cublas_gemm_ex(cublasHandle_t handle,
  10. cublasOperation_t transa,
  11. cublasOperation_t transb,
  12. int m,
  13. int n,
  14. int k,
  15. const float* alpha,
  16. const float* beta,
  17. const float* A,
  18. const float* B,
  19. float* C,
  20. cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
  21. int cublas_gemm_ex(cublasHandle_t handle,
  22. cublasOperation_t transa,
  23. cublasOperation_t transb,
  24. int m,
  25. int n,
  26. int k,
  27. const float* alpha,
  28. const float* beta,
  29. const __half* A,
  30. const __half* B,
  31. __half* C,
  32. cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  33. int cublas_strided_batched_gemm(cublasHandle_t handle,
  34. int m,
  35. int n,
  36. int k,
  37. const float* alpha,
  38. const float* beta,
  39. const float* A,
  40. const float* B,
  41. float* C,
  42. cublasOperation_t op_A,
  43. cublasOperation_t op_B,
  44. int stride_A,
  45. int stride_B,
  46. int stride_C,
  47. int batch,
  48. cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
  49. int cublas_strided_batched_gemm(cublasHandle_t handle,
  50. int m,
  51. int n,
  52. int k,
  53. const float* alpha,
  54. const float* beta,
  55. const __half* A,
  56. const __half* B,
  57. __half* C,
  58. cublasOperation_t op_A,
  59. cublasOperation_t op_B,
  60. int stride_A,
  61. int stride_B,
  62. int stride_C,
  63. int batch,
  64. cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);