From 32f6f02c811932a618e034069c76c946bf7f1f5a Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Mon, 20 Feb 2023 14:38:13 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AF=BC=E5=87=BA=205=20=E4=B8=AA?= =?UTF-8?q?=E5=8D=95=E7=9B=AE=E7=AE=97=E5=AD=90=E5=88=B0=20onnx?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- pyinfinitensor/src/pyinfinitensor/onnx.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 0a5b3c29..ae21ab8b 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -336,6 +336,7 @@ def to_onnx(graph: backend.GraphHandler): def push_output(self, name: str, tensor: backend.Tensor) -> None: self.names[tensor] = name + # TODO 需要判断全图输出并保存到 outputs def push_input(self, tensor: backend.Tensor) -> str: name = self.names.get(tensor) @@ -384,16 +385,16 @@ def to_onnx(graph: backend.GraphHandler): a = context.push_input(inputs[0]) b = context.push_input(inputs[1]) context.push_node(make_node(ty.name, [a, b], [name], name)) - elif ty == backend.OpType.Relu: - raise Exception("TODO") - elif ty == backend.OpType.Sigmoid: - raise Exception("TODO") - elif ty == backend.OpType.Tanh: - raise Exception("TODO") - elif ty == backend.OpType.Softmax: - raise Exception("TODO") - elif ty == backend.OpType.Abs: - raise Exception("TODO") + elif ty in [ + backend.OpType.Relu, + backend.OpType.Sigmoid, + backend.OpType.Tanh, + backend.OpType.Softmax, + backend.OpType.Abs, + ]: + context.push_output(name, outputs[0]) + x = context.push_input(inputs[0]) + context.push_node(make_node(ty.name, [x], [name], name)) elif ty == backend.OpType.Identity: raise Exception("TODO") elif ty == backend.OpType.Flatten: