diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index df778850..41d01c1b 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -58,6 +58,7 @@ class GraphHandlerObj { Tensor identity(Tensor x, Tensor y); Tensor flatten(Tensor s, Tensor y); Tensor reshape(Tensor data, Tensor reshaped, Shape shape); + Tensor concat(TensorVec inputs, Tensor output, int dim); }; } // namespace infini diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index bbbaaa17..87c3899c 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -1,4 +1,5 @@ import onnx, backend +from typing import Dict runtime = backend.cpu_runtime() @@ -6,8 +7,8 @@ runtime = backend.cpu_runtime() def from_onnx(model: onnx.ModelProto): handler = backend.GraphHandlerObj(runtime) - tensors = dict() - data = dict() + tensors: Dict[str, backend.TensorObj] = dict() + data: Dict[str, onnx.TensorProto] = dict() for input in model.graph.input: dims = [d.dim_value for d in input.type.tensor_type.shape.dim] @@ -121,6 +122,12 @@ def from_onnx(model: onnx.ModelProto): tensors.get(node.output[0]), [int(i) for i in data[node.input[1]].int64_data], ) + elif node.op_type == "Concat": + tensors[node.output[0]] = handler.concat( + [tensors[name] for name in node.input], + tensors.get(node.output[0]), + next((attr.i for attr in node.attribute if attr.name == "axis")), + ) else: raise Exception('Unsupported operator "{}"'.format(node.op_type)) diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 7e36b125..605fb6bf 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -136,11 +136,11 @@ class TestStringMethods(unittest.TestCase): make_and_import_model(make_graph([flatten], "flatten", [x], [y])) def test_reshape(self): - data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4]) + data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5]) # shape 对于后端来说并不是一个张量,然而转换中可能没有办法分辨 # 不知道怎么把 ValueInfoProto 转换成 TensorProto - shape = make_tensor_value_info("shape", TensorProto.INT64, [4]) - shape_data = make_tensor("shape", TensorProto.INT64, [4], [3, 2, 4, 3]) + shape = make_tensor_value_info("shape", TensorProto.INT64, [3]) + shape_data = make_tensor("shape", TensorProto.INT64, [3], [5, 3, 8]) reshaped = make_tensor_value_info( "reshaped", TensorProto.FLOAT, shape_data.int64_data ) @@ -151,6 +151,17 @@ class TestStringMethods(unittest.TestCase): make_graph([reshape], "reshape", [data, shape], [reshaped], [shape_data]) ) + def test_concat(self): + input1 = make_tensor_value_info("input1", TensorProto.FLOAT, [1, 3, 2, 4]) + input2 = make_tensor_value_info("input2", TensorProto.FLOAT, [1, 3, 2, 5]) + output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 9]) + concat = make_node( + "Concat", ["input1", "input2"], ["output"], axis=3, name="concat" + ) + make_and_import_model( + make_graph([concat], "concat", [input1, input2], [output]) + ) + # see def test_linear(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index cd3b355d..f9aa095e 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -1,5 +1,6 @@ #include "core/graph_handler.h" #include "operators/batch_norm.h" +#include "operators/concat.h" #include "operators/element_wise.h" #include "operators/matmul.h" #include "operators/reshape.h" @@ -93,6 +94,15 @@ Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape) { } } +Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) { + if (output) { + g->addOpWithOutputs(std::move(inputs), output, dim); + return output; + } else { + return g->addOp(std::move(inputs), output, dim)->getOutput(); + } +} + static DataType dtype_repr_convert(int dtype) { switch ((OnnxDType)dtype) { case OnnxDType::FLOAT: diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 00a30866..3ad18699 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -76,6 +76,9 @@ void init_graph_builder(py::module &m) { policy::move) .def("reshape", py::overload_cast(&Handler::reshape), + policy::move) + .def("concat", + py::overload_cast(&Handler::concat), policy::move); }