diff --git a/include/operators/reshape.h b/include/operators/reshape.h
index 66fb1bda..31cc3576 100644
--- a/include/operators/reshape.h
+++ b/include/operators/reshape.h
@@ -28,6 +28,8 @@ class ReshapeObj : public OperatorObj {
     int numInputs() const override { return 1; }
     int numOutputs() const override { return 1; }
 
+    inline Shape getShape() const { return dims; }
+
   private:
     vector<int> getWorkloadVector() const override;
     vector<int> getOpAttrVector() const override;
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 1008b391..bb155a6c 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -450,16 +450,17 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
        elif ty == backend.OpType.Flatten:
            raise Exception("TODO")
        elif ty == backend.OpType.Reshape:
-            data = ctx.push_input(inputs[0])
-            # shape = context.push_data_input(
-            #     name,
-            #     "shape",
-            #     TensorProto.INT32,
-            #     shape=[len(vals)],
-            #     vals=1,
-            # )
-            # context.push_node(make_node(ty.name, [data, shape], [name], name))
-            raise Exception("TODO")
+            shape = backend.reshape_shape_of(op)
+            inputs.append(
+                ctx.push_data_input(
+                    name,
+                    "shape",
+                    TensorProto.INT32,
+                    shape=[len(shape)],
+                    vals=shape,
+                )
+            )
+            ctx.push_node(make_node(ty.name, inputs, outputs, name))
        elif ty == backend.OpType.Concat:
            axis = backend.concat_axis_of(op)
            ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index d512b504..1f839256 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -296,20 +296,17 @@ class TestStringMethods(unittest.TestCase):
        handler = backend.GraphHandler(runtime)
        a = handler.tensor([1, 2, 3], 12)
        b = handler.tensor([1, 2, 3], 12)
-        ab = handler.tensor([1, 2, 3], 12)
        c = handler.tensor([1, 2, 3], 12)
-        abc = handler.tensor([1, 2, 3], 12)
        d = handler.tensor([1, 2, 3], 12)
-        abcd = handler.tensor([1, 2, 3], 12)
        e = handler.tensor([1, 2, 3], 12)
-        abcde = handler.tensor([1, 2, 3], 12)
-        handler.add(a, b, ab)
-        handler.add(ab, c, abc)
-        handler.add(abc, d, abcd)
-        handler.add(abcd, e, abcde)
+        x = handler.add(
+            handler.add(handler.add(handler.add(a, b, None), c, None), d, None), e, None
+        )
+        y = handler.tensor([3, 2, 1], 12)
+        handler.reshape(x, y, [3, 2, 1])
 
-        to_onnx(handler, "add")
+        to_onnx(handler, "test_frontend")
 
 
 if __name__ == "__main__":
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 0ffb1b31..f979b03c 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -1,6 +1,7 @@
 #include "core/graph_handler.h"
 #include "operators/concat.h"
 #include "operators/gather.h"
+#include "operators/reshape.h"
 #include <pybind11/stl.h>
 
 #ifdef USE_CUDA
@@ -102,12 +103,18 @@
 static int gather_axis_of(Operator op) {
     return reinterpret_cast<const GatherObj *>(op.get())->getAxis();
 }
 
+static Shape reshape_shape_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::Reshape);
+    return reinterpret_cast<const ReshapeObj *>(op.get())->getShape();
+}
+
 void init_graph_builder(py::module &m) {
     using Handler = GraphHandlerObj;
 
     m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
         .def("tensor_dtype", &tensor_dtype)
+        .def("reshape_shape_of", &reshape_shape_of)
         .def("concat_axis_of", &concat_axis_of)
         .def("gather_axis_of", &gather_axis_of);
     py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
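
For reference, below is a minimal standalone sketch (not part of the patch) of the graph structure this export path produces: the target shape is materialized as a small constant tensor and passed as the second input of the ONNX `Reshape` node, mirroring what the new `OpType.Reshape` branch does with `ctx.push_data_input` and `backend.reshape_shape_of`. All names here (`data`, `shape`, `reshaped`, `reshape0`) are illustrative; note the sketch uses `INT64` for the shape tensor, which is what the ONNX `Reshape` schema specifies, whereas the branch above currently pushes `INT32`.

```python
# Illustrative only: a standalone ONNX model with the same Reshape layout
# that to_onnx emits (data input + constant shape input -> Reshape node).
from onnx import TensorProto, checker, helper

data = helper.make_tensor_value_info("data", TensorProto.FLOAT, [1, 2, 3])
reshaped = helper.make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 2, 1])

# The target shape becomes a constant tensor input, analogous to the
# ctx.push_data_input(...) call whose vals come from reshape_shape_of(op).
shape = helper.make_tensor("shape", TensorProto.INT64, dims=[3], vals=[3, 2, 1])

node = helper.make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape0")
graph = helper.make_graph(
    [node], "reshape_example", [data], [reshaped], initializer=[shape]
)
checker.check_model(helper.make_model(graph))
```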