Fix: avoid reload library

This commit is contained in:
Liyan Zheng 2023-04-14 15:10:47 +08:00
parent b6b37ccf33
commit 582de83629
1 changed files with 5 additions and 3 deletions

View File

@ -28,6 +28,7 @@ class TVMRecordObj : public PerfRecordObj {
std::string dllPath;
std::string funcName;
std::vector<int> inputIdx;
tvm::runtime::PackedFunc packedFunc;
};
using TVMRecord = Ref<TVMRecordObj>;
@ -39,9 +40,8 @@ class MemboundTVMPackedFunction : public Kernel {
auto op = as<MemBoundObj>(_op);
// auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
auto tvmRecord = std::dynamic_pointer_cast<TVMRecordObj>(record);
tvm::runtime::PackedFunc packedFunc =
getPackedFunction(tvmRecord->dllPath, tvmRecord->funcName);
IT_ASSERT(packedFunc != nullptr);
tvm::runtime::PackedFunc packedFunc = tvmRecord->packedFunc;
// IT_ASSERT(packedFunc != nullptr);
// prepare inputs and outputs
vector<DLTensorHolder> inputsHolder;
@ -99,6 +99,7 @@ class MemboundTVMPackedFunction : public Kernel {
if (inputName == op->getNnetInputs()[j]->getName())
break;
}
IT_ASSERT(j < numInputs, "Cannot find input name: " + inputName);
inputIdx.emplace_back(j);
}
@ -126,6 +127,7 @@ class MemboundTVMPackedFunction : public Kernel {
ret->dllPath = dllPath;
ret->funcName = func;
ret->inputIdx = inputIdx;
ret->packedFunc = packedFunc;
return std::dynamic_pointer_cast<PerfRecordObj>(ret);
}