- modify error info when kernel not found (#191)

* - modify error info when kernel not found

* - modify code as reviewer suggested

---------

Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
Chenjie Duan 2023-12-27 09:43:57 +08:00 committed by GitHub
parent 9a9587556c
commit 3f34372012
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 6 deletions

View File

@ -2,6 +2,7 @@
#include "core/common.h" #include "core/common.h"
#include "core/operator.h" #include "core/operator.h"
#include "core/tensor.h" #include "core/tensor.h"
#include "utils/operator_utils.h"
#include <functional> #include <functional>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
using json = nlohmann::json; using json = nlohmann::json;
@ -102,11 +103,9 @@ class KernelRegistry {
} }
Kernel *getKernel(const KernelAttrs &kernelAttrs) const { Kernel *getKernel(const KernelAttrs &kernelAttrs) const {
auto it = kernels.find(kernelAttrs); auto it = kernels.find(kernelAttrs);
IT_ASSERT(it != kernels.end(), IT_ASSERT(it != kernels.end(), "Kernel not found for key {" +
"Kernel not found for key {" + get_kernel_attrs_str(kernelAttrs) +
to_string(enum_to_underlying(std::get<0>(kernelAttrs))) + "}");
", " + std::to_string(std::get<1>(kernelAttrs)) + ", " +
std::get<2>(kernelAttrs).toString() + "}");
return std::get<0>(it->second); return std::get<0>(it->second);
} }
const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const { const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const {

View File

@ -2,6 +2,7 @@
#ifndef OPERATOR_UTIL_H #ifndef OPERATOR_UTIL_H
#define OPERATOR_UTIL_H #define OPERATOR_UTIL_H
#include "core/operator.h"
#include "core/tensor.h" #include "core/tensor.h"
namespace infini { namespace infini {
@ -10,8 +11,10 @@ namespace infini {
Shape infer_broadcast(const Shape &A, const Shape &B); Shape infer_broadcast(const Shape &A, const Shape &B);
// Launch the real axis based on rank and current axis // Launch the real axis based on rank and current axis
int get_real_axis(const int &axis, const int &rank); int get_real_axis(const int &axis, const int &rank);
// check if tensor B is unidirectional broadcastable to tensor A // Check if tensor B is unidirectional broadcastable to tensor A
bool is_unidirectional_broadcasting(const Shape &A, const Shape &B); bool is_unidirectional_broadcasting(const Shape &A, const Shape &B);
// Convert KernelAttrs to a string representation
std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs);
} // namespace infini } // namespace infini
#endif #endif

View File

@ -1,4 +1,5 @@
#include "utils/operator_utils.h" #include "utils/operator_utils.h"
#include "core/runtime.h"
namespace infini { namespace infini {
@ -64,4 +65,29 @@ bool is_unidirectional_broadcasting(const Shape &A, const Shape &B) {
} }
return true; return true;
} }
std::string device_to_str(Device device) {
std::string deviceStr;
switch (device) {
case Device::CPU:
return "CPU";
case Device::CUDA:
return "CUDA";
case Device::BANG:
return "BANG";
case Device::INTELCPU:
return "INTELCPU";
case Device::KUNLUN:
return "KUNLUN";
default:
IT_TODO_HALT();
}
}
std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs) {
std::string deviceStr = device_to_str(std::get<0>(kernelAttrs));
std::string opStr = OpType(std::get<1>(kernelAttrs)).toString();
std::string datatypeStr = std::get<2>(kernelAttrs).toString();
return deviceStr + ", " + opStr + ", " + datatypeStr;
}
} // namespace infini } // namespace infini