// cublas_wrappers.h
  1. #pragma once
  2. #include <assert.h>
  3. #include <cublas_v2.h>
  4. #include <cuda.h>
  5. #include <cuda_fp16.h>
  6. #include <cuda_runtime.h>
  7. #ifndef __HIP_PLATFORM_HCC__
  8. #include <mma.h>
  9. #endif
  10. #include <stdio.h>
  11. int cublas_gemm_ex(cublasHandle_t handle,
  12. cublasOperation_t transa,
  13. cublasOperation_t transb,
  14. int m,
  15. int n,
  16. int k,
  17. const float* alpha,
  18. const float* beta,
  19. const float* A,
  20. const float* B,
  21. float* C,
  22. #ifdef __HIP_PLATFORM_HCC__
  23. rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
  24. #else
  25. cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
  26. #endif
  27. int cublas_gemm_ex(cublasHandle_t handle,
  28. cublasOperation_t transa,
  29. cublasOperation_t transb,
  30. int m,
  31. int n,
  32. int k,
  33. const float* alpha,
  34. const float* beta,
  35. const __half* A,
  36. const __half* B,
  37. __half* C,
  38. #ifdef __HIP_PLATFORM_HCC__
  39. rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
  40. #else
  41. cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  42. #endif
  43. int cublas_strided_batched_gemm(cublasHandle_t handle,
  44. int m,
  45. int n,
  46. int k,
  47. const float* alpha,
  48. const float* beta,
  49. const float* A,
  50. const float* B,
  51. float* C,
  52. cublasOperation_t op_A,
  53. cublasOperation_t op_B,
  54. int stride_A,
  55. int stride_B,
  56. int stride_C,
  57. int batch,
  58. #ifdef __HIP_PLATFORM_HCC__
  59. rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
  60. #else
  61. cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
  62. #endif
  63. int cublas_strided_batched_gemm(cublasHandle_t handle,
  64. int m,
  65. int n,
  66. int k,
  67. const float* alpha,
  68. const float* beta,
  69. const __half* A,
  70. const __half* B,
  71. __half* C,
  72. cublasOperation_t op_A,
  73. cublasOperation_t op_B,
  74. int stride_A,
  75. int stride_B,
  76. int stride_C,
  77. int batch,
  78. #ifdef __HIP_PLATFORM_HCC__
  79. rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
  80. #else
  81. cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  82. #endif