forked from jiuyuan/InfiniTensor
Fix: avoid reload library
This commit is contained in:
parent
b6b37ccf33
commit
582de83629
|
@ -28,6 +28,7 @@ class TVMRecordObj : public PerfRecordObj {
|
|||
std::string dllPath;
|
||||
std::string funcName;
|
||||
std::vector<int> inputIdx;
|
||||
tvm::runtime::PackedFunc packedFunc;
|
||||
};
|
||||
|
||||
using TVMRecord = Ref<TVMRecordObj>;
|
||||
|
@ -39,9 +40,8 @@ class MemboundTVMPackedFunction : public Kernel {
|
|||
auto op = as<MemBoundObj>(_op);
|
||||
// auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
auto tvmRecord = std::dynamic_pointer_cast<TVMRecordObj>(record);
|
||||
tvm::runtime::PackedFunc packedFunc =
|
||||
getPackedFunction(tvmRecord->dllPath, tvmRecord->funcName);
|
||||
IT_ASSERT(packedFunc != nullptr);
|
||||
tvm::runtime::PackedFunc packedFunc = tvmRecord->packedFunc;
|
||||
// IT_ASSERT(packedFunc != nullptr);
|
||||
|
||||
// prepare inputs and outputs
|
||||
vector<DLTensorHolder> inputsHolder;
|
||||
|
@ -99,6 +99,7 @@ class MemboundTVMPackedFunction : public Kernel {
|
|||
if (inputName == op->getNnetInputs()[j]->getName())
|
||||
break;
|
||||
}
|
||||
IT_ASSERT(j < numInputs, "Cannot find input name: " + inputName);
|
||||
inputIdx.emplace_back(j);
|
||||
}
|
||||
|
||||
|
@ -126,6 +127,7 @@ class MemboundTVMPackedFunction : public Kernel {
|
|||
ret->dllPath = dllPath;
|
||||
ret->funcName = func;
|
||||
ret->inputIdx = inputIdx;
|
||||
ret->packedFunc = packedFunc;
|
||||
|
||||
return std::dynamic_pointer_cast<PerfRecordObj>(ret);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue