From bd9e1aeb3faa80513413d05a06a548bebe86a019 Mon Sep 17 00:00:00 2001 From: zhangyunze <93699316+bitzyz@users.noreply.github.com> Date: Thu, 10 Aug 2023 15:22:18 +0800 Subject: [PATCH] fix: fix cuda conv_fp16 run fail (#105) --- src/bang/bang_runtime.cc | 4 ++-- src/cuda/cuda_runtime.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/bang/bang_runtime.cc b/src/bang/bang_runtime.cc index d909b57c..9d422a56 100644 --- a/src/bang/bang_runtime.cc +++ b/src/bang/bang_runtime.cc @@ -13,8 +13,8 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false, std::map opCnt; for (auto &op : graph->getOperators()) { // HACK: set correct data type - auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(), - DataType::Float32}; + auto kernelAttrs = + KernelAttrs{device, op->getOpType().underlying(), op->getDType()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfData = perfEngine.getPerfData(perfKey); diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc index 972bfb4c..23f58ced 100644 --- a/src/cuda/cuda_runtime.cc +++ b/src/cuda/cuda_runtime.cc @@ -11,8 +11,8 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const { auto &perfEngine = PerfEngine::getInstance(); for (auto &op : graph->getOperators()) { // HACK: set correct data type - auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(), - DataType::Float32}; + auto kernelAttrs = + KernelAttrs{device, op->getOpType().underlying(), op->getDType()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfData = perfEngine.getPerfData(perfKey);