diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 50eb6481..df778850 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -57,6 +57,7 @@ class GraphHandlerObj {
     Tensor abs(Tensor x, Tensor y);
     Tensor identity(Tensor x, Tensor y);
     Tensor flatten(Tensor s, Tensor y);
+    Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
 };
 
 } // namespace infini
diff --git a/include/operators/reshape.h b/include/operators/reshape.h
index 324c5374..66fb1bda 100644
--- a/include/operators/reshape.h
+++ b/include/operators/reshape.h
@@ -19,7 +19,7 @@ class ReshapeObj : public OperatorObj {
      * @param output The output tensor.
      * @param dims The shape of the output tensor.
      */
-    ReshapeObj(GraphObj *graph, Tensor input, Tensor output, const Shape &dims);
+    ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims);
     OP_CLONE(ReshapeObj);
 
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index ac46f0d8..b4a8f6a1 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -1,4 +1,4 @@
-import typing, onnx, backend
+import onnx, backend
 
 runtime = backend.cpu_runtime()
 
@@ -21,7 +21,7 @@ def from_onnx(model: onnx.ModelProto):
             tensors[node.output[0]] = handler.matmul(
                 tensors[node.input[0]],
                 tensors[node.input[1]],
-                tensors.get(node.output[0], None),
+                tensors.get(node.output[0]),
                 False,
                 False,
                 None,
@@ -31,7 +31,7 @@
             (input, mean, var, scale, bias) = (
                 tensors[node.input[i]] for i in [0, 3, 4, 1, 2]
             )
-            output = tensors.get(node.output[0], None)
+            output = tensors.get(node.output[0])
             attributes = _parse_attribute(
                 node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0}
             )
@@ -45,61 +45,61 @@ def from_onnx(model: onnx.ModelProto):
             tensors[node.output[0]] = handler.add(
                 tensors[node.input[0]],
                 tensors[node.input[1]],
-                tensors.get(node.output[0], None),
+                tensors.get(node.output[0]),
             )
         elif node.op_type == "Sub":
             tensors[node.output[0]] = handler.sub(
                 tensors[node.input[0]],
                 tensors[node.input[1]],
-                tensors.get(node.output[0], None),
+                tensors.get(node.output[0]),
             )
         elif node.op_type == "Mul":
             tensors[node.output[0]] = handler.mul(
                 tensors[node.input[0]],
                 tensors[node.input[1]],
-                tensors.get(node.output[0], None),
+                tensors.get(node.output[0]),
             )
         elif node.op_type == "Div":
             tensors[node.output[0]] = handler.div(
                 tensors[node.input[0]],
                 tensors[node.input[1]],
-                tensors.get(node.output[0], None),
+                tensors.get(node.output[0]),
             )
         elif node.op_type == "Pow":
             tensors[node.output[0]] = handler.pow(
                 tensors[node.input[0]],
                 tensors[node.input[1]],
-                tensors.get(node.output[0], None),
+                tensors.get(node.output[0]),
             )
         elif node.op_type == "Relu":
             tensors[node.output[0]] = handler.relu(
                 tensors[node.input[0]],
-                tensors.get(node.output[0], None),
+                tensors.get(node.output[0]),
             )
         elif node.op_type == "Sigmoid":
             tensors[node.output[0]] = handler.sigmoid(
                 tensors[node.input[0]],
-                tensors.get(node.output[0], None),
+                tensors.get(node.output[0]),
             )
         elif node.op_type == "Tanh":
             tensors[node.output[0]] = handler.tanh(
                 tensors[node.input[0]],
-                tensors.get(node.output[0], None),
+                tensors.get(node.output[0]),
            )
         elif node.op_type == "Softmax":
             tensors[node.output[0]] = handler.softmax(
                 tensors[node.input[0]],
-                tensors.get(node.output[0], None),
+                tensors.get(node.output[0]),
             )
         elif node.op_type == "Abs":
             tensors[node.output[0]] = handler.abs(
                 tensors[node.input[0]],
-                tensors.get(node.output[0], None),
+                tensors.get(node.output[0]),
             )
         elif node.op_type == "Identity":
             tensors[node.output[0]] = handler.identity(
                 tensors[node.input[0]],
-                tensors.get(node.output[0], None),
+                tensors.get(node.output[0]),
             )
         elif node.op_type == "Flatten":
             # TODO: the backend operator does not support flattening along an arbitrary axis
@@ -109,7 +109,13 @@ def from_onnx(model: onnx.ModelProto):
             assert axis == None or axis == 1
             tensors[node.output[0]] = handler.flatten(
                 tensors[node.input[0]],
-                tensors.get(node.output[0], None),
+                tensors.get(node.output[0]),
+            )
+        elif node.op_type == "Reshape":
+            tensors[node.output[0]] = handler.reshape(
+                tensors[node.input[0]],
+                tensors.get(node.output[0]),
+                [int(i) for i in tensors[node.input[1]]],
             )
         else:
             raise Exception('Unsupported operator "{}"'.format(node.op_type))
diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index 18b75773..e617d519 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -129,6 +129,16 @@ class TestStringMethods(unittest.TestCase):
         flatten = make_node("Flatten", ["x"], ["y"], name="flatten")
         make_and_import_model(make_graph([flatten], "flatten", [x], [y]))
 
+    # FIXME: the INT64 data type is not supported
+    # def test_reshape(self):
+    #     data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5])
+    #     shape = make_tensor_value_info("shape", TensorProto.INT64, [3, 5, 8])
+    #     reshaped = make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 5, 8])
+    #     reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape")
+    #     make_and_import_model(
+    #         make_graph([reshape], "reshape", [data, shape], [reshaped])
+    #     )
+
     # see
     def test_linear(self):
         x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index 49b8ec57..80e88a09 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -82,6 +82,19 @@ DEFINE_UNARY_METHOD(abs, Abs)
 DEFINE_UNARY_METHOD(identity, Identity)
 DEFINE_UNARY_METHOD(flatten, Flatten)
 
+Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape) {
+    if (reshaped) {
+        g->addOpWithOutputs<ReshapeObj>(std::move(data), reshaped,
+                                        std::move(shape));
+        return reshaped;
+    } else {
+        return g
+            ->addOp<ReshapeObj>(std::move(data), reshaped,
+                                std::move(shape))
+            ->getOutput();
+    }
+}
+
 static DataType dtype_repr_convert(int dtype) {
     switch ((OnnxDType)dtype) {
     case OnnxDType::FLOAT:
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 9d604d7f..00a30866 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -73,6 +73,9 @@ void init_graph_builder(py::module &m) {
         .def("identity", py::overload_cast<Tensor, Tensor>(&Handler::identity),
             policy::move)
        .def("flatten", py::overload_cast<Tensor, Tensor>(&Handler::flatten),
+             policy::move)
+        .def("reshape",
+             py::overload_cast<Tensor, Tensor, Shape>(&Handler::reshape),
             policy::move);
 }
 
diff --git a/src/operators/reshape.cc b/src/operators/reshape.cc
index fb15681e..6ae7673b 100644
--- a/src/operators/reshape.cc
+++ b/src/operators/reshape.cc
@@ -1,9 +1,8 @@
 #include "operators/reshape.h"
 
 namespace infini {
-ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output,
-                       const Shape &dims)
-    : OperatorObj(OpType::Reshape, {input}, {output}), dims(dims) {
+ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
+    : OperatorObj(OpType::Reshape, {input}, {output}), dims(std::move(dims)) {
     IT_ASSERT(checkValid(graph));
 }
 
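
Note on the disabled test_reshape: the importer reads the target shape by iterating the node's second input (`[int(i) for i in tensors[node.input[1]]]`), so the shape values must be known at import time, which in ONNX normally means an INT64 initializer; that is what the FIXME refers to. Below is a minimal sketch of how the test could supply concrete shape values once INT64 initializers are accepted. It is an illustration, not part of this patch: `make_tensor`, the `initializer` argument, and the final `from_onnx` call use standard onnx/pyinfinitensor names, but the backend support they assume does not exist yet.

    import onnx
    from onnx import TensorProto
    from onnx.helper import (
        make_graph,
        make_model,
        make_node,
        make_tensor,
        make_tensor_value_info,
    )

    # The shape input is a 1-D INT64 tensor of length 3 holding [3, 5, 8];
    # an initializer gives it concrete values so the importer can read them.
    data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5])
    shape = make_tensor_value_info("shape", TensorProto.INT64, [3])
    reshaped = make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 5, 8])
    shape_init = make_tensor("shape", TensorProto.INT64, dims=[3], vals=[3, 5, 8])
    reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape")
    model = make_model(
        make_graph(
            [reshape], "reshape", [data, shape], [reshaped], initializer=[shape_init]
        )
    )
    onnx.checker.check_model(model)
    # from_onnx(model)  # would exercise handler.reshape once INT64 is supported

One detail worth noting: the shape input itself is 1-D with three elements, so its value info dims should be [3]; the [3, 5, 8] in the disabled test describes the target shape, not the shape tensor's own dimensions.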