cublas_wrappers.h 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. #pragma once
  5. #include <assert.h>
  6. #include <cublas_v2.h>
  7. #include <cuda.h>
  8. #include <cuda_fp16.h>
  9. #include <cuda_runtime.h>
  10. #ifndef __HIP_PLATFORM_HCC__
  11. #include <mma.h>
  12. #endif
  13. #include <stdio.h>
  14. int cublas_gemm_ex(cublasHandle_t handle,
  15. cublasOperation_t transa,
  16. cublasOperation_t transb,
  17. int m,
  18. int n,
  19. int k,
  20. const float* alpha,
  21. const float* beta,
  22. const float* A,
  23. const float* B,
  24. float* C,
  25. #ifdef __HIP_PLATFORM_HCC__
  26. rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
  27. #else
  28. cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
  29. #endif
  30. int cublas_gemm_ex(cublasHandle_t handle,
  31. cublasOperation_t transa,
  32. cublasOperation_t transb,
  33. int m,
  34. int n,
  35. int k,
  36. const float* alpha,
  37. const float* beta,
  38. const __half* A,
  39. const __half* B,
  40. __half* C,
  41. #ifdef __HIP_PLATFORM_HCC__
  42. rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
  43. #else
  44. cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  45. #endif
  46. int cublas_strided_batched_gemm(cublasHandle_t handle,
  47. int m,
  48. int n,
  49. int k,
  50. const float* alpha,
  51. const float* beta,
  52. const float* A,
  53. const float* B,
  54. float* C,
  55. cublasOperation_t op_A,
  56. cublasOperation_t op_B,
  57. int stride_A,
  58. int stride_B,
  59. int stride_C,
  60. int batch,
  61. #ifdef __HIP_PLATFORM_HCC__
  62. rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
  63. #else
  64. cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
  65. #endif
  66. int cublas_strided_batched_gemm(cublasHandle_t handle,
  67. int m,
  68. int n,
  69. int k,
  70. const float* alpha,
  71. const float* beta,
  72. const __half* A,
  73. const __half* B,
  74. __half* C,
  75. cublasOperation_t op_A,
  76. cublasOperation_t op_B,
  77. int stride_A,
  78. int stride_B,
  79. int stride_C,
  80. int batch,
  81. #ifdef __HIP_PLATFORM_HCC__
  82. rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
  83. #else
  84. cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  85. #endif