diff --git a/include/core/kernel.h b/include/core/kernel.h index 3ef0d1b9..a19f3f1a 100644 --- a/include/core/kernel.h +++ b/include/core/kernel.h @@ -2,6 +2,7 @@ #include "core/common.h" #include "core/operator.h" #include "core/tensor.h" +#include "utils/operator_utils.h" #include #include using json = nlohmann::json; @@ -102,11 +103,9 @@ class KernelRegistry { } Kernel *getKernel(const KernelAttrs &kernelAttrs) const { auto it = kernels.find(kernelAttrs); - IT_ASSERT(it != kernels.end(), - "Kernel not found for key {" + - to_string(enum_to_underlying(std::get<0>(kernelAttrs))) + - ", " + std::to_string(std::get<1>(kernelAttrs)) + ", " + - std::get<2>(kernelAttrs).toString() + "}"); + IT_ASSERT(it != kernels.end(), "Kernel not found for key {" + + get_kernel_attrs_str(kernelAttrs) + + "}"); return std::get<0>(it->second); } const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const { diff --git a/include/utils/operator_utils.h b/include/utils/operator_utils.h index 4f6a6985..1b3a1eb2 100644 --- a/include/utils/operator_utils.h +++ b/include/utils/operator_utils.h @@ -2,6 +2,7 @@ #ifndef OPERATOR_UTIL_H #define OPERATOR_UTIL_H +#include "core/operator.h" #include "core/tensor.h" namespace infini { @@ -10,8 +11,10 @@ namespace infini { Shape infer_broadcast(const Shape &A, const Shape &B); // Launch the real axis based on rank and current axis int get_real_axis(const int &axis, const int &rank); -// check if tensor B is unidirectional broadcastable to tensor A +// Check if tensor B is unidirectional broadcastable to tensor A bool is_unidirectional_broadcasting(const Shape &A, const Shape &B); +// Convert KernelAttrs to a string representation +std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs); } // namespace infini #endif diff --git a/src/utils/operator_utils.cc b/src/utils/operator_utils.cc index a9b81a5e..76a1d91f 100644 --- a/src/utils/operator_utils.cc +++ b/src/utils/operator_utils.cc @@ -1,4 +1,5 @@ #include "utils/operator_utils.h" +#include "core/runtime.h" namespace infini { @@ -64,4 +65,29 @@ bool is_unidirectional_broadcasting(const Shape &A, const Shape &B) { } return true; } + +std::string device_to_str(Device device) { + std::string deviceStr; + switch (device) { + case Device::CPU: + return "CPU"; + case Device::CUDA: + return "CUDA"; + case Device::BANG: + return "BANG"; + case Device::INTELCPU: + return "INTELCPU"; + case Device::KUNLUN: + return "KUNLUN"; + default: + IT_TODO_HALT(); + } +} + +std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs) { + std::string deviceStr = device_to_str(std::get<0>(kernelAttrs)); + std::string opStr = OpType(std::get<1>(kernelAttrs)).toString(); + std::string datatypeStr = std::get<2>(kernelAttrs).toString(); + return deviceStr + ", " + opStr + ", " + datatypeStr; +} } // namespace infini