diff --git a/src/bang/bang_runtime.cc b/src/bang/bang_runtime.cc index d909b57c..9d422a56 100644 --- a/src/bang/bang_runtime.cc +++ b/src/bang/bang_runtime.cc @@ -13,8 +13,8 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false, std::map opCnt; for (auto &op : graph->getOperators()) { // HACK: set correct data type - auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(), - DataType::Float32}; + auto kernelAttrs = + KernelAttrs{device, op->getOpType().underlying(), op->getDType()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfData = perfEngine.getPerfData(perfKey); diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc index 972bfb4c..23f58ced 100644 --- a/src/cuda/cuda_runtime.cc +++ b/src/cuda/cuda_runtime.cc @@ -11,8 +11,8 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const { auto &perfEngine = PerfEngine::getInstance(); for (auto &op : graph->getOperators()) { // HACK: set correct data type - auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(), - DataType::Float32}; + auto kernelAttrs = + KernelAttrs{device, op->getOpType().underlying(), op->getDType()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfData = perfEngine.getPerfData(perfKey);