forked from jiuyuan/InfiniTensor
Add: TF32 supports and accurate timing for conv
parent abcfa76fb5
commit d25b606e12
@@ -77,6 +77,6 @@ double timeit(
     const std::function<void()> &func,
     // HACK: set timeit rounds to 10 for fast debug
     const std::function<void(void)> &sync = []() {}, int warmupRounds = 10,
-    int timingRounds = 10);
+    int timingRounds = 100);
 
 } // namespace infini

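For orientation, a minimal sketch of how this timeit helper is typically invoked (the kernel launch below is a hypothetical placeholder, not part of this commit):

    // Sketch: average the cost of a GPU launch with infini::timeit.
    // `myKernelLaunch` is a stand-in for the work being measured.
    double t = infini::timeit(
        []() { myKernelLaunch(); },         // work under measurement
        []() { cudaDeviceSynchronize(); },  // sync before reading the clock
        /*warmupRounds=*/10,
        /*timingRounds=*/100);              // matches the new default above
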
@@ -56,6 +56,14 @@ class GraphObj : public Object {
         return cloneOperator(op, inputs, outputs);
     }
 
+    Operator cloneOpAndCreateInputsOutputs(Operator op) {
+        vector<Tensor> inputs;
+        for (auto t : op->getInputs()) {
+            inputs.emplace_back(cloneTensor(t));
+        }
+        return cloneOpAndCreateOutputs(op, inputs);
+    }
+
     const TensorVec &getTensors() const { return tensors; }
     const OpVec &getOperators() const { return ops; }
     OpVec getComputeOps() const;

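A call-site sketch for the helper added above (`g` and `op` are hypothetical, assumed to be a GraphObj and one of its operators):

    // Sketch: copy an operator into the graph together with freshly cloned
    // input tensors, letting the graph create matching outputs.
    Operator copy = g->cloneOpAndCreateInputsOutputs(op);
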
@@ -21,6 +21,8 @@ class CudaRuntimeObj : public RuntimeObj {
     // CUDA device properties
     cudaDeviceProp deviceProperties;
 
+    bool enableTF32 = false;
+
   public:
     CudaRuntimeObj();
     virtual ~CudaRuntimeObj();

@@ -82,7 +84,11 @@ class CudaRuntimeObj : public RuntimeObj {
     bool isInCudaGraph() const { return cudaGraphStatus; }
     cudaStream_t getStream() const { return stream; }
 
-    double timeWithCudaGraph(Graph graph, int rounds = 1000);
+    double timeWithCudaGraph(Graph graph, int rounds = 50);
+    double timeWithCudaGraph(vector<std::function<void(void)>> funcs,
+                             int rounds = 50);
+    void setEnableTF32(bool state);
+    bool getEnableTF32() const { return enableTF32; }
 
   private:
     void tune(const Graph &graph, bool profiling) const;

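Taken together, the new declarations suggest a usage pattern like the following sketch (graph construction and data allocation elided; `graph` is assumed to be an already-built, allocated Graph on this runtime):

    // Sketch: opt in to TF32 kernels, then time the whole graph through
    // CUDA Graph capture with the new, smaller default of 50 rounds.
    auto runtime = make_ref<CudaRuntimeObj>();
    runtime->setEnableTF32(true);
    double time = runtime->timeWithCudaGraph(graph, 50);
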
@@ -4,9 +4,9 @@
 #include "core/runtime.h"
 #include "cuda_profiler_api.h"
 #include "nnet/dbg.h"
+#include "operators/any.h"
 #include "operators/conv.h"
 #include "operators/matmul.h"
-#include "operators/any.h"
 #ifdef INFINI_USE_TVM
 #include "tvm/runtime/device_api.h"
 #endif

@@ -148,13 +148,25 @@ double CudaRuntimeObj::timeWithCudaGraph(Graph graph, int rounds) {
         if (as<AnyObj>(op))
             dbg(op, as<AnyObj>(op)->getKernelName() == string("FakeOp"));
         if (!ctcMap.at(op->getGuid()) && op->getOpType() != OpType::Reshape &&
-            !isFakeOp)
+            op->getOpType() != OpType::Flatten && !isFakeOp)
             kernels.emplace_back(op, kernel, perfData);
     }
     for (auto &[op, kernel, perfData] : kernels) {
         dbg(op);
     }
+    vector<std::function<void(void)>> funcs;
+    for (auto &[op, kernel, perfData] : kernels) {
+        if (perfData)
+            funcs.push_back([&]() { kernel->compute(op, perfData, this); });
+        else
+            funcs.push_back([&]() { kernel->compute(op, this); });
+    }
+    return timeWithCudaGraph(funcs, rounds);
+}
 
+double
+CudaRuntimeObj::timeWithCudaGraph(std::vector<std::function<void(void)>> funcs,
+                                  int rounds) {
     // TODO: move this to kernel source?
     // Init tvm stream
 #ifdef INFINI_USE_TVM

@@ -163,25 +175,22 @@ double CudaRuntimeObj::timeWithCudaGraph(Graph graph, int rounds) {
     tvm_device->SetStream(tvm_device_id, getStream());
 #endif
     beginCudaGraphStreamCapture();
-    for (auto &[op, kernel, perfData] : kernels) {
-        if (perfData)
-            kernel->compute(op, perfData, this);
-        else
-            kernel->compute(op, this);
-    }
+    for (auto &f : funcs)
+        f();
     auto [cudaGraphInstance, numCudaGraphNodes] = endCudaGraphStreamCapture();
     // Since one TVM packed function may contaion more than one CUDA kernel, the
     // number of captured kernels may exceed the number of operators.
-    IT_ASSERT(numCudaGraphNodes >= kernels.size(),
+    IT_ASSERT(numCudaGraphNodes >= funcs.size(),
               std::to_string(numCudaGraphNodes) +
-                  " != " + std::to_string(kernels.size()));
-    printf("numCudaGraphNodes = %lu\n", numCudaGraphNodes);
+                  " != " + std::to_string(funcs.size()));
     return timeit(
         [&, cudaGraphInstance = cudaGraphInstance, stream = getStream()]() {
             checkCudaError(cudaGraphLaunch(cudaGraphInstance, stream));
         },
-        [&, stream = getStream()]() { cudaStreamSynchronize(stream); }, rounds,
-        rounds);
+        [&, stream = getStream()]() { cudaStreamSynchronize(stream); },
+        std::min(50, rounds), rounds);
 }
 
+void CudaRuntimeObj::setEnableTF32(bool state) { enableTF32 = state; }
+
 } // namespace infini

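The wrapper above leans on CUDA Graph stream capture. A minimal standalone sketch of that mechanism with the raw CUDA runtime API (`launchWork` and `rounds` are placeholders; the 5-argument cudaGraphInstantiate shown is the pre-CUDA-12 form):

    // Sketch: record a stream's launches into a CUDA graph, then replay it.
    // Replaying one instantiated graph avoids per-kernel launch overhead,
    // which is what makes the captured timing more precise. Error checks omitted.
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    cudaGraph_t graph;
    cudaGraphExec_t instance;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    launchWork(stream); // kernels are recorded here, not executed
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0);

    for (int i = 0; i < rounds; ++i)
        cudaGraphLaunch(instance, stream);
    cudaStreamSynchronize(stream);
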
@@ -323,7 +323,9 @@ void init_graph_builder(py::module &m) {
 #ifdef USE_CUDA
     py::class_<CudaRuntimeObj, Ref<CudaRuntimeObj>, RuntimeObj>(m,
                                                                 "CudaRuntime")
-        .def("timeWithCudaGraph", &CudaRuntimeObj::timeWithCudaGraph);
+        .def("timeWithCudaGraph",
+             py::overload_cast<Graph, int>(&CudaRuntimeObj::timeWithCudaGraph))
+        .def("setEnableTF32", &CudaRuntimeObj::setEnableTF32);
 #endif
 #ifdef USE_BANG
     py::class_<BangRuntimeObj, std::shared_ptr<BangRuntimeObj>, RuntimeObj>(

@@ -1,4 +1,5 @@
 #include "operators/conv.h"
+#include "core/graph.h"
 #include "core/kernel.h"
 #include "cuda/cuda_runtime.h"
 #include <chrono>

@@ -234,7 +235,8 @@ class convCudnn : public Kernel {
                      const RuntimeObj *_context) const override {
         ConvCuDnnPerfRecordObj ret;
         ret.time = std::numeric_limits<double>::max();
-        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+        auto context = const_cast<CudaRuntimeObj *>(
+            dynamic_cast<const CudaRuntimeObj *>(_context));
         auto op = as<ConvBaseObj>(_op);
         int try_algo = op->getOpType() == OpType::ConvNHWC ? 2 : N_ALGO;
         // Both modes have the same performance. Only run cross-correlation.

@@ -267,16 +269,15 @@ class convCudnn : public Kernel {
                     record.workspaceSize, &beta, outDesc, outData);
                 if (stat != CUDNN_STATUS_SUCCESS)
                     continue;
-                record.time = timeit(
-                    [&]() {
-                        cudnnConvolutionForward(context->cudnnHandle(), &alpha,
-                                                inDesc, inData, knDesc, knData,
-                                                convDesc, ALGOS[record.algo],
-                                                wsData, record.workspaceSize,
-                                                &beta, outDesc, outData);
-                    },
-                    [&]() { context->sync(); });
-                // printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time);
+                // Time the kernel with CUDA Graph to get a precise time
+                std::function<void(void)> func = [&]() {
+                    cudnnConvolutionForward(
+                        context->cudnnHandle(), &alpha, inDesc, inData, knDesc,
+                        knData, convDesc, ALGOS[record.algo], wsData,
+                        record.workspaceSize, &beta, outDesc, outData);
+                };
+                record.time = context->timeWithCudaGraph({func}, 100);
+                // printf("mode:%d algo:%d :%.4lf\n", mode, algo, record.time);
 
                 // Update the tune result
                 if (ret.time > record.time)

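Compared with the removed timeit path, which re-issued cudnnConvolutionForward from the host on every round, routing the same call through timeWithCudaGraph measures the replay of a captured CUDA graph, so launch overhead largely drops out of the per-algorithm number. A reduced sketch of that pattern (assuming `context` is the CudaRuntimeObj and `launchConv` stands in for the cuDNN call):

    // Sketch: wrap an arbitrary GPU call and time it via the funcs overload.
    std::function<void(void)> func = [&]() { launchConv(); };
    double t = context->timeWithCudaGraph({func}, 100);
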
@@ -1,6 +1,7 @@
 #include "operators/matmul.h"
 #include "core/kernel.h"
 #include "cuda/cuda_runtime.h"
+#include "nnet/dbg.h"
 
 namespace infini {
 

@@ -21,8 +22,7 @@ struct MatmulCublasPerfRecordObj : public PerfRecordObj {
     }
 };
 
-constexpr int N_ALGO = 24;
-constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
+const vector<cublasGemmAlgo_t> Algos = {
     CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1, CUBLAS_GEMM_ALGO2,
     CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4, CUBLAS_GEMM_ALGO5,
     CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7, CUBLAS_GEMM_ALGO8,

@@ -32,6 +32,17 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
     CUBLAS_GEMM_ALGO18, CUBLAS_GEMM_ALGO19, CUBLAS_GEMM_ALGO20,
     CUBLAS_GEMM_ALGO21, CUBLAS_GEMM_ALGO22, CUBLAS_GEMM_ALGO23,
 };
+const vector<cublasGemmAlgo_t> AlgosTensorOp = {
+    CUBLAS_GEMM_DFALT_TENSOR_OP, CUBLAS_GEMM_ALGO0_TENSOR_OP,
+    CUBLAS_GEMM_ALGO1_TENSOR_OP, CUBLAS_GEMM_ALGO2_TENSOR_OP,
+    CUBLAS_GEMM_ALGO3_TENSOR_OP, CUBLAS_GEMM_ALGO4_TENSOR_OP,
+    CUBLAS_GEMM_ALGO5_TENSOR_OP, CUBLAS_GEMM_ALGO6_TENSOR_OP,
+    CUBLAS_GEMM_ALGO7_TENSOR_OP, CUBLAS_GEMM_ALGO8_TENSOR_OP,
+    CUBLAS_GEMM_ALGO9_TENSOR_OP, CUBLAS_GEMM_ALGO10_TENSOR_OP,
+    CUBLAS_GEMM_ALGO11_TENSOR_OP, CUBLAS_GEMM_ALGO12_TENSOR_OP,
+    CUBLAS_GEMM_ALGO13_TENSOR_OP, CUBLAS_GEMM_ALGO14_TENSOR_OP,
+    CUBLAS_GEMM_ALGO15_TENSOR_OP};
+
 class matmulCublas : public Kernel {
     bool do_compute(const Operator &_op, const PerfRecord &_record,
                     const RuntimeObj *_context) const {

@@ -49,8 +60,11 @@ class matmulCublas : public Kernel {
         const int lda = op->getTransA() ? m : k, ldb = op->getTransB() ? k : n,
                   ldc = n;
         const float alpha = 1.f, beta = 0.f;
-        // TODO:use compute type
         cublasStatus_t stat;
+        // Set the compute type to TF32 if enabled
+        cublasComputeType_t computeType = context->getEnableTF32()
+                                              ? CUBLAS_COMPUTE_32F_FAST_TF32
+                                              : CUBLAS_COMPUTE_32F;
         if (record->apiId == 0) {
             // Support batch broadcast with zero stride
             int dimA = op->getInputs(0)->getDims().size();

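For comparison, TF32 can also be switched on handle-wide rather than per call; a sketch of that alternative (not what this commit does, which passes CUBLAS_COMPUTE_32F_FAST_TF32 per GEMM) would be:

    // Sketch: handle-wide TF32, affecting every eligible call on `handle`.
    cublasHandle_t handle;
    cublasCreate(&handle);
    cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH);

The per-call computeType used here keeps the choice local to the tuned GEMM, so FP32 behavior elsewhere is untouched.
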
@@ -73,13 +87,13 @@ class matmulCublas : public Kernel {
             stat = cublasGemmStridedBatchedEx(
                 context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
                 CUDA_R_32F, ldb, strideB, inAData, CUDA_R_32F, lda, strideA,
-                &beta, outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F,
+                &beta, outData, CUDA_R_32F, ldc, m * n, b, computeType,
                 (cublasGemmAlgo_t)record->algo);
         } else if (record->apiId == 1) {
             stat = cublasGemmEx(
                 context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
                 CUDA_R_32F, ldb, inAData, CUDA_R_32F, lda, &beta, outData,
-                CUDA_R_32F, ldc, CUDA_R_32F, (cublasGemmAlgo_t)record->algo);
+                CUDA_R_32F, ldc, computeType, (cublasGemmAlgo_t)record->algo);
         } else
             IT_ASSERT(false);
         // if (stat != CUBLAS_STATUS_SUCCESS)

@@ -109,11 +123,19 @@ class matmulCublas : public Kernel {
         vector<int> apis{0};
         if (op->getB() == 1)
             apis.emplace_back(1);
+
+        // Set the possible algorithm range
+        auto algos = Algos;
+        if (context->getEnableTF32()) {
+            algos.insert(algos.end(), AlgosTensorOp.begin(),
+                         AlgosTensorOp.end());
+        }
+
         for (int api : apis) {
-            for (int i = 0; i < N_ALGO; i++) {
+            for (size_t i = 0; i < algos.size(); i++) {
                 auto rcd = make_ref<MatmulCublasPerfRecordObj>();
                 rcd->apiId = api;
-                rcd->algo = ALGOS[i];
+                rcd->algo = algos[i];
                 if (!do_compute(_op, rcd, _context))
                     continue;
                 rcd->time = timeit([&]() { do_compute(_op, rcd, _context); },

@@ -68,16 +68,16 @@ TEST(cuBLAS_Matmul, tune) {
     const int B = 1, M = 4, N = 4096, K = 448;
     const bool transA = true, transB = false;
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    cudaRuntime->setEnableTF32(true);
     Graph g = make_ref<GraphObj>(cudaRuntime);
     auto a = g->addTensor(transA ? Shape{B, K, M} : Shape{B, M, K});
     auto b = g->addTensor(transB ? Shape{B, N, K} : Shape{B, K, N});
-    // allocate CUDA memory
-
-    auto matmul = g->addOp<MatmulObj>(a, b, nullptr, transA, transB);
     g->dataMalloc();
     a->setData(IncrementalGenerator());
     b->setData(IncrementalGenerator());
+
+    auto matmul = g->addOp<MatmulObj>(a, b, nullptr, transA, transB);
     matmul->print();
     double time = cudaRuntime->getPerfTime(g);
     EXPECT_GT(time, 1e-3);
     EXPECT_LT(time, 1);