diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 53fd9f32..31e10c31 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -52,6 +52,7 @@ class GraphHandlerObj { Tensor softmax(Tensor x, Tensor y); Tensor abs(Tensor x, Tensor y); Tensor identity(Tensor x, Tensor y); + Tensor flatten(Tensor s, Tensor y); }; } // namespace infini diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 709529de..f3fd2898 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -88,6 +88,16 @@ def from_onnx(model: onnx.ModelProto): tensors[node.input[0]], tensors.get(node.output[0], None), ) + elif node.op_type == "Flatten": + # TODO 后端算子不支持沿任意轴展开 + axis = next( + (attr.i for attr in node.attribute if attr.name == "axis"), None + ) + assert axis == None or axis == 1 + tensors[node.output[0]] = handler.flatten( + tensors[node.input[0]], + tensors.get(node.output[0], None), + ) else: raise Exception('Unsupported operator "{}"'.format(node.op_type)) diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index c3b23c34..1f27c6b1 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -5,6 +5,12 @@ from onnx.checker import check_model from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime +def make_and_import_model(graph: onnx.GraphProto): + model = make_model(graph) + check_model(model) + from_onnx(model) + + class TestStringMethods(unittest.TestCase): def test_load(self): model_file = next( @@ -20,115 +26,91 @@ class TestStringMethods(unittest.TestCase): def test_tensor(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) - graph = make_graph([], "tensor", [x], [x]) - model = make_model(graph) - check_model(model) - from_onnx(model) + make_and_import_model(make_graph([], "tensor", [x], [x])) def test_matmul(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4]) xa = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4]) matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul") - graph = make_graph([matmul], "matmul", [x, a], [xa]) - model = make_model(graph) - check_model(model) - from_onnx(model) + make_and_import_model(make_graph([matmul], "matmul", [x, a], [xa])) def test_add(self): a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7]) b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7]) c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7]) add = make_node("Add", ["a", "b"], ["c"], name="add") - graph = make_graph([add], "add", [a, b], [c]) - model = make_model(graph) - check_model(model) - from_onnx(model) + make_and_import_model(make_graph([add], "add", [a, b], [c])) def test_sub(self): a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7]) b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7]) c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7]) sub = make_node("Sub", ["a", "b"], ["c"], name="sub") - graph = make_graph([sub], "sub", [a, b], [c]) - model = make_model(graph) - check_model(model) - from_onnx(model) + make_and_import_model(make_graph([sub], "sub", [a, b], [c])) def test_mul(self): a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7]) b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7]) c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7]) mul = make_node("Mul", ["a", "b"], ["c"], name="mul") - graph = make_graph([mul], "mul", [a, b], [c]) - model = make_model(graph) - check_model(model) - from_onnx(model) + make_and_import_model(make_graph([mul], "mul", [a, b], [c])) def test_div(self): a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7]) b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7]) c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7]) div = make_node("Div", ["a", "b"], ["c"], name="div") - graph = make_graph([div], "div", [a, b], [c]) - model = make_model(graph) - check_model(model) - from_onnx(model) + make_and_import_model(make_graph([div], "div", [a, b], [c])) def test_pow(self): a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7]) b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7]) c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7]) pow = make_node("Pow", ["a", "b"], ["c"], name="pow") - graph = make_graph([pow], "pow", [a, b], [c]) - model = make_model(graph) - check_model(model) - from_onnx(model) + make_and_import_model(make_graph([pow], "pow", [a, b], [c])) def test_relu(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) relu = make_node("Relu", ["x"], ["y"], name="relu") - graph = make_graph([relu], "relu", [x], [y]) - model = make_model(graph) - check_model(model) - from_onnx(model) + make_and_import_model(make_graph([relu], "relu", [x], [y])) def test_sigmoid(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) sigmoid = make_node("Sigmoid", ["x"], ["y"], name="sigmoid") - graph = make_graph([sigmoid], "sigmoid", [x], [y]) - model = make_model(graph) - check_model(model) - from_onnx(model) + make_and_import_model(make_graph([sigmoid], "sigmoid", [x], [y])) def test_tanh(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) tanh = make_node("Tanh", ["x"], ["y"], name="tanh") - graph = make_graph([tanh], "tanh", [x], [y]) - model = make_model(graph) - check_model(model) - from_onnx(model) + make_and_import_model(make_graph([tanh], "tanh", [x], [y])) def test_softmax(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) softmax = make_node("Softmax", ["x"], ["y"], name="softmax") - graph = make_graph([softmax], "softmax", [x], [y]) - model = make_model(graph) - check_model(model) - from_onnx(model) + make_and_import_model(make_graph([softmax], "softmax", [x], [y])) def test_abs(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) abs = make_node("Abs", ["x"], ["y"], name="abs") - graph = make_graph([abs], "abs", [x], [y]) - model = make_model(graph) - check_model(model) - from_onnx(model) + make_and_import_model(make_graph([abs], "abs", [x], [y])) + + def test_identity(self): + x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) + y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) + identity = make_node("Identity", ["x"], ["y"], name="identity") + make_and_import_model(make_graph([identity], "identity", [x], [y])) + + def test_flatten(self): + x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) + y = make_tensor_value_info("y", TensorProto.FLOAT, [1 * 3 * 5 * 7]) + flatten = make_node("Flatten", ["x"], ["y"], name="flatten") + make_and_import_model(make_graph([flatten], "flatten", [x], [y])) # see def test_linear(self): @@ -141,8 +123,6 @@ class TestStringMethods(unittest.TestCase): graph = make_graph([matmul, add], "lr", [x, a, b], [y]) model = make_model(graph) check_model(model) - print(model) - from_onnx(model) parse_onnx(model) diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index a32ecd64..662e9e44 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -61,6 +61,7 @@ DEFINE_UNARY_METHOD(softmax, Softmax) DEFINE_UNARY_METHOD(abs, Abs) // see operators/reshape.h DEFINE_UNARY_METHOD(identity, Identity) +DEFINE_UNARY_METHOD(flatten, Flatten) static DataType dtype_repr_convert(int dtype) { switch ((OnnxDType)dtype) { diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 2de5bc1b..be78a4d5 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -67,6 +67,8 @@ void init_graph_builder(py::module &m) { .def("abs", py::overload_cast(&Handler::abs), policy::move) .def("identity", py::overload_cast(&Handler::identity), + policy::move) + .def("flatten", py::overload_cast(&Handler::flatten), policy::move); }