forked from jiuyuan/InfiniTensor
add constant op; fix concat bug (#151)
This commit is contained in:
parent f25bcca076
commit c0ff584e04
@@ -663,6 +663,15 @@ class OnnxStub:
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
+            elif node.op_type == "Constant":
+                output_name = node.output[0]
+                attributes = _parse_attribute(node)
+                tensor = attributes['value']
+                dims = [d for d in tensor.dims]
+                tensors[output_name] = self.handler.tensor(
+                    dims, tensor.data_type)
+                data[output_name] = tensor
+                tensors[output_name].set_weight()
             else:
                 raise Exception('Unsupported operator "{}"'.format(node.op_type))
             new_node_name.append(node.name)
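For reference, the data the new Constant branch consumes comes straight from the node's "value" attribute, which is a TensorProto. Below is a minimal standalone sketch of building such a node and reading back the two fields the branch feeds to handler.tensor() (dims and data_type); it relies only on the onnx and numpy packages and does not touch InfiniTensor's handler, and the tensor contents are illustrative.

import numpy as np
from onnx import helper, numpy_helper

value = numpy_helper.from_array(np.arange(6, dtype=np.float32).reshape(2, 3))
node = helper.make_node("Constant", inputs=[], outputs=["const_out"], value=value)

# The Constant op stores its payload in the "value" attribute as a TensorProto.
tensor = {a.name: a.t for a in node.attribute}["value"]
dims = [d for d in tensor.dims]      # -> [2, 3]
print(dims, tensor.data_type)        # data_type is the TensorProto enum (1 == FLOAT)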
@@ -1067,19 +1076,18 @@ def _search_shape(model: ModelProto, name: str) -> List[int]:
 
 
 def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
     for attr in node.attribute:
-        if attr.name in attrs:
-            if attr.type == AttributeProto.INT:
-                attrs[attr.name] = attr.i
-            elif attr.type == AttributeProto.INTS:
-                attrs[attr.name] = attr.ints
-            elif attr.type == AttributeProto.FLOAT:
-                attrs[attr.name] = attr.f
-            elif attr.type == AttributeProto.STRING:
-                attrs[attr.name] = attr.s
-            elif attr.type == AttributeProto.TENSOR:
-                attrs[attr.name] = attr.t
-            else:
-                assert False, "Unsupported Attribute Type: {}".format(attr.type)
+        if attr.type == AttributeProto.INT:
+            attrs[attr.name] = attr.i
+        elif attr.type == AttributeProto.INTS:
+            attrs[attr.name] = attr.ints
+        elif attr.type == AttributeProto.FLOAT:
+            attrs[attr.name] = attr.f
+        elif attr.type == AttributeProto.STRING:
+            attrs[attr.name] = attr.s
+        elif attr.type == AttributeProto.TENSOR:
+            attrs[attr.name] = attr.t
+        else:
+            assert False, "Unsupported Attribute Type: {}".format(attr.type)
     return attrs
 
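This companion change to _parse_attribute is what makes the Constant branch above work: the old code only filled in attributes whose names were already present in the defaults dict, so a bare _parse_attribute(node) returned the defaults untouched, while the new code collects every attribute found on the node. A small hedged illustration of the caller-side difference (the node here is just an example built with the onnx package). Note also the standard Python caveat that attrs defaults to a single shared dict() instance, so calls without an explicit defaults dict now accumulate into that shared object.

import numpy as np
from onnx import helper, numpy_helper

node = helper.make_node(
    "Constant", inputs=[], outputs=["c"],
    value=numpy_helper.from_array(np.zeros((1,), dtype=np.float32)),
)
# Before this commit: _parse_attribute(node) -> {}   ("value" was never pre-seeded, so never parsed)
# After this commit:  _parse_attribute(node) -> {"value": <TensorProto>}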
@@ -10,7 +10,6 @@ ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim)
 }
 
 optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) const {
-    IT_ASSERT(inputs.size() > 1);
     Shape dims = inputs[0]->getDims();
     auto rank = inputs[0]->getRank();
     ShapeElem n = dims.at(dim);
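On the C++ side, the removed IT_ASSERT is presumably the concat fix named in the commit title: the ONNX Concat schema accepts anywhere from one to arbitrarily many inputs, so rejecting a single-input Concat in inferShape was too strict. A hedged sketch showing that onnx's own checker accepts a one-input Concat model (graph and tensor names are illustrative):

import onnx
from onnx import helper, TensorProto

x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3])
concat = helper.make_node("Concat", inputs=["x"], outputs=["y"], axis=0)
graph = helper.make_graph([concat], "single_input_concat", [x], [y])
model = helper.make_model(graph)
onnx.checker.check_model(model)   # passes: the Concat schema allows 1..N inputs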