Add: TF32 support and accurate timing for conv

Liyan Zheng 2023-05-07 13:22:39 +08:00
parent abcfa76fb5
commit d25b606e12
8 changed files with 85 additions and 37 deletions
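For orientation, a minimal usage sketch of the two new entry points, setEnableTF32() and timeWithCudaGraph(), pieced together from the declarations and the updated cuBLAS_Matmul test in this diff (include paths and tensor shapes are illustrative, not part of the commit):

#include "core/graph.h"
#include "cuda/cuda_runtime.h"
#include "operators/matmul.h"

// Sketch only; mirrors the updated test at the end of this commit.
double timeMatmulWithTF32() {
    using namespace infini;
    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    cudaRuntime->setEnableTF32(true); // opt in to TF32 tensor-core kernels
    Graph g = make_ref<GraphObj>(cudaRuntime);
    auto a = g->addTensor(Shape{1, 4, 448});
    auto b = g->addTensor(Shape{1, 448, 4096});
    g->addOp<MatmulObj>(a, b, nullptr, /*transA=*/false, /*transB=*/false);
    g->dataMalloc();
    // Captures the graph into a CUDA Graph and replays it (now 50 rounds by default).
    return cudaRuntime->timeWithCudaGraph(g);
}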

View File

@@ -77,6 +77,6 @@ double timeit(
const std::function<void()> &func,
// HACK: set timeit rounds to 10 for fast debug
const std::function<void(void)> &sync = []() {}, int warmupRounds = 10,
int timingRounds = 10);
int timingRounds = 100);
} // namespace infini

View File

@@ -56,6 +56,14 @@ class GraphObj : public Object {
return cloneOperator(op, inputs, outputs);
}
Operator cloneOpAndCreateInputsOutputs(Operator op) {
vector<Tensor> inputs;
for (auto t : op->getInputs()) {
inputs.emplace_back(cloneTensor(t));
}
return cloneOpAndCreateOutputs(op, inputs);
}
const TensorVec &getTensors() const { return tensors; }
const OpVec &getOperators() const { return ops; }
OpVec getComputeOps() const;
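A hypothetical use of the new cloneOpAndCreateInputsOutputs() helper (the surrounding names are assumed, not part of this diff): clone an operator together with fresh copies of its input tensors into a separate graph, e.g. so a single op can be measured in isolation.

// Sketch only; the exact cloning semantics follow cloneTensor/cloneOpAndCreateOutputs above.
Operator cloneForIsolatedTiming(Runtime runtime, const Operator &op) {
    Graph g2 = make_ref<GraphObj>(runtime);
    Operator cloned = g2->cloneOpAndCreateInputsOutputs(op);
    g2->dataMalloc();
    return cloned;
}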

View File

@@ -21,6 +21,8 @@ class CudaRuntimeObj : public RuntimeObj {
// CUDA device properties
cudaDeviceProp deviceProperties;
bool enableTF32 = false;
public:
CudaRuntimeObj();
virtual ~CudaRuntimeObj();
@@ -82,7 +84,11 @@ class CudaRuntimeObj : public RuntimeObj {
bool isInCudaGraph() const { return cudaGraphStatus; }
cudaStream_t getStream() const { return stream; }
double timeWithCudaGraph(Graph graph, int rounds = 1000);
double timeWithCudaGraph(Graph graph, int rounds = 50);
double timeWithCudaGraph(vector<std::function<void(void)>> funcs,
int rounds = 50);
void setEnableTF32(bool state);
bool getEnableTF32() const { return enableTF32; }
private:
void tune(const Graph &graph, bool profiling) const;

View File

@@ -4,9 +4,9 @@
#include "core/runtime.h"
#include "cuda_profiler_api.h"
#include "nnet/dbg.h"
#include "operators/any.h"
#include "operators/conv.h"
#include "operators/matmul.h"
#include "operators/any.h"
#ifdef INFINI_USE_TVM
#include "tvm/runtime/device_api.h"
#endif
@@ -148,13 +148,25 @@ double CudaRuntimeObj::timeWithCudaGraph(Graph graph, int rounds) {
if (as<AnyObj>(op))
dbg(op, as<AnyObj>(op)->getKernelName() == string("FakeOp"));
if (!ctcMap.at(op->getGuid()) && op->getOpType() != OpType::Reshape &&
!isFakeOp)
op->getOpType() != OpType::Flatten && !isFakeOp)
kernels.emplace_back(op, kernel, perfData);
}
for (auto &[op, kernel, perfData] : kernels) {
dbg(op);
}
vector<std::function<void(void)>> funcs;
for (auto &[op, kernel, perfData] : kernels) {
if (perfData)
funcs.push_back([&]() { kernel->compute(op, perfData, this); });
else
funcs.push_back([&]() { kernel->compute(op, this); });
}
return timeWithCudaGraph(funcs, rounds);
}
double
CudaRuntimeObj::timeWithCudaGraph(std::vector<std::function<void(void)>> funcs,
int rounds) {
// TODO: move this to kernel source?
// Init tvm stream
#ifdef INFINI_USE_TVM
@@ -163,25 +175,22 @@ double CudaRuntimeObj::timeWithCudaGraph(Graph graph, int rounds) {
tvm_device->SetStream(tvm_device_id, getStream());
#endif
beginCudaGraphStreamCapture();
for (auto &[op, kernel, perfData] : kernels) {
if (perfData)
kernel->compute(op, perfData, this);
else
kernel->compute(op, this);
}
for (auto &f : funcs)
f();
auto [cudaGraphInstance, numCudaGraphNodes] = endCudaGraphStreamCapture();
// Since one TVM packed function may contain more than one CUDA kernel, the
// number of captured kernels may exceed the number of operators.
IT_ASSERT(numCudaGraphNodes >= kernels.size(),
IT_ASSERT(numCudaGraphNodes >= funcs.size(),
std::to_string(numCudaGraphNodes) +
" != " + std::to_string(kernels.size()));
printf("numCudaGraphNodes = %lu\n", numCudaGraphNodes);
" != " + std::to_string(funcs.size()));
return timeit(
[&, cudaGraphInstance = cudaGraphInstance, stream = getStream()]() {
checkCudaError(cudaGraphLaunch(cudaGraphInstance, stream));
},
[&, stream = getStream()]() { cudaStreamSynchronize(stream); }, rounds,
rounds);
[&, stream = getStream()]() { cudaStreamSynchronize(stream); },
std::min(50, rounds), rounds);
}
void CudaRuntimeObj::setEnableTF32(bool state) { enableTF32 = state; }
} // namespace infini
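The accuracy gain here comes from replaying all captured kernels as a single CUDA Graph launch, so per-kernel CPU launch overhead and framework bookkeeping are excluded from the measurement. For reference, a generic sketch of the capture-and-replay pattern this relies on, written against the plain CUDA runtime API (the event-based timing is illustrative; the code above instead reuses timeit() with stream synchronization):

#include <cuda_runtime.h>
#include <functional>
#include <vector>

// Sketch only: capture the work enqueued by `funcs` on `stream`, then time
// repeated replays of the instantiated graph.
float timeCapturedWork(const std::vector<std::function<void()>> &funcs,
                       cudaStream_t stream, int rounds) {
    cudaGraph_t graph;
    cudaGraphExec_t instance;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    for (auto &f : funcs)
        f(); // kernel launches are recorded, not executed, during capture
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0); // CUDA 11.x signature

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start, stream);
    for (int i = 0; i < rounds; ++i)
        cudaGraphLaunch(instance, stream); // one launch replays every captured kernel
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);
    float ms = 0.f;
    cudaEventElapsedTime(&ms, start, stop);
    return ms / rounds;
}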

View File

@@ -323,7 +323,9 @@ void init_graph_builder(py::module &m) {
#ifdef USE_CUDA
py::class_<CudaRuntimeObj, Ref<CudaRuntimeObj>, RuntimeObj>(m,
"CudaRuntime")
.def("timeWithCudaGraph", &CudaRuntimeObj::timeWithCudaGraph);
.def("timeWithCudaGraph",
py::overload_cast<Graph, int>(&CudaRuntimeObj::timeWithCudaGraph))
.def("setEnableTF32", &CudaRuntimeObj::setEnableTF32);
#endif
#ifdef USE_BANG
py::class_<BangRuntimeObj, std::shared_ptr<BangRuntimeObj>, RuntimeObj>(

View File

@@ -1,4 +1,5 @@
#include "operators/conv.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "cuda/cuda_runtime.h"
#include <chrono>
@@ -234,7 +235,8 @@ class convCudnn : public Kernel {
const RuntimeObj *_context) const override {
ConvCuDnnPerfRecordObj ret;
ret.time = std::numeric_limits<double>::max();
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
auto context = const_cast<CudaRuntimeObj *>(
dynamic_cast<const CudaRuntimeObj *>(_context));
auto op = as<ConvBaseObj>(_op);
int try_algo = op->getOpType() == OpType::ConvNHWC ? 2 : N_ALGO;
// Both modes have the same performance. Only run cross-correlation.
@@ -267,16 +269,15 @@ class convCudnn : public Kernel {
record.workspaceSize, &beta, outDesc, outData);
if (stat != CUDNN_STATUS_SUCCESS)
continue;
record.time = timeit(
[&]() {
cudnnConvolutionForward(context->cudnnHandle(), &alpha,
inDesc, inData, knDesc, knData,
convDesc, ALGOS[record.algo],
wsData, record.workspaceSize,
&beta, outDesc, outData);
},
[&]() { context->sync(); });
// printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time);
// Time the kernel with CUDA Graph to get a precise time
std::function<void(void)> func = [&]() {
cudnnConvolutionForward(
context->cudnnHandle(), &alpha, inDesc, inData, knDesc,
knData, convDesc, ALGOS[record.algo], wsData,
record.workspaceSize, &beta, outDesc, outData);
};
record.time = context->timeWithCudaGraph({func}, 100);
// printf("mode:%d algo:%d :%.4lf\n", mode, algo, record.time);
// Update the tune result
if (ret.time > record.time)

View File

@@ -1,6 +1,7 @@
#include "operators/matmul.h"
#include "core/kernel.h"
#include "cuda/cuda_runtime.h"
#include "nnet/dbg.h"
namespace infini {
@@ -21,8 +22,7 @@ struct MatmulCublasPerfRecordObj : public PerfRecordObj {
}
};
constexpr int N_ALGO = 24;
constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
const vector<cublasGemmAlgo_t> Algos = {
CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1, CUBLAS_GEMM_ALGO2,
CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4, CUBLAS_GEMM_ALGO5,
CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7, CUBLAS_GEMM_ALGO8,
@@ -32,6 +32,17 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
CUBLAS_GEMM_ALGO18, CUBLAS_GEMM_ALGO19, CUBLAS_GEMM_ALGO20,
CUBLAS_GEMM_ALGO21, CUBLAS_GEMM_ALGO22, CUBLAS_GEMM_ALGO23,
};
const vector<cublasGemmAlgo_t> AlgosTensorOp = {
CUBLAS_GEMM_DFALT_TENSOR_OP, CUBLAS_GEMM_ALGO0_TENSOR_OP,
CUBLAS_GEMM_ALGO1_TENSOR_OP, CUBLAS_GEMM_ALGO2_TENSOR_OP,
CUBLAS_GEMM_ALGO3_TENSOR_OP, CUBLAS_GEMM_ALGO4_TENSOR_OP,
CUBLAS_GEMM_ALGO5_TENSOR_OP, CUBLAS_GEMM_ALGO6_TENSOR_OP,
CUBLAS_GEMM_ALGO7_TENSOR_OP, CUBLAS_GEMM_ALGO8_TENSOR_OP,
CUBLAS_GEMM_ALGO9_TENSOR_OP, CUBLAS_GEMM_ALGO10_TENSOR_OP,
CUBLAS_GEMM_ALGO11_TENSOR_OP, CUBLAS_GEMM_ALGO12_TENSOR_OP,
CUBLAS_GEMM_ALGO13_TENSOR_OP, CUBLAS_GEMM_ALGO14_TENSOR_OP,
CUBLAS_GEMM_ALGO15_TENSOR_OP};
class matmulCublas : public Kernel {
bool do_compute(const Operator &_op, const PerfRecord &_record,
const RuntimeObj *_context) const {
@@ -49,8 +60,11 @@ class matmulCublas : public Kernel {
const int lda = op->getTransA() ? m : k, ldb = op->getTransB() ? k : n,
ldc = n;
const float alpha = 1.f, beta = 0.f;
// TODO:use compute type
cublasStatus_t stat;
// Set the compute type to TF32 if enabled
cublasComputeType_t computeType = context->getEnableTF32()
? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
if (record->apiId == 0) {
// Support batch broadcast with zero stride
int dimA = op->getInputs(0)->getDims().size();
@@ -73,13 +87,13 @@ class matmulCublas : public Kernel {
stat = cublasGemmStridedBatchedEx(
context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
CUDA_R_32F, ldb, strideB, inAData, CUDA_R_32F, lda, strideA,
&beta, outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F,
&beta, outData, CUDA_R_32F, ldc, m * n, b, computeType,
(cublasGemmAlgo_t)record->algo);
} else if (record->apiId == 1) {
stat = cublasGemmEx(
context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
CUDA_R_32F, ldb, inAData, CUDA_R_32F, lda, &beta, outData,
CUDA_R_32F, ldc, CUDA_R_32F, (cublasGemmAlgo_t)record->algo);
CUDA_R_32F, ldc, computeType, (cublasGemmAlgo_t)record->algo);
} else
IT_ASSERT(false);
// if (stat != CUBLAS_STATUS_SUCCESS)
@@ -109,11 +123,19 @@ class matmulCublas : public Kernel {
vector<int> apis{0};
if (op->getB() == 1)
apis.emplace_back(1);
// Set the possible algorithm range
auto algos = Algos;
if (context->getEnableTF32()) {
algos.insert(algos.end(), AlgosTensorOp.begin(),
AlgosTensorOp.end());
}
for (int api : apis) {
for (int i = 0; i < N_ALGO; i++) {
for (size_t i = 0; i < algos.size(); i++) {
auto rcd = make_ref<MatmulCublasPerfRecordObj>();
rcd->apiId = api;
rcd->algo = ALGOS[i];
rcd->algo = algos[i];
if (!do_compute(_op, rcd, _context))
continue;
rcd->time = timeit([&]() { do_compute(_op, rcd, _context); },
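For context on the trade-off behind CUBLAS_COMPUTE_32F_FAST_TF32: TF32 keeps FP32's 8-bit exponent (so the dynamic range is unchanged) but carries only 10 explicit mantissa bits into the tensor-core multiply, while accumulation stays in FP32. Roughly, and ignoring the exact rounding mode, each GEMM input behaves as if quantized like this (illustration only, not project code):

#include <cstdint>
#include <cstring>

// Approximate the precision loss of TF32 by truncating an FP32 value to 10
// mantissa bits (the hardware rounds rather than truncates).
float tf32Approx(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits &= 0xFFFFE000u; // clear the low 13 of 23 mantissa bits
    std::memcpy(&x, &bits, sizeof(x));
    return x;
}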

View File

@@ -68,16 +68,16 @@ TEST(cuBLAS_Matmul, tune) {
const int B = 1, M = 4, N = 4096, K = 448;
const bool transA = true, transB = false;
auto cudaRuntime = make_ref<CudaRuntimeObj>();
cudaRuntime->setEnableTF32(true);
Graph g = make_ref<GraphObj>(cudaRuntime);
auto a = g->addTensor(transA ? Shape{B, K, M} : Shape{B, M, K});
auto b = g->addTensor(transB ? Shape{B, N, K} : Shape{B, K, N});
// allocate CUDA memory
auto matmul = g->addOp<MatmulObj>(a, b, nullptr, transA, transB);
g->dataMalloc();
a->setData(IncrementalGenerator());
b->setData(IncrementalGenerator());
auto matmul = g->addOp<MatmulObj>(a, b, nullptr, transA, transB);
matmul->print();
double time = cudaRuntime->getPerfTime(g);
EXPECT_GT(time, 1e-3);
EXPECT_LT(time, 1);