forked from jiuyuan/InfiniTensor
fix: fix cuda conv_fp16 run fail (#105)
This commit is contained in:
parent
57ac94d893
commit
bd9e1aeb3f
|
@ -13,8 +13,8 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
|||
std::map<OpType, int> opCnt;
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(),
|
||||
DataType::Float32};
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
|
|
@ -11,8 +11,8 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
|
|||
auto &perfEngine = PerfEngine::getInstance();
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(),
|
||||
DataType::Float32};
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
|
Loading…
Reference in New Issue