forked from jiuyuan/InfiniTensor
feat: 导出输入张量到 onnx
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
eff4c14a85
commit
0517089dca
|
@ -1,6 +1,13 @@
|
|||
import backend
|
||||
from onnx import ModelProto, TensorProto, NodeProto, AttributeProto, TensorShapeProto
|
||||
from onnx.helper import make_node
|
||||
from onnx import (
|
||||
ModelProto,
|
||||
TensorProto,
|
||||
NodeProto,
|
||||
AttributeProto,
|
||||
TensorShapeProto,
|
||||
ValueInfoProto,
|
||||
)
|
||||
from onnx.helper import make_node, make_tensor_value_info
|
||||
from onnx.shape_inference import infer_shapes
|
||||
from typing import Dict, List, Any, Tuple
|
||||
from functools import reduce
|
||||
|
@ -307,11 +314,18 @@ def from_onnx(model: ModelProto) -> backend.GraphHandler:
|
|||
|
||||
def to_onnx(graph: backend.GraphHandler):
|
||||
class Context:
|
||||
names: Dict[Any, str] = dict() # 记录所有对象的名字,包括张量和算子
|
||||
nodes: List[NodeProto] = [] # 保存所有算子
|
||||
count_op: Dict[backend.OpType, int] = dict() # 统计每个算子出现次数,用于命名
|
||||
count_in = 0 # 统计输入张量数量,用于命名
|
||||
count_out = 0 # 统计输出张量数量,用于命名
|
||||
# saves object names, including tensors and operators
|
||||
names: Dict[Any, str] = dict()
|
||||
# counts the occurrence times of each operator for naming
|
||||
count_op: Dict[backend.OpType, int] = dict()
|
||||
# counts input and output tensors for naming
|
||||
count_in, count_out = 0, 0
|
||||
# saves nodes (operators)
|
||||
nodes: List[NodeProto] = []
|
||||
# saves global input tensors
|
||||
inputs: List[ValueInfoProto] = []
|
||||
# saves global output tensors
|
||||
outputs: List[ValueInfoProto] = []
|
||||
|
||||
def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
|
||||
ty = op.op_type()
|
||||
|
@ -325,10 +339,15 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
|
||||
def push_input(self, tensor: backend.Tensor) -> str:
|
||||
name = self.names.get(tensor)
|
||||
# means that this input is a global input
|
||||
if name is None:
|
||||
self.count_in += 1
|
||||
name = "input{}".format(self.count_in)
|
||||
self.names[tensor] = name
|
||||
shape = tensor.shape()
|
||||
dtype = backend.tensor_dtype(tensor)
|
||||
self.inputs.append(make_tensor_value_info(name, dtype, shape))
|
||||
|
||||
return name
|
||||
|
||||
def push_node(self, node: NodeProto) -> None:
|
||||
|
@ -394,7 +413,14 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
else:
|
||||
raise Exception("Unsupported OpType {}".format(ty.name))
|
||||
|
||||
print()
|
||||
print(context.names)
|
||||
print()
|
||||
print(context.inputs)
|
||||
print()
|
||||
print(context.outputs)
|
||||
print()
|
||||
print(context.nodes)
|
||||
|
||||
|
||||
def parse_onnx(model: ModelProto):
|
||||
|
|
|
@ -70,14 +70,37 @@ void init_values(py::module &m) {
|
|||
#undef VALUE
|
||||
}
|
||||
|
||||
static int tensor_dtype(Tensor t) {
|
||||
if (t->getDType() == DataType::Float32)
|
||||
return OnnxDType::FLOAT;
|
||||
if (t->getDType() == DataType::UInt32)
|
||||
return OnnxDType::UINT32;
|
||||
if (t->getDType() == DataType::UInt8)
|
||||
return OnnxDType::UINT8;
|
||||
if (t->getDType() == DataType::Int8)
|
||||
return OnnxDType::INT8;
|
||||
if (t->getDType() == DataType::UInt16)
|
||||
return OnnxDType::UINT16;
|
||||
if (t->getDType() == DataType::Int16)
|
||||
return OnnxDType::INT16;
|
||||
if (t->getDType() == DataType::Int32)
|
||||
return OnnxDType::INT32;
|
||||
if (t->getDType() == DataType::Int64)
|
||||
return OnnxDType::INT64;
|
||||
IT_ASSERT(false, "Unsupported data type");
|
||||
}
|
||||
|
||||
void init_graph_builder(py::module &m) {
|
||||
|
||||
using Handler = GraphHandlerObj;
|
||||
|
||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
|
||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
|
||||
.def("tensor_dtype", &tensor_dtype);
|
||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
||||
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
||||
m, "CpuRuntime");
|
||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
|
||||
.def("shape", &TensorObj::getDims, policy::move)
|
||||
.def("src", &TensorObj::getOutputOf, policy::move);
|
||||
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
|
||||
.def("op_type", &OperatorObj::getOpType, policy::automatic)
|
||||
|
|
Loading…
Reference in New Issue