feat: 导出 5 个单目算子到 onnx

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-20 14:38:13 +08:00
parent 0517089dca
commit 32f6f02c81
1 changed files with 11 additions and 10 deletions

View File

@ -336,6 +336,7 @@ def to_onnx(graph: backend.GraphHandler):
def push_output(self, name: str, tensor: backend.Tensor) -> None:
self.names[tensor] = name
# TODO 需要判断全图输出并保存到 outputs
def push_input(self, tensor: backend.Tensor) -> str:
name = self.names.get(tensor)
@ -384,16 +385,16 @@ def to_onnx(graph: backend.GraphHandler):
a = context.push_input(inputs[0])
b = context.push_input(inputs[1])
context.push_node(make_node(ty.name, [a, b], [name], name))
elif ty == backend.OpType.Relu:
raise Exception("TODO")
elif ty == backend.OpType.Sigmoid:
raise Exception("TODO")
elif ty == backend.OpType.Tanh:
raise Exception("TODO")
elif ty == backend.OpType.Softmax:
raise Exception("TODO")
elif ty == backend.OpType.Abs:
raise Exception("TODO")
elif ty in [
backend.OpType.Relu,
backend.OpType.Sigmoid,
backend.OpType.Tanh,
backend.OpType.Softmax,
backend.OpType.Abs,
]:
context.push_output(name, outputs[0])
x = context.push_input(inputs[0])
context.push_node(make_node(ty.name, [x], [name], name))
elif ty == backend.OpType.Identity:
raise Exception("TODO")
elif ty == backend.OpType.Flatten: