From c0ff584e0407c71aa372db21c4d29ee3ea1b01d8 Mon Sep 17 00:00:00 2001 From: Haojie Wang Date: Sun, 8 Oct 2023 21:42:41 +0800 Subject: [PATCH] add constant op; fix concat bug (#151) --- pyinfinitensor/src/pyinfinitensor/onnx.py | 34 ++++++++++++++--------- src/operators/concat.cc | 1 - 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 121ad22e..f8e53b1c 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -663,6 +663,15 @@ class OnnxStub: tensors[node.input[0]], tensors.get(node.output[0]), ) + elif node.op_type == "Constant": + output_name = node.output[0] + attributes = _parse_attribute(node) + tensor = attributes['value'] + dims = [d for d in tensor.dims] + tensors[output_name] = self.handler.tensor( + dims, tensor.data_type) + data[output_name] = tensor + tensors[output_name].set_weight() else: raise Exception('Unsupported operator "{}"'.format(node.op_type)) new_node_name.append(node.name) @@ -1067,19 +1076,18 @@ def _search_shape(model: ModelProto, name: str) -> List[int]: def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]: for attr in node.attribute: - if attr.name in attrs: - if attr.type == AttributeProto.INT: - attrs[attr.name] = attr.i - elif attr.type == AttributeProto.INTS: - attrs[attr.name] = attr.ints - elif attr.type == AttributeProto.FLOAT: - attrs[attr.name] = attr.f - elif attr.type == AttributeProto.STRING: - attrs[attr.name] = attr.s - elif attr.type == AttributeProto.TENSOR: - attrs[attr.name] = attr.t - else: - assert False, "Unsupported Attribute Type: {}".format(attr.type) + if attr.type == AttributeProto.INT: + attrs[attr.name] = attr.i + elif attr.type == AttributeProto.INTS: + attrs[attr.name] = attr.ints + elif attr.type == AttributeProto.FLOAT: + attrs[attr.name] = attr.f + elif attr.type == AttributeProto.STRING: + attrs[attr.name] = attr.s + elif attr.type == AttributeProto.TENSOR: + attrs[attr.name] = attr.t + else: + assert False, "Unsupported Attribute Type: {}".format(attr.type) return attrs diff --git a/src/operators/concat.cc b/src/operators/concat.cc index 78e30dad..de836d58 100644 --- a/src/operators/concat.cc +++ b/src/operators/concat.cc @@ -10,7 +10,6 @@ ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim) } optional> ConcatObj::inferShape(const TensorVec &inputs) const { - IT_ASSERT(inputs.size() > 1); Shape dims = inputs[0]->getDims(); auto rank = inputs[0]->getRank(); ShapeElem n = dims.at(dim);