diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 9fe92940..53f556f4 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -39,6 +39,12 @@ class GraphHandlerObj {
     Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
                   Tensor bias, ActType act);
+
+    Tensor add(Tensor a, Tensor b, Tensor c);
+    Tensor sub(Tensor a, Tensor b, Tensor c);
+    Tensor mul(Tensor a, Tensor b, Tensor c);
+    Tensor div(Tensor a, Tensor b, Tensor c);
+    Tensor pow(Tensor a, Tensor b, Tensor c);
 };
 
 } // namespace infini
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index ff023914..38277614 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -19,15 +19,45 @@ def from_onnx(model: onnx.ModelProto):
     for node in model.graph.node:
         if node.op_type == "MatMul":
-            handler.matmul(
+            tensors[node.output[0]] = handler.matmul(
                 tensors[node.input[0]],
                 tensors[node.input[1]],
-                tensors[node.output[0]],
+                tensors.get(node.output[0], None),
                 False,
                 False,
                 None,
                 backend.ActType.Linear,
             )
+        elif node.op_type == "Add":
+            tensors[node.output[0]] = handler.add(
+                tensors[node.input[0]],
+                tensors[node.input[1]],
+                tensors.get(node.output[0], None),
+            )
+        elif node.op_type == "Sub":
+            tensors[node.output[0]] = handler.sub(
+                tensors[node.input[0]],
+                tensors[node.input[1]],
+                tensors.get(node.output[0], None),
+            )
+        elif node.op_type == "Mul":
+            tensors[node.output[0]] = handler.mul(
+                tensors[node.input[0]],
+                tensors[node.input[1]],
+                tensors.get(node.output[0], None),
+            )
+        elif node.op_type == "Div":
+            tensors[node.output[0]] = handler.div(
+                tensors[node.input[0]],
+                tensors[node.input[1]],
+                tensors.get(node.output[0], None),
+            )
+        elif node.op_type == "Pow":
+            tensors[node.output[0]] = handler.pow(
+                tensors[node.input[0]],
+                tensors[node.input[1]],
+                tensors.get(node.output[0], None),
+            )
 
 
 def parse_onnx(model: onnx.ModelProto):
diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index 3178ebd2..276a78a8 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -18,12 +18,82 @@ class TestStringMethods(unittest.TestCase):
         )
         parse_onnx(onnx.load(model_file))
 
-    def test_import(self):
-        i = make_tensor_value_info("i", TensorProto.FLOAT, [1, 2, 3])
-        w = make_tensor_value_info("w", TensorProto.FLOAT, [1, 3, 4])
-        o = make_tensor_value_info("o", TensorProto.FLOAT, [1, 2, 4])
-        matmul = make_node("MatMul", ["i", "w"], ["o"], name="matmul")
-        graph = make_graph([matmul], "mm", [i, w], [o])
+    def test_tensor(self):
+        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
+        graph = make_graph([], "tensor", [x], [x])
+        model = make_model(graph)
+        check_model(model)
+        from_onnx(model)
+
+    def test_matmul(self):
+        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
+        a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
+        xa = make_tensor_value_info("xa", TensorProto.FLOAT, [1, 2, 4])
+        matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
+        graph = make_graph([matmul], "matmul", [x, a], [xa])
+        model = make_model(graph)
+        check_model(model)
+        from_onnx(model)
+
+    def test_add(self):
+        a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
+        b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
+        c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
+        add = make_node("Add", ["a", "b"], ["c"], name="add")
+        graph = make_graph([add], "add", [a, b], [c])
+        model = make_model(graph)
+        check_model(model)
+        from_onnx(model)
+
+    def test_sub(self):
+        a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
+        b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
+        c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
+        sub = make_node("Sub", ["a", "b"], ["c"], name="sub")
+        graph = make_graph([sub], "sub", [a, b], [c])
+        model = make_model(graph)
+        check_model(model)
+        from_onnx(model)
+
+    def test_mul(self):
+        a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
+        b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
+        c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
+        mul = make_node("Mul", ["a", "b"], ["c"], name="mul")
+        graph = make_graph([mul], "mul", [a, b], [c])
+        model = make_model(graph)
+        check_model(model)
+        from_onnx(model)
+
+    def test_div(self):
+        a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
+        b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
+        c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
+        div = make_node("Div", ["a", "b"], ["c"], name="div")
+        graph = make_graph([div], "div", [a, b], [c])
+        model = make_model(graph)
+        check_model(model)
+        from_onnx(model)
+
+    def test_pow(self):
+        a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
+        b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
+        c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
+        pow = make_node("Pow", ["a", "b"], ["c"], name="pow")
+        graph = make_graph([pow], "pow", [a, b], [c])
+        model = make_model(graph)
+        check_model(model)
+        from_onnx(model)
+
+    # see
+    def test_linear(self):
+        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
+        a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
+        b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4])
+        y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 2, 4])
+        matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
+        add = make_node("Add", ["xa", "b"], ["y"], name="add")
+        graph = make_graph([matmul, add], "lr", [x, a, b], [y])
         model = make_model(graph)
         check_model(model)
         print(model)
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index ffc68473..4fe72dd6 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -1,4 +1,5 @@
 #include "core/graph_handler.h"
+#include "operators/element_wise.h"
 #include "operators/matmul.h"
 
 namespace infini {
@@ -23,6 +24,22 @@ Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
     }
 }
 
+#define DEFINE_ELEMENT_WISE_METHOD(name, obj)                                  \
+    Tensor GraphHandlerObj::name(Tensor a, Tensor b, Tensor c) {               \
+        if (c) {                                                               \
+            g->addOpWithOutputs<obj##Obj>(a, b, c);                            \
+            return c;                                                          \
+        } else {                                                               \
+            return g->addOp<obj##Obj>(a, b, c)->getOutput();                   \
+        }                                                                      \
+    }
+
+DEFINE_ELEMENT_WISE_METHOD(add, Add)
+DEFINE_ELEMENT_WISE_METHOD(sub, Sub)
+DEFINE_ELEMENT_WISE_METHOD(mul, Mul)
+DEFINE_ELEMENT_WISE_METHOD(div, Div)
+DEFINE_ELEMENT_WISE_METHOD(pow, Pow)
+
 static DataType dtype_repr_convert(int dtype) {
     switch ((OnnxDType)dtype) {
     case OnnxDType::FLOAT:
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index b8d4354e..07e5a6bb 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -43,6 +43,21 @@ void init_graph_builder(py::module &m) {
         .def("matmul",
              py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
                                ActType>(&GraphHandlerObj::matmul),
+             policy::reference_internal)
+        .def("add",
+             py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::add),
+             policy::reference_internal)
+        .def("sub",
+             py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::sub),
+             policy::reference_internal)
+        .def("mul",
+             py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::mul),
+             policy::reference_internal)
+        .def("div",
+             py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::div),
+             policy::reference_internal)
+        .def("pow",
+             py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::pow),
              policy::reference_internal);
 }
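For reference, a minimal sketch of exercising the new importer path end to end, mirroring `test_add` above (it assumes the `pyinfinitensor` package is built and importable; all other names come from the diff itself):

```python
import onnx
from onnx import TensorProto
from onnx.checker import check_model
from onnx.helper import (
    make_graph,
    make_model,
    make_node,
    make_tensor_value_info,
)

from pyinfinitensor.onnx import from_onnx

# Build a one-node ONNX graph computing c = a + b.
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
add = make_node("Add", ["a", "b"], ["c"], name="add")
model = make_model(make_graph([add], "add", [a, b], [c]))
check_model(model)

# Dispatches through the new "add" pybind11 binding to
# GraphHandlerObj::add, which builds an AddObj in the graph.
from_onnx(model)
```

Note the design of the handler methods: `from_onnx` passes `tensors.get(node.output[0], None)` as the third argument, so each method either adopts an existing output tensor (the `addOpWithOutputs` branch) or lets the operator infer and allocate one (the `addOp` branch), and in both cases returns the tensor actually used so the importer can record it under the node's output name.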