forked from jiuyuan/InfiniTensor
add constant op; fix concat bug (#151)
This commit is contained in:
parent
f25bcca076
commit
c0ff584e04
|
@ -663,6 +663,15 @@ class OnnxStub:
|
|||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Constant":
|
||||
output_name = node.output[0]
|
||||
attributes = _parse_attribute(node)
|
||||
tensor = attributes['value']
|
||||
dims = [d for d in tensor.dims]
|
||||
tensors[output_name] = self.handler.tensor(
|
||||
dims, tensor.data_type)
|
||||
data[output_name] = tensor
|
||||
tensors[output_name].set_weight()
|
||||
else:
|
||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
new_node_name.append(node.name)
|
||||
|
@ -1067,7 +1076,6 @@ def _search_shape(model: ModelProto, name: str) -> List[int]:
|
|||
|
||||
def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
|
||||
for attr in node.attribute:
|
||||
if attr.name in attrs:
|
||||
if attr.type == AttributeProto.INT:
|
||||
attrs[attr.name] = attr.i
|
||||
elif attr.type == AttributeProto.INTS:
|
||||
|
|
|
@ -10,7 +10,6 @@ ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim)
|
|||
}
|
||||
|
||||
optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) const {
|
||||
IT_ASSERT(inputs.size() > 1);
|
||||
Shape dims = inputs[0]->getDims();
|
||||
auto rank = inputs[0]->getRank();
|
||||
ShapeElem n = dims.at(dim);
|
||||
|
|
Loading…
Reference in New Issue