forked from jiuyuan/InfiniTensor
- modify error info when kernel not found (#191)
* - modify error info when kernel not found * - modify code as reviewer suggested --------- Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
9a9587556c
commit
3f34372012
|
@ -2,6 +2,7 @@
|
||||||
#include "core/common.h"
|
#include "core/common.h"
|
||||||
#include "core/operator.h"
|
#include "core/operator.h"
|
||||||
#include "core/tensor.h"
|
#include "core/tensor.h"
|
||||||
|
#include "utils/operator_utils.h"
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <nlohmann/json.hpp>
|
#include <nlohmann/json.hpp>
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
|
@ -102,11 +103,9 @@ class KernelRegistry {
|
||||||
}
|
}
|
||||||
Kernel *getKernel(const KernelAttrs &kernelAttrs) const {
|
Kernel *getKernel(const KernelAttrs &kernelAttrs) const {
|
||||||
auto it = kernels.find(kernelAttrs);
|
auto it = kernels.find(kernelAttrs);
|
||||||
IT_ASSERT(it != kernels.end(),
|
IT_ASSERT(it != kernels.end(), "Kernel not found for key {" +
|
||||||
"Kernel not found for key {" +
|
get_kernel_attrs_str(kernelAttrs) +
|
||||||
to_string(enum_to_underlying(std::get<0>(kernelAttrs))) +
|
"}");
|
||||||
", " + std::to_string(std::get<1>(kernelAttrs)) + ", " +
|
|
||||||
std::get<2>(kernelAttrs).toString() + "}");
|
|
||||||
return std::get<0>(it->second);
|
return std::get<0>(it->second);
|
||||||
}
|
}
|
||||||
const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const {
|
const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const {
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
#ifndef OPERATOR_UTIL_H
|
#ifndef OPERATOR_UTIL_H
|
||||||
#define OPERATOR_UTIL_H
|
#define OPERATOR_UTIL_H
|
||||||
|
|
||||||
|
#include "core/operator.h"
|
||||||
#include "core/tensor.h"
|
#include "core/tensor.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
@ -10,8 +11,10 @@ namespace infini {
|
||||||
Shape infer_broadcast(const Shape &A, const Shape &B);
|
Shape infer_broadcast(const Shape &A, const Shape &B);
|
||||||
// Launch the real axis based on rank and current axis
|
// Launch the real axis based on rank and current axis
|
||||||
int get_real_axis(const int &axis, const int &rank);
|
int get_real_axis(const int &axis, const int &rank);
|
||||||
// check if tensor B is unidirectional broadcastable to tensor A
|
// Check if tensor B is unidirectional broadcastable to tensor A
|
||||||
bool is_unidirectional_broadcasting(const Shape &A, const Shape &B);
|
bool is_unidirectional_broadcasting(const Shape &A, const Shape &B);
|
||||||
|
// Convert KernelAttrs to a string representation
|
||||||
|
std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs);
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
#include "utils/operator_utils.h"
|
#include "utils/operator_utils.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
|
@ -64,4 +65,29 @@ bool is_unidirectional_broadcasting(const Shape &A, const Shape &B) {
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string device_to_str(Device device) {
|
||||||
|
std::string deviceStr;
|
||||||
|
switch (device) {
|
||||||
|
case Device::CPU:
|
||||||
|
return "CPU";
|
||||||
|
case Device::CUDA:
|
||||||
|
return "CUDA";
|
||||||
|
case Device::BANG:
|
||||||
|
return "BANG";
|
||||||
|
case Device::INTELCPU:
|
||||||
|
return "INTELCPU";
|
||||||
|
case Device::KUNLUN:
|
||||||
|
return "KUNLUN";
|
||||||
|
default:
|
||||||
|
IT_TODO_HALT();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs) {
|
||||||
|
std::string deviceStr = device_to_str(std::get<0>(kernelAttrs));
|
||||||
|
std::string opStr = OpType(std::get<1>(kernelAttrs)).toString();
|
||||||
|
std::string datatypeStr = std::get<2>(kernelAttrs).toString();
|
||||||
|
return deviceStr + ", " + opStr + ", " + datatypeStr;
|
||||||
|
}
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
Loading…
Reference in New Issue