123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293 |
#pragma once

#include <cuda_fp16.h>
#include <cuda_profiler_api.h>

#include <array>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <limits>
#include <memory>
#include <string>

#include "StopWatch.h"
#include "cublas_wrappers.h"
// Reports a failed CUDA/cuBLAS status code with full call-site context.
//
// `result` is any status type whose non-zero values indicate failure
// (e.g. cudaError_t, cublasStatus_t). `func` is the stringified call
// expression, `file`/`line` the call site (all supplied by the
// check_cuda_error macro below).
//
// Deliberately does not abort: the algorithm sweep should keep running
// even if an individual GEMM algorithm is unsupported on this device.
template <typename T>
void check(T result, char const* const func, const char* const file, int const line)
{
    if (result) {
        // Report the failing expression and the raw status value; errors
        // belong on stderr so they are not lost in the benchmark output.
        std::cerr << "CUDA runtime error: " << func << " returned " << result << " at " << file
                  << ":" << line << "\n";
    }
}

// Wraps a CUDA/cuBLAS call so failures are reported with the call text,
// file, and line of the call site.
#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
- template <typename T>
- class GemmTest {
- public:
- GemmTest(int m, int n, int k, cublasOperation_t ta, cublasOperation_t tb, cublasHandle_t h)
- : M(m), N(n), K(k), transa(ta), transb(tb), handle(h)
- {
- check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K));
- check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N));
- check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N));
- }
- ~GemmTest()
- {
- check_cuda_error(cudaFree(A));
- check_cuda_error(cudaFree(B));
- check_cuda_error(cudaFree(C));
- }
- std::array<int, 3> TestAlgo(int loops)
- {
- float alpha = (T)1.0f;
- float beta = (T)0.0f;
- int algo_fw = Run(loops, [=](int algo) {
- cublas_gemm_ex(handle,
- CUBLAS_OP_T,
- CUBLAS_OP_N,
- N,
- M,
- K,
- &alpha,
- &beta,
- B,
- A,
- C,
- static_cast<cublasGemmAlgo_t>(algo));
- });
- int algo_bw1 = Run(loops, [=](int algo) {
- cublas_gemm_ex(handle,
- CUBLAS_OP_N,
- CUBLAS_OP_T,
- K,
- N,
- M,
- &alpha,
- &beta,
- A,
- C,
- B,
- static_cast<cublasGemmAlgo_t>(algo));
- });
- int algo_bw2 = Run(loops, [=](int algo) {
- cublas_gemm_ex(handle,
- CUBLAS_OP_N,
- CUBLAS_OP_N,
- K,
- M,
- N,
- &alpha,
- &beta,
- B,
- C,
- A,
- static_cast<cublasGemmAlgo_t>(algo));
- });
- return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
- }
- template <typename Func>
- int Run(int loops, Func f)
- {
- float fast_latency = (std::numeric_limits<float>::max)();
- int fast_algo = 0;
- for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
- algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
- algo++) {
- int warm_up = 5;
- for (int i = 0; i < warm_up; ++i) f(algo);
- cudaDeviceSynchronize();
- Stopwatch timer;
- timer.Restart();
- for (int i = 0; i < loops; ++i) f(algo);
- cudaDeviceSynchronize();
- timer.Stop();
- float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops;
- printf("algo-%d: %.3fms\n", algo, avg_latency);
- if (avg_latency < fast_latency) {
- fast_latency = avg_latency;
- fast_algo = algo;
- }
- }
- printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency);
- return fast_algo;
- }
- private:
- int M, N, K;
- cublasHandle_t handle;
- cublasOperation_t transa, transb;
- T *A, *B, *C;
- };
- template <typename T>
- class StridedGemmTest {
- public:
- StridedGemmTest(int b,
- int m,
- int n,
- int k,
- cublasOperation_t ta,
- cublasOperation_t tb,
- cublasHandle_t h)
- : bsz(b), M(m), N(n), K(k), transa(ta), transb(tb), handle(h)
- {
- check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K * bsz));
- check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N * bsz));
- check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N * bsz));
- }
- ~StridedGemmTest()
- {
- check_cuda_error(cudaFree(A));
- check_cuda_error(cudaFree(B));
- check_cuda_error(cudaFree(C));
- }
- std::array<int, 3> TestAlgo(int loops)
- {
- float alpha = (T)1.0f;
- float beta = (T)0.0f;
- int algo_fw = Run(loops, [=](int algo) {
- int stride_a = M * K;
- int stride_b = N * K;
- int stride_c = M * N;
- cublas_strided_batched_gemm(handle,
- M,
- N,
- K,
- &alpha,
- &beta,
- A,
- B,
- C,
- transa,
- transb,
- stride_a,
- stride_b,
- stride_c,
- bsz,
- static_cast<cublasGemmAlgo_t>(algo));
- });
- int algo_bw1 = Run(loops, [=](int algo) {
- int mb = (transa == CUBLAS_OP_T ? K : M);
- int kb = (transa == CUBLAS_OP_T ? M : K);
- int stride_a = mb * N;
- int stride_b = N * kb;
- int stride_c = M * K;
- // B need to transpose.
- cublasOperation_t op_b = (transb == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
- // Calculate d_A.
- cublas_strided_batched_gemm(handle,
- mb,
- kb,
- N,
- &alpha,
- &beta,
- (transa == CUBLAS_OP_T ? B : C),
- (transa == CUBLAS_OP_T ? C : B),
- A,
- CUBLAS_OP_N,
- op_b,
- stride_a,
- stride_b,
- stride_c,
- bsz,
- static_cast<cublasGemmAlgo_t>(algo));
- });
- int algo_bw2 = Run(loops, [=](int algo) {
- // A need to transpose.
- cublasOperation_t op_a = (transa == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
- int stride_a = M * K;
- int stride_b = M * N;
- int stride_c = N * K;
- // Calculate d_B.
- cublas_strided_batched_gemm(handle,
- K,
- N,
- M,
- &alpha,
- &beta,
- A,
- C,
- B,
- op_a,
- CUBLAS_OP_N,
- stride_a,
- stride_b,
- stride_c,
- bsz,
- static_cast<cublasGemmAlgo_t>(algo));
- });
- return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
- }
- template <typename Func>
- int Run(int loops, Func f)
- {
- float fast_latency = (std::numeric_limits<float>::max)();
- int fast_algo = 0;
- for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
- algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
- algo++) {
- int warm_up = 5;
- for (int i = 0; i < warm_up; ++i) f(algo);
- cudaDeviceSynchronize();
- Stopwatch timer;
- timer.Restart();
- for (int i = 0; i < loops; ++i) f(algo);
- cudaDeviceSynchronize();
- timer.Stop();
- float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops;
- printf("algo-%d: %.3fms\n", algo, avg_latency);
- if (avg_latency < fast_latency) {
- fast_latency = avg_latency;
- fast_algo = algo;
- }
- }
- printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency);
- return fast_algo;
- }
- private:
- int bsz, M, N, K;
- cublasHandle_t handle;
- cublasOperation_t transa, transb;
- T *A, *B, *C;
- };
|