Add: CUDA Matmul selection

This commit is contained in:
Liyan Zheng 2023-04-28 19:13:19 +08:00
parent c875f3cbb8
commit d0ae48d21d
1 changed file with 21 additions and 12 deletions

View File

@ -6,6 +6,8 @@ namespace infini {
struct MatmulCublasPerfRecordObj : public PerfRecordObj {
int algo = CUBLAS_GEMM_DEFAULT;
/// @brief Selects the cuBLAS API to call: 0 for cublasGemmStridedBatchedEx, 1 for cublasGemmEx
int apiId = 0;
void to_json(json &j) override {
j["type"] = 2;
j["data"] = std::make_pair(algo, time);
@ -49,7 +51,7 @@ class matmulCublas : public Kernel {
const float alpha = 1.f, beta = 0.f;
// TODO: use compute type
cublasStatus_t stat;
if (b >= 1) {
if (record->apiId == 0) {
// Support batch broadcast with zero stride
int dimA = op->getInputs(0)->getDims().size();
int dimB = op->getInputs(1)->getDims().size();
@ -73,12 +75,13 @@ class matmulCublas : public Kernel {
CUDA_R_32F, ldb, strideB, inAData, CUDA_R_32F, lda, strideA,
&beta, outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F,
(cublasGemmAlgo_t)record->algo);
} else {
} else if (record->apiId == 1) {
stat = cublasGemmEx(
context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
CUDA_R_32F, ldb, inAData, CUDA_R_32F, lda, &beta, outData,
CUDA_R_32F, ldc, CUDA_R_32F, (cublasGemmAlgo_t)record->algo);
}
} else
IT_ASSERT(false);
// if (stat != CUBLAS_STATUS_SUCCESS)
// cout << cublasGetErrorString(stat);
return (stat == CUBLAS_STATUS_SUCCESS);
@ -103,15 +106,21 @@ class matmulCublas : public Kernel {
IT_ASSERT(op);
auto ret = make_ref<MatmulCublasPerfRecordObj>();
ret->time = std::numeric_limits<double>::max();
for (int i = 0; i < N_ALGO; i++) {
auto rcd = make_ref<MatmulCublasPerfRecordObj>();
rcd->algo = ALGOS[i];
if (!do_compute(_op, rcd, _context))
continue;
rcd->time = timeit([&]() { do_compute(_op, rcd, _context); },
[&]() { context->sync(); });
if (rcd->time < ret->time)
ret = rcd;
vector<int> apis{0};
if (op->getB() == 1)
apis.emplace_back(1);
for (int api : apis) {
for (int i = 0; i < N_ALGO; i++) {
auto rcd = make_ref<MatmulCublasPerfRecordObj>();
rcd->apiId = api;
rcd->algo = ALGOS[i];
if (!do_compute(_op, rcd, _context))
continue;
rcd->time = timeit([&]() { do_compute(_op, rcd, _context); },
[&]() { context->sync(); });
if (rcd->time < ret->time)
ret = rcd;
}
}
IT_ASSERT(ret->time < std::numeric_limits<double>::max(),
"No valid algorithm found for " + op->toString());