feat: export MatMul and Concat to onnx

Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster 2023-02-21 12:44:45 +08:00
parent 32f6f02c81
commit 9d9fbd44af
2 changed files with 45 additions and 7 deletions


@@ -7,9 +7,9 @@ from onnx import (
     TensorShapeProto,
     ValueInfoProto,
 )
-from onnx.helper import make_node, make_tensor_value_info
+from onnx.helper import make_node, make_tensor_value_info, make_tensor
 from onnx.shape_inference import infer_shapes
-from typing import Dict, List, Any, Tuple
+from typing import Dict, List, Any, Tuple, Sequence
 from functools import reduce
 
 runtime = backend.cpu_runtime()
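
Note: the newly imported make_tensor complements make_tensor_value_info: the
latter only declares a tensor's name, element type, and shape (a
ValueInfoProto), while the former also carries the values (a TensorProto). A
minimal sketch with illustrative names:

from onnx import TensorProto
from onnx.helper import make_tensor, make_tensor_value_info

# Declaration only: name, element type, shape.
info = make_tensor_value_info("w", TensorProto.FLOAT, [2, 2])
# Declaration plus the actual values, flattened row-major.
data = make_tensor("w", TensorProto.FLOAT, [2, 2], [1.0, 0.0, 0.0, 1.0])
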
@@ -326,6 +326,8 @@ def to_onnx(graph: backend.GraphHandler):
         inputs: List[ValueInfoProto] = []
         # saves global output tensors
         outputs: List[ValueInfoProto] = []
+        # saves the data of constant inputs (initializers)
+        initializers: List[TensorProto] = []
 
         def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
             ty = op.op_type()
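
The new initializers list is presumably consumed when the model is assembled,
via the initializer argument of onnx.helper.make_graph. A sketch, assuming a
populated Context instance named context:

from onnx.helper import make_graph, make_model

graph = make_graph(
    context.nodes,
    "exported_graph",  # illustrative name
    context.inputs,
    context.outputs,
    initializer=context.initializers,
)
model = make_model(graph)
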
@@ -351,6 +353,19 @@ def to_onnx(graph: backend.GraphHandler):
 
             return name
 
+        def push_data_input(
+            self,
+            node_name: str,
+            attr_name: str,
+            elem_type: int,
+            shape: Sequence[int],
+            vals: Any,
+        ) -> str:
+            name = "{}_{}".format(node_name, attr_name)
+            self.inputs.append(make_tensor_value_info(name, elem_type, shape))
+            self.initializers.append(make_tensor(name, elem_type, shape, vals))
+            return name
+
         def push_node(self, node: NodeProto) -> None:
             self.nodes.append(node)
 
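
push_data_input registers one constant under the derived name
"{node_name}_{attr_name}", both as a graph input (ValueInfoProto) and as an
initializer (TensorProto) holding the actual values. A hypothetical call,
assuming a Context instance named context:

# Hypothetical: expose a 1-D int64 constant "reshape1_shape" to the graph.
shape_name = context.push_data_input(
    "reshape1",        # node_name
    "shape",           # attr_name
    TensorProto.INT64,
    [2],               # shape of the constant tensor itself
    [1, -1],           # vals
)
# shape_name == "reshape1_shape"; it can now appear in a node's input list.
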
@@ -367,7 +382,10 @@ def to_onnx(graph: backend.GraphHandler):
         inputs = op.inputs()
         outputs = op.outputs()
         if ty == backend.OpType.Matmul:
-            raise Exception("TODO")
+            context.push_output(name, outputs[0])
+            a = context.push_input(inputs[0])
+            b = context.push_input(inputs[1])
+            context.push_node(make_node("MatMul", [a, b], [name], name))
         elif ty == backend.OpType.BatchNorm:
             raise Exception("TODO")
         elif ty == backend.OpType.MaxPool:
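
ONNX MatMul takes no attributes, so inputs, outputs, and a node name are all
the branch needs to emit. A self-contained sketch (tensor names and shapes are
illustrative) showing that such a node passes the ONNX checker:

from onnx import TensorProto, checker
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info

a = make_tensor_value_info("a", TensorProto.FLOAT, [2, 3])
b = make_tensor_value_info("b", TensorProto.FLOAT, [3, 4])
y = make_tensor_value_info("y", TensorProto.FLOAT, [2, 4])
node = make_node("MatMul", ["a", "b"], ["y"], "matmul1")
checker.check_model(make_model(make_graph([node], "matmul_demo", [a, b], [y])))
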
@@ -391,18 +409,31 @@ def to_onnx(graph: backend.GraphHandler):
             backend.OpType.Tanh,
             backend.OpType.Softmax,
             backend.OpType.Abs,
+            backend.OpType.Identity,
         ]:
             context.push_output(name, outputs[0])
             x = context.push_input(inputs[0])
             context.push_node(make_node(ty.name, [x], [name], name))
-        elif ty == backend.OpType.Identity:
-            raise Exception("TODO")
         elif ty == backend.OpType.Flatten:
             raise Exception("TODO")
         elif ty == backend.OpType.Reshape:
+            context.push_output(name, outputs[0])
+            data = context.push_input(inputs[0])
+            # shape = context.push_data_input(
+            #     name,
+            #     "shape",
+            #     TensorProto.INT32,
+            #     shape=[len(vals)],
+            #     vals=1,
+            # )
+            # context.push_node(make_node(ty.name, [data, shape], [name], name))
             raise Exception("TODO")
         elif ty == backend.OpType.Concat:
-            raise Exception("TODO")
+            context.push_output(name, outputs[0])
+            a = context.push_input(inputs[0])
+            b = context.push_input(inputs[1])
+            axis = backend.concat_dim_of(op)
+            context.push_node(make_node("Concat", [a, b], [name], name, axis=axis))
         elif ty == backend.OpType.Gather:
             raise Exception("TODO")
         elif ty == backend.OpType.ReduceMean:
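
The commented-out Reshape draft passes TensorProto.INT32 and placeholder vals;
note that ONNX Reshape takes its target shape as a second int64 tensor input,
which is exactly what push_data_input can supply. A possible completion, where
backend.reshape_dims_of is a hypothetical accessor for the target dims (no
such binding exists in this commit):

dims = backend.reshape_dims_of(op)  # hypothetical accessor, e.g. [1, -1]
shape = context.push_data_input(
    name,
    "shape",
    TensorProto.INT64,  # the ONNX spec requires an int64 shape tensor
    shape=[len(dims)],
    vals=dims,
)
context.push_node(make_node(ty.name, [data, shape], [name], name))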


@@ -1,4 +1,5 @@
 #include "core/graph_handler.h"
+#include "operators/concat.h"
 #include <pybind11/stl.h>
 
 #ifdef USE_CUDA
@@ -90,12 +91,18 @@ static int tensor_dtype(Tensor t) {
     IT_ASSERT(false, "Unsupported data type");
 }
 
+static int concat_dim_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::Concat);
+    return reinterpret_cast<const ConcatObj *>(op.get())->getDim();
+}
+
 void init_graph_builder(py::module &m) {
     using Handler = GraphHandlerObj;
 
     m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
-        .def("tensor_dtype", &tensor_dtype);
+        .def("tensor_dtype", &tensor_dtype)
+        .def("concat_dim_of", &concat_dim_of);
 
     py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
     py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
         m, "CpuRuntime");