forked from jiuyuan/InfiniTensor
feat: 导出加减乘除幂到 onnx
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
f2591edbb4
commit
0833a2f779
|
@ -1,4 +1,6 @@
|
|||
import onnx, backend
|
||||
import backend
|
||||
from onnx import ModelProto, TensorProto, NodeProto, AttributeProto, TensorShapeProto
|
||||
from onnx.helper import make_node
|
||||
from onnx.shape_inference import infer_shapes
|
||||
from typing import Dict, List, Any
|
||||
from functools import reduce
|
||||
|
@ -6,12 +8,12 @@ from functools import reduce
|
|||
runtime = backend.cpu_runtime()
|
||||
|
||||
|
||||
def from_onnx(model: onnx.ModelProto) -> backend.GraphHandler:
|
||||
def from_onnx(model: ModelProto) -> backend.GraphHandler:
|
||||
model = infer_shapes(model)
|
||||
handler = backend.GraphHandler(runtime)
|
||||
|
||||
tensors: Dict[str, backend.Tensor] = dict()
|
||||
data: Dict[str, onnx.TensorProto] = dict()
|
||||
data: Dict[str, TensorProto] = dict()
|
||||
|
||||
for input in model.graph.input:
|
||||
dims = _take_shape_dim(input.type.tensor_type.shape)
|
||||
|
@ -310,16 +312,123 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
ops = graph.operators()
|
||||
|
||||
names: Dict[Any, str] = dict()
|
||||
count: Dict[backend.OpType, int] = dict()
|
||||
nodes: List[NodeProto] = []
|
||||
count_op: Dict[backend.OpType, int] = dict()
|
||||
count_in = 0
|
||||
|
||||
for op in ops:
|
||||
ty = op.op_type()
|
||||
names[op] = "{}{}".format(ty.name, count.setdefault(ty, 0) + 1)
|
||||
count[ty] += 1
|
||||
name = "{}{}".format(ty.name, count_op.setdefault(ty, 0) + 1)
|
||||
inputs = op.inputs()
|
||||
outputs = op.outputs()
|
||||
names[op] = name
|
||||
count_op[ty] += 1
|
||||
if ty == backend.OpType.Matmul:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.BatchNorm:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.MaxPool:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.AvgPool:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Add:
|
||||
names[outputs[0]] = name
|
||||
if inputs[0] in names:
|
||||
a = names[inputs[0]]
|
||||
else:
|
||||
count_in += 1
|
||||
a = "input{}".format(count_in)
|
||||
if inputs[1] in names:
|
||||
b = names[inputs[1]]
|
||||
else:
|
||||
count_in += 1
|
||||
b = "input{}".format(count_in)
|
||||
nodes.append(make_node("Add", [a, b], [name], name))
|
||||
elif ty == backend.OpType.Sub:
|
||||
names[outputs[0]] = name
|
||||
if inputs[0] in names:
|
||||
a = names[inputs[0]]
|
||||
else:
|
||||
count_in += 1
|
||||
a = "input{}".format(count_in)
|
||||
if inputs[1] in names:
|
||||
b = names[inputs[1]]
|
||||
else:
|
||||
count_in += 1
|
||||
b = "input{}".format(count_in)
|
||||
nodes.append(make_node("Sub", [a, b], [name], name))
|
||||
elif ty == backend.OpType.Mul:
|
||||
names[outputs[0]] = name
|
||||
if inputs[0] in names:
|
||||
a = names[inputs[0]]
|
||||
else:
|
||||
count_in += 1
|
||||
a = "input{}".format(count_in)
|
||||
if inputs[1] in names:
|
||||
b = names[inputs[1]]
|
||||
else:
|
||||
count_in += 1
|
||||
b = "input{}".format(count_in)
|
||||
nodes.append(make_node("Mul", [a, b], [name], name))
|
||||
elif ty == backend.OpType.Div:
|
||||
names[outputs[0]] = name
|
||||
if inputs[0] in names:
|
||||
a = names[inputs[0]]
|
||||
else:
|
||||
count_in += 1
|
||||
a = "input{}".format(count_in)
|
||||
if inputs[1] in names:
|
||||
b = names[inputs[1]]
|
||||
else:
|
||||
count_in += 1
|
||||
b = "input{}".format(count_in)
|
||||
nodes.append(make_node("Div", [a, b], [name], name))
|
||||
elif ty == backend.OpType.Pow:
|
||||
names[outputs[0]] = name
|
||||
if inputs[0] in names:
|
||||
a = names[inputs[0]]
|
||||
else:
|
||||
count_in += 1
|
||||
a = "input{}".format(count_in)
|
||||
if inputs[1] in names:
|
||||
b = names[inputs[1]]
|
||||
else:
|
||||
count_in += 1
|
||||
b = "input{}".format(count_in)
|
||||
nodes.append(make_node("Pow", [a, b], [name], name))
|
||||
elif ty == backend.OpType.Relu:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Sigmoid:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Tanh:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Softmax:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Abs:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Identity:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Flatten:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Reshape:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Concat:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Gather:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.ReduceMean:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Slice:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Pad:
|
||||
raise Exception("TODO")
|
||||
else:
|
||||
raise Exception("Unsupported OpType {}".format(ty.name))
|
||||
|
||||
print(names)
|
||||
|
||||
|
||||
def parse_onnx(model: onnx.ModelProto):
|
||||
def parse_onnx(model: ModelProto):
|
||||
print()
|
||||
|
||||
for field in [
|
||||
|
@ -355,34 +464,32 @@ def parse_onnx(model: onnx.ModelProto):
|
|||
print(" {}".format(node.name))
|
||||
|
||||
|
||||
def _parse_attribute(
|
||||
node: onnx.NodeProto, attrs: Dict[str, Any] = dict()
|
||||
) -> Dict[str, Any]:
|
||||
def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
|
||||
for attr in node.attribute:
|
||||
if attr.name in attrs:
|
||||
if attr.type == onnx.AttributeProto.INT:
|
||||
if attr.type == AttributeProto.INT:
|
||||
attrs[attr.name] = attr.i
|
||||
elif attr.type == onnx.AttributeProto.INTS:
|
||||
elif attr.type == AttributeProto.INTS:
|
||||
attrs[attr.name] = attr.ints
|
||||
elif attr.type == onnx.AttributeProto.FLOAT:
|
||||
elif attr.type == AttributeProto.FLOAT:
|
||||
attrs[attr.name] = attr.f
|
||||
elif attr.type == onnx.AttributeProto.STRING:
|
||||
elif attr.type == AttributeProto.STRING:
|
||||
attrs[attr.name] = attr.s
|
||||
elif attr.type == onnx.AttributeProto.TENSOR:
|
||||
elif attr.type == AttributeProto.TENSOR:
|
||||
attrs[attr.name] = attr.t
|
||||
else:
|
||||
assert False, "Unsupported Attribute Type: {}".format(attr.type)
|
||||
return attrs
|
||||
|
||||
|
||||
def _parse_data(tensor: onnx.TensorProto) -> List[int]:
|
||||
if tensor.data_type == onnx.TensorProto.INT32:
|
||||
def _parse_data(tensor: TensorProto) -> List[int]:
|
||||
if tensor.data_type == TensorProto.INT32:
|
||||
return [int(i) for i in tensor.int32_data]
|
||||
elif tensor.data_type == onnx.TensorProto.INT64:
|
||||
elif tensor.data_type == TensorProto.INT64:
|
||||
return [int(i) for i in tensor.int64_data]
|
||||
else:
|
||||
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
||||
|
||||
|
||||
def _take_shape_dim(shape: onnx.TensorShapeProto) -> List[int]:
|
||||
def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
|
||||
return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
|
||||
|
|
|
@ -294,10 +294,20 @@ class TestStringMethods(unittest.TestCase):
|
|||
|
||||
def test_frontend(self):
|
||||
handler = backend.GraphHandler(runtime)
|
||||
i = handler.tensor([1, 2, 3], 12)
|
||||
w = handler.tensor([1, 3, 4], 12)
|
||||
o = handler.tensor([1, 2, 4], 12)
|
||||
handler.matmul(i, w, o, False, False, None, backend.ActType.Relu)
|
||||
a = handler.tensor([1, 2, 3], 12)
|
||||
b = handler.tensor([1, 2, 3], 12)
|
||||
ab = handler.tensor([1, 2, 3], 12)
|
||||
c = handler.tensor([1, 2, 3], 12)
|
||||
abc = handler.tensor([1, 2, 3], 12)
|
||||
d = handler.tensor([1, 2, 3], 12)
|
||||
abcd = handler.tensor([1, 2, 3], 12)
|
||||
e = handler.tensor([1, 2, 3], 12)
|
||||
abcde = handler.tensor([1, 2, 3], 12)
|
||||
|
||||
handler.add(a, b, ab)
|
||||
handler.add(ab, c, abc)
|
||||
handler.add(abc, d, abcd)
|
||||
handler.add(abcd, e, abcde)
|
||||
|
||||
to_onnx(handler)
|
||||
|
||||
|
|
|
@ -80,75 +80,38 @@ void init_graph_builder(py::module &m) {
|
|||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj")
|
||||
.def("src", &TensorObj::getOutputOf, policy::move);
|
||||
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
|
||||
.def("op_type", &OperatorObj::getOpType, policy::move);
|
||||
.def("op_type", &OperatorObj::getOpType, policy::automatic)
|
||||
.def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_),
|
||||
policy::reference)
|
||||
.def("outputs",
|
||||
py::overload_cast<>(&OperatorObj::getOutputs, py::const_),
|
||||
policy::reference);
|
||||
py::class_<Handler>(m, "GraphHandler")
|
||||
.def(py::init<Runtime>())
|
||||
.def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
|
||||
policy::move)
|
||||
.def("tensor", &Handler::tensor, policy::move)
|
||||
.def("conv", &Handler::conv, policy::move)
|
||||
.def("matmul",
|
||||
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
||||
ActType>(&Handler::matmul),
|
||||
policy::move)
|
||||
.def("batchNorm",
|
||||
py::overload_cast<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
|
||||
float, float, bool>(&Handler::batchNorm),
|
||||
policy::move)
|
||||
.def("maxPool",
|
||||
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
|
||||
int, int>(&Handler::maxPool),
|
||||
policy::move)
|
||||
.def("avgPool",
|
||||
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
|
||||
int, int>(&Handler::avgPool),
|
||||
policy::move)
|
||||
.def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add),
|
||||
policy::move)
|
||||
.def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub),
|
||||
policy::move)
|
||||
.def("mul", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::mul),
|
||||
policy::move)
|
||||
.def("div", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::div),
|
||||
policy::move)
|
||||
.def("pow", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::pow),
|
||||
policy::move)
|
||||
.def("relu", py::overload_cast<Tensor, Tensor>(&Handler::relu),
|
||||
policy::move)
|
||||
.def("sigmoid", py::overload_cast<Tensor, Tensor>(&Handler::sigmoid),
|
||||
policy::move)
|
||||
.def("tanh", py::overload_cast<Tensor, Tensor>(&Handler::tanh),
|
||||
policy::move)
|
||||
.def("softmax", py::overload_cast<Tensor, Tensor>(&Handler::softmax),
|
||||
policy::move)
|
||||
.def("abs", py::overload_cast<Tensor, Tensor>(&Handler::abs),
|
||||
policy::move)
|
||||
.def("identity", py::overload_cast<Tensor, Tensor>(&Handler::identity),
|
||||
policy::move)
|
||||
.def("flatten", py::overload_cast<Tensor, Tensor>(&Handler::flatten),
|
||||
policy::move)
|
||||
.def("reshape",
|
||||
py::overload_cast<Tensor, Tensor, Shape>(&Handler::reshape),
|
||||
policy::move)
|
||||
.def("concat",
|
||||
py::overload_cast<TensorVec, Tensor, int>(&Handler::concat),
|
||||
policy::move)
|
||||
.def("gather",
|
||||
py::overload_cast<Tensor, Tensor, Tensor, int>(&Handler::gather),
|
||||
policy::move)
|
||||
.def("reduceMean",
|
||||
py::overload_cast<Tensor, Tensor, const optional<vector<int>> &,
|
||||
bool>(&Handler::reduceMean),
|
||||
policy::move)
|
||||
.def("slice",
|
||||
py::overload_cast<
|
||||
Tensor, Tensor, const vector<int> &, const vector<int> &,
|
||||
const optional<vector<int>> &, const optional<vector<int>> &>(
|
||||
&Handler::slice),
|
||||
policy::move)
|
||||
.def("pad",
|
||||
py::overload_cast<Tensor, Tensor, const vector<int> &,
|
||||
const optional<vector<int>> &>(&Handler::pad),
|
||||
policy::move)
|
||||
.def("matmul", &Handler::matmul, policy::move)
|
||||
.def("batchNorm", &Handler::batchNorm, policy::move)
|
||||
.def("maxPool", &Handler::maxPool, policy::move)
|
||||
.def("avgPool", &Handler::avgPool, policy::move)
|
||||
.def("add", &Handler::add, policy::move)
|
||||
.def("sub", &Handler::sub, policy::move)
|
||||
.def("mul", &Handler::mul, policy::move)
|
||||
.def("div", &Handler::div, policy::move)
|
||||
.def("pow", &Handler::pow, policy::move)
|
||||
.def("relu", &Handler::relu, policy::move)
|
||||
.def("sigmoid", &Handler::sigmoid, policy::move)
|
||||
.def("tanh", &Handler::tanh, policy::move)
|
||||
.def("softmax", &Handler::softmax, policy::move)
|
||||
.def("abs", &Handler::abs, policy::move)
|
||||
.def("identity", &Handler::identity, policy::move)
|
||||
.def("flatten", &Handler::flatten, policy::move)
|
||||
.def("reshape", &Handler::reshape, policy::move)
|
||||
.def("concat", &Handler::concat, policy::move)
|
||||
.def("gather", &Handler::gather, policy::move)
|
||||
.def("reduceMean", &Handler::reduceMean, policy::move)
|
||||
.def("slice", &Handler::slice, policy::move)
|
||||
.def("pad", &Handler::pad, policy::move)
|
||||
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
||||
.def("operators", &Handler::operators, policy::move)
|
||||
.def("data_malloc", &Handler::data_malloc, policy::automatic)
|
||||
|
|
Loading…
Reference in New Issue