forked from jiuyuan/InfiniTensor
- modify error info when kernel not found (#191)
* - modify error info when kernel not found * - modify code as reviewer suggested --------- Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
9a9587556c
commit
3f34372012
|
@ -2,6 +2,7 @@
|
|||
#include "core/common.h"
|
||||
#include "core/operator.h"
|
||||
#include "core/tensor.h"
|
||||
#include "utils/operator_utils.h"
|
||||
#include <functional>
|
||||
#include <nlohmann/json.hpp>
|
||||
using json = nlohmann::json;
|
||||
|
@ -102,11 +103,9 @@ class KernelRegistry {
|
|||
}
|
||||
Kernel *getKernel(const KernelAttrs &kernelAttrs) const {
|
||||
auto it = kernels.find(kernelAttrs);
|
||||
IT_ASSERT(it != kernels.end(),
|
||||
"Kernel not found for key {" +
|
||||
to_string(enum_to_underlying(std::get<0>(kernelAttrs))) +
|
||||
", " + std::to_string(std::get<1>(kernelAttrs)) + ", " +
|
||||
std::get<2>(kernelAttrs).toString() + "}");
|
||||
IT_ASSERT(it != kernels.end(), "Kernel not found for key {" +
|
||||
get_kernel_attrs_str(kernelAttrs) +
|
||||
"}");
|
||||
return std::get<0>(it->second);
|
||||
}
|
||||
const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const {
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
#ifndef OPERATOR_UTIL_H
|
||||
#define OPERATOR_UTIL_H
|
||||
|
||||
#include "core/operator.h"
|
||||
#include "core/tensor.h"
|
||||
|
||||
namespace infini {
|
||||
|
@ -10,8 +11,10 @@ namespace infini {
|
|||
Shape infer_broadcast(const Shape &A, const Shape &B);
|
||||
// Launch the real axis based on rank and current axis
|
||||
int get_real_axis(const int &axis, const int &rank);
|
||||
// check if tensor B is unidirectional broadcastable to tensor A
|
||||
// Check if tensor B is unidirectional broadcastable to tensor A
|
||||
bool is_unidirectional_broadcasting(const Shape &A, const Shape &B);
|
||||
// Convert KernelAttrs to a string representation
|
||||
std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs);
|
||||
} // namespace infini
|
||||
|
||||
#endif
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
#include "utils/operator_utils.h"
|
||||
#include "core/runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
@ -64,4 +65,29 @@ bool is_unidirectional_broadcasting(const Shape &A, const Shape &B) {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string device_to_str(Device device) {
|
||||
std::string deviceStr;
|
||||
switch (device) {
|
||||
case Device::CPU:
|
||||
return "CPU";
|
||||
case Device::CUDA:
|
||||
return "CUDA";
|
||||
case Device::BANG:
|
||||
return "BANG";
|
||||
case Device::INTELCPU:
|
||||
return "INTELCPU";
|
||||
case Device::KUNLUN:
|
||||
return "KUNLUN";
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
|
||||
std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs) {
|
||||
std::string deviceStr = device_to_str(std::get<0>(kernelAttrs));
|
||||
std::string opStr = OpType(std::get<1>(kernelAttrs)).toString();
|
||||
std::string datatypeStr = std::get<2>(kernelAttrs).toString();
|
||||
return deviceStr + ", " + opStr + ", " + datatypeStr;
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
Loading…
Reference in New Issue