refactor(py): use factory methods to create OnnxStub

Signed-off-by: YdrMaster <ydrml@hotmail.com>
Authored by YdrMaster on 2023-04-18 15:32:59 +08:00; committed by Liyan Zheng
parent 725f9260cf
commit 4e1cc8d3e4
4 changed files with 69 additions and 57 deletions
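The change replaces direct construction of `OnnxStub` with named factory methods: `from_onnx` parses a `ModelProto`, and `from_graph` wraps an existing graph handler. A minimal before/after sketch of the call sites touched below (the model path is a placeholder):

```python
import onnx
from pyinfinitensor import backend
from pyinfinitensor.onnx import OnnxStub

# Before this commit: the constructor itself parsed the ONNX model.
# stub = OnnxStub(onnx.load("model.onnx"), backend.cpu_runtime())

# After: a named factory builds the stub from a ModelProto ...
stub = OnnxStub.from_onnx(onnx.load("model.onnx"), backend.cpu_runtime())

# ... and a second factory adopts an already-built graph handler.
# stub = OnnxStub.from_graph(handler)
```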

View File

@@ -81,7 +81,7 @@ import onnx
 from pyinfinitensor.onnx import OnnxStub
 from pyinfinitensor import backend
-stub = OnnxStub(onnx.load("model_file"), backend.cpu_runtime())
+stub = OnnxStub.from_onnx(onnx.load("model_file"), backend.cpu_runtime())
 ```
 
 [`onnx.load`](https://onnx.ai/onnx/api/serialization.html#load-a-model) is the loading function provided by onnx: it reads an onnx file into an in-memory onnx model.
@@ -201,7 +201,7 @@ def infer(model: ModelProto, input) -> dict:
 model0 = onnx.load(sys.argv[1])
-model1 = OnnxStub(model0, backend.cpu_runtime()).to_onnx("new")
+model1 = OnnxStub.from_onnx(model0, backend.cpu_runtime()).to_onnx("new")
 input_shape = [x.dim_value for x in model1.graph.input[0].type.tensor_type.shape.dim]
 input = numpy.random.random(input_shape).astype(numpy.float32)
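`to_onnx` converts the stub back into an ONNX `ModelProto`, so the round trip can be checked with onnx's own tooling. A minimal sketch, assuming a model file named `model.onnx`:

```python
import onnx
from pyinfinitensor import backend
from pyinfinitensor.onnx import OnnxStub

model0 = onnx.load("model.onnx")
model1 = OnnxStub.from_onnx(model0, backend.cpu_runtime()).to_onnx("new")

onnx.checker.check_model(model1)  # structural validation of the exported graph
onnx.save(model1, "model.new.onnx")
```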

View File

@@ -35,9 +35,11 @@ class GraphHandlerObj {
     Graph g;
 
   public:
-    GraphHandlerObj(Runtime runtime)
+    explicit GraphHandlerObj(Runtime runtime)
         : g(make_ref<GraphObj>(std::move(runtime))) {}
+    explicit GraphHandlerObj(Graph g) : g(std::move(g)) {}
 
     Tensor tensor(Shape dims, int dtype);
 
     //------ operators
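The new `explicit GraphHandlerObj(Graph)` overload lets a handler adopt an already-built graph instead of always starting from an empty one, which is what the Python-side `from_graph` factory below relies on. Since Python cannot overload `__init__` by argument type, the same two construction paths are expressed as named classmethods; a minimal, self-contained sketch of the pattern (illustrative names only, not the real implementation):

```python
class Stub:
    """Two named construction paths instead of one overloaded __init__."""

    def __init__(self):
        self.handler = None

    @classmethod
    def from_onnx(cls, model, runtime):
        ans = cls()
        ans.handler = (model, runtime)  # stand-in for real graph building
        return ans

    @classmethod
    def from_graph(cls, handler):
        ans = cls()
        ans.handler = handler  # adopt an existing handler unchanged
        return ans
```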

View File

@@ -38,28 +38,30 @@ class OnnxStub:
     initializer: Dict[int, TensorProto] = {}
     handler: backend.GraphHandler
 
-    def __init__(self, model: ModelProto, runtime):
+    @classmethod
+    def from_onnx(cls, model: ModelProto, runtime):
         model = infer_shapes(model)
-        self.handler = backend.GraphHandler(runtime)
+        ans = OnnxStub()
+        ans.handler = backend.GraphHandler(runtime)
         tensors: Dict[str, backend.Tensor] = dict()
         data: Dict[str, TensorProto] = dict()
 
         for input in model.graph.input:
             dims = _take_shape_dim(input.type.tensor_type.shape)
-            tensors[input.name] = self.handler.tensor(
+            tensors[input.name] = ans.handler.tensor(
                 dims, input.type.tensor_type.elem_type
             )
 
         for output in model.graph.output:
             dims = _take_shape_dim(output.type.tensor_type.shape)
-            tensors[output.name] = self.handler.tensor(
+            tensors[output.name] = ans.handler.tensor(
                 dims, output.type.tensor_type.elem_type
             )
 
         for initializer in model.graph.initializer:
             dims = [d for d in initializer.dims]
-            tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type)
+            tensors[initializer.name] = ans.handler.tensor(dims, initializer.data_type)
             data[initializer.name] = initializer
 
         for node in model.graph.node:
@@ -77,7 +79,7 @@ class OnnxStub:
                 )
                 if p[0] != p[2] or p[1] != p[3]:
                     adapt = "{}-adapt".format(node.output[0])
-                    tensors[adapt] = self.handler.pad(
+                    tensors[adapt] = ans.handler.pad(
                         tensors[node.input[0]], None, p, [-2, -1]
                     )
                     p = [0, 0, 0, 0]
@@ -87,7 +89,7 @@ class OnnxStub:
                 if len(node.input) > 2:
                     bias = "{}-bias".format(node.output[0])
                     reshape = "{}-reshape".format(node.output[0])
-                    tensors[bias] = self.handler.conv(
+                    tensors[bias] = ans.handler.conv(
                         tensors[adapt],
                         tensors[node.input[1]],
                         None,
@@ -98,7 +100,7 @@ class OnnxStub:
                         d[0],
                         d[1],
                     )
-                    tensors[reshape] = self.handler.reshape(
+                    tensors[reshape] = ans.handler.reshape(
                         tensors[node.input[2]],
                         None,
                         [
@@ -111,13 +113,13 @@ class OnnxStub:
                             1,
                         ],
                     )
-                    tensors[node.output[0]] = self.handler.add(
+                    tensors[node.output[0]] = ans.handler.add(
                         tensors[bias],
                         tensors[reshape],
                         tensors.get(node.output[0]),
                     )
                 else:
-                    tensors[node.output[0]] = self.handler.conv(
+                    tensors[node.output[0]] = ans.handler.conv(
                         tensors[adapt],
                         tensors[node.input[1]],
                         tensors.get(node.output[0]),
@@ -142,7 +144,7 @@ class OnnxStub:
                     attributes[name]
                     for name in ["dilations", "pads", "strides", "output_padding"]
                 )
-                tensors[node.output[0]] = self.handler.convTransposed2d(
+                tensors[node.output[0]] = ans.handler.convTransposed2d(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
@@ -156,7 +158,7 @@ class OnnxStub:
                     op[1],
                 )
             elif node.op_type == "MatMul":
-                tensors[node.output[0]] = self.handler.matmul(
+                tensors[node.output[0]] = ans.handler.matmul(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
@@ -175,7 +177,7 @@ class OnnxStub:
                 # FIXME unsupport attributes: `alpha` `beta`
                 assert alpha == 1.0
                 assert beta == 1.0
-                tensors[node.output[0]] = self.handler.matmul(
+                tensors[node.output[0]] = ans.handler.matmul(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
@@ -196,7 +198,7 @@ class OnnxStub:
                     attributes[name]
                     for name in ["momentum", "epsilon", "training_mode"]
                 )
-                tensors[node.output[0]] = self.handler.batchNorm(
+                tensors[node.output[0]] = ans.handler.batchNorm(
                     input, output, mean, var, scale, bias, momentum, eps, training != 0
                 )
             elif node.op_type == "MaxPool":
@@ -215,10 +217,10 @@ class OnnxStub:
                 )
                 if p[0] != p[2] or p[1] != p[3]:
                     adapt = "{}-adapt".format(node.output[0])
-                    tensors[adapt] = self.handler.pad(
+                    tensors[adapt] = ans.handler.pad(
                         tensors.get(node.input[0]), None, p, [-2, -1]
                     )
-                    tensors[node.output[0]] = self.handler.maxPool(
+                    tensors[node.output[0]] = ans.handler.maxPool(
                         tensors[adapt],
                         tensors.get(node.output[0]),
                         k[0],
@@ -231,7 +233,7 @@ class OnnxStub:
                         s[1],
                     )
                 else:
-                    tensors[node.output[0]] = self.handler.maxPool(
+                    tensors[node.output[0]] = ans.handler.maxPool(
                         tensors[node.input[0]],
                         tensors.get(node.output[0]),
                         k[0],
@@ -257,10 +259,10 @@ class OnnxStub:
                 )
                 if p[0] != p[2] or p[1] != p[3]:
                     adapt = "{}-adapt".format(node.output[0])
-                    tensors[adapt] = self.handler.pad(
+                    tensors[adapt] = ans.handler.pad(
                         tensors.get(node.input[0]), None, p, [-2, -1]
                     )
-                    tensors[node.output[0]] = self.handler.avgPool(
+                    tensors[node.output[0]] = ans.handler.avgPool(
                         tensors[adapt],
                         tensors.get(node.output[0]),
                         k[0],
@@ -273,7 +275,7 @@ class OnnxStub:
                         s[1],
                     )
                 else:
-                    tensors[node.output[0]] = self.handler.avgPool(
+                    tensors[node.output[0]] = ans.handler.avgPool(
                         tensors[node.input[0]],
                         tensors.get(node.output[0]),
                         k[0],
@@ -287,7 +289,7 @@ class OnnxStub:
                 )
             elif node.op_type == "GlobalAveragePool":
                 [_, _, h, w] = _search_shape(model, node.input[0])
-                tensors[node.output[0]] = self.handler.avgPool(
+                tensors[node.output[0]] = ans.handler.avgPool(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     h,
@@ -300,52 +302,52 @@ class OnnxStub:
                     1,
                 )
             elif node.op_type == "Add":
-                tensors[node.output[0]] = self.handler.add(
+                tensors[node.output[0]] = ans.handler.add(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Sub":
-                tensors[node.output[0]] = self.handler.sub(
+                tensors[node.output[0]] = ans.handler.sub(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Mul":
-                tensors[node.output[0]] = self.handler.mul(
+                tensors[node.output[0]] = ans.handler.mul(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Div":
-                tensors[node.output[0]] = self.handler.div(
+                tensors[node.output[0]] = ans.handler.div(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Pow":
-                tensors[node.output[0]] = self.handler.pow(
+                tensors[node.output[0]] = ans.handler.pow(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Relu":
-                tensors[node.output[0]] = self.handler.relu(
+                tensors[node.output[0]] = ans.handler.relu(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Sigmoid":
-                tensors[node.output[0]] = self.handler.sigmoid(
+                tensors[node.output[0]] = ans.handler.sigmoid(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Tanh":
-                tensors[node.output[0]] = self.handler.tanh(
+                tensors[node.output[0]] = ans.handler.tanh(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Softmax":
-                tensors[node.output[0]] = self.handler.softmax(
+                tensors[node.output[0]] = ans.handler.softmax(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     next(
@@ -353,34 +355,34 @@ class OnnxStub:
                     ),
                 )
             elif node.op_type == "Abs":
-                tensors[node.output[0]] = self.handler.abs(
+                tensors[node.output[0]] = ans.handler.abs(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Shape":
-                tensors[node.output[0]] = self.handler.shape(
+                tensors[node.output[0]] = ans.handler.shape(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Identity":
-                tensors[node.output[0]] = self.handler.identity(
+                tensors[node.output[0]] = ans.handler.identity(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Flatten":
-                tensors[node.output[0]] = self.handler.flatten(
+                tensors[node.output[0]] = ans.handler.flatten(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     next((attr.i for attr in node.attribute if attr.name == "axis")),
                 )
             elif node.op_type == "PRelu":
-                tensors[node.output[0]] = self.handler.pRelu(
+                tensors[node.output[0]] = ans.handler.pRelu(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Clip":
-                tensors[node.output[0]] = self.handler.clip(
+                tensors[node.output[0]] = ans.handler.clip(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     next(_parse_data(data[node.input[1]]).__iter__(), None)
@@ -394,7 +396,7 @@ class OnnxStub:
                 perm = next(
                     (attr.ints for attr in node.attribute if attr.name == "perm"), None
                 )
-                tensors[node.output[0]] = self.handler.transpose(
+                tensors[node.output[0]] = ans.handler.transpose(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     perm,
@@ -409,7 +411,7 @@ class OnnxStub:
                 temp = reduce(lambda acc, x: acc * x, input_shape, 1)
                 if temp < 0:
                     input_shape[input_shape.index(-1)] = size // -temp
-                tensors[node.output[0]] = self.handler.reshape(
+                tensors[node.output[0]] = ans.handler.reshape(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     input_shape,
@@ -426,7 +428,7 @@ class OnnxStub:
                 for i, x in enumerate(input_shape):
                     if i not in axes:
                         output_shape.append(x)
-                tensors[node.output[0]] = self.handler.reshape(
+                tensors[node.output[0]] = ans.handler.reshape(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     output_shape,
@@ -440,13 +442,13 @@ class OnnxStub:
                 )
                 for i in axes:
                     input_shape.insert(i, 1)
-                tensors[node.output[0]] = self.handler.reshape(
+                tensors[node.output[0]] = ans.handler.reshape(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     input_shape,
                 )
             elif node.op_type == "Concat":
-                tensors[node.output[0]] = self.handler.concat(
+                tensors[node.output[0]] = ans.handler.concat(
                     [tensors[name] for name in node.input],
                     tensors.get(node.output[0]),
                     next((attr.i for attr in node.attribute if attr.name == "axis")),
@@ -454,7 +456,7 @@ class OnnxStub:
             elif node.op_type == "Split":
                 for name, tensor in zip(
                     node.output,
-                    self.handler.split(
+                    ans.handler.split(
                         tensors[node.input[0]],
                         None,
                         next(
@@ -466,14 +468,14 @@ class OnnxStub:
                 ):
                     tensors[name] = tensor
             elif node.op_type == "Gather":
-                tensors[node.output[0]] = self.handler.gather(
+                tensors[node.output[0]] = ans.handler.gather(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                     next((attr.i for attr in node.attribute if attr.name == "axis")),
                 )
             elif node.op_type == "ReduceMean":
-                tensors[node.output[0]] = self.handler.reduce_mean(
+                tensors[node.output[0]] = ans.handler.reduce_mean(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     tensors[node.input[1]] if len(node.input) > 1 else None,
@@ -481,7 +483,7 @@ class OnnxStub:
                     != 0,
                 )
             elif node.op_type == "Slice":
-                tensors[node.output[0]] = self.handler.slice(
+                tensors[node.output[0]] = ans.handler.slice(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     _parse_data(data[node.input[1]]),
@@ -490,7 +492,7 @@ class OnnxStub:
                     _parse_data(data[node.input[4]]) if len(node.input) > 4 else None,
                 )
             elif node.op_type == "Pad":
-                tensors[node.output[0]] = self.handler.pad(
+                tensors[node.output[0]] = ans.handler.pad(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     _parse_data(data[node.input[1]]),
@@ -499,7 +501,7 @@ class OnnxStub:
             elif node.op_type == "Dropout":
                 for name, tensor in zip(
                     node.output,
-                    self.handler.dropout(
+                    ans.handler.dropout(
                         tensors[node.input[0]],
                         tensors.get(node.output[0]),
                         tensors.get(node.output[1]) if len(node.output) > 1 else None,
@@ -517,15 +519,15 @@ class OnnxStub:
             else:
                 raise Exception('Unsupported operator "{}"'.format(node.op_type))
 
-        self.handler.data_malloc()
+        ans.handler.data_malloc()
 
         for name, obj in tensors.items():
            tensor = data.get(name)
            if tensor == None:
                if any(input.name == name for input in model.graph.input):
-                    self.inputs[name] = obj
+                    ans.inputs[name] = obj
            else:
-                self.initializer[obj.fuid()] = tensor
+                ans.initializer[obj.fuid()] = tensor
                if tensor.data_type == TensorProto.INT32:
                    obj.copyin_int32(_parse_data(tensor))
                elif tensor.data_type == TensorProto.INT64:
@@ -535,9 +537,17 @@ class OnnxStub:
                else:
                    assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
+        return ans
 
-        for output in model.graph.output:
-            self.outputs[output.name] = tensors[output.name]
+    @classmethod
+    def from_graph(cls, handler: backend.GraphHandler):
+        ans = OnnxStub()
+        ans.handler = handler
+        return ans
 
     def to_onnx(self, name: str) -> ModelProto:
         class Context:
             # saves object names, including tensors and operators
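`from_graph` wraps an already-populated `backend.GraphHandler` without touching ONNX at all. A hedged usage sketch: the handler calls mirror those visible in `from_onnx` above, `1` stands for `TensorProto.FLOAT`, and whether `to_onnx` fully supports a hand-built graph is an assumption:

```python
from pyinfinitensor import backend
from pyinfinitensor.onnx import OnnxStub

# Build a small graph directly through the handler, bypassing ONNX parsing.
handler = backend.GraphHandler(backend.cpu_runtime())
a = handler.tensor([1, 3, 224, 224], 1)  # dims, elem_type (1 == FLOAT)
b = handler.tensor([1, 3, 224, 224], 1)
handler.add(a, b, None)  # same call shape as in from_onnx above

stub = OnnxStub.from_graph(handler)  # adopt the finished handler
model = stub.to_onnx("new")          # export the hand-built graph
```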
@@ -841,7 +851,7 @@ class OnnxStub:
 def from_onnx(model: ModelProto, runtime):
-    stub = OnnxStub(model, runtime)
+    stub = OnnxStub.from_onnx(model, runtime)
     return stub.inputs, stub.outputs, stub.handler
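The module-level `from_onnx` wrapper keeps its tuple-returning interface and now just delegates to the factory. A minimal sketch of calling it (the model path is a placeholder):

```python
import onnx
from pyinfinitensor import backend
from pyinfinitensor.onnx import from_onnx

# Returns the stub's input dict, output dict, and underlying graph handler.
inputs, outputs, handler = from_onnx(onnx.load("model.onnx"), backend.cpu_runtime())
```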

View File

@@ -40,9 +40,9 @@ class TestStringMethods(unittest.TestCase):
                         file=model_file, size=os.path.getsize(model_file) / 1024 / 1024
                     )
                 )
-                model = OnnxStub(onnx.load(model_file), backend.cpu_runtime()).to_onnx(
-                    "new"
-                )
+                model = OnnxStub.from_onnx(
+                    onnx.load(model_file), backend.cpu_runtime()
+                ).to_onnx("new")
                 model = infer_shapes(model)
 
     def test_tensor(self):