forked from jiuyuan/InfiniTensor
Add: CUDA Matmul selection
This commit is contained in:
parent
c875f3cbb8
commit
d0ae48d21d
|
@ -6,6 +6,8 @@ namespace infini {
|
|||
|
||||
struct MatmulCublasPerfRecordObj : public PerfRecordObj {
|
||||
int algo = CUBLAS_GEMM_DEFAULT;
|
||||
/// @brief 0 for cublasGemmStridedBatchedEx, 1 for cublasGemmEx
|
||||
int apiId = 0;
|
||||
void to_json(json &j) override {
|
||||
j["type"] = 2;
|
||||
j["data"] = std::make_pair(algo, time);
|
||||
|
@ -49,7 +51,7 @@ class matmulCublas : public Kernel {
|
|||
const float alpha = 1.f, beta = 0.f;
|
||||
// TODO:use compute type
|
||||
cublasStatus_t stat;
|
||||
if (b >= 1) {
|
||||
if (record->apiId == 0) {
|
||||
// Support batch broadcast with zero stride
|
||||
int dimA = op->getInputs(0)->getDims().size();
|
||||
int dimB = op->getInputs(1)->getDims().size();
|
||||
|
@ -73,12 +75,13 @@ class matmulCublas : public Kernel {
|
|||
CUDA_R_32F, ldb, strideB, inAData, CUDA_R_32F, lda, strideA,
|
||||
&beta, outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F,
|
||||
(cublasGemmAlgo_t)record->algo);
|
||||
} else {
|
||||
} else if (record->apiId == 1) {
|
||||
stat = cublasGemmEx(
|
||||
context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
|
||||
CUDA_R_32F, ldb, inAData, CUDA_R_32F, lda, &beta, outData,
|
||||
CUDA_R_32F, ldc, CUDA_R_32F, (cublasGemmAlgo_t)record->algo);
|
||||
}
|
||||
} else
|
||||
IT_ASSERT(false);
|
||||
// if (stat != CUBLAS_STATUS_SUCCESS)
|
||||
// cout << cublasGetErrorString(stat);
|
||||
return (stat == CUBLAS_STATUS_SUCCESS);
|
||||
|
@ -103,15 +106,21 @@ class matmulCublas : public Kernel {
|
|||
IT_ASSERT(op);
|
||||
auto ret = make_ref<MatmulCublasPerfRecordObj>();
|
||||
ret->time = std::numeric_limits<double>::max();
|
||||
for (int i = 0; i < N_ALGO; i++) {
|
||||
auto rcd = make_ref<MatmulCublasPerfRecordObj>();
|
||||
rcd->algo = ALGOS[i];
|
||||
if (!do_compute(_op, rcd, _context))
|
||||
continue;
|
||||
rcd->time = timeit([&]() { do_compute(_op, rcd, _context); },
|
||||
[&]() { context->sync(); });
|
||||
if (rcd->time < ret->time)
|
||||
ret = rcd;
|
||||
vector<int> apis{0};
|
||||
if (op->getB() == 1)
|
||||
apis.emplace_back(1);
|
||||
for (int api : apis) {
|
||||
for (int i = 0; i < N_ALGO; i++) {
|
||||
auto rcd = make_ref<MatmulCublasPerfRecordObj>();
|
||||
rcd->apiId = api;
|
||||
rcd->algo = ALGOS[i];
|
||||
if (!do_compute(_op, rcd, _context))
|
||||
continue;
|
||||
rcd->time = timeit([&]() { do_compute(_op, rcd, _context); },
|
||||
[&]() { context->sync(); });
|
||||
if (rcd->time < ret->time)
|
||||
ret = rcd;
|
||||
}
|
||||
}
|
||||
IT_ASSERT(ret->time < std::numeric_limits<double>::max(),
|
||||
"No valid algorithm found for " + op->toString());
|
||||
|
|
Loading…
Reference in New Issue