add constant op; fix concat bug (#151)

This commit is contained in:
Haojie Wang 2023-10-08 21:42:41 +08:00 committed by GitHub
parent f25bcca076
commit c0ff584e04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 14 deletions

View File

@ -663,6 +663,15 @@ class OnnxStub:
tensors[node.input[0]],
tensors.get(node.output[0]),
)
elif node.op_type == "Constant":
output_name = node.output[0]
attributes = _parse_attribute(node)
tensor = attributes['value']
dims = [d for d in tensor.dims]
tensors[output_name] = self.handler.tensor(
dims, tensor.data_type)
data[output_name] = tensor
tensors[output_name].set_weight()
else:
raise Exception('Unsupported operator "{}"'.format(node.op_type))
new_node_name.append(node.name)
@ -1067,19 +1076,18 @@ def _search_shape(model: ModelProto, name: str) -> List[int]:
def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
for attr in node.attribute:
if attr.name in attrs:
if attr.type == AttributeProto.INT:
attrs[attr.name] = attr.i
elif attr.type == AttributeProto.INTS:
attrs[attr.name] = attr.ints
elif attr.type == AttributeProto.FLOAT:
attrs[attr.name] = attr.f
elif attr.type == AttributeProto.STRING:
attrs[attr.name] = attr.s
elif attr.type == AttributeProto.TENSOR:
attrs[attr.name] = attr.t
else:
assert False, "Unsupported Attribute Type: {}".format(attr.type)
if attr.type == AttributeProto.INT:
attrs[attr.name] = attr.i
elif attr.type == AttributeProto.INTS:
attrs[attr.name] = attr.ints
elif attr.type == AttributeProto.FLOAT:
attrs[attr.name] = attr.f
elif attr.type == AttributeProto.STRING:
attrs[attr.name] = attr.s
elif attr.type == AttributeProto.TENSOR:
attrs[attr.name] = attr.t
else:
assert False, "Unsupported Attribute Type: {}".format(attr.type)
return attrs

View File

@ -10,7 +10,6 @@ ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim)
}
optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) const {
IT_ASSERT(inputs.size() > 1);
Shape dims = inputs[0]->getDims();
auto rank = inputs[0]->getRank();
ShapeElem n = dims.at(dim);