Fix: ignore Transpose in CudaGraph timing since it has no kernel

This commit is contained in:
Liyan Zheng 2023-04-22 16:08:40 +08:00
parent 0865f8d823
commit a732b6f176
2 changed files with 4 additions and 2 deletions

View File

@ -49,7 +49,7 @@ template <typename R, typename... Args> class Functor<R(Args...)> {
virtual R visit_(const Tensor &c, Args... args) FUNCTOR_DEFAULT;
virtual R visit_(const Func &c, Args... args) FUNCTOR_DEFAULT;
// Fallback visitor: prints the expression with dbg and then fails an
// unconditional nnet_assert ("Reach unimplemented visit function.").
// Presumably reached through FUNCTOR_DEFAULT when a visit_ overload is not
// overridden; the macro is defined outside this view, so confirm.
// `args` are accepted only to match the visit_ signatures and are unused.
virtual R visitDefault(const Expr &c, [[maybe_unused]] Args... args) {
// NOTE(review): the two dbg calls below may be the pre-change and post-change
// lines of the same diff hunk rather than two live statements; confirm
// against the actual header.
dbg(*c);
// Also prints the expression's type via c->getType().
dbg(*c, c->getType());
nnet_assert(0, "Reach unimplemented visit function.");
// Only reached if nnet_assert(0, ...) returns; yields a value-initialized R.
return R();
};

View File

@ -141,7 +141,9 @@ double CudaRuntimeObj::timeWithCudaGraph(Graph graph) {
kernel->compute(op, perfData, this);
else
kernel->compute(op, this);
if (!ctcMap.at(op->getGuid()) && op->getOpType() != OpType::Reshape)
// FIXME: Transpose is excluded from the recorded kernels since it has no kernel
if (!ctcMap.at(op->getGuid()) && op->getOpType() != OpType::Transpose &&
op->getOpType() != OpType::Reshape)
kernels.emplace_back(op, kernel, perfData);
}
for (auto &[op, kernel, perfData] : kernels) {