forked from jiuyuan/InfiniTensor
feat: 支持导出权重
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
40fb8390b1
commit
afed749b74
|
@ -22,7 +22,7 @@ from onnx.checker import (
|
|||
check_tensor,
|
||||
)
|
||||
from onnx.shape_inference import infer_shapes
|
||||
from typing import Dict, List, Any, Tuple, Sequence, Union
|
||||
from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
|
||||
from functools import reduce
|
||||
|
||||
cpu_runtime = backend.cpu_runtime()
|
||||
|
@ -365,7 +365,7 @@ class OnnxStub:
|
|||
def to_onnx(self, name: str) -> ModelProto:
|
||||
class Context:
|
||||
# saves object names, including tensors and operators
|
||||
names: Dict[Any, str] = dict()
|
||||
names: Dict[Union[backend.Tensor, backend.Operator], str] = dict()
|
||||
# counts the occurrence times of each operator for naming
|
||||
count_op: Dict[backend.OpType, int] = dict()
|
||||
# counts input and output tensors for naming
|
||||
|
@ -396,7 +396,9 @@ class OnnxStub:
|
|||
self.outputs.append(value_info)
|
||||
return name
|
||||
|
||||
def push_input(self, tensor: backend.Tensor) -> str:
|
||||
def push_input(
|
||||
self, tensor: backend.Tensor, init: Optional[TensorProto]
|
||||
) -> str:
|
||||
name = self.names.get(tensor)
|
||||
# means that this input is a global input
|
||||
if name is None:
|
||||
|
@ -408,7 +410,9 @@ class OnnxStub:
|
|||
value_info = make_tensor_value_info(name, dtype, shape)
|
||||
check_value_info(value_info)
|
||||
self.inputs.append(value_info)
|
||||
|
||||
if init != None:
|
||||
init.name = name
|
||||
self.initializers.append(init)
|
||||
return name
|
||||
|
||||
def push_data_input(
|
||||
|
@ -462,7 +466,10 @@ class OnnxStub:
|
|||
|
||||
for op in ops:
|
||||
ty, name = ctx.name_op(op)
|
||||
inputs = [ctx.push_input(it) for it in op.inputs()]
|
||||
inputs = [
|
||||
ctx.push_input(it, self.initializer.get(it.fuid()))
|
||||
for it in op.inputs()
|
||||
]
|
||||
outputs = [
|
||||
ctx.push_output("{}_{}".format(name, i), it)
|
||||
for (i, it) in enumerate(op.outputs())
|
||||
|
|
Loading…
Reference in New Issue