feat: 支持导出权重

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-24 16:57:37 +08:00
parent 40fb8390b1
commit afed749b74
1 changed files with 12 additions and 5 deletions

View File

@ -22,7 +22,7 @@ from onnx.checker import (
check_tensor,
)
from onnx.shape_inference import infer_shapes
from typing import Dict, List, Any, Tuple, Sequence, Union
from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
from functools import reduce
cpu_runtime = backend.cpu_runtime()
@ -365,7 +365,7 @@ class OnnxStub:
def to_onnx(self, name: str) -> ModelProto:
class Context:
# saves object names, including tensors and operators
names: Dict[Any, str] = dict()
names: Dict[Union[backend.Tensor, backend.Operator], str] = dict()
# counts the occurrence times of each operator for naming
count_op: Dict[backend.OpType, int] = dict()
# counts input and output tensors for naming
@ -396,7 +396,9 @@ class OnnxStub:
self.outputs.append(value_info)
return name
def push_input(self, tensor: backend.Tensor) -> str:
def push_input(
self, tensor: backend.Tensor, init: Optional[TensorProto]
) -> str:
name = self.names.get(tensor)
# means that this input is a global input
if name is None:
@ -408,7 +410,9 @@ class OnnxStub:
value_info = make_tensor_value_info(name, dtype, shape)
check_value_info(value_info)
self.inputs.append(value_info)
if init != None:
init.name = name
self.initializers.append(init)
return name
def push_data_input(
@ -462,7 +466,10 @@ class OnnxStub:
for op in ops:
ty, name = ctx.name_op(op)
inputs = [ctx.push_input(it) for it in op.inputs()]
inputs = [
ctx.push_input(it, self.initializer.get(it.fuid()))
for it in op.inputs()
]
outputs = [
ctx.push_output("{}_{}".format(name, i), it)
for (i, it) in enumerate(op.outputs())