From a732b6f176ed5e2df0c48003c5fd89e2f37b8746 Mon Sep 17 00:00:00 2001 From: Liyan Zheng Date: Sat, 22 Apr 2023 16:08:40 +0800 Subject: [PATCH] Fix: ignore transpose in CudaGraph since no kernel --- include/nnet/visitor.h | 2 +- src/cuda/cuda_runtime.cc | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/include/nnet/visitor.h b/include/nnet/visitor.h index c415a097..f3be4f8a 100644 --- a/include/nnet/visitor.h +++ b/include/nnet/visitor.h @@ -49,7 +49,7 @@ template class Functor { virtual R visit_(const Tensor &c, Args... args) FUNCTOR_DEFAULT; virtual R visit_(const Func &c, Args... args) FUNCTOR_DEFAULT; virtual R visitDefault(const Expr &c, [[maybe_unused]] Args... args) { - dbg(*c); + dbg(*c, c->getType()); nnet_assert(0, "Reach unimplemented visit function."); return R(); }; diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc index 87bb8c4a..d717b092 100644 --- a/src/cuda/cuda_runtime.cc +++ b/src/cuda/cuda_runtime.cc @@ -141,7 +141,9 @@ double CudaRuntimeObj::timeWithCudaGraph(Graph graph) { kernel->compute(op, perfData, this); else kernel->compute(op, this); - if (!ctcMap.at(op->getGuid()) && op->getOpType() != OpType::Reshape) + // FIXME: transpose + if (!ctcMap.at(op->getGuid()) && op->getOpType() != OpType::Transpose && + op->getOpType() != OpType::Reshape) kernels.emplace_back(op, kernel, perfData); } for (auto &[op, kernel, perfData] : kernels) {