From 58993d433980f76528b9ad0004f1a7919c1c63cb Mon Sep 17 00:00:00 2001
From: zhangyunze <93699316+bitzyz@users.noreply.github.com>
Date: Fri, 12 Jan 2024 14:54:27 +0800
Subject: [PATCH] Remove the frontend's dependency on onnx infershape (#206)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* feat: SqueezeOp: lift the dependency on onnx infershape.

* feat: UnsqueezeOp: lift the dependency on onnx infershape.

* feat: lift the dependency on onnx infershape

* fix: Makefile: turn nccl (BUILD_DIST) off

---
 Makefile                                  |   1 -
 include/core/graph_handler.h              |   2 +
 include/operators/squeeze.h               |  39 +++++++
 include/operators/unsqueeze.h             |  38 +++++++
 pyinfinitensor/src/pyinfinitensor/onnx.py | 127 ++++++++++------------
 pyinfinitensor/tests/test_onnx.py         |  22 ++++
 src/core/graph_handler.cc                 |  24 ++++
 src/ffi/ffi_infinitensor.cc               |  26 +++++
 src/kernels/cuda/reshape.cc               |   4 +
 src/operators/squeeze.cc                  |  60 ++++++++++
 src/operators/unsqueeze.cc                |  52 +++++++++
 test/operators/test_reshape.cc            |  34 ++++++
 12 files changed, 358 insertions(+), 71 deletions(-)
 create mode 100644 include/operators/squeeze.h
 create mode 100644 include/operators/unsqueeze.h
 create mode 100644 src/operators/squeeze.cc
 create mode 100644 src/operators/unsqueeze.cc

diff --git a/Makefile b/Makefile
index d21a406b..302f47b8 100644
--- a/Makefile
+++ b/Makefile
@@ -29,7 +29,6 @@ CMAKE_OPT += -DUSE_BANG=$(BANG)
 CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN)
 CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
 CMAKE_OPT += -DBUILD_TEST=$(TEST)
-CMAKE_OPT += -DBUILD_DIST=ON
 CMAKE_OPT += -DBUILD_NNET=$(NNET)
 
 ifeq ($(INTELCPU), ON)
diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 313a1f79..0e1472bb 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -71,6 +71,8 @@ class GraphHandlerObj {
                   vector<float> scales_, vector<float> roi_, string mode,
                   string ratioPolicy, string nearestMode,
                   string coordTransMode);
+    Tensor squeeze(Tensor input, Tensor output, Shape axes);
+    Tensor unsqueeze(Tensor input, Tensor output, Shape axes);
     Tensor concat(TensorVec inputs, Tensor output, int dim);
     Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
                             Tensor input_q, Tensor input_k, Tensor input_v,
diff --git a/include/operators/squeeze.h b/include/operators/squeeze.h
new file mode 100644
index 00000000..d1d99683
--- /dev/null
+++ b/include/operators/squeeze.h
@@ -0,0 +1,39 @@
+#pragma once
+
+#include "core/operator.h"
+
+namespace infini {
+
+/**
+ * @brief Remove single-dimensional entries from the shape of a tensor.
+ *
+ */
+class SqueezeObj : public OperatorObj {
+    Shape axes;
+
+  public:
+    /**
+     * @brief Construct a new Squeeze object.
+     *
+     * @param graph The computation graph that this operator belongs to.
+     * @param input The input tensor.
+     * @param output The output tensor.
+     * @param axes List of integers indicating the dimensions to squeeze.
+     */
+    SqueezeObj(GraphObj *graph, Tensor input, Tensor output, Shape axes);
+    OP_CLONE(SqueezeObj);
+
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
+
+    std::string toString() const override;
+    int numInputs() const override { return 1; }
+    int numOutputs() const override { return 1; }
+
+    inline Shape getAxes() const { return axes; }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+
+} // namespace infini
diff --git a/include/operators/unsqueeze.h b/include/operators/unsqueeze.h
new file mode 100644
index 00000000..f496d32a
--- /dev/null
+++ b/include/operators/unsqueeze.h
@@ -0,0 +1,38 @@
+#pragma once
+
+#include "core/operator.h"
+
+namespace infini {
+/**
+ * @brief Insert single-dimensional entries into the shape of an input
+ * tensor.
+ *
+ */
+class UnsqueezeObj : public OperatorObj {
+    Shape axes;
+
+  public:
+    /**
+     * @brief Construct a new Unsqueeze object.
+     *
+     * @param graph The computation graph that this operator belongs to.
+     * @param input The input tensor.
+     * @param output The output tensor.
+     * @param axes List of integers indicating the dimensions to be inserted.
+     */
+    UnsqueezeObj(GraphObj *graph, Tensor input, Tensor output, Shape axes);
+    OP_CLONE(UnsqueezeObj);
+
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
+
+    std::string toString() const override;
+    int numInputs() const override { return 1; }
+    int numOutputs() const override { return 1; }
+
+    inline Shape getAxes() const { return axes; }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+
+} // namespace infini
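
Note: both new operators follow ONNX semantics. Squeeze drops the listed
size-1 dimensions (or every size-1 dimension when axes is empty), Unsqueeze
inserts a size-1 dimension at each listed position of the output shape, and
both accept negative axes counted from the back. A minimal usage sketch,
mirroring the shape-inference tests added at the end of this patch (runtime
setup assumed from test/operators/test_reshape.cc, not part of this diff):

    // Sketch: expected shape behavior of the new operators.
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph g = make_ref<GraphObj>(runtime);
    Tensor x = g->addTensor({1, 3, 1, 5}, DataType::Float32);
    // Squeeze axes {0, 2}: {1, 3, 1, 5} -> {3, 5}
    auto sq = g->addOp<SqueezeObj>(x, nullptr, Shape{0, 2});
    // Unsqueeze axes {0, 2} on the result: {3, 5} -> {1, 3, 1, 5}
    auto us = g->addOp<UnsqueezeObj>(sq->getOutput(), nullptr, Shape{0, 2});
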
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index c63746af..192e5273 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -52,10 +52,10 @@ class OnnxStub:
         self.inputs: Dict[str, backend.Tensor] = {}
         self.outputs: Dict[str, backend.Tensor] = {}
         self.initializer: Dict[int, TensorProto] = {}
-        try:
-            model = infer_shapes(model)
-        except:
-            warnings.warn("infer_shapes failed.")
+        # try:
+        #     model = infer_shapes(model)
+        # except:
+        #     warnings.warn("infer_shapes failed.")
         self.handler = backend.GraphHandler(runtime)
 
         tensors: Dict[str, backend.Tensor] = dict()
@@ -135,7 +135,7 @@ class OnnxStub:
                             1,
                             reduce(
                                 lambda acc, x: acc * x,
-                                _search_shape(model, node.input[2]),
+                                tensors[node.input[2]].shape(),
                             ),
                             1,
                             1,
@@ -357,7 +357,7 @@ class OnnxStub:
                     ceil_mode,
                 )
             elif node.op_type == "GlobalAveragePool":
-                [_, _, h, w] = _search_shape(model, node.input[0])
+                [_, _, h, w] = tensors[node.input[0]].shape()
                 tensors[node.output[0]] = self.handler.avgPool(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
@@ -595,35 +595,43 @@ class OnnxStub:
                     coordinate_transformation_mode,
                 )
             elif node.op_type == "Squeeze":
-                input_shape = _search_shape(model, node.input[0])
-                axes = set(
-                    [int(i) for i in data[node.input[1]].int64_data]
+                axes = (
+                    _parse_data(data[node.input[1]])
                     if len(node.input) > 1
-                    else _parse_attribute(node, {"axes": None})["axes"]
+                    else None
                 )
-                assert all(input_shape[d] == 1 for d in axes)
-                output_shape = []
-                for i, x in enumerate(input_shape):
-                    if i not in axes:
-                        output_shape.append(x)
-                tensors[node.output[0]] = self.handler.reshape(
+                if axes is None:
+                    axes = next(
+                        (
+                            attr.ints
+                            for attr in node.attribute
+                            if attr.name == "axes"
+                        ),
+                        [],
+                    )
+                tensors[node.output[0]] = self.handler.squeeze(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
-                    output_shape,
+                    axes,
                 )
             elif node.op_type == "Unsqueeze":
-                input_shape = _search_shape(model, node.input[0])
                 axes = (
-                    [int(i) for i in data[node.input[1]].int64_data]
+                    _parse_data(data[node.input[1]])
                     if len(node.input) > 1
-                    else _parse_attribute(node, {"axes": None})["axes"]
+                    else None
                 )
-                for i in axes:
-                    input_shape.insert(i, 1)
-                tensors[node.output[0]] = self.handler.reshape(
+                if axes is None:
+                    axes = next(
+                        (
+                            attr.ints
+                            for attr in node.attribute
+                            if attr.name == "axes"
+                        )
+                    )
+                tensors[node.output[0]] = self.handler.unsqueeze(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
-                    input_shape,
+                    axes,
                 )
             elif node.op_type == "Concat":
                 tensors[node.output[0]] = self.handler.concat(
@@ -935,8 +943,7 @@ class OnnxStub:
                     node, {"alpha": 0.0001, "beta": 0.75, "bias": 1.0, "size": 1}
                 )
                 (alpha, beta, bias, size) = (
-                    attributes[name]
-                    for name in ["alpha", "beta", "bias", "size"]
+                    attributes[name] for name in ["alpha", "beta", "bias", "size"]
                 )
                 tensors[node.output[0]] = self.handler.lrn(
                     tensors[node.input[0]],
@@ -1207,6 +1214,30 @@ class OnnxStub:
                     )
                 )
                 ctx.push_node(make_node(ty.name, inputs, outputs, name))
+            elif ty == backend.OpTypeId.Squeeze:
+                axes = backend.squeeze_axes_of(op)
+                inputs.append(
+                    ctx.push_data_input(
+                        name,
+                        "axes",
+                        TensorProto.INT64,
+                        [len(axes)],
+                        axes,
+                    )
+                )
+                ctx.push_node(make_node(ty.name, inputs, outputs, name))
+            elif ty == backend.OpTypeId.Unsqueeze:
+                axes = backend.unsqueeze_axes_of(op)
+                inputs.append(
+                    ctx.push_data_input(
+                        name,
+                        "axes",
+                        TensorProto.INT64,
+                        [len(axes)],
+                        axes,
+                    )
+                )
+                ctx.push_node(make_node(ty.name, inputs, outputs, name))
             elif ty == backend.OpTypeId.Concat:
                 axis = backend.concat_axis_of(op)
                 ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
@@ -1344,50 +1375,6 @@ def from_onnx(model: ModelProto, runtime):
     return stub.inputs, stub.outputs, stub.handler
 
 
-def _search_shape(model: ModelProto, name: str) -> List[int]:
-    ans = (
-        next(
-            (
-                [
-                    (d.dim_value if d.dim_value > 0 else 1)
-                    for d in tensor.type.tensor_type.shape.dim
-                ]
-                for tensor in model.graph.value_info
-                if tensor.name == name
-            ),
-            None,
-        )
-        or next(
-            (
-                [
-                    (d.dim_value if d.dim_value > 0 else 1)
-                    for d in tensor.type.tensor_type.shape.dim
-                ]
-                for tensor in model.graph.input
-                if tensor.name == name
-            ),
-            None,
-        )
-        or next(
-            (
-                [
-                    (d.dim_value if d.dim_value > 0 else 1)
-                    for d in tensor.type.tensor_type.shape.dim
-                ]
-                for tensor in model.graph.output
-                if tensor.name == name
-            ),
-            None,
-        )
-        or next(
-            [int(d) for d in tensor.dims]
-            for tensor in model.graph.initializer
-            if tensor.name == name
-        )
-    )
-    return ans
-
-
 def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
     for attr in node.attribute:
         if attr.type == AttributeProto.INT:
make_node("Squeeze", ["input", "axes"], ["output"], name="squeeze") + make_and_import_model( + make_graph([squeeze], "squeeze", [input, axes], [output], [axes_data]) + ) + + def test_unsqueeze(self): + input = make_tensor_value_info("input", TensorProto.FLOAT, [2, 3, 4, 5]) + axes = make_tensor_value_info("axes", TensorProto.INT64, [2]) + axes_data = make_tensor("axes", TensorProto.INT64, [2], [0, 2]) + output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 2, 1, 3, 4, 5]) + unsqueeze = make_node( + "Unsqueeze", ["input", "axes"], ["output"], name="unsqueeze" + ) + make_and_import_model( + make_graph([unsqueeze], "unsqueeze", [input, axes], [output], [axes_data]) + ) + def test_concat(self): input1 = make_tensor_value_info("input1", TensorProto.FLOAT, [1, 3, 2, 4]) input2 = make_tensor_value_info("input2", TensorProto.FLOAT, [1, 3, 2, 5]) diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 7fc6f977..415ea947 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -22,8 +22,10 @@ #include "operators/slice.h" #include "operators/softmax.h" #include "operators/split.h" +#include "operators/squeeze.h" #include "operators/transpose.h" #include "operators/unary.h" +#include "operators/unsqueeze.h" #include "operators/where.h" #include #include @@ -608,6 +610,28 @@ Tensor GraphHandlerObj::lrn(Tensor input, Tensor output, float alpha, } } +Tensor GraphHandlerObj::squeeze(Tensor input, Tensor output, Shape axes) { + if (output) { + g->addOpWithOutputs(std::move(input), output, + std::move(axes)); + return output; + } else { + return g->addOp(std::move(input), output, std::move(axes)) + ->getOutput(); + } +} + +Tensor GraphHandlerObj::unsqueeze(Tensor input, Tensor output, Shape axes) { + if (output) { + g->addOpWithOutputs(std::move(input), output, + std::move(axes)); + return output; + } else { + return g->addOp(std::move(input), output, std::move(axes)) + ->getOutput(); + } +} + static CastType inferCastType(Tensor input, int to) { auto iType = input->getDType(); auto oType = DataType(to); diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index eadd4a4e..b565ad4d 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -12,8 +12,10 @@ #include "operators/reduce.h" #include "operators/reshape.h" #include "operators/split.h" +#include "operators/squeeze.h" #include "operators/transpose.h" #include "operators/unary.h" +#include "operators/unsqueeze.h" #include #include #include @@ -93,6 +95,8 @@ void export_values(py::module &m) { .VALUE(OpType, ReduceMean) .VALUE(OpType, ReduceSum) .VALUE(OpType, Reshape) + .VALUE(OpType, Squeeze) + .VALUE(OpType, Unsqueeze) .VALUE(OpType, Flatten) .VALUE(OpType, Identity) .VALUE(OpType, BatchNormalization) @@ -256,6 +260,24 @@ static vector reshape_shape_of(Operator op) { return ans; } +static vector squeeze_axes_of(Operator op) { + IT_ASSERT(op->getOpType() == OpType::Squeeze); + auto axes = dynamic_cast(op.get())->getAxes(); + vector ans(axes.size()); + std::transform(axes.begin(), axes.end(), ans.begin(), + [](auto x) { return static_cast(x); }); + return ans; +} + +static vector unsqueeze_axes_of(Operator op) { + IT_ASSERT(op->getOpType() == OpType::Unsqueeze); + auto axes = dynamic_cast(op.get())->getAxes(); + vector ans(axes.size()); + std::transform(axes.begin(), axes.end(), ans.begin(), + [](auto x) { return static_cast(x); }); + return ans; +} + static vector expand_shape_of(Operator op) { IT_ASSERT(op->getOpType() == OpType::Expand); auto shape = 
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index eadd4a4e..b565ad4d 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -12,8 +12,10 @@
 #include "operators/reduce.h"
 #include "operators/reshape.h"
 #include "operators/split.h"
+#include "operators/squeeze.h"
 #include "operators/transpose.h"
 #include "operators/unary.h"
+#include "operators/unsqueeze.h"
 #include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
@@ -93,6 +95,8 @@ void export_values(py::module &m) {
         .VALUE(OpType, ReduceMean)
         .VALUE(OpType, ReduceSum)
         .VALUE(OpType, Reshape)
+        .VALUE(OpType, Squeeze)
+        .VALUE(OpType, Unsqueeze)
         .VALUE(OpType, Flatten)
         .VALUE(OpType, Identity)
         .VALUE(OpType, BatchNormalization)
@@ -256,6 +260,24 @@ static vector<int64_t> reshape_shape_of(Operator op) {
     return ans;
 }
 
+static vector<int64_t> squeeze_axes_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::Squeeze);
+    auto axes = dynamic_cast<const SqueezeObj *>(op.get())->getAxes();
+    vector<int64_t> ans(axes.size());
+    std::transform(axes.begin(), axes.end(), ans.begin(),
+                   [](auto x) { return static_cast<int64_t>(x); });
+    return ans;
+}
+
+static vector<int64_t> unsqueeze_axes_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::Unsqueeze);
+    auto axes = dynamic_cast<const UnsqueezeObj *>(op.get())->getAxes();
+    vector<int64_t> ans(axes.size());
+    std::transform(axes.begin(), axes.end(), ans.begin(),
+                   [](auto x) { return static_cast<int64_t>(x); });
+    return ans;
+}
+
 static vector<int64_t> expand_shape_of(Operator op) {
     IT_ASSERT(op->getOpType() == OpType::Expand);
     auto shape = dynamic_cast<const ExpandObj *>(op.get())->getShape();
@@ -343,6 +365,8 @@ void export_functions(py::module &m) {
         .FUNCTION(flatten_axis_of)
         .FUNCTION(cast_to_of)
         .FUNCTION(depth_to_space_attrs_of)
+        .FUNCTION(squeeze_axes_of)
+        .FUNCTION(unsqueeze_axes_of)
         .FUNCTION(lrn_attrs_of);
 #undef FUNCTION
 }
@@ -509,6 +533,8 @@ void init_graph_builder(py::module &m) {
         .def("depthToSpace", &Handler::depthToSpace, policy::move)
         .def("reshape", &Handler::reshape, policy::move)
        .def("resize", &Handler::resize, policy::move)
+        .def("squeeze", &Handler::squeeze, policy::move)
+        .def("unsqueeze", &Handler::unsqueeze, policy::move)
         .def("concat", &Handler::concat, policy::move)
         .def("attentionKVCache", &Handler::attentionKVCache, policy::move)
         .def("split", &Handler::split, policy::move)
diff --git a/src/kernels/cuda/reshape.cc b/src/kernels/cuda/reshape.cc
index 7be6aca8..232bcdf6 100644
--- a/src/kernels/cuda/reshape.cc
+++ b/src/kernels/cuda/reshape.cc
@@ -19,6 +19,10 @@ REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int32, CopyCuda,
                 "Reshape_CUDA_Int32");
 REGISTER_KERNEL(Device::CUDA, OpType::Flatten, DataType::Float32, CopyCuda,
                 "Flatten_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Squeeze, DataType::Float32, CopyCuda,
+                "Squeeze_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Unsqueeze, DataType::Float32, CopyCuda,
+                "Unsqueeze_CUDA_Float32");
 REGISTER_KERNEL(Device::CUDA, OpType::Identity, DataType::Float32, CopyCuda,
                 "Identity_CUDA_Float32");
diff --git a/src/operators/squeeze.cc b/src/operators/squeeze.cc
new file mode 100644
index 00000000..1609ecb9
--- /dev/null
+++ b/src/operators/squeeze.cc
@@ -0,0 +1,60 @@
+#include "operators/squeeze.h"
+#include "utils/operator_utils.h"
+
+namespace infini {
+SqueezeObj::SqueezeObj(GraphObj *graph, Tensor input, Tensor output, Shape axes)
+    : OperatorObj(OpType::Squeeze, {input}, {output}), axes(std::move(axes)) {
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>> SqueezeObj::inferShape(const TensorVec &inputs) {
+    Shape inputDim = inputs[0]->getDims();
+    Shape outputShape;
+    auto rank = inputs[0]->getRank();
+    if (axes.size() == 0) {
+        for (int i = 0; i < (int)rank; ++i) {
+            if (inputDim[i] == 1) {
+                axes.emplace_back(i);
+            }
+        }
+    }
+    auto new_axes = axes;
+    std::transform(axes.begin(), axes.end(), new_axes.begin(),
+                   [inputDim, rank](auto x) {
+                       x = get_real_axis(x, rank);
+                       IT_ASSERT(inputDim[x] == 1);
+                       return x;
+                   });
+    for (int i = 0; i < (int)rank; ++i) {
+        auto it = std::find(new_axes.begin(), new_axes.end(), i);
+        if (it == new_axes.end()) {
+            outputShape.emplace_back(inputDim[i]);
+        }
+    }
+    return {{outputShape}};
+}
+
+std::string SqueezeObj::toString() const {
+    std::ostringstream os;
+    os << "Squeeze[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(inputs[0]->getDims()) << ",";
+    os << "axes=" << vecToString(axes) << ",";
+    os << "input=" << inputs[0]->getGuid() << ",";
+    os << "output=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+vector<int> SqueezeObj::getWorkloadVector() const {
+    vector<int> ret = inputs[0]->getDims();
+    ret.insert(ret.end(), axes.begin(), axes.end());
+    ret.emplace(ret.begin(), type.underlying());
+    return ret;
+}
+vector<int> SqueezeObj::getOpAttrVector() const {
+    vector<int> ret = axes;
+    ret.emplace(ret.begin(), type.underlying());
+    return ret;
+}
+
+} // namespace infini
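
Note on the helper used above: squeeze.cc (and unsqueeze.cc below) normalizes
negative axes through get_real_axis from utils/operator_utils.h, which this
patch only calls and never defines. A hypothetical stand-in with the behavior
the call sites rely on:

    #include <cassert>

    // Hypothetical stand-in for get_real_axis (an assumption, not the
    // repo's definition): map a possibly negative ONNX axis into [0, rank).
    int get_real_axis(int axis, int rank) {
        int real = axis < 0 ? axis + rank : axis;
        assert(real >= 0 && real < rank);
        return real;
    }

Worked through on the Squeeze test below: input {2, 3, 1, 4} with axes {-2}
normalizes -2 to 4 - 2 = 2, asserts inputDim[2] == 1, and keeps every other
dimension, giving {2, 3, 4}.
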
"utils/operator_utils.h" + +namespace infini { +UnsqueezeObj::UnsqueezeObj(GraphObj *graph, Tensor input, Tensor output, + Shape axes) + : OperatorObj(OpType::Unsqueeze, {input}, {output}), axes(std::move(axes)) { + IT_ASSERT(checkValid(graph)); +} + +optional> UnsqueezeObj::inferShape(const TensorVec &inputs) { + Shape inputDim = inputs[0]->getDims(); + auto rank = inputs[0]->getRank() + axes.size(); + Shape outputShape(rank, -1); + for (size_t i = 0; i < axes.size(); ++i) { + axes[i] = get_real_axis(axes[i], rank); + IT_ASSERT(outputShape[axes[i]] == -1, "Axes have duplicate"); + outputShape[axes[i]] = 1; + } + auto it = inputDim.begin(); + for (size_t i = 0; i < outputShape.size(); ++i) { + if (outputShape[i] == -1) { + outputShape[i] = *it++; + } + } + return {{outputShape}}; +} + +std::string UnsqueezeObj::toString() const { + std::ostringstream os; + os << "Unsqueeze[" << getGuid() << "]"; + os << "("; + os << vecToString(inputs[0]->getDims()) << ","; + os << "axes=" << vecToString(axes) << ","; + os << "input=" << inputs[0]->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ")"; + return os.str(); +} + +vector UnsqueezeObj::getWorkloadVector() const { + vector ret = inputs[0]->getDims(); + ret.insert(ret.end(), axes.begin(), axes.end()); + ret.emplace(ret.begin(), type.underlying()); + return ret; +} +vector UnsqueezeObj::getOpAttrVector() const { + vector ret = axes; + ret.emplace(ret.begin(), type.underlying()); + return ret; +} + +} // namespace infini diff --git a/test/operators/test_reshape.cc b/test/operators/test_reshape.cc index 00ab514f..39fa823d 100644 --- a/test/operators/test_reshape.cc +++ b/test/operators/test_reshape.cc @@ -2,6 +2,8 @@ #include "core/kernel.h" #include "core/runtime.h" #include "operators/reshape.h" +#include "operators/squeeze.h" +#include "operators/unsqueeze.h" #include "test.h" @@ -54,4 +56,36 @@ TEST(Identity, ShapeInference) { } } +TEST(Squeeze, ShapeInference) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({2, 3, 1, 4}, DataType::Float32); + auto op = g->addOp(i, nullptr, Shape{-2}); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 4})); + } + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({1, 1, 3, 4}, DataType::Float32); + auto op = g->addOp(i, nullptr, Shape{}); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{3, 4})); + } +} + +TEST(Unsqueeze, ShapeInference) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({2, 3, 4}, DataType::Float32); + auto op = g->addOp(i, nullptr, Shape{0, 1}); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 2, 3, 4})); + } + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({2, 3, 4}, DataType::Float32); + auto op = g->addOp(i, nullptr, Shape{-1, -2}); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 4, 1, 1})); + } +} + } // namespace infini