// cublas_wrappers.h (3.3 KB)
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. #pragma once
  5. #include <assert.h>
  6. #include <cublas_v2.h>
  7. #include <cuda.h>
  8. #include <cuda_fp16.h>
  9. #include <cuda_runtime.h>
  10. #ifndef __HIP_PLATFORM_AMD__
  11. #include <mma.h>
  12. #endif
  13. #ifdef __HIP_PLATFORM_AMD__
  14. #include <rocblas/rocblas.h>
  15. #endif
  16. #include <stdio.h>
  17. int cublas_gemm_ex(cublasHandle_t handle,
  18. cublasOperation_t transa,
  19. cublasOperation_t transb,
  20. int m,
  21. int n,
  22. int k,
  23. const float* alpha,
  24. const float* beta,
  25. const float* A,
  26. const float* B,
  27. float* C,
  28. #ifdef __HIP_PLATFORM_AMD__
  29. rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
  30. #else
  31. cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
  32. #endif
  33. int cublas_gemm_ex(cublasHandle_t handle,
  34. cublasOperation_t transa,
  35. cublasOperation_t transb,
  36. int m,
  37. int n,
  38. int k,
  39. const float* alpha,
  40. const float* beta,
  41. const __half* A,
  42. const __half* B,
  43. __half* C,
  44. #ifdef __HIP_PLATFORM_AMD__
  45. rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
  46. #else
  47. cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  48. #endif
  49. int cublas_strided_batched_gemm(cublasHandle_t handle,
  50. int m,
  51. int n,
  52. int k,
  53. const float* alpha,
  54. const float* beta,
  55. const float* A,
  56. const float* B,
  57. float* C,
  58. cublasOperation_t op_A,
  59. cublasOperation_t op_B,
  60. int stride_A,
  61. int stride_B,
  62. int stride_C,
  63. int batch,
  64. #ifdef __HIP_PLATFORM_AMD__
  65. rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
  66. #else
  67. cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
  68. #endif
  69. int cublas_strided_batched_gemm(cublasHandle_t handle,
  70. int m,
  71. int n,
  72. int k,
  73. const float* alpha,
  74. const float* beta,
  75. const __half* A,
  76. const __half* B,
  77. __half* C,
  78. cublasOperation_t op_A,
  79. cublasOperation_t op_B,
  80. int stride_A,
  81. int stride_B,
  82. int stride_C,
  83. int batch,
  84. #ifdef __HIP_PLATFORM_AMD__
  85. rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
  86. #else
  87. cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  88. #endif