diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 0a54ab41..53fd9f32 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -51,6 +51,7 @@ class GraphHandlerObj { Tensor tanh(Tensor x, Tensor y); Tensor softmax(Tensor x, Tensor y); Tensor abs(Tensor x, Tensor y); + Tensor identity(Tensor x, Tensor y); }; } // namespace infini diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 38277614..709529de 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -58,6 +58,38 @@ def from_onnx(model: onnx.ModelProto): tensors[node.input[1]], tensors.get(node.output[0], None), ) + elif node.op_type == "Relu": + tensors[node.output[0]] = handler.relu( + tensors[node.input[0]], + tensors.get(node.output[0], None), + ) + elif node.op_type == "Sigmoid": + tensors[node.output[0]] = handler.sigmoid( + tensors[node.input[0]], + tensors.get(node.output[0], None), + ) + elif node.op_type == "Tanh": + tensors[node.output[0]] = handler.tanh( + tensors[node.input[0]], + tensors.get(node.output[0], None), + ) + elif node.op_type == "Softmax": + tensors[node.output[0]] = handler.softmax( + tensors[node.input[0]], + tensors.get(node.output[0], None), + ) + elif node.op_type == "Abs": + tensors[node.output[0]] = handler.abs( + tensors[node.input[0]], + tensors.get(node.output[0], None), + ) + elif node.op_type == "Identity": + tensors[node.output[0]] = handler.identity( + tensors[node.input[0]], + tensors.get(node.output[0], None), + ) + else: + raise Exception('Unsupported operator "{}"'.format(node.op_type)) def parse_onnx(model: onnx.ModelProto): diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 9ba1d802..c3b23c34 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -135,7 +135,7 @@ class TestStringMethods(unittest.TestCase): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4]) b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4]) - y = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4]) + y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 2, 4]) matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul") add = make_node("Add", ["xa", "b"], ["y"], name="add") graph = make_graph([matmul, add], "lr", [x, a, b], [y]) diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 17fa2d18..a32ecd64 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -1,6 +1,7 @@ #include "core/graph_handler.h" #include "operators/element_wise.h" #include "operators/matmul.h" +#include "operators/reshape.h" #include "operators/unary.h" namespace infini { @@ -58,6 +59,8 @@ DEFINE_UNARY_METHOD(sigmoid, Sigmoid) DEFINE_UNARY_METHOD(tanh, Tanh) DEFINE_UNARY_METHOD(softmax, Softmax) DEFINE_UNARY_METHOD(abs, Abs) +// see operators/reshape.h +DEFINE_UNARY_METHOD(identity, Identity) static DataType dtype_repr_convert(int dtype) { switch ((OnnxDType)dtype) { diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index e61e422f..2de5bc1b 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -22,6 +22,8 @@ void register_operator_timer(py::module &m) { } void init_graph_builder(py::module &m) { + using Handler = GraphHandlerObj; + m.def("cpu_runtime", &CpuRuntimeObj::getInstance); py::class_>(m, "RuntimeObj"); py::class_, RuntimeObj>( @@ -36,40 +38,35 @@ void init_graph_builder(py::module &m) { .value("Tanh", ActType::Tanh) .export_values(); py::class_(m, "GraphHandler"); - py::class_(m, "GraphHandlerObj") + py::class_(m, "GraphHandlerObj") .def(py::init()) - .def("tensor", py::overload_cast(&GraphHandlerObj::tensor), - policy::reference_internal) + .def("tensor", py::overload_cast(&Handler::tensor), + policy::move) .def("matmul", py::overload_cast(&GraphHandlerObj::matmul), + ActType>(&Handler::matmul), policy::move) - .def("add", - py::overload_cast(&GraphHandlerObj::add), + .def("add", py::overload_cast(&Handler::add), policy::move) - .def("sub", - py::overload_cast(&GraphHandlerObj::sub), + .def("sub", py::overload_cast(&Handler::sub), policy::move) - .def("mul", - py::overload_cast(&GraphHandlerObj::mul), + .def("mul", py::overload_cast(&Handler::mul), policy::move) - .def("div", - py::overload_cast(&GraphHandlerObj::div), + .def("div", py::overload_cast(&Handler::div), policy::move) - .def("pow", - py::overload_cast(&GraphHandlerObj::pow), + .def("pow", py::overload_cast(&Handler::pow), policy::move) - .def("relu", py::overload_cast(&GraphHandlerObj::relu), + .def("relu", py::overload_cast(&Handler::relu), policy::move) - .def("sigmoid", - py::overload_cast(&GraphHandlerObj::sigmoid), + .def("sigmoid", py::overload_cast(&Handler::sigmoid), policy::move) - .def("tanh", py::overload_cast(&GraphHandlerObj::tanh), - policy::reference_internal) - .def("softmax", - py::overload_cast(&GraphHandlerObj::softmax), + .def("tanh", py::overload_cast(&Handler::tanh), policy::move) - .def("abs", py::overload_cast(&GraphHandlerObj::abs), + .def("softmax", py::overload_cast(&Handler::softmax), + policy::move) + .def("abs", py::overload_cast(&Handler::abs), + policy::move) + .def("identity", py::overload_cast(&Handler::identity), policy::move); }