// Review header for this patch (unified git diff touching six files of a
// C++/CUDA project). This text sits before the first `diff --git` header.
//
// NOTE(review): the line breaks of this copy are collapsed, so it will
// presumably not apply as-is -- restore the line structure from the upstream
// patch before using it.
//
// What the hunks show:
// - include/core/kernel.h: the "Kernel not found for key" assertion message no
//   longer appends std::get<2>(kernelAttrs).toString(); REGISTER_KERNEL and
//   _REGISTER_KERNEL_1 drop their `dataType` parameter and register with
//   KernelAttrs{device, opType}.
// - include/core/operator.h: OperatorObj gains getInDType(), returning
//   getInputs(0)->getDType(), and getOutDType(), returning
//   getOutput()->getDType(). getDType() is an unchanged context line.
// - include/cuda/cuda_utility.h: adds `#pragma once` and the
//   "cuda/cuda_common.h" include, reduces cudaPrintTensor to a declaration,
//   and declares cudnnDataTypeConvert(DataType).
// - src/core/runtime.cc and src/cuda/cuda_runtime.cc: each KernelAttrs is now
//   built from {device, op->getOpType().underlying()} only; the
//   DataType::Float32 argument in CudaRuntimeObj::tune is removed.
// - src/cuda/cuda_utility.cu: defines cudaPrintTensor and
//   cudnnDataTypeConvert, which maps nine DataType values to CUDNN_DATA_*
//   constants through a chain of `if` statements.
//
// NOTE(review), to confirm before merging:
// - cudnnDataTypeConvert ends with IT_ASSERT(false, ...) and has no return
//   statement after it. IT_ASSERT's definition is not visible here; if it can
//   return, control falls off the end of a non-void function.
// - REGISTER_KERNEL loses a parameter, so any call site still passing a
//   dataType argument stops compiling. No call sites are updated in the hunks
//   shown here.
// - With the data type gone from the lookup key, presumably a single kernel
//   per (device, opType) must now handle every dtype itself -- verify.
// - cudaPrintTensor always forwards to cudaPrintFloat(float *, int);
//   presumably only meaningful for float tensors -- verify.
// - `using KernelAttrs = std::tuple;` is identical on its `-` and `+` lines,
//   and `std::map opCnt;`, a bare `#include`, and `getRawDataPtr()` carry no
//   template arguments or header name. It looks like angle-bracket text was
//   stripped from this copy -- confirm against the upstream patch.
// - CUDNN_DATA_BFLOAT16 / CUDNN_DATA_INT64 / CUDNN_DATA_BOOLEAN presumably
//   require a sufficiently recent cuDNN -- confirm the minimum version.
diff --git a/include/core/kernel.h b/include/core/kernel.h index 3ef0d1b9..ca627551 100644 --- a/include/core/kernel.h +++ b/include/core/kernel.h @@ -29,7 +29,6 @@ class Kernel { public: Kernel() {} virtual ~Kernel() {} - /** * @param op The operator to be executed. * @param record The parameters for kernel execution. If extra parameters @@ -105,8 +104,7 @@ class KernelRegistry { IT_ASSERT(it != kernels.end(), "Kernel not found for key {" + to_string(enum_to_underlying(std::get<0>(kernelAttrs))) + - ", " + std::to_string(std::get<1>(kernelAttrs)) + ", " + - std::get<2>(kernelAttrs).toString() + "}"); + ", " + std::to_string(std::get<1>(kernelAttrs)) + "}"); return std::get<0>(it->second); } const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const { @@ -131,15 +129,16 @@ class CpuKernelWithoutConfig : public Kernel { } // namespace infini -#define _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, cnt) \ +#define _REGISTER_KERNEL_1(device, opType, kernel, name, cnt) \ namespace infini { \ static const bool _CAT(_register_kernel_, cnt) = \ - KernelRegistry::getInstance().registerKernel( \ - KernelAttrs{device, opType, dataType}, new kernel(), name); \ + KernelRegistry::getInstance().registerKernel(KernelAttrs{device, \ + opType}, \ + new kernel(), name); \ } -#define REGISTER_KERNEL(device, opType, dataType, kernel, name) \ - _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, __COUNTER__) +#define REGISTER_KERNEL(device, opType, kernel, name) \ + _REGISTER_KERNEL_1(device, opType, kernel, name, __COUNTER__) #define _REGISTER_CONSTRUCTOR_1(type, constructor, cnt) \ namespace infini { \ diff --git a/include/core/operator.h b/include/core/operator.h index cc8ce174..b2cb4aff 100644 --- a/include/core/operator.h +++ b/include/core/operator.h @@ -4,7 +4,7 @@ #include "core/tensor.h" namespace infini { -using KernelAttrs = std::tuple; +using KernelAttrs = std::tuple; struct OpPerfKey { HashType hash; @@ -90,6 +90,8 @@ class OperatorObj : 
public Object { OpType getOpType() const { return type; } // HACK: set correct data type DataType getDType() const { return getInputs(0)->getDType(); } + DataType getInDType() const { return getInputs(0)->getDType(); } + DataType getOutDType() const { return getOutput()->getDType(); } virtual int numInputs() const = 0; virtual int numOutputs() const = 0; diff --git a/include/cuda/cuda_utility.h b/include/cuda/cuda_utility.h index 85e3478b..7c3e8044 100644 --- a/include/cuda/cuda_utility.h +++ b/include/cuda/cuda_utility.h @@ -1,11 +1,13 @@ +#pragma once #include "core/tensor.h" +#include "cuda/cuda_common.h" namespace infini { void cudaPrintFloat(float *x, int len); -void cudaPrintTensor(const Tensor &tensor) { - cudaPrintFloat(tensor->getRawDataPtr(), tensor->size()); -} +void cudaPrintTensor(const Tensor &tensor); -} // namespace infini \ No newline at end of file +cudnnDataType_t cudnnDataTypeConvert(DataType dataType); + +} // namespace infini diff --git a/src/core/runtime.cc b/src/core/runtime.cc index 4d64d433..f1ae8849 100644 --- a/src/core/runtime.cc +++ b/src/core/runtime.cc @@ -17,8 +17,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const { std::map opCnt; for (auto &op : graph->getOperators()) { - auto kernelAttrs = - KernelAttrs{device, op->getOpType().underlying(), op->getDType()}; + auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfData = perfEngine.getPerfData(perfKey); @@ -66,8 +65,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const { std::map opCnt; for (auto &op : graph->getOperators()) { - auto kernelAttrs = - KernelAttrs{device, op->getOpType().underlying(), op->getDType()}; + auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, 
op->getOpPerfKey()}; auto perfData = perfEngine.getPerfData(perfKey); diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc index 0676646a..b92cb18f 100644 --- a/src/cuda/cuda_runtime.cc +++ b/src/cuda/cuda_runtime.cc @@ -25,8 +25,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const { auto &perfEngine = PerfEngine::getInstance(); for (auto &op : graph->getOperators()) { // HACK: set correct data type - auto kernelAttrs = - KernelAttrs{device, op->getOpType().underlying(), op->getDType()}; + auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfData = perfEngine.getPerfData(perfKey); @@ -48,8 +47,7 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const { std::map opCnt; for (auto &op : graph->getOperators()) { // HACK: set correct data type - auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(), - DataType::Float32}; + auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfData = perfEngine.getPerfData(perfKey); diff --git a/src/cuda/cuda_utility.cu b/src/cuda/cuda_utility.cu index 83490959..1c03d5c9 100644 --- a/src/cuda/cuda_utility.cu +++ b/src/cuda/cuda_utility.cu @@ -1,4 +1,6 @@ +#include "core/data_type.h" #include "cuda/cuda_common.h" +#include "cuda/cuda_utility.h" #include __global__ void cudaPrintFloatImpl(float *x, int len) { @@ -18,4 +20,39 @@ void cudaPrintFloat(float *x, int len) { cudaDeviceSynchronize(); } +void cudaPrintTensor(const Tensor &tensor) { + cudaPrintFloat(tensor->getRawDataPtr(), tensor->size()); +} + +cudnnDataType_t cudnnDataTypeConvert(DataType dataType) { + if (dataType == DataType::Float32) { + return CUDNN_DATA_FLOAT; + } + if (dataType == DataType::Double) { + return 
CUDNN_DATA_DOUBLE; + } + if (dataType == DataType::Float16) { + return CUDNN_DATA_HALF; + } + if (dataType == DataType::Int8) { + return CUDNN_DATA_INT8; + } + if (dataType == DataType::Int32) { + return CUDNN_DATA_INT32; + } + if (dataType == DataType::UInt8) { + return CUDNN_DATA_UINT8; + } + if (dataType == DataType::BFloat16) { + return CUDNN_DATA_BFLOAT16; + } + if (dataType == DataType::Int64) { + return CUDNN_DATA_INT64; + } + if (dataType == DataType::Bool) { + return CUDNN_DATA_BOOLEAN; + } + IT_ASSERT(false, "Unsupported data type"); +} + } // namespace infini