forked from jiuyuan/InfiniTensor
feat: 支持导出权重
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
40fb8390b1
commit
afed749b74
|
@ -22,7 +22,7 @@ from onnx.checker import (
|
||||||
check_tensor,
|
check_tensor,
|
||||||
)
|
)
|
||||||
from onnx.shape_inference import infer_shapes
|
from onnx.shape_inference import infer_shapes
|
||||||
from typing import Dict, List, Any, Tuple, Sequence, Union
|
from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
|
||||||
cpu_runtime = backend.cpu_runtime()
|
cpu_runtime = backend.cpu_runtime()
|
||||||
|
@ -365,7 +365,7 @@ class OnnxStub:
|
||||||
def to_onnx(self, name: str) -> ModelProto:
|
def to_onnx(self, name: str) -> ModelProto:
|
||||||
class Context:
|
class Context:
|
||||||
# saves object names, including tensors and operators
|
# saves object names, including tensors and operators
|
||||||
names: Dict[Any, str] = dict()
|
names: Dict[Union[backend.Tensor, backend.Operator], str] = dict()
|
||||||
# counts the occurrence times of each operator for naming
|
# counts the occurrence times of each operator for naming
|
||||||
count_op: Dict[backend.OpType, int] = dict()
|
count_op: Dict[backend.OpType, int] = dict()
|
||||||
# counts input and output tensors for naming
|
# counts input and output tensors for naming
|
||||||
|
@ -396,7 +396,9 @@ class OnnxStub:
|
||||||
self.outputs.append(value_info)
|
self.outputs.append(value_info)
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def push_input(self, tensor: backend.Tensor) -> str:
|
def push_input(
|
||||||
|
self, tensor: backend.Tensor, init: Optional[TensorProto]
|
||||||
|
) -> str:
|
||||||
name = self.names.get(tensor)
|
name = self.names.get(tensor)
|
||||||
# means that this input is a global input
|
# means that this input is a global input
|
||||||
if name is None:
|
if name is None:
|
||||||
|
@ -408,7 +410,9 @@ class OnnxStub:
|
||||||
value_info = make_tensor_value_info(name, dtype, shape)
|
value_info = make_tensor_value_info(name, dtype, shape)
|
||||||
check_value_info(value_info)
|
check_value_info(value_info)
|
||||||
self.inputs.append(value_info)
|
self.inputs.append(value_info)
|
||||||
|
if init != None:
|
||||||
|
init.name = name
|
||||||
|
self.initializers.append(init)
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def push_data_input(
|
def push_data_input(
|
||||||
|
@ -462,7 +466,10 @@ class OnnxStub:
|
||||||
|
|
||||||
for op in ops:
|
for op in ops:
|
||||||
ty, name = ctx.name_op(op)
|
ty, name = ctx.name_op(op)
|
||||||
inputs = [ctx.push_input(it) for it in op.inputs()]
|
inputs = [
|
||||||
|
ctx.push_input(it, self.initializer.get(it.fuid()))
|
||||||
|
for it in op.inputs()
|
||||||
|
]
|
||||||
outputs = [
|
outputs = [
|
||||||
ctx.push_output("{}_{}".format(name, i), it)
|
ctx.push_output("{}_{}".format(name, i), it)
|
||||||
for (i, it) in enumerate(op.outputs())
|
for (i, it) in enumerate(op.outputs())
|
||||||
|
|
Loading…
Reference in New Issue