forked from jiuyuan/InfiniTensor
feat: 导出 MatMul Concat 到 onnx
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
32f6f02c81
commit
9d9fbd44af
|
@ -7,9 +7,9 @@ from onnx import (
|
|||
TensorShapeProto,
|
||||
ValueInfoProto,
|
||||
)
|
||||
from onnx.helper import make_node, make_tensor_value_info
|
||||
from onnx.helper import make_node, make_tensor_value_info, make_tensor
|
||||
from onnx.shape_inference import infer_shapes
|
||||
from typing import Dict, List, Any, Tuple
|
||||
from typing import Dict, List, Any, Tuple, Sequence
|
||||
from functools import reduce
|
||||
|
||||
runtime = backend.cpu_runtime()
|
||||
|
@ -326,6 +326,8 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
inputs: List[ValueInfoProto] = []
|
||||
# saves global output tensors
|
||||
outputs: List[ValueInfoProto] = []
|
||||
# saves global input tensors
|
||||
initializers: List[TensorProto] = []
|
||||
|
||||
def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
|
||||
ty = op.op_type()
|
||||
|
@ -351,6 +353,19 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
|
||||
return name
|
||||
|
||||
def push_data_input(
|
||||
self,
|
||||
node_name: str,
|
||||
attr_name: str,
|
||||
elem_type: int,
|
||||
shape: Sequence[int],
|
||||
vals: Any,
|
||||
) -> str:
|
||||
name = "{}_{}".format(node_name, attr_name)
|
||||
self.inputs.append(make_tensor_value_info(name, elem_type, shape))
|
||||
self.initializers.append(make_tensor(name, elem_type, shape, vals))
|
||||
return name
|
||||
|
||||
def push_node(self, node: NodeProto) -> None:
|
||||
self.nodes.append(node)
|
||||
|
||||
|
@ -367,7 +382,10 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
inputs = op.inputs()
|
||||
outputs = op.outputs()
|
||||
if ty == backend.OpType.Matmul:
|
||||
raise Exception("TODO")
|
||||
context.push_output(name, outputs[0])
|
||||
a = context.push_input(inputs[0])
|
||||
b = context.push_input(inputs[1])
|
||||
context.push_node(make_node("MatMul", [a, b], [name], name))
|
||||
elif ty == backend.OpType.BatchNorm:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.MaxPool:
|
||||
|
@ -391,18 +409,31 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
backend.OpType.Tanh,
|
||||
backend.OpType.Softmax,
|
||||
backend.OpType.Abs,
|
||||
backend.OpType.Identity,
|
||||
]:
|
||||
context.push_output(name, outputs[0])
|
||||
x = context.push_input(inputs[0])
|
||||
context.push_node(make_node(ty.name, [x], [name], name))
|
||||
elif ty == backend.OpType.Identity:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Flatten:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Reshape:
|
||||
context.push_output(name, outputs[0])
|
||||
data = context.push_input(inputs[0])
|
||||
# shape = context.push_data_input(
|
||||
# name,
|
||||
# "shape",
|
||||
# TensorProto.INT32,
|
||||
# shape=[len(vals)],
|
||||
# vals=1,
|
||||
# )
|
||||
# context.push_node(make_node(ty.name, [data, shape], [name], name))
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Concat:
|
||||
raise Exception("TODO")
|
||||
context.push_output(name, outputs[0])
|
||||
a = context.push_input(inputs[0])
|
||||
b = context.push_input(inputs[1])
|
||||
axis = backend.concat_dim_of(op)
|
||||
context.push_node(make_node("Concat", [a, b], [name], name, axis=axis))
|
||||
elif ty == backend.OpType.Gather:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.ReduceMean:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
#include "core/graph_handler.h"
|
||||
#include "operators/concat.h"
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#ifdef USE_CUDA
|
||||
|
@ -90,12 +91,18 @@ static int tensor_dtype(Tensor t) {
|
|||
IT_ASSERT(false, "Unsupported data type");
|
||||
}
|
||||
|
||||
static int concat_dim_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Concat);
|
||||
return reinterpret_cast<const ConcatObj *>(op.get())->getDim();
|
||||
}
|
||||
|
||||
void init_graph_builder(py::module &m) {
|
||||
|
||||
using Handler = GraphHandlerObj;
|
||||
|
||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
|
||||
.def("tensor_dtype", &tensor_dtype);
|
||||
.def("tensor_dtype", &tensor_dtype)
|
||||
.def("concat_dim_of", &concat_dim_of);
|
||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
||||
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
||||
m, "CpuRuntime");
|
||||
|
|
Loading…
Reference in New Issue