forked from jiuyuan/InfiniTensor
Op matmul (#20)
ADD:add cuda kernel for matmul. matmul tune Add test_matmul.cc
This commit is contained in:
parent
32a01efbbe
commit
c3bc278c12
|
@ -44,6 +44,7 @@ class MatmulObj : public OperatorObj {
|
||||||
int getM() const { return m; }
|
int getM() const { return m; }
|
||||||
int getN() const { return n; }
|
int getN() const { return n; }
|
||||||
int getK() const { return k; }
|
int getK() const { return k; }
|
||||||
|
auto getBMNK() const { return tuple{b, m, n, k}; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
vector<int> getWorkloadVector() const override;
|
vector<int> getWorkloadVector() const override;
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
#include "operators/matmul.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include <chrono>
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
struct MatmulCudnnPerfRecord : public PerfRecord {
|
||||||
|
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
|
||||||
|
};
|
||||||
|
constexpr int N_ALGO = 24;
|
||||||
|
constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
|
||||||
|
CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1, CUBLAS_GEMM_ALGO2,
|
||||||
|
CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4, CUBLAS_GEMM_ALGO5,
|
||||||
|
CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7, CUBLAS_GEMM_ALGO8,
|
||||||
|
CUBLAS_GEMM_ALGO9, CUBLAS_GEMM_ALGO10, CUBLAS_GEMM_ALGO11,
|
||||||
|
CUBLAS_GEMM_ALGO12, CUBLAS_GEMM_ALGO13, CUBLAS_GEMM_ALGO14,
|
||||||
|
CUBLAS_GEMM_ALGO15, CUBLAS_GEMM_ALGO16, CUBLAS_GEMM_ALGO17,
|
||||||
|
CUBLAS_GEMM_ALGO18, CUBLAS_GEMM_ALGO19, CUBLAS_GEMM_ALGO20,
|
||||||
|
CUBLAS_GEMM_ALGO21, CUBLAS_GEMM_ALGO22, CUBLAS_GEMM_ALGO23,
|
||||||
|
};
|
||||||
|
|
||||||
|
class matmulCublas : public Kernel {
|
||||||
|
bool do_compute(const Operator &_op, const PerfRecord &_record,
|
||||||
|
const RuntimeObj *_context) const {
|
||||||
|
auto op = as<MatmulObj>(_op);
|
||||||
|
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||||
|
void *const inAData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||||
|
void *const inBData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||||
|
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
|
||||||
|
auto record = dynamic_cast<const MatmulCudnnPerfRecord &>(_record);
|
||||||
|
|
||||||
|
const auto [b, m, n, k] = op->getBMNK();
|
||||||
|
auto opA =
|
||||||
|
op->getTransA() ? CUBLAS_OP_T : CUBLAS_OP_N; // BLAS_N = col major
|
||||||
|
auto opB = op->getTransB() ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||||
|
const int lda = op->getTransA() ? m : k, ldb = op->getTransB() ? k : n,
|
||||||
|
ldc = n;
|
||||||
|
const float alpha = 1.f, beta = 0.f;
|
||||||
|
// TODO:use compute type
|
||||||
|
cublasStatus_t stat;
|
||||||
|
if (b > 1) {
|
||||||
|
stat = cublasGemmStridedBatchedEx(
|
||||||
|
context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
|
||||||
|
CUDA_R_32F, ldb, k * n, inAData, CUDA_R_32F, lda, m * k, &beta,
|
||||||
|
outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F, record.algo);
|
||||||
|
} else {
|
||||||
|
stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
|
||||||
|
&alpha, inBData, CUDA_R_32F, ldb, inAData,
|
||||||
|
CUDA_R_32F, lda, &beta, outData, CUDA_R_32F,
|
||||||
|
ldc, CUDA_R_32F, record.algo);
|
||||||
|
}
|
||||||
|
return (stat == CUBLAS_STATUS_SUCCESS);
|
||||||
|
}
|
||||||
|
|
||||||
|
void compute(const Operator &_op, const PerfRecord &_record,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
IT_ASSERT(do_compute(_op, _record, _context));
|
||||||
|
}
|
||||||
|
|
||||||
|
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||||
|
MatmulCudnnPerfRecord record; // use default record;
|
||||||
|
compute(op, record, context);
|
||||||
|
}
|
||||||
|
|
||||||
|
PerfRecord tune(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||||
|
auto op = as<MatmulObj>(_op);
|
||||||
|
MatmulCudnnPerfRecord ret;
|
||||||
|
ret.time = std::numeric_limits<double>::max();
|
||||||
|
for (int i = 0; i < N_ALGO; i++) {
|
||||||
|
MatmulCudnnPerfRecord rcd;
|
||||||
|
rcd.algo = ALGOS[i];
|
||||||
|
if (!do_compute(_op, rcd, _context))
|
||||||
|
continue;
|
||||||
|
rcd.time = timeit([&]() { do_compute(_op, rcd, _context); },
|
||||||
|
[&]() { context->sync(); });
|
||||||
|
if (rcd.time < ret.time)
|
||||||
|
ret = rcd;
|
||||||
|
}
|
||||||
|
IT_ASSERT(ret.time < std::numeric_limits<double>::max(), "No valid "
|
||||||
|
"algorithm "
|
||||||
|
"found");
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::Matmul, DataType::Float32, matmulCublas,
|
||||||
|
"Matmul_cuBLAS_CUDA_Float32");
|
||||||
|
|
||||||
|
}; // namespace infini
|
|
@ -0,0 +1,95 @@
|
||||||
|
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "operators/matmul.h"
|
||||||
|
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
using ExpectOutput = vector<float>;
|
||||||
|
|
||||||
|
TEST(Matmul, ShapeInference) {
|
||||||
|
auto runtime = CpuRuntimeObj::getInstance();
|
||||||
|
{
|
||||||
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
auto A = g->addTensor(Shape{1, 3, 5});
|
||||||
|
auto B = g->addTensor(Shape{1, 5, 2});
|
||||||
|
auto matmul = g->addOp<MatmulObj>(A, B, nullptr);
|
||||||
|
auto C = matmul->getOutputs()[0];
|
||||||
|
EXPECT_EQ(C->getDims(), (Shape{1, 3, 2}));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
auto A = g->addTensor(Shape{3, 5, 4});
|
||||||
|
auto B = g->addTensor(Shape{3, 5, 2});
|
||||||
|
auto matmul = g->addOp<MatmulObj>(A, B, nullptr, true, false);
|
||||||
|
auto C = matmul->getOutputs()[0];
|
||||||
|
EXPECT_EQ(C->getDims(), (Shape{3, 4, 2}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void testMatmulCuda(
|
||||||
|
const std::function<void(void *, size_t, DataType)> &generatorA,
|
||||||
|
const std::function<void(void *, size_t, DataType)> &generatorB,
|
||||||
|
bool transA, bool transB, const Shape &shapeA, const Shape &shapeB,
|
||||||
|
const ExpectOutput &ansVec) {
|
||||||
|
auto cpuRuntime = CpuRuntimeObj::getInstance();
|
||||||
|
Graph gCpu = make_ref<GraphObj>(cpuRuntime);
|
||||||
|
auto ACpu = gCpu->addTensor(shapeA, DataType::Float32);
|
||||||
|
auto BCpu = gCpu->addTensor(shapeB, DataType::Float32);
|
||||||
|
gCpu->dataMalloc();
|
||||||
|
ACpu->setData(generatorA);
|
||||||
|
BCpu->setData(generatorB);
|
||||||
|
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
auto gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
auto ACuda = gCuda->cloneTensor(ACpu);
|
||||||
|
auto BCuda = gCuda->cloneTensor(BCpu);
|
||||||
|
auto matmul =
|
||||||
|
gCuda->addOp<MatmulObj>(ACuda, BCuda, nullptr, transA, transB);
|
||||||
|
|
||||||
|
// allocate CUDA memory
|
||||||
|
gCuda->dataMalloc();
|
||||||
|
cudaRuntime->run(gCuda);
|
||||||
|
|
||||||
|
auto CCpu = gCpu->cloneTensor(matmul->getOutput());
|
||||||
|
// CCpu->printData();
|
||||||
|
// check results on CPU
|
||||||
|
EXPECT_TRUE(CCpu->equalData(ansVec));
|
||||||
|
// print a tensor/operator/graph by print()
|
||||||
|
// gCuda->print();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Matmul, cuBlas) {
|
||||||
|
testMatmulCuda(IncrementalGenerator(), OneGenerator(), false, false,
|
||||||
|
Shape{1, 3, 5}, Shape{1, 5, 2},
|
||||||
|
ExpectOutput{10, 10, 35, 35, 60, 60});
|
||||||
|
testMatmulCuda(IncrementalGenerator(), IncrementalGenerator(), true, false,
|
||||||
|
Shape{2, 3, 4}, Shape{2, 3, 2},
|
||||||
|
ExpectOutput{40, 52, 46, 61, 52, 70, 58, 79, 400, 448, 424,
|
||||||
|
475, 448, 502, 472, 529});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Matmul, tune) {
|
||||||
|
auto cpuRuntime = CpuRuntimeObj::getInstance();
|
||||||
|
Graph gCpu = make_ref<GraphObj>(cpuRuntime);
|
||||||
|
auto ACpu = gCpu->addTensor(Shape{1, 3, 5}, DataType::Float32);
|
||||||
|
auto BCpu = gCpu->addTensor(Shape{1, 5, 2}, DataType::Float32);
|
||||||
|
gCpu->dataMalloc();
|
||||||
|
ACpu->setData(IncrementalGenerator());
|
||||||
|
BCpu->setData(IncrementalGenerator());
|
||||||
|
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
auto gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
auto ACuda = gCuda->cloneTensor(ACpu);
|
||||||
|
auto BCuda = gCuda->cloneTensor(BCpu);
|
||||||
|
auto matmul = gCuda->addOp<MatmulObj>(ACuda, BCuda, nullptr);
|
||||||
|
|
||||||
|
// allocate CUDA memory
|
||||||
|
gCuda->dataMalloc();
|
||||||
|
cudaRuntime->run(gCuda, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
}; // namespace infini
|
Loading…
Reference in New Issue