diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 53f556f4..0a54ab41 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -45,6 +45,12 @@ class GraphHandlerObj { Tensor mul(Tensor a, Tensor b, Tensor c); Tensor div(Tensor a, Tensor b, Tensor c); Tensor pow(Tensor a, Tensor b, Tensor c); + + Tensor relu(Tensor x, Tensor y); + Tensor sigmoid(Tensor x, Tensor y); + Tensor tanh(Tensor x, Tensor y); + Tensor softmax(Tensor x, Tensor y); + Tensor abs(Tensor x, Tensor y); }; } // namespace infini diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 276a78a8..9ba1d802 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -85,6 +85,51 @@ class TestStringMethods(unittest.TestCase): check_model(model) from_onnx(model) + def test_relu(self): + x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) + y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) + relu = make_node("Relu", ["x"], ["y"], name="relu") + graph = make_graph([relu], "relu", [x], [y]) + model = make_model(graph) + check_model(model) + from_onnx(model) + + def test_sigmoid(self): + x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) + y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) + sigmoid = make_node("Sigmoid", ["x"], ["y"], name="sigmoid") + graph = make_graph([sigmoid], "sigmoid", [x], [y]) + model = make_model(graph) + check_model(model) + from_onnx(model) + + def test_tanh(self): + x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) + y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) + tanh = make_node("Tanh", ["x"], ["y"], name="tanh") + graph = make_graph([tanh], "tanh", [x], [y]) + model = make_model(graph) + check_model(model) + from_onnx(model) + + def test_softmax(self): + x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) + y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) + softmax = make_node("Softmax", ["x"], ["y"], name="softmax") + graph = make_graph([softmax], "softmax", [x], [y]) + model = make_model(graph) + check_model(model) + from_onnx(model) + + def test_abs(self): + x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) + y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) + abs = make_node("Abs", ["x"], ["y"], name="abs") + graph = make_graph([abs], "abs", [x], [y]) + model = make_model(graph) + check_model(model) + from_onnx(model) + # see def test_linear(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 4fe72dd6..17fa2d18 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -1,6 +1,7 @@ #include "core/graph_handler.h" #include "operators/element_wise.h" #include "operators/matmul.h" +#include "operators/unary.h" namespace infini { @@ -24,6 +25,7 @@ Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA, } } +// see operators/element_wise.h #define DEFINE_ELEMENT_WISE_METHOD(name, obj) \ Tensor GraphHandlerObj::name(Tensor a, Tensor b, Tensor c) { \ if (c) { \ @@ -40,6 +42,23 @@ DEFINE_ELEMENT_WISE_METHOD(mul, Mul) DEFINE_ELEMENT_WISE_METHOD(div, Div) DEFINE_ELEMENT_WISE_METHOD(pow, Pow) +// see operators/unary.h +#define DEFINE_UNARY_METHOD(name, obj) \ + Tensor GraphHandlerObj::name(Tensor x, Tensor y) { \ + if (y) { \ + g->addOpWithOutputs(x, y); \ + return y; \ + } else { \ + return g->addOp(x, y)->getOutput(); \ + } \ + } + +DEFINE_UNARY_METHOD(relu, Relu) +DEFINE_UNARY_METHOD(sigmoid, Sigmoid) +DEFINE_UNARY_METHOD(tanh, Tanh) +DEFINE_UNARY_METHOD(softmax, Softmax) +DEFINE_UNARY_METHOD(abs, Abs) + static DataType dtype_repr_convert(int dtype) { switch ((OnnxDType)dtype) { case OnnxDType::FLOAT: diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 07e5a6bb..e61e422f 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -43,22 +43,34 @@ void init_graph_builder(py::module &m) { .def("matmul", py::overload_cast(&GraphHandlerObj::matmul), - policy::reference_internal) + policy::move) .def("add", py::overload_cast(&GraphHandlerObj::add), - policy::reference_internal) + policy::move) .def("sub", py::overload_cast(&GraphHandlerObj::sub), - policy::reference_internal) + policy::move) .def("mul", py::overload_cast(&GraphHandlerObj::mul), - policy::reference_internal) + policy::move) .def("div", py::overload_cast(&GraphHandlerObj::div), - policy::reference_internal) + policy::move) .def("pow", py::overload_cast(&GraphHandlerObj::pow), - policy::reference_internal); + policy::move) + .def("relu", py::overload_cast(&GraphHandlerObj::relu), + policy::move) + .def("sigmoid", + py::overload_cast(&GraphHandlerObj::sigmoid), + policy::move) + .def("tanh", py::overload_cast(&GraphHandlerObj::tanh), + policy::reference_internal) + .def("softmax", + py::overload_cast(&GraphHandlerObj::softmax), + policy::move) + .def("abs", py::overload_cast(&GraphHandlerObj::abs), + policy::move); } } // namespace infini