From afed749b7423d5b9de3a7d65dd54a3b6923f94c5 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 24 Feb 2023 16:57:37 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=AF=BC=E5=87=BA?= =?UTF-8?q?=E6=9D=83=E9=87=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- pyinfinitensor/src/pyinfinitensor/onnx.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index d554883c..d20b4ca7 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -22,7 +22,7 @@ from onnx.checker import ( check_tensor, ) from onnx.shape_inference import infer_shapes -from typing import Dict, List, Any, Tuple, Sequence, Union +from typing import Dict, List, Any, Tuple, Sequence, Union, Optional from functools import reduce cpu_runtime = backend.cpu_runtime() @@ -365,7 +365,7 @@ class OnnxStub: def to_onnx(self, name: str) -> ModelProto: class Context: # saves object names, including tensors and operators - names: Dict[Any, str] = dict() + names: Dict[Union[backend.Tensor, backend.Operator], str] = dict() # counts the occurrence times of each operator for naming count_op: Dict[backend.OpType, int] = dict() # counts input and output tensors for naming @@ -396,7 +396,9 @@ class OnnxStub: self.outputs.append(value_info) return name - def push_input(self, tensor: backend.Tensor) -> str: + def push_input( + self, tensor: backend.Tensor, init: Optional[TensorProto] + ) -> str: name = self.names.get(tensor) # means that this input is a global input if name is None: @@ -408,7 +410,9 @@ class OnnxStub: value_info = make_tensor_value_info(name, dtype, shape) check_value_info(value_info) self.inputs.append(value_info) - + if init != None: + init.name = name + self.initializers.append(init) return name def push_data_input( @@ -462,7 +466,10 @@ class OnnxStub: for op in ops: ty, name = ctx.name_op(op) - inputs = [ctx.push_input(it) for it in op.inputs()] + inputs = [ + ctx.push_input(it, self.initializer.get(it.fuid())) + for it in op.inputs() + ] outputs = [ ctx.push_output("{}_{}".format(name, i), it) for (i, it) in enumerate(op.outputs())