forked from jiuyuan/InfiniTensor
feat: 前端支持 identity 及单元测试
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
7f0c8ebae3
commit
e4ec9c4230
|
@ -51,6 +51,7 @@ class GraphHandlerObj {
|
||||||
Tensor tanh(Tensor x, Tensor y);
|
Tensor tanh(Tensor x, Tensor y);
|
||||||
Tensor softmax(Tensor x, Tensor y);
|
Tensor softmax(Tensor x, Tensor y);
|
||||||
Tensor abs(Tensor x, Tensor y);
|
Tensor abs(Tensor x, Tensor y);
|
||||||
|
Tensor identity(Tensor x, Tensor y);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -58,6 +58,38 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
tensors[node.input[1]],
|
tensors[node.input[1]],
|
||||||
tensors.get(node.output[0], None),
|
tensors.get(node.output[0], None),
|
||||||
)
|
)
|
||||||
|
elif node.op_type == "Relu":
|
||||||
|
tensors[node.output[0]] = handler.relu(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0], None),
|
||||||
|
)
|
||||||
|
elif node.op_type == "Sigmoid":
|
||||||
|
tensors[node.output[0]] = handler.sigmoid(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0], None),
|
||||||
|
)
|
||||||
|
elif node.op_type == "Tanh":
|
||||||
|
tensors[node.output[0]] = handler.tanh(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0], None),
|
||||||
|
)
|
||||||
|
elif node.op_type == "Softmax":
|
||||||
|
tensors[node.output[0]] = handler.softmax(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0], None),
|
||||||
|
)
|
||||||
|
elif node.op_type == "Abs":
|
||||||
|
tensors[node.output[0]] = handler.abs(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0], None),
|
||||||
|
)
|
||||||
|
elif node.op_type == "Identity":
|
||||||
|
tensors[node.output[0]] = handler.identity(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0], None),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||||
|
|
||||||
|
|
||||||
def parse_onnx(model: onnx.ModelProto):
|
def parse_onnx(model: onnx.ModelProto):
|
||||||
|
|
|
@ -135,7 +135,7 @@ class TestStringMethods(unittest.TestCase):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||||
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
|
||||||
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4])
|
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4])
|
||||||
y = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4])
|
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 2, 4])
|
||||||
matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
|
matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
|
||||||
add = make_node("Add", ["xa", "b"], ["y"], name="add")
|
add = make_node("Add", ["xa", "b"], ["y"], name="add")
|
||||||
graph = make_graph([matmul, add], "lr", [x, a, b], [y])
|
graph = make_graph([matmul, add], "lr", [x, a, b], [y])
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
#include "core/graph_handler.h"
|
#include "core/graph_handler.h"
|
||||||
#include "operators/element_wise.h"
|
#include "operators/element_wise.h"
|
||||||
#include "operators/matmul.h"
|
#include "operators/matmul.h"
|
||||||
|
#include "operators/reshape.h"
|
||||||
#include "operators/unary.h"
|
#include "operators/unary.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
@ -58,6 +59,8 @@ DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
|
||||||
DEFINE_UNARY_METHOD(tanh, Tanh)
|
DEFINE_UNARY_METHOD(tanh, Tanh)
|
||||||
DEFINE_UNARY_METHOD(softmax, Softmax)
|
DEFINE_UNARY_METHOD(softmax, Softmax)
|
||||||
DEFINE_UNARY_METHOD(abs, Abs)
|
DEFINE_UNARY_METHOD(abs, Abs)
|
||||||
|
// see operators/reshape.h
|
||||||
|
DEFINE_UNARY_METHOD(identity, Identity)
|
||||||
|
|
||||||
static DataType dtype_repr_convert(int dtype) {
|
static DataType dtype_repr_convert(int dtype) {
|
||||||
switch ((OnnxDType)dtype) {
|
switch ((OnnxDType)dtype) {
|
||||||
|
|
|
@ -22,6 +22,8 @@ void register_operator_timer(py::module &m) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void init_graph_builder(py::module &m) {
|
void init_graph_builder(py::module &m) {
|
||||||
|
using Handler = GraphHandlerObj;
|
||||||
|
|
||||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
|
m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
|
||||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
|
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
|
||||||
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
||||||
|
@ -36,40 +38,35 @@ void init_graph_builder(py::module &m) {
|
||||||
.value("Tanh", ActType::Tanh)
|
.value("Tanh", ActType::Tanh)
|
||||||
.export_values();
|
.export_values();
|
||||||
py::class_<GraphHandler>(m, "GraphHandler");
|
py::class_<GraphHandler>(m, "GraphHandler");
|
||||||
py::class_<GraphHandlerObj>(m, "GraphHandlerObj")
|
py::class_<Handler>(m, "GraphHandlerObj")
|
||||||
.def(py::init<Runtime>())
|
.def(py::init<Runtime>())
|
||||||
.def("tensor", py::overload_cast<Shape, int>(&GraphHandlerObj::tensor),
|
.def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
|
||||||
policy::reference_internal)
|
policy::move)
|
||||||
.def("matmul",
|
.def("matmul",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
||||||
ActType>(&GraphHandlerObj::matmul),
|
ActType>(&Handler::matmul),
|
||||||
policy::move)
|
policy::move)
|
||||||
.def("add",
|
.def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add),
|
||||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::add),
|
|
||||||
policy::move)
|
policy::move)
|
||||||
.def("sub",
|
.def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub),
|
||||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::sub),
|
|
||||||
policy::move)
|
policy::move)
|
||||||
.def("mul",
|
.def("mul", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::mul),
|
||||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::mul),
|
|
||||||
policy::move)
|
policy::move)
|
||||||
.def("div",
|
.def("div", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::div),
|
||||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::div),
|
|
||||||
policy::move)
|
policy::move)
|
||||||
.def("pow",
|
.def("pow", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::pow),
|
||||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::pow),
|
|
||||||
policy::move)
|
policy::move)
|
||||||
.def("relu", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::relu),
|
.def("relu", py::overload_cast<Tensor, Tensor>(&Handler::relu),
|
||||||
policy::move)
|
policy::move)
|
||||||
.def("sigmoid",
|
.def("sigmoid", py::overload_cast<Tensor, Tensor>(&Handler::sigmoid),
|
||||||
py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::sigmoid),
|
|
||||||
policy::move)
|
policy::move)
|
||||||
.def("tanh", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::tanh),
|
.def("tanh", py::overload_cast<Tensor, Tensor>(&Handler::tanh),
|
||||||
policy::reference_internal)
|
|
||||||
.def("softmax",
|
|
||||||
py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::softmax),
|
|
||||||
policy::move)
|
policy::move)
|
||||||
.def("abs", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::abs),
|
.def("softmax", py::overload_cast<Tensor, Tensor>(&Handler::softmax),
|
||||||
|
policy::move)
|
||||||
|
.def("abs", py::overload_cast<Tensor, Tensor>(&Handler::abs),
|
||||||
|
policy::move)
|
||||||
|
.def("identity", py::overload_cast<Tensor, Tensor>(&Handler::identity),
|
||||||
policy::move);
|
policy::move);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue