feat: export add/sub/mul/div/pow to onnx

Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster 2023-02-17 17:15:15 +08:00
parent f2591edbb4
commit 0833a2f779
3 changed files with 169 additions and 89 deletions


@@ -1,4 +1,6 @@
-import onnx, backend
+import backend
+from onnx import ModelProto, TensorProto, NodeProto, AttributeProto, TensorShapeProto
+from onnx.helper import make_node
 from onnx.shape_inference import infer_shapes
 from typing import Dict, List, Any
 from functools import reduce
@@ -6,12 +8,12 @@ from functools import reduce
 runtime = backend.cpu_runtime()
 
 
-def from_onnx(model: onnx.ModelProto) -> backend.GraphHandler:
+def from_onnx(model: ModelProto) -> backend.GraphHandler:
     model = infer_shapes(model)
     handler = backend.GraphHandler(runtime)
     tensors: Dict[str, backend.Tensor] = dict()
-    data: Dict[str, onnx.TensorProto] = dict()
+    data: Dict[str, TensorProto] = dict()
     for input in model.graph.input:
         dims = _take_shape_dim(input.type.tensor_type.shape)
@@ -310,16 +312,123 @@ def to_onnx(graph: backend.GraphHandler):
     ops = graph.operators()
     names: Dict[Any, str] = dict()
-    count: Dict[backend.OpType, int] = dict()
+    nodes: List[NodeProto] = []
+    count_op: Dict[backend.OpType, int] = dict()
+    count_in = 0
     for op in ops:
         ty = op.op_type()
-        names[op] = "{}{}".format(ty.name, count.setdefault(ty, 0) + 1)
-        count[ty] += 1
+        name = "{}{}".format(ty.name, count_op.setdefault(ty, 0) + 1)
+        inputs = op.inputs()
+        outputs = op.outputs()
+        names[op] = name
+        count_op[ty] += 1
+        if ty == backend.OpType.Matmul:
+            raise Exception("TODO")
+        elif ty == backend.OpType.BatchNorm:
+            raise Exception("TODO")
+        elif ty == backend.OpType.MaxPool:
+            raise Exception("TODO")
+        elif ty == backend.OpType.AvgPool:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Add:
+            names[outputs[0]] = name
+            if inputs[0] in names:
+                a = names[inputs[0]]
+            else:
+                count_in += 1
+                a = "input{}".format(count_in)
+            if inputs[1] in names:
+                b = names[inputs[1]]
+            else:
+                count_in += 1
+                b = "input{}".format(count_in)
+            nodes.append(make_node("Add", [a, b], [name], name))
+        elif ty == backend.OpType.Sub:
+            names[outputs[0]] = name
+            if inputs[0] in names:
+                a = names[inputs[0]]
+            else:
+                count_in += 1
+                a = "input{}".format(count_in)
+            if inputs[1] in names:
+                b = names[inputs[1]]
+            else:
+                count_in += 1
+                b = "input{}".format(count_in)
+            nodes.append(make_node("Sub", [a, b], [name], name))
+        elif ty == backend.OpType.Mul:
+            names[outputs[0]] = name
+            if inputs[0] in names:
+                a = names[inputs[0]]
+            else:
+                count_in += 1
+                a = "input{}".format(count_in)
+            if inputs[1] in names:
+                b = names[inputs[1]]
+            else:
+                count_in += 1
+                b = "input{}".format(count_in)
+            nodes.append(make_node("Mul", [a, b], [name], name))
+        elif ty == backend.OpType.Div:
+            names[outputs[0]] = name
+            if inputs[0] in names:
+                a = names[inputs[0]]
+            else:
+                count_in += 1
+                a = "input{}".format(count_in)
+            if inputs[1] in names:
+                b = names[inputs[1]]
+            else:
+                count_in += 1
+                b = "input{}".format(count_in)
+            nodes.append(make_node("Div", [a, b], [name], name))
+        elif ty == backend.OpType.Pow:
+            names[outputs[0]] = name
+            if inputs[0] in names:
+                a = names[inputs[0]]
+            else:
+                count_in += 1
+                a = "input{}".format(count_in)
+            if inputs[1] in names:
+                b = names[inputs[1]]
+            else:
+                count_in += 1
+                b = "input{}".format(count_in)
+            nodes.append(make_node("Pow", [a, b], [name], name))
+        elif ty == backend.OpType.Relu:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Sigmoid:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Tanh:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Softmax:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Abs:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Identity:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Flatten:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Reshape:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Concat:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Gather:
+            raise Exception("TODO")
+        elif ty == backend.OpType.ReduceMean:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Slice:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Pad:
+            raise Exception("TODO")
+        else:
+            raise Exception("Unsupported OpType {}".format(ty.name))
     print(names)
 
 
-def parse_onnx(model: onnx.ModelProto):
+def parse_onnx(model: ModelProto):
     print()
     for field in [
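
The five arithmetic branches above are identical except for the ONNX operator string, so the name-resolution bookkeeping is a natural candidate for a helper. A minimal sketch of that refactor, reusing the names/nodes/count_in state from to_onnx (the helper name _export_binary is hypothetical, not part of this commit):

    def _export_binary(op_name: str, inputs, name, names, nodes, count_in: int) -> int:
        # Resolve each operand to the name of the operator that produced it,
        # or allocate a fresh graph-level input name ("input1", "input2", ...).
        operands = []
        for t in inputs:
            if t in names:
                operands.append(names[t])
            else:
                count_in += 1
                operands.append("input{}".format(count_in))
        nodes.append(make_node(op_name, operands, [name], name))
        return count_in

Each branch would then shrink to names[outputs[0]] = name followed by count_in = _export_binary("Add", inputs, name, names, nodes, count_in), and supporting a new binary operator becomes a one-line change.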
@@ -355,34 +464,32 @@ def parse_onnx(model: onnx.ModelProto):
         print(" {}".format(node.name))
 
 
-def _parse_attribute(
-    node: onnx.NodeProto, attrs: Dict[str, Any] = dict()
-) -> Dict[str, Any]:
+def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
     for attr in node.attribute:
         if attr.name in attrs:
-            if attr.type == onnx.AttributeProto.INT:
+            if attr.type == AttributeProto.INT:
                 attrs[attr.name] = attr.i
-            elif attr.type == onnx.AttributeProto.INTS:
+            elif attr.type == AttributeProto.INTS:
                 attrs[attr.name] = attr.ints
-            elif attr.type == onnx.AttributeProto.FLOAT:
+            elif attr.type == AttributeProto.FLOAT:
                 attrs[attr.name] = attr.f
-            elif attr.type == onnx.AttributeProto.STRING:
+            elif attr.type == AttributeProto.STRING:
                 attrs[attr.name] = attr.s
-            elif attr.type == onnx.AttributeProto.TENSOR:
+            elif attr.type == AttributeProto.TENSOR:
                 attrs[attr.name] = attr.t
             else:
                 assert False, "Unsupported Attribute Type: {}".format(attr.type)
     return attrs
 
 
-def _parse_data(tensor: onnx.TensorProto) -> List[int]:
-    if tensor.data_type == onnx.TensorProto.INT32:
+def _parse_data(tensor: TensorProto) -> List[int]:
+    if tensor.data_type == TensorProto.INT32:
         return [int(i) for i in tensor.int32_data]
-    elif tensor.data_type == onnx.TensorProto.INT64:
+    elif tensor.data_type == TensorProto.INT64:
         return [int(i) for i in tensor.int64_data]
     else:
         assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
 
 
-def _take_shape_dim(shape: onnx.TensorShapeProto) -> List[int]:
+def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
     return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
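
One Python caveat in _parse_attribute: attrs: Dict[str, Any] = dict() is a mutable default argument, created once at definition time and shared by every call that omits attrs. The current body never writes to it when omitted (it only fills keys that are already present), but the pattern is a classic footgun if that check ever changes. The usual guard, sketched (Optional would also need importing from typing):

    def _parse_attribute(node: NodeProto, attrs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        if attrs is None:
            attrs = {}  # a fresh dict per call, not one shared default
        # ... attribute dispatch as above ...
        return attrs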


@@ -294,10 +294,20 @@ class TestStringMethods(unittest.TestCase):
     def test_frontend(self):
         handler = backend.GraphHandler(runtime)
-        i = handler.tensor([1, 2, 3], 12)
-        w = handler.tensor([1, 3, 4], 12)
-        o = handler.tensor([1, 2, 4], 12)
-        handler.matmul(i, w, o, False, False, None, backend.ActType.Relu)
+        a = handler.tensor([1, 2, 3], 12)
+        b = handler.tensor([1, 2, 3], 12)
+        ab = handler.tensor([1, 2, 3], 12)
+        c = handler.tensor([1, 2, 3], 12)
+        abc = handler.tensor([1, 2, 3], 12)
+        d = handler.tensor([1, 2, 3], 12)
+        abcd = handler.tensor([1, 2, 3], 12)
+        e = handler.tensor([1, 2, 3], 12)
+        abcde = handler.tensor([1, 2, 3], 12)
+        handler.add(a, b, ab)
+        handler.add(ab, c, abc)
+        handler.add(abc, d, abcd)
+        handler.add(abcd, e, abcde)
         to_onnx(handler)
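
to_onnx currently only collects NodeProto objects and prints the name table. For the chain built in this test it yields nodes Add1 through Add4 reading fresh graph inputs input1 through input5, with Add4 as the final output. To turn that into a checkable model, the nodes would eventually be wrapped in a GraphProto/ModelProto; a rough sketch using stock onnx helpers (shapes follow the test's [1, 2, 3] tensors; FLOAT is illustrative, since the dtype code 12 passed to handler.tensor is backend-defined):

    from onnx import TensorProto
    from onnx.helper import make_graph, make_model, make_tensor_value_info
    from onnx.checker import check_model

    def wrap(nodes):
        # Five fresh inputs feed the Add chain; the last Add is the graph output.
        inputs = [
            make_tensor_value_info("input{}".format(i), TensorProto.FLOAT, [1, 2, 3])
            for i in range(1, 6)
        ]
        outputs = [make_tensor_value_info("Add4", TensorProto.FLOAT, [1, 2, 3])]
        model = make_model(make_graph(nodes, "test_frontend", inputs, outputs))
        check_model(model)
        return model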


@@ -80,75 +80,38 @@ void init_graph_builder(py::module &m) {
     py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj")
         .def("src", &TensorObj::getOutputOf, policy::move);
     py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
-        .def("op_type", &OperatorObj::getOpType, policy::move);
+        .def("op_type", &OperatorObj::getOpType, policy::automatic)
+        .def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_),
+             policy::reference)
+        .def("outputs",
+             py::overload_cast<>(&OperatorObj::getOutputs, py::const_),
+             policy::reference);
     py::class_<Handler>(m, "GraphHandler")
         .def(py::init<Runtime>())
-        .def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
-             policy::move)
+        .def("tensor", &Handler::tensor, policy::move)
         .def("conv", &Handler::conv, policy::move)
-        .def("matmul",
-             py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
-                               ActType>(&Handler::matmul),
-             policy::move)
-        .def("batchNorm",
-             py::overload_cast<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
-                               float, float, bool>(&Handler::batchNorm),
-             policy::move)
-        .def("maxPool",
-             py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
-                               int, int>(&Handler::maxPool),
-             policy::move)
-        .def("avgPool",
-             py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
-                               int, int>(&Handler::avgPool),
-             policy::move)
-        .def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add),
-             policy::move)
-        .def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub),
-             policy::move)
-        .def("mul", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::mul),
-             policy::move)
-        .def("div", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::div),
-             policy::move)
-        .def("pow", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::pow),
-             policy::move)
-        .def("relu", py::overload_cast<Tensor, Tensor>(&Handler::relu),
-             policy::move)
-        .def("sigmoid", py::overload_cast<Tensor, Tensor>(&Handler::sigmoid),
-             policy::move)
-        .def("tanh", py::overload_cast<Tensor, Tensor>(&Handler::tanh),
-             policy::move)
-        .def("softmax", py::overload_cast<Tensor, Tensor>(&Handler::softmax),
-             policy::move)
-        .def("abs", py::overload_cast<Tensor, Tensor>(&Handler::abs),
-             policy::move)
-        .def("identity", py::overload_cast<Tensor, Tensor>(&Handler::identity),
-             policy::move)
-        .def("flatten", py::overload_cast<Tensor, Tensor>(&Handler::flatten),
-             policy::move)
-        .def("reshape",
-             py::overload_cast<Tensor, Tensor, Shape>(&Handler::reshape),
-             policy::move)
-        .def("concat",
-             py::overload_cast<TensorVec, Tensor, int>(&Handler::concat),
-             policy::move)
-        .def("gather",
-             py::overload_cast<Tensor, Tensor, Tensor, int>(&Handler::gather),
-             policy::move)
-        .def("reduceMean",
-             py::overload_cast<Tensor, Tensor, const optional<vector<int>> &,
-                               bool>(&Handler::reduceMean),
-             policy::move)
-        .def("slice",
-             py::overload_cast<
-                 Tensor, Tensor, const vector<int> &, const vector<int> &,
-                 const optional<vector<int>> &, const optional<vector<int>> &>(
-                 &Handler::slice),
-             policy::move)
-        .def("pad",
-             py::overload_cast<Tensor, Tensor, const vector<int> &,
-                               const optional<vector<int>> &>(&Handler::pad),
-             policy::move)
+        .def("matmul", &Handler::matmul, policy::move)
+        .def("batchNorm", &Handler::batchNorm, policy::move)
+        .def("maxPool", &Handler::maxPool, policy::move)
+        .def("avgPool", &Handler::avgPool, policy::move)
+        .def("add", &Handler::add, policy::move)
+        .def("sub", &Handler::sub, policy::move)
+        .def("mul", &Handler::mul, policy::move)
+        .def("div", &Handler::div, policy::move)
+        .def("pow", &Handler::pow, policy::move)
+        .def("relu", &Handler::relu, policy::move)
+        .def("sigmoid", &Handler::sigmoid, policy::move)
+        .def("tanh", &Handler::tanh, policy::move)
+        .def("softmax", &Handler::softmax, policy::move)
+        .def("abs", &Handler::abs, policy::move)
+        .def("identity", &Handler::identity, policy::move)
+        .def("flatten", &Handler::flatten, policy::move)
+        .def("reshape", &Handler::reshape, policy::move)
+        .def("concat", &Handler::concat, policy::move)
+        .def("gather", &Handler::gather, policy::move)
+        .def("reduceMean", &Handler::reduceMean, policy::move)
+        .def("slice", &Handler::slice, policy::move)
+        .def("pad", &Handler::pad, policy::move)
+        .def("topo_sort", &Handler::topo_sort, policy::automatic)
         .def("operators", &Handler::operators, policy::move)
         .def("data_malloc", &Handler::data_malloc, policy::automatic)