feat: export input tensors to onnx

Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster 2023-02-20 14:15:54 +08:00
parent eff4c14a85
commit 0517089dca
2 changed files with 57 additions and 8 deletions


@@ -1,6 +1,13 @@
 import backend
-from onnx import ModelProto, TensorProto, NodeProto, AttributeProto, TensorShapeProto
-from onnx.helper import make_node
+from onnx import (
+    ModelProto,
+    TensorProto,
+    NodeProto,
+    AttributeProto,
+    TensorShapeProto,
+    ValueInfoProto,
+)
+from onnx.helper import make_node, make_tensor_value_info
 from onnx.shape_inference import infer_shapes
 from typing import Dict, List, Any, Tuple
 from functools import reduce
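
The newly imported make_tensor_value_info is the onnx.helper factory for ValueInfoProto, which push_input below uses to describe graph inputs. A minimal sketch (the name, dtype, and shape here are illustrative):

    from onnx import TensorProto
    from onnx.helper import make_tensor_value_info

    # build a ValueInfoProto for a float32 tensor named "input1" with shape (1, 3)
    vi = make_tensor_value_info("input1", TensorProto.FLOAT, [1, 3])
    print(vi.name)                        # input1
    print(vi.type.tensor_type.elem_type)  # 1, i.e. TensorProto.FLOAT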
@@ -307,11 +314,18 @@ def from_onnx(model: ModelProto) -> backend.GraphHandler:
 def to_onnx(graph: backend.GraphHandler):
     class Context:
-        names: Dict[Any, str] = dict()  # records the names of all objects, including tensors and operators
-        nodes: List[NodeProto] = []  # saves all operators
-        count_op: Dict[backend.OpType, int] = dict()  # counts how many times each operator appears, for naming
-        count_in = 0  # counts input tensors for naming
-        count_out = 0  # counts output tensors for naming
+        # saves object names, including tensors and operators
+        names: Dict[Any, str] = dict()
+        # counts the occurrences of each operator for naming
+        count_op: Dict[backend.OpType, int] = dict()
+        # counts input and output tensors for naming
+        count_in, count_out = 0, 0
+        # saves nodes (operators)
+        nodes: List[NodeProto] = []
+        # saves global input tensors
+        inputs: List[ValueInfoProto] = []
+        # saves global output tensors
+        outputs: List[ValueInfoProto] = []

         def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
             ty = op.op_type()
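
name_op's full body is not shown in this hunk, but count_op suggests a per-type counter; a sketch of that pattern, with the exact name format assumed:

    count_op = {}

    def name_op(ty):
        # bump the per-type counter and derive a unique operator name from it
        count_op[ty] = count_op.get(ty, 0) + 1
        return ty, "{}{}".format(ty, count_op[ty])

    assert name_op("Relu") == ("Relu", "Relu1")
    assert name_op("Relu") == ("Relu", "Relu2")
    assert name_op("Matmul") == ("Matmul", "Matmul1")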
@@ -325,10 +339,15 @@ def to_onnx(graph: backend.GraphHandler):
         def push_input(self, tensor: backend.Tensor) -> str:
             name = self.names.get(tensor)
+            # no recorded name means this tensor is a global input
             if name is None:
                 self.count_in += 1
                 name = "input{}".format(self.count_in)
                 self.names[tensor] = name
+                shape = tensor.shape()
+                dtype = backend.tensor_dtype(tensor)
+                self.inputs.append(make_tensor_value_info(name, dtype, shape))
             return name

         def push_node(self, node: NodeProto) -> None:
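
A self-contained sketch of push_input's caching behaviour, with plain objects standing in for backend.Tensor: an unseen tensor is counted and named once, and later pushes return the cached name:

    names, count_in = {}, 0

    def push_input(tensor):
        global count_in
        name = names.get(tensor)
        if name is None:  # unseen tensor: treat it as a global input
            count_in += 1
            name = "input{}".format(count_in)
            names[tensor] = name
            # the real method also records a ValueInfoProto in `inputs` here
        return name

    a, b = object(), object()  # stand-ins for backend.Tensor
    assert push_input(a) == "input1"
    assert push_input(b) == "input2"
    assert push_input(a) == "input1"  # cached name, counter untouched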
@@ -394,7 +413,14 @@ def to_onnx(graph: backend.GraphHandler):
         else:
             raise Exception("Unsupported OpType {}".format(ty.name))

     print()
     print(context.names)
+    print()
+    print(context.inputs)
+    print()
+    print(context.outputs)
+    print()
+    print(context.nodes)


 def parse_onnx(model: ModelProto):
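
The debug dumps above are exactly the pieces an ONNX graph needs; a hedged sketch of how they could later be assembled with onnx.helper (not part of this commit; the collections are shown empty so the example runs standalone):

    from onnx.helper import make_graph, make_model

    # Context.nodes/inputs/outputs map directly onto the GraphProto fields
    graph = make_graph(nodes=[], name="graph", inputs=[], outputs=[])
    model = make_model(graph)
    print(model.graph.name)  # graph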


@@ -70,14 +70,37 @@ void init_values(py::module &m) {
 #undef VALUE
 }

+static int tensor_dtype(Tensor t) {
+    if (t->getDType() == DataType::Float32)
+        return OnnxDType::FLOAT;
+    if (t->getDType() == DataType::UInt32)
+        return OnnxDType::UINT32;
+    if (t->getDType() == DataType::UInt8)
+        return OnnxDType::UINT8;
+    if (t->getDType() == DataType::Int8)
+        return OnnxDType::INT8;
+    if (t->getDType() == DataType::UInt16)
+        return OnnxDType::UINT16;
+    if (t->getDType() == DataType::Int16)
+        return OnnxDType::INT16;
+    if (t->getDType() == DataType::Int32)
+        return OnnxDType::INT32;
+    if (t->getDType() == DataType::Int64)
+        return OnnxDType::INT64;
+    IT_ASSERT(false, "Unsupported data type");
+}
+
 void init_graph_builder(py::module &m) {
     using Handler = GraphHandlerObj;

-    m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
+    m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
+        .def("tensor_dtype", &tensor_dtype);
     py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
     py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
         m, "CpuRuntime");
     py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
         .def("shape", &TensorObj::getDims, policy::move)
         .def("src", &TensorObj::getOutputOf, policy::move);
     py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
         .def("op_type", &OperatorObj::getOpType, policy::automatic)