forked from jiuyuan/InfiniTensor
Merge branch 'NNET_e2e_fix' into NNET_e2e
Support CUDA Graph for TVM kernels
This commit is contained in:
commit
2cd75bd79b
|
@ -5,6 +5,9 @@
|
|||
#include "cuda_profiler_api.h"
|
||||
#include "operators/conv.h"
|
||||
#include "operators/matmul.h"
|
||||
#ifdef INFINI_USE_TVM
|
||||
#include "tvm/runtime/device_api.h"
|
||||
#endif
|
||||
namespace infini {
|
||||
|
||||
CudaRuntimeObj::CudaRuntimeObj()
|
||||
|
@ -145,6 +148,13 @@ double CudaRuntimeObj::timeWithCudaGraph(Graph graph) {
|
|||
dbg(op);
|
||||
}
|
||||
|
||||
// Init tvm stream
|
||||
#ifdef INFINI_USE_TVM
|
||||
DLDevice tvm_device_id = {kDLCUDA, 0};
|
||||
auto tvm_device = tvm::runtime::DeviceAPI::Get(tvm_device_id);
|
||||
tvm_device->SetStream(tvm_device_id, getStream());
|
||||
#endif
|
||||
|
||||
beginCudaGraphStreamCapture();
|
||||
for (auto &[op, kernel, perfData] : kernels) {
|
||||
if (perfData)
|
||||
|
|
|
@ -3,6 +3,8 @@
|
|||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/conv.h"
|
||||
#include "nnet/nmutator.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
@ -27,4 +29,30 @@ TEST(TestCudaRuntime, CudaGraph) {
|
|||
EXPECT_GE(time, 0.01);
|
||||
}
|
||||
|
||||
TEST(TestCudaRuntime, CudaGraphMembound) {
|
||||
auto runtime = make_ref<CudaRuntimeObj>();
|
||||
Runtime cpu = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
|
||||
Tensor i0 = g->addTensor({1, 2, 3}, DataType::Float32);
|
||||
Tensor w0 = g->addTensor({1, 3, 4}, DataType::Float32);
|
||||
Tensor o0 = g->addTensor({1, 2, 4}, DataType::Float32);
|
||||
g->dataMalloc();
|
||||
i0->copyin(vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
w0->copyin(vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
|
||||
NMutator nmutator(NMutator::Mode::ToNaiveMembound);
|
||||
auto mutations = nmutator.run(g);
|
||||
ASSERT_EQ(mutations.size(), 2u);
|
||||
Graph gNew = mutations[1];
|
||||
gNew->print();
|
||||
gNew->dataMalloc();
|
||||
|
||||
runtime->run(gNew, true); // tune kernels
|
||||
runtime->run(gNew, false);
|
||||
runtime->getPerfTime(gNew);
|
||||
|
||||
runtime->timeWithCudaGraph(gNew);
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
Loading…
Reference in New Issue