forked from jiuyuan/InfiniTensor
refactor(py): 使用工厂方法创建 OnnxStub
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
725f9260cf
commit
4e1cc8d3e4
|
@ -81,7 +81,7 @@ import onnx
|
|||
from pyinfinitensor.onnx import OnnxStub
|
||||
from pyinfinitensor import backend
|
||||
|
||||
stub = OnnxStub(onnx.load("model_file"), backend.cpu_runtime())
|
||||
stub = OnnxStub.from_model(onnx.load("model_file"), backend.cpu_runtime())
|
||||
```
|
||||
|
||||
[`onnx.load`](https://onnx.ai/onnx/api/serialization.html#load-a-model) 是 onnx 提供的加载函数,将 onnx 文件读取为保存在内存中的 onnx 模型。
|
||||
|
@ -201,7 +201,7 @@ def infer(model: ModelProto, input) -> dict:
|
|||
|
||||
|
||||
model0 = onnx.load(sys.argv[1])
|
||||
model1 = OnnxStub(model0, backend.cpu_runtime()).to_onnx("new")
|
||||
model1 = OnnxStub.from_model(model0, backend.cpu_runtime()).to_onnx("new")
|
||||
|
||||
input_shape = [x.dim_value for x in model1.graph.input[0].type.tensor_type.shape.dim]
|
||||
input = numpy.random.random(input_shape).astype(numpy.float32)
|
||||
|
|
|
@ -35,9 +35,11 @@ class GraphHandlerObj {
|
|||
Graph g;
|
||||
|
||||
public:
|
||||
GraphHandlerObj(Runtime runtime)
|
||||
explicit GraphHandlerObj(Runtime runtime)
|
||||
: g(make_ref<GraphObj>(std::move(runtime))) {}
|
||||
|
||||
explicit GraphHandlerObj(Graph g) : g(std::move(g)) {}
|
||||
|
||||
Tensor tensor(Shape dims, int dtype);
|
||||
|
||||
//------ operators
|
||||
|
|
|
@ -38,28 +38,30 @@ class OnnxStub:
|
|||
initializer: Dict[int, TensorProto] = {}
|
||||
handler: backend.GraphHandler
|
||||
|
||||
def __init__(self, model: ModelProto, runtime):
|
||||
@classmethod
|
||||
def from_onnx(cls, model: ModelProto, runtime):
|
||||
model = infer_shapes(model)
|
||||
self.handler = backend.GraphHandler(runtime)
|
||||
ans = OnnxStub()
|
||||
ans.handler = backend.GraphHandler(runtime)
|
||||
|
||||
tensors: Dict[str, backend.Tensor] = dict()
|
||||
data: Dict[str, TensorProto] = dict()
|
||||
|
||||
for input in model.graph.input:
|
||||
dims = _take_shape_dim(input.type.tensor_type.shape)
|
||||
tensors[input.name] = self.handler.tensor(
|
||||
tensors[input.name] = ans.handler.tensor(
|
||||
dims, input.type.tensor_type.elem_type
|
||||
)
|
||||
|
||||
for output in model.graph.output:
|
||||
dims = _take_shape_dim(output.type.tensor_type.shape)
|
||||
tensors[output.name] = self.handler.tensor(
|
||||
tensors[output.name] = ans.handler.tensor(
|
||||
dims, output.type.tensor_type.elem_type
|
||||
)
|
||||
|
||||
for initializer in model.graph.initializer:
|
||||
dims = [d for d in initializer.dims]
|
||||
tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type)
|
||||
tensors[initializer.name] = ans.handler.tensor(dims, initializer.data_type)
|
||||
data[initializer.name] = initializer
|
||||
|
||||
for node in model.graph.node:
|
||||
|
@ -77,7 +79,7 @@ class OnnxStub:
|
|||
)
|
||||
if p[0] != p[2] or p[1] != p[3]:
|
||||
adapt = "{}-adapt".format(node.output[0])
|
||||
tensors[adapt] = self.handler.pad(
|
||||
tensors[adapt] = ans.handler.pad(
|
||||
tensors[node.input[0]], None, p, [-2, -1]
|
||||
)
|
||||
p = [0, 0, 0, 0]
|
||||
|
@ -87,7 +89,7 @@ class OnnxStub:
|
|||
if len(node.input) > 2:
|
||||
bias = "{}-bias".format(node.output[0])
|
||||
reshape = "{}-reshape".format(node.output[0])
|
||||
tensors[bias] = self.handler.conv(
|
||||
tensors[bias] = ans.handler.conv(
|
||||
tensors[adapt],
|
||||
tensors[node.input[1]],
|
||||
None,
|
||||
|
@ -98,7 +100,7 @@ class OnnxStub:
|
|||
d[0],
|
||||
d[1],
|
||||
)
|
||||
tensors[reshape] = self.handler.reshape(
|
||||
tensors[reshape] = ans.handler.reshape(
|
||||
tensors[node.input[2]],
|
||||
None,
|
||||
[
|
||||
|
@ -111,13 +113,13 @@ class OnnxStub:
|
|||
1,
|
||||
],
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.add(
|
||||
tensors[node.output[0]] = ans.handler.add(
|
||||
tensors[bias],
|
||||
tensors[reshape],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
else:
|
||||
tensors[node.output[0]] = self.handler.conv(
|
||||
tensors[node.output[0]] = ans.handler.conv(
|
||||
tensors[adapt],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
|
@ -142,7 +144,7 @@ class OnnxStub:
|
|||
attributes[name]
|
||||
for name in ["dilations", "pads", "strides", "output_padding"]
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.convTransposed2d(
|
||||
tensors[node.output[0]] = ans.handler.convTransposed2d(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
|
@ -156,7 +158,7 @@ class OnnxStub:
|
|||
op[1],
|
||||
)
|
||||
elif node.op_type == "MatMul":
|
||||
tensors[node.output[0]] = self.handler.matmul(
|
||||
tensors[node.output[0]] = ans.handler.matmul(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
|
@ -175,7 +177,7 @@ class OnnxStub:
|
|||
# FIXME unsupport attributes: `alpha` `beta`
|
||||
assert alpha == 1.0
|
||||
assert beta == 1.0
|
||||
tensors[node.output[0]] = self.handler.matmul(
|
||||
tensors[node.output[0]] = ans.handler.matmul(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
|
@ -196,7 +198,7 @@ class OnnxStub:
|
|||
attributes[name]
|
||||
for name in ["momentum", "epsilon", "training_mode"]
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.batchNorm(
|
||||
tensors[node.output[0]] = ans.handler.batchNorm(
|
||||
input, output, mean, var, scale, bias, momentum, eps, training != 0
|
||||
)
|
||||
elif node.op_type == "MaxPool":
|
||||
|
@ -215,10 +217,10 @@ class OnnxStub:
|
|||
)
|
||||
if p[0] != p[2] or p[1] != p[3]:
|
||||
adapt = "{}-adapt".format(node.output[0])
|
||||
tensors[adapt] = self.handler.pad(
|
||||
tensors[adapt] = ans.handler.pad(
|
||||
tensors.get(node.input[0]), None, p, [-2, -1]
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.maxPool(
|
||||
tensors[node.output[0]] = ans.handler.maxPool(
|
||||
tensors[adapt],
|
||||
tensors.get(node.output[0]),
|
||||
k[0],
|
||||
|
@ -231,7 +233,7 @@ class OnnxStub:
|
|||
s[1],
|
||||
)
|
||||
else:
|
||||
tensors[node.output[0]] = self.handler.maxPool(
|
||||
tensors[node.output[0]] = ans.handler.maxPool(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
k[0],
|
||||
|
@ -257,10 +259,10 @@ class OnnxStub:
|
|||
)
|
||||
if p[0] != p[2] or p[1] != p[3]:
|
||||
adapt = "{}-adapt".format(node.output[0])
|
||||
tensors[adapt] = self.handler.pad(
|
||||
tensors[adapt] = ans.handler.pad(
|
||||
tensors.get(node.input[0]), None, p, [-2, -1]
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.avgPool(
|
||||
tensors[node.output[0]] = ans.handler.avgPool(
|
||||
tensors[adapt],
|
||||
tensors.get(node.output[0]),
|
||||
k[0],
|
||||
|
@ -273,7 +275,7 @@ class OnnxStub:
|
|||
s[1],
|
||||
)
|
||||
else:
|
||||
tensors[node.output[0]] = self.handler.avgPool(
|
||||
tensors[node.output[0]] = ans.handler.avgPool(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
k[0],
|
||||
|
@ -287,7 +289,7 @@ class OnnxStub:
|
|||
)
|
||||
elif node.op_type == "GlobalAveragePool":
|
||||
[_, _, h, w] = _search_shape(model, node.input[0])
|
||||
tensors[node.output[0]] = self.handler.avgPool(
|
||||
tensors[node.output[0]] = ans.handler.avgPool(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
h,
|
||||
|
@ -300,52 +302,52 @@ class OnnxStub:
|
|||
1,
|
||||
)
|
||||
elif node.op_type == "Add":
|
||||
tensors[node.output[0]] = self.handler.add(
|
||||
tensors[node.output[0]] = ans.handler.add(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Sub":
|
||||
tensors[node.output[0]] = self.handler.sub(
|
||||
tensors[node.output[0]] = ans.handler.sub(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Mul":
|
||||
tensors[node.output[0]] = self.handler.mul(
|
||||
tensors[node.output[0]] = ans.handler.mul(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Div":
|
||||
tensors[node.output[0]] = self.handler.div(
|
||||
tensors[node.output[0]] = ans.handler.div(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Pow":
|
||||
tensors[node.output[0]] = self.handler.pow(
|
||||
tensors[node.output[0]] = ans.handler.pow(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Relu":
|
||||
tensors[node.output[0]] = self.handler.relu(
|
||||
tensors[node.output[0]] = ans.handler.relu(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Sigmoid":
|
||||
tensors[node.output[0]] = self.handler.sigmoid(
|
||||
tensors[node.output[0]] = ans.handler.sigmoid(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Tanh":
|
||||
tensors[node.output[0]] = self.handler.tanh(
|
||||
tensors[node.output[0]] = ans.handler.tanh(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Softmax":
|
||||
tensors[node.output[0]] = self.handler.softmax(
|
||||
tensors[node.output[0]] = ans.handler.softmax(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
next(
|
||||
|
@ -353,34 +355,34 @@ class OnnxStub:
|
|||
),
|
||||
)
|
||||
elif node.op_type == "Abs":
|
||||
tensors[node.output[0]] = self.handler.abs(
|
||||
tensors[node.output[0]] = ans.handler.abs(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Shape":
|
||||
tensors[node.output[0]] = self.handler.shape(
|
||||
tensors[node.output[0]] = ans.handler.shape(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Identity":
|
||||
tensors[node.output[0]] = self.handler.identity(
|
||||
tensors[node.output[0]] = ans.handler.identity(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Flatten":
|
||||
tensors[node.output[0]] = self.handler.flatten(
|
||||
tensors[node.output[0]] = ans.handler.flatten(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
next((attr.i for attr in node.attribute if attr.name == "axis")),
|
||||
)
|
||||
elif node.op_type == "PRelu":
|
||||
tensors[node.output[0]] = self.handler.pRelu(
|
||||
tensors[node.output[0]] = ans.handler.pRelu(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Clip":
|
||||
tensors[node.output[0]] = self.handler.clip(
|
||||
tensors[node.output[0]] = ans.handler.clip(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
next(_parse_data(data[node.input[1]]).__iter__(), None)
|
||||
|
@ -394,7 +396,7 @@ class OnnxStub:
|
|||
perm = next(
|
||||
(attr.ints for attr in node.attribute if attr.name == "perm"), None
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.transpose(
|
||||
tensors[node.output[0]] = ans.handler.transpose(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
perm,
|
||||
|
@ -409,7 +411,7 @@ class OnnxStub:
|
|||
temp = reduce(lambda acc, x: acc * x, input_shape, 1)
|
||||
if temp < 0:
|
||||
input_shape[input_shape.index(-1)] = size // -temp
|
||||
tensors[node.output[0]] = self.handler.reshape(
|
||||
tensors[node.output[0]] = ans.handler.reshape(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
input_shape,
|
||||
|
@ -426,7 +428,7 @@ class OnnxStub:
|
|||
for i, x in enumerate(input_shape):
|
||||
if i not in axes:
|
||||
output_shape.append(x)
|
||||
tensors[node.output[0]] = self.handler.reshape(
|
||||
tensors[node.output[0]] = ans.handler.reshape(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
output_shape,
|
||||
|
@ -440,13 +442,13 @@ class OnnxStub:
|
|||
)
|
||||
for i in axes:
|
||||
input_shape.insert(i, 1)
|
||||
tensors[node.output[0]] = self.handler.reshape(
|
||||
tensors[node.output[0]] = ans.handler.reshape(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
input_shape,
|
||||
)
|
||||
elif node.op_type == "Concat":
|
||||
tensors[node.output[0]] = self.handler.concat(
|
||||
tensors[node.output[0]] = ans.handler.concat(
|
||||
[tensors[name] for name in node.input],
|
||||
tensors.get(node.output[0]),
|
||||
next((attr.i for attr in node.attribute if attr.name == "axis")),
|
||||
|
@ -454,7 +456,7 @@ class OnnxStub:
|
|||
elif node.op_type == "Split":
|
||||
for name, tensor in zip(
|
||||
node.output,
|
||||
self.handler.split(
|
||||
ans.handler.split(
|
||||
tensors[node.input[0]],
|
||||
None,
|
||||
next(
|
||||
|
@ -466,14 +468,14 @@ class OnnxStub:
|
|||
):
|
||||
tensors[name] = tensor
|
||||
elif node.op_type == "Gather":
|
||||
tensors[node.output[0]] = self.handler.gather(
|
||||
tensors[node.output[0]] = ans.handler.gather(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
next((attr.i for attr in node.attribute if attr.name == "axis")),
|
||||
)
|
||||
elif node.op_type == "ReduceMean":
|
||||
tensors[node.output[0]] = self.handler.reduce_mean(
|
||||
tensors[node.output[0]] = ans.handler.reduce_mean(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
tensors[node.input[1]] if len(node.input) > 1 else None,
|
||||
|
@ -481,7 +483,7 @@ class OnnxStub:
|
|||
!= 0,
|
||||
)
|
||||
elif node.op_type == "Slice":
|
||||
tensors[node.output[0]] = self.handler.slice(
|
||||
tensors[node.output[0]] = ans.handler.slice(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
_parse_data(data[node.input[1]]),
|
||||
|
@ -490,7 +492,7 @@ class OnnxStub:
|
|||
_parse_data(data[node.input[4]]) if len(node.input) > 4 else None,
|
||||
)
|
||||
elif node.op_type == "Pad":
|
||||
tensors[node.output[0]] = self.handler.pad(
|
||||
tensors[node.output[0]] = ans.handler.pad(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
_parse_data(data[node.input[1]]),
|
||||
|
@ -499,7 +501,7 @@ class OnnxStub:
|
|||
elif node.op_type == "Dropout":
|
||||
for name, tensor in zip(
|
||||
node.output,
|
||||
self.handler.dropout(
|
||||
ans.handler.dropout(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
tensors.get(node.output[1]) if len(node.output) > 1 else None,
|
||||
|
@ -517,15 +519,15 @@ class OnnxStub:
|
|||
else:
|
||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
|
||||
self.handler.data_malloc()
|
||||
ans.handler.data_malloc()
|
||||
|
||||
for name, obj in tensors.items():
|
||||
tensor = data.get(name)
|
||||
if tensor == None:
|
||||
if any(input.name == name for input in model.graph.input):
|
||||
self.inputs[name] = obj
|
||||
ans.inputs[name] = obj
|
||||
else:
|
||||
self.initializer[obj.fuid()] = tensor
|
||||
ans.initializer[obj.fuid()] = tensor
|
||||
if tensor.data_type == TensorProto.INT32:
|
||||
obj.copyin_int32(_parse_data(tensor))
|
||||
elif tensor.data_type == TensorProto.INT64:
|
||||
|
@ -535,9 +537,17 @@ class OnnxStub:
|
|||
else:
|
||||
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
||||
|
||||
return ans
|
||||
|
||||
for output in model.graph.output:
|
||||
self.outputs[output.name] = tensors[output.name]
|
||||
|
||||
@classmethod
|
||||
def from_graph(cls, handler: backend.GraphHandler):
|
||||
ans = OnnxStub()
|
||||
ans.handler = handler
|
||||
return ans
|
||||
|
||||
def to_onnx(self, name: str) -> ModelProto:
|
||||
class Context:
|
||||
# saves object names, including tensors and operators
|
||||
|
@ -841,7 +851,7 @@ class OnnxStub:
|
|||
|
||||
|
||||
def from_onnx(model: ModelProto, runtime):
|
||||
stub = OnnxStub(model, runtime)
|
||||
stub = OnnxStub.from_onnx(model, runtime)
|
||||
return stub.inputs, stub.outputs, stub.handler
|
||||
|
||||
|
||||
|
|
|
@ -40,9 +40,9 @@ class TestStringMethods(unittest.TestCase):
|
|||
file=model_file, size=os.path.getsize(model_file) / 1024 / 1024
|
||||
)
|
||||
)
|
||||
model = OnnxStub(onnx.load(model_file), backend.cpu_runtime()).to_onnx(
|
||||
"new"
|
||||
)
|
||||
model = OnnxStub.from_onnx(
|
||||
onnx.load(model_file), backend.cpu_runtime()
|
||||
).to_onnx("new")
|
||||
model = infer_shapes(model)
|
||||
|
||||
def test_tensor(self):
|
||||
|
|
Loading…
Reference in New Issue