diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 00838b8d..03339455 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -336,9 +336,10 @@ def to_onnx(graph: backend.GraphHandler): self.count_op[ty] += 1 return ty, name - def push_output(self, name: str, tensor: backend.Tensor) -> None: + def push_output(self, name: str, tensor: backend.Tensor) -> str: self.names[tensor] = name # TODO 需要判断全图输出并保存到 outputs + return name def push_input(self, tensor: backend.Tensor) -> str: name = self.names.get(tensor) @@ -375,17 +376,17 @@ def to_onnx(graph: backend.GraphHandler): ops = graph.operators() # 图中所有算子(节点) - context = Context() + ctx = Context() for op in ops: - ty, name = context.name_op(op) - inputs = op.inputs() - outputs = op.outputs() + ty, name = ctx.name_op(op) + inputs = [ctx.push_input(it) for it in op.inputs()] + outputs = [ + ctx.push_output("{}_{}".format(name, i), it) + for (i, it) in enumerate(op.outputs()) + ] if ty == backend.OpType.Matmul: - context.push_output(name, outputs[0]) - a = context.push_input(inputs[0]) - b = context.push_input(inputs[1]) - context.push_node(make_node("MatMul", [a, b], [name], name)) + ctx.push_node(make_node("MatMul", inputs, outputs, name)) elif ty == backend.OpType.BatchNorm: raise Exception("TODO") elif ty == backend.OpType.MaxPool: @@ -398,12 +399,6 @@ def to_onnx(graph: backend.GraphHandler): backend.OpType.Mul, backend.OpType.Div, backend.OpType.Pow, - ]: - context.push_output(name, outputs[0]) - a = context.push_input(inputs[0]) - b = context.push_input(inputs[1]) - context.push_node(make_node(ty.name, [a, b], [name], name)) - elif ty in [ backend.OpType.Relu, backend.OpType.Sigmoid, backend.OpType.Tanh, @@ -411,14 +406,11 @@ def to_onnx(graph: backend.GraphHandler): backend.OpType.Abs, backend.OpType.Identity, ]: - context.push_output(name, outputs[0]) - x = context.push_input(inputs[0]) - 
context.push_node(make_node(ty.name, [x], [name], name)) + ctx.push_node(make_node(ty.name, inputs, outputs, name)) elif ty == backend.OpType.Flatten: raise Exception("TODO") elif ty == backend.OpType.Reshape: - context.push_output(name, outputs[0]) - data = context.push_input(inputs[0]) + data = ctx.push_input(inputs[0]) # shape = context.push_data_input( # name, # "shape", @@ -429,13 +421,11 @@ def to_onnx(graph: backend.GraphHandler): # context.push_node(make_node(ty.name, [data, shape], [name], name)) raise Exception("TODO") elif ty == backend.OpType.Concat: - context.push_output(name, outputs[0]) - a = context.push_input(inputs[0]) - b = context.push_input(inputs[1]) - axis = backend.concat_dim_of(op) - context.push_node(make_node("Concat", [a, b], [name], name, axis=axis)) + axis = backend.concat_axis_of(op) + ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis)) elif ty == backend.OpType.Gather: - raise Exception("TODO") + axis = backend.gather_axis_of(op) + ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis)) elif ty == backend.OpType.ReduceMean: raise Exception("TODO") elif ty == backend.OpType.Slice: @@ -446,13 +436,13 @@ def to_onnx(graph: backend.GraphHandler): raise Exception("Unsupported OpType {}".format(ty.name)) print() - print(context.names) + print(ctx.names) print() - print(context.inputs) + print(ctx.inputs) print() - print(context.outputs) + print(ctx.outputs) print() - print(context.nodes) + print(ctx.nodes) def parse_onnx(model: ModelProto): diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 795468de..0ffb1b31 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -1,5 +1,6 @@ #include "core/graph_handler.h" #include "operators/concat.h" +#include "operators/gather.h" #include #ifdef USE_CUDA @@ -91,18 +92,24 @@ static int tensor_dtype(Tensor t) { IT_ASSERT(false, "Unsupported data type"); } -static int concat_dim_of(Operator op) { +static int 
concat_axis_of(Operator op) { IT_ASSERT(op->getOpType() == OpType::Concat); return reinterpret_cast<ConcatObj *>(op.get())->getDim(); } +static int gather_axis_of(Operator op) { + IT_ASSERT(op->getOpType() == OpType::Gather); + return reinterpret_cast<GatherObj *>(op.get())->getAxis(); +} + void init_graph_builder(py::module &m) { using Handler = GraphHandlerObj; m.def("cpu_runtime", &CpuRuntimeObj::getInstance) .def("tensor_dtype", &tensor_dtype) - .def("concat_dim_of", &concat_dim_of); + .def("concat_axis_of", &concat_axis_of) + .def("gather_axis_of", &gather_axis_of); py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime"); py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>( m, "CpuRuntime");