123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869 |
- #pragma once
- #include <assert.h>
- #include <cublas_v2.h>
- #include <cuda.h>
- #include <cuda_fp16.h>
- #include <cuda_runtime.h>
- #include <mma.h>
- #include <stdio.h>
// Declaration of the float overload of cublas_gemm_ex. Only the prototype
// lives here; the definition is not visible in this header.
//
// Parameters, as declared:
//   handle          cublasHandle_t passed by value.
//   transa, transb  cublasOperation_t selectors, one per input operand.
//   m, n, k         int problem dimensions.
//   alpha, beta     pointers to const float scalars. Both scalars are passed
//                   together, ahead of the A/B/C pointers.
//   A, B            read-only (const) float buffers.
//   C               writable float buffer.
//   algo            cublasGemmAlgo_t, defaults to CUBLAS_GEMM_DEFAULT.
//
// Leading dimensions are not part of this interface.
//
// NOTE(review): presumably computes C = alpha * op(A) * op(B) + beta * C by
// forwarding to cublasGemmEx, and returns a status code as int (0 on
// success) -- confirm against the definition.
// NOTE(review): whether alpha/beta must be host or device pointers depends on
// the handle's pointer mode; verify against the definition and its callers.
- int cublas_gemm_ex(cublasHandle_t handle,
- cublasOperation_t transa,
- cublasOperation_t transb,
- int m,
- int n,
- int k,
- const float* alpha,
- const float* beta,
- const float* A,
- const float* B,
- float* C,
- cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
// Declaration of the __half overload of cublas_gemm_ex. Only the prototype
// lives here; the definition is not visible in this header.
//
// Parameters, as declared:
//   handle          cublasHandle_t passed by value.
//   transa, transb  cublasOperation_t selectors, one per input operand.
//   m, n, k         int problem dimensions.
//   alpha, beta     pointers to const float scalars. The scalars stay float
//                   even though the buffers are __half.
//   A, B            read-only (const) __half buffers.
//   C               writable __half buffer.
//   algo            cublasGemmAlgo_t, defaults to
//                   CUBLAS_GEMM_DEFAULT_TENSOR_OP (the float overload
//                   defaults to CUBLAS_GEMM_DEFAULT instead).
//
// Leading dimensions are not part of this interface.
//
// NOTE(review): presumably computes C = alpha * op(A) * op(B) + beta * C by
// forwarding to cublasGemmEx with float accumulation, and returns a status
// code as int (0 on success) -- confirm against the definition.
// NOTE(review): whether alpha/beta must be host or device pointers depends on
// the handle's pointer mode; verify against the definition and its callers.
- int cublas_gemm_ex(cublasHandle_t handle,
- cublasOperation_t transa,
- cublasOperation_t transb,
- int m,
- int n,
- int k,
- const float* alpha,
- const float* beta,
- const __half* A,
- const __half* B,
- __half* C,
- cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
// Declaration of the float overload of cublas_strided_batched_gemm. Only the
// prototype lives here; the definition is not visible in this header.
//
// Parameters, as declared:
//   handle          cublasHandle_t passed by value.
//   m, n, k         int problem dimensions.
//   alpha, beta     pointers to const float scalars.
//   A, B            read-only (const) float buffers.
//   C               writable float buffer.
//   op_A, op_B      cublasOperation_t selectors for A and B. Unlike
//                   cublas_gemm_ex, they come after the buffer pointers.
//   stride_A, stride_B, stride_C
//                   int strides, one per buffer.
//   batch           int batch count.
//   algo            cublasGemmAlgo_t, defaults to CUBLAS_GEMM_DEFAULT.
//
// Leading dimensions are not part of this interface.
//
// NOTE(review): presumably forwards to cublasGemmStridedBatchedEx, with the
// strides counted in elements between consecutive matrices of the batch, and
// returns a status code as int (0 on success) -- confirm against the
// definition.
// NOTE(review): the strides are int here while the cuBLAS strided-batched
// entry points take long long strides; check that callers never need a
// per-matrix stride above INT_MAX.
- int cublas_strided_batched_gemm(cublasHandle_t handle,
- int m,
- int n,
- int k,
- const float* alpha,
- const float* beta,
- const float* A,
- const float* B,
- float* C,
- cublasOperation_t op_A,
- cublasOperation_t op_B,
- int stride_A,
- int stride_B,
- int stride_C,
- int batch,
- cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
// Declaration of the __half overload of cublas_strided_batched_gemm. Only the
// prototype lives here; the definition is not visible in this header.
//
// Parameters, as declared:
//   handle          cublasHandle_t passed by value.
//   m, n, k         int problem dimensions.
//   alpha, beta     pointers to const float scalars. The scalars stay float
//                   even though the buffers are __half.
//   A, B            read-only (const) __half buffers.
//   C               writable __half buffer.
//   op_A, op_B      cublasOperation_t selectors for A and B. Unlike
//                   cublas_gemm_ex, they come after the buffer pointers.
//   stride_A, stride_B, stride_C
//                   int strides, one per buffer.
//   batch           int batch count.
//   algo            cublasGemmAlgo_t, defaults to
//                   CUBLAS_GEMM_DEFAULT_TENSOR_OP (the float overload
//                   defaults to CUBLAS_GEMM_DEFAULT instead).
//
// Leading dimensions are not part of this interface.
//
// NOTE(review): presumably forwards to cublasGemmStridedBatchedEx, with the
// strides counted in elements between consecutive matrices of the batch, and
// returns a status code as int (0 on success) -- confirm against the
// definition.
// NOTE(review): the strides are int here while the cuBLAS strided-batched
// entry points take long long strides; check that callers never need a
// per-matrix stride above INT_MAX.
- int cublas_strided_batched_gemm(cublasHandle_t handle,
- int m,
- int n,
- int k,
- const float* alpha,
- const float* beta,
- const __half* A,
- const __half* B,
- __half* C,
- cublasOperation_t op_A,
- cublasOperation_t op_B,
- int stride_A,
- int stride_B,
- int stride_C,
- int batch,
- cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|