From 9d9fbd44aff51db81ed1c8cc17dfcc315b9de813 Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Tue, 21 Feb 2023 12:44:45 +0800
Subject: [PATCH] feat: export MatMul and Concat to onnx
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster

---
 pyinfinitensor/src/pyinfinitensor/onnx.py | 43 +++++++++++++++++++----
 src/ffi/ffi_infinitensor.cc               |  9 ++++-
 2 files changed, 45 insertions(+), 7 deletions(-)

diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index ae21ab8b..00838b8d 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -7,9 +7,9 @@ from onnx import (
     TensorShapeProto,
     ValueInfoProto,
 )
-from onnx.helper import make_node, make_tensor_value_info
+from onnx.helper import make_node, make_tensor_value_info, make_tensor
 from onnx.shape_inference import infer_shapes
-from typing import Dict, List, Any, Tuple
+from typing import Dict, List, Any, Tuple, Sequence
 from functools import reduce
 
 runtime = backend.cpu_runtime()
@@ -326,6 +326,8 @@ def to_onnx(graph: backend.GraphHandler):
         inputs: List[ValueInfoProto] = []
         # saves global output tensors
         outputs: List[ValueInfoProto] = []
+        # saves tensors with constant data, registered as inputs as well
+        initializers: List[TensorProto] = []
 
         def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
             ty = op.op_type()
@@ -351,6 +353,19 @@ def to_onnx(graph: backend.GraphHandler):
 
             return name
 
+        def push_data_input(
+            self,
+            node_name: str,
+            attr_name: str,
+            elem_type: int,
+            shape: Sequence[int],
+            vals: Any,
+        ) -> str:
+            name = "{}_{}".format(node_name, attr_name)
+            self.inputs.append(make_tensor_value_info(name, elem_type, shape))
+            self.initializers.append(make_tensor(name, elem_type, shape, vals))
+            return name
+
         def push_node(self, node: NodeProto) -> None:
             self.nodes.append(node)
 
@@ -367,7 +382,10 @@ def to_onnx(graph: backend.GraphHandler):
         inputs = op.inputs()
         outputs = op.outputs()
         if ty == backend.OpType.Matmul:
-            raise Exception("TODO")
+            context.push_output(name, outputs[0])
+            a = context.push_input(inputs[0])
+            b = context.push_input(inputs[1])
+            context.push_node(make_node("MatMul", [a, b], [name], name))
         elif ty == backend.OpType.BatchNorm:
             raise Exception("TODO")
         elif ty == backend.OpType.MaxPool:
@@ -391,18 +409,31 @@ def to_onnx(graph: backend.GraphHandler):
             backend.OpType.Tanh,
             backend.OpType.Softmax,
             backend.OpType.Abs,
+            backend.OpType.Identity,
         ]:
             context.push_output(name, outputs[0])
             x = context.push_input(inputs[0])
             context.push_node(make_node(ty.name, [x], [name], name))
-        elif ty == backend.OpType.Identity:
-            raise Exception("TODO")
         elif ty == backend.OpType.Flatten:
             raise Exception("TODO")
         elif ty == backend.OpType.Reshape:
+            context.push_output(name, outputs[0])
+            data = context.push_input(inputs[0])
+            # shape = context.push_data_input(
+            #     name,
+            #     "shape",
+            #     TensorProto.INT64,  # ONNX Reshape requires an INT64 shape
+            #     shape=[len(dims)],  # dims: target dims, not yet readable here
+            #     vals=dims,
+            # )
+            # context.push_node(make_node(ty.name, [data, shape], [name], name))
             raise Exception("TODO")
         elif ty == backend.OpType.Concat:
-            raise Exception("TODO")
+            context.push_output(name, outputs[0])
+            # ONNX Concat is variadic, so export every input
+            srcs = [context.push_input(it) for it in inputs]
+            axis = backend.concat_dim_of(op)
+            context.push_node(make_node("Concat", srcs, [name], name, axis=axis))
         elif ty == backend.OpType.Gather:
             raise Exception("TODO")
         elif ty == backend.OpType.ReduceMean:
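The push_data_input helper added above registers one name in both inputs and
initializers, which is the standard ONNX pattern for feeding constant data
(for example, the shape operand that the Reshape TODO needs) into a graph.
A minimal standalone sketch of that pattern, using only the official onnx
package; the tensor names and dims below are made up for illustration:

import onnx
from onnx import TensorProto, checker
from onnx.helper import (
    make_graph,
    make_model,
    make_node,
    make_tensor,
    make_tensor_value_info,
)

# Hypothetical target shape for a Reshape node.
dims = [2, 3]
shape_name = "reshape_shape"
# ONNX Reshape requires its shape input to be an INT64 tensor.
shape_input = make_tensor_value_info(shape_name, TensorProto.INT64, [len(dims)])
# The same name gets an initializer carrying the actual values, exactly what
# push_data_input appends to self.inputs and self.initializers.
shape_init = make_tensor(shape_name, TensorProto.INT64, [len(dims)], dims)

data = make_tensor_value_info("data", TensorProto.FLOAT, [1, 6])
out = make_tensor_value_info("out", TensorProto.FLOAT, dims)
node = make_node("Reshape", ["data", shape_name], ["out"], "reshape")

model = make_model(
    make_graph([node], "demo", [data, shape_input], [out], [shape_init])
)
checker.check_model(model)

check_model passing confirms the wiring: a graph input backed by a
same-named initializer is how exported constant shapes (and weights) should
appear in the model.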
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 5b556e89..795468de 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -1,4 +1,5 @@
 #include "core/graph_handler.h"
+#include "operators/concat.h"
 #include <pybind11/stl.h>
 
 #ifdef USE_CUDA
@@ -90,12 +91,18 @@ static int tensor_dtype(Tensor t) {
     IT_ASSERT(false, "Unsupported data type");
 }
 
+static int concat_dim_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::Concat);
+    return reinterpret_cast<const ConcatObj *>(op.get())->getDim();
+}
+
 void init_graph_builder(py::module &m) {
     using Handler = GraphHandlerObj;
 
     m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
-        .def("tensor_dtype", &tensor_dtype);
+        .def("tensor_dtype", &tensor_dtype)
+        .def("concat_dim_of", &concat_dim_of);
     py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
     py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
         m, "CpuRuntime");
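The new Concat path can be sanity-checked against ONNX semantics without the
InfiniTensor backend at all. A sketch, assuming onnx >= 1.13 for
onnx.reference.ReferenceEvaluator; here axis=1 stands in for whatever
backend.concat_dim_of(op) returns on a real graph:

import numpy as np
from onnx import TensorProto, checker
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
from onnx.reference import ReferenceEvaluator

a = make_tensor_value_info("a", TensorProto.FLOAT, [2, 2])
b = make_tensor_value_info("b", TensorProto.FLOAT, [2, 3])
out = make_tensor_value_info("out", TensorProto.FLOAT, [2, 5])
# axis=1 plays the role of the value concat_dim_of(op) exposes to Python.
node = make_node("Concat", ["a", "b"], ["out"], "concat", axis=1)
model = make_model(make_graph([node], "demo", [a, b], [out]))
checker.check_model(model)

ref = ReferenceEvaluator(model)
(result,) = ref.run(None, {
    "a": np.ones((2, 2), dtype=np.float32),
    "b": np.zeros((2, 3), dtype=np.float32),
})
assert result.shape == (2, 5)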