diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 0a5b3c29..ae21ab8b 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -336,6 +336,7 @@ def to_onnx(graph: backend.GraphHandler): def push_output(self, name: str, tensor: backend.Tensor) -> None: self.names[tensor] = name + # TODO: detect whole-graph outputs and record them in `outputs` def push_input(self, tensor: backend.Tensor) -> str: name = self.names.get(tensor) @@ -384,16 +385,16 @@ def to_onnx(graph: backend.GraphHandler): a = context.push_input(inputs[0]) b = context.push_input(inputs[1]) context.push_node(make_node(ty.name, [a, b], [name], name)) - elif ty == backend.OpType.Relu: - raise Exception("TODO") - elif ty == backend.OpType.Sigmoid: - raise Exception("TODO") - elif ty == backend.OpType.Tanh: - raise Exception("TODO") - elif ty == backend.OpType.Softmax: - raise Exception("TODO") - elif ty == backend.OpType.Abs: - raise Exception("TODO") + elif ty in [ + backend.OpType.Relu, + backend.OpType.Sigmoid, + backend.OpType.Tanh, + backend.OpType.Softmax, + backend.OpType.Abs, + ]: + context.push_output(name, outputs[0]) + x = context.push_input(inputs[0]) + context.push_node(make_node(ty.name, [x], [name], name)) elif ty == backend.OpType.Identity: raise Exception("TODO") elif ty == backend.OpType.Flatten: