forked from jiuyuan/InfiniTensor

feat: export batchnorm to onnx

Signed-off-by: YdrMaster <ydrml@hotmail.com>

This commit is contained in:
parent 71ca4459d9
commit a5e692baea
@@ -39,9 +39,11 @@ class BatchNormObj : public OperatorObj {
     std::string toString() const override;

     // output size will be 3 when training
-    int numInputs() const override { return 5; }
-    int numOutputs() const override { return outputs.size(); }
-    float getEps() const { return eps; }
+    inline int numInputs() const override { return 5; }
+    inline int numOutputs() const override { return outputs.size(); }
+    inline float getMomentum() const { return momentum; }
+    inline float getEps() const { return eps; }
+    inline bool getTraining() const { return training; }

   private:
     vector<int> getWorkloadVector() const override;
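The comment about three outputs matches the ONNX operator: with training_mode=1, BatchNormalization also emits running_mean and running_var. A plain-onnx sketch of the two variants (illustration only, independent of this repository):

# Hedged sketch using plain onnx helpers, not this project's API.
from onnx.helper import make_node

inference_node = make_node(
    "BatchNormalization",
    inputs=["X", "scale", "B", "input_mean", "input_var"],
    outputs=["Y"],                                  # one output at inference time
    epsilon=1e-5,
    momentum=0.9,
    training_mode=0,
)

training_node = make_node(
    "BatchNormalization",
    inputs=["X", "scale", "B", "input_mean", "input_var"],
    outputs=["Y", "running_mean", "running_var"],   # three outputs when training
    epsilon=1e-5,
    momentum=0.9,
    training_mode=1,
)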
@@ -32,490 +32,516 @@ def cuda_runtime():
    return backend.cuda_runtime()


def from_onnx(
    model: ModelProto, runtime
) -> Tuple[Dict[str, backend.Tensor], Dict[str, backend.Tensor], backend.GraphHandler]:
    model = infer_shapes(model)
    handler = backend.GraphHandler(runtime)

    tensors: Dict[str, backend.Tensor] = dict()
    data: Dict[str, TensorProto] = dict()

    for input in model.graph.input:
        dims = _take_shape_dim(input.type.tensor_type.shape)
        tensors[input.name] = handler.tensor(dims, input.type.tensor_type.elem_type)

    for output in model.graph.output:
        dims = _take_shape_dim(output.type.tensor_type.shape)
        tensors[output.name] = handler.tensor(dims, output.type.tensor_type.elem_type)

    for initializer in model.graph.initializer:
        data[initializer.name] = initializer

    for node in model.graph.node:
if node.op_type == "Conv":
|
||||
attributes = _parse_attribute(
|
||||
node,
|
||||
{
|
||||
"dilations": [1, 1],
|
||||
"pads": [0, 0],
|
||||
"strides": [1, 1],
|
||||
},
|
||||
)
|
||||
(d, p, s) = (attributes[name] for name in ["dilations", "pads", "strides"])
|
||||
tensors[node.output[0]] = handler.conv(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
p[0],
|
||||
p[1],
|
||||
s[0],
|
||||
s[1],
|
||||
d[0],
|
||||
d[1],
|
||||
)
|
||||
elif node.op_type == "MatMul":
|
||||
tensors[node.output[0]] = handler.matmul(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
backend.ActType.Linear,
|
||||
)
|
||||
elif node.op_type == "Gemm":
|
||||
attributes = _parse_attribute(
|
||||
node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
|
||||
)
|
||||
(alpha, beta, transA, transB) = (
|
||||
attributes[name] for name in ["alpha", "beta", "transA", "transB"]
|
||||
)
|
||||
# FIXME 不支持 `alpha` `beta`
|
||||
assert alpha == 1.0
|
||||
assert beta == 1.0
|
||||
tensors[node.output[0]] = handler.matmul(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
transA == 1,
|
||||
transB == 1,
|
||||
tensors[node.input[2]] if len(node.input) > 2 else None,
|
||||
backend.ActType.Linear,
|
||||
)
|
||||
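For context, ONNX Gemm computes Y = alpha * A' @ B' + beta * C, which is why only the alpha == beta == 1.0 case can be forwarded to the backend matmul with a bias. A numpy sketch of the reference semantics (not part of this backend):

# Reference semantics of ONNX Gemm (illustration only).
import numpy as np

def gemm_reference(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0):
    A = A.T if transA else A
    B = B.T if transB else B
    Y = alpha * (A @ B)
    return Y if C is None else Y + beta * C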
elif node.op_type == "BatchNormalization":
|
||||
(input, mean, var, scale, bias) = (
|
||||
tensors[node.input[i]] for i in [0, 3, 4, 1, 2]
|
||||
)
|
||||
output = tensors.get(node.output[0])
|
||||
attributes = _parse_attribute(
|
||||
node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0}
|
||||
)
|
||||
(momentum, eps, training) = (
|
||||
attributes[name] for name in ["momentum", "epsilon", "training_mode"]
|
||||
)
|
||||
tensors[node.output[0]] = handler.batchNorm(
|
||||
input, output, mean, var, scale, bias, momentum, eps, training != 0
|
||||
)
|
||||
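ONNX orders the BatchNormalization inputs as (X, scale, B, input_mean, input_var), while the backend's batchNorm expects (input, mean, var, scale, bias); the index list [0, 3, 4, 1, 2] performs exactly that reordering. A tiny sketch of the mapping (names taken from the ONNX spec, not from this code):

# Sketch of the reordering above.
onnx_inputs = ["X", "scale", "B", "input_mean", "input_var"]   # ONNX input order
input, mean, var, scale, bias = (onnx_inputs[i] for i in [0, 3, 4, 1, 2])
assert (input, mean, var, scale, bias) == ("X", "input_mean", "input_var", "scale", "B")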
elif node.op_type == "MaxPool":
|
||||
attributes = _parse_attribute(
|
||||
node,
|
||||
{
|
||||
"kernel_shape": None,
|
||||
"dilations": [1, 1],
|
||||
"pads": [0, 0],
|
||||
"strides": [1, 1],
|
||||
},
|
||||
)
|
||||
(k, d, p, s) = (
|
||||
attributes[name]
|
||||
for name in ["kernel_shape", "dilations", "pads", "strides"]
|
||||
)
|
||||
tensors[node.output[0]] = handler.maxPool(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
k[0],
|
||||
k[1],
|
||||
d[0],
|
||||
d[1],
|
||||
p[0],
|
||||
p[1],
|
||||
s[0],
|
||||
s[1],
|
||||
)
|
||||
elif node.op_type == "AveragePool":
|
||||
attributes = _parse_attribute(
|
||||
node,
|
||||
{
|
||||
"kernel_shape": None,
|
||||
"pads": [0, 0],
|
||||
"strides": [1, 1],
|
||||
},
|
||||
)
|
||||
(k, p, s) = (
|
||||
attributes[name] for name in ["kernel_shape", "pads", "strides"]
|
||||
)
|
||||
tensors[node.output[0]] = handler.avgPool(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
k[0],
|
||||
k[1],
|
||||
1,
|
||||
1,
|
||||
p[0],
|
||||
p[1],
|
||||
s[0],
|
||||
s[1],
|
||||
)
|
||||
elif node.op_type == "GlobalAveragePool":
|
||||
shape = next(
|
||||
(
|
||||
value.type.tensor_type.shape
|
||||
for value in model.graph.value_info
|
||||
if value.name == node.input[0]
|
||||
),
|
||||
None,
|
||||
) or next(
|
||||
input.type.tensor_type.shape
|
||||
for input in model.graph.input
|
||||
if input.name == node.input[0]
|
||||
)
|
||||
[_, _, h, w] = _take_shape_dim(shape)
|
||||
tensors[node.output[0]] = handler.avgPool(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
h,
|
||||
w,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
)
|
||||
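GlobalAveragePool has no kernel attribute, so it is lowered here to an average pool whose kernel covers the full spatial extent (h, w); the result equals a mean over the spatial axes. A numpy sketch of the equivalence for an NCHW tensor (illustration only):

import numpy as np

x = np.random.rand(1, 3, 4, 4).astype(np.float32)
global_avg = x.mean(axis=(2, 3), keepdims=True)   # shape (1, 3, 1, 1)
# equivalent to an average pool with kernel (4, 4), stride 1 and no padding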
elif node.op_type == "Add":
|
||||
tensors[node.output[0]] = handler.add(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Sub":
|
||||
tensors[node.output[0]] = handler.sub(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Mul":
|
||||
tensors[node.output[0]] = handler.mul(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Div":
|
||||
tensors[node.output[0]] = handler.div(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Pow":
|
||||
tensors[node.output[0]] = handler.pow(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Relu":
|
||||
tensors[node.output[0]] = handler.relu(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Sigmoid":
|
||||
tensors[node.output[0]] = handler.sigmoid(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Tanh":
|
||||
tensors[node.output[0]] = handler.tanh(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Softmax":
|
||||
tensors[node.output[0]] = handler.softmax(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Abs":
|
||||
tensors[node.output[0]] = handler.abs(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Identity":
|
||||
tensors[node.output[0]] = handler.identity(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Flatten":
|
||||
# FIXME the backend operator does not support flattening along an arbitrary axis
|
||||
axis = next(
|
||||
(attr.i for attr in node.attribute if attr.name == "axis"), None
|
||||
)
|
||||
assert axis is None or axis == 1
|
||||
tensors[node.output[0]] = handler.flatten(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Reshape":
|
||||
input_shape = next(
|
||||
(
|
||||
value.type.tensor_type.shape
|
||||
for value in model.graph.value_info
|
||||
if value.name == node.input[0]
|
||||
),
|
||||
None,
|
||||
) or next(
|
||||
input.type.tensor_type.shape
|
||||
for input in model.graph.input
|
||||
if input.name == node.input[0]
|
||||
)
|
||||
dims = _take_shape_dim(input_shape)
|
||||
size = reduce(lambda acc, x: acc * x, dims)
|
||||
output_shape = [int(i) for i in data[node.input[1]].int64_data]
|
||||
for i, x in enumerate(output_shape):
|
||||
if x == 0:
|
||||
output_shape[i] = dims[i]
|
||||
temp = reduce(lambda acc, x: acc * x, output_shape)
|
||||
if temp < 0:
|
||||
output_shape[output_shape.index(-1)] = size // -temp
|
||||
tensors[node.output[0]] = handler.reshape(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
output_shape,
|
||||
)
|
||||
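The loop above implements the two special values ONNX Reshape allows in the target shape: 0 keeps the corresponding input dimension, and a single -1 is inferred from the remaining element count. A standalone sketch of the same resolution (illustrative helper, not part of this module):

# Assumes at most one -1, as the ONNX spec requires.
from functools import reduce

def resolve_reshape(input_dims, target):
    target = [input_dims[i] if x == 0 else x for i, x in enumerate(target)]
    size = reduce(lambda acc, x: acc * x, input_dims)
    prod = reduce(lambda acc, x: acc * x, target)
    if prod < 0:                       # exactly one -1 left to infer
        target[target.index(-1)] = size // -prod
    return target

assert resolve_reshape([2, 3, 4], [0, -1]) == [2, 12]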
elif node.op_type == "Concat":
|
||||
tensors[node.output[0]] = handler.concat(
|
||||
[tensors[name] for name in node.input],
|
||||
tensors.get(node.output[0]),
|
||||
next((attr.i for attr in node.attribute if attr.name == "axis")),
|
||||
)
|
||||
elif node.op_type == "Gather":
|
||||
tensors[node.output[0]] = handler.gather(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
next((attr.i for attr in node.attribute if attr.name == "axis")),
|
||||
)
|
||||
elif node.op_type == "ReduceMean":
|
||||
tensors[node.output[0]] = handler.reduceMean(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
tensors[node.input[1]] if len(node.input) > 1 else None,
|
||||
next((attr.i for attr in node.attribute if attr.name == "keepdims"))
|
||||
!= 0,
|
||||
)
|
||||
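For reference, ReduceMean with keepdims=1 is a mean over the given axes that preserves rank; the optional second input carries the axes. A small numpy illustration (not tied to this backend):

import numpy as np

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
y = x.mean(axis=(1, 2), keepdims=True)   # what ReduceMean(axes=[1, 2], keepdims=1) returns
assert y.shape == (2, 1, 1)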
elif node.op_type == "Slice":
|
||||
tensors[node.output[0]] = handler.slice(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
_parse_data(data[node.input[1]]),
|
||||
_parse_data(data[node.input[2]]),
|
||||
_parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
|
||||
_parse_data(data[node.input[4]]) if len(node.input) > 4 else None,
|
||||
)
|
||||
elif node.op_type == "Pad":
|
||||
tensors[node.output[0]] = handler.pad(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
_parse_data(data[node.input[1]]),
|
||||
_parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
|
||||
)
|
||||
else:
|
||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
|
||||
handler.data_malloc()
|
||||
|
||||
class OnnxStub:
|
||||
inputs: Dict[str, backend.Tensor] = {}
|
||||
for name, obj in tensors.items():
|
||||
tensor = data.get(name)
|
||||
if tensor is None:
|
||||
if any(input.name == name for input in model.graph.input):
|
||||
inputs[name] = obj
|
||||
else:
|
||||
if tensor.data_type == TensorProto.INT32:
|
||||
handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
|
||||
elif tensor.data_type == TensorProto.INT64:
|
||||
handler.copy_int64(obj, [int(i) for i in tensor.int64_data])
|
||||
elif tensor.data_type == TensorProto.FLOAT:
|
||||
handler.copy_float(obj, [float(i) for i in tensor.float_data])
|
||||
else:
|
||||
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
||||
|
||||
outputs: Dict[str, backend.Tensor] = {}
|
||||
for output in model.graph.output:
|
||||
outputs[output.name] = tensors[output.name]
|
||||
handler: backend.GraphHandler
|
||||
|
||||
return inputs, outputs, handler
|
||||
def __init__(self, model: ModelProto, runtime):
|
||||
model = infer_shapes(model)
|
||||
self.handler = backend.GraphHandler(runtime)
|
||||
|
||||
tensors: Dict[str, backend.Tensor] = dict()
|
||||
data: Dict[str, TensorProto] = dict()
|
||||
|
||||
def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
||||
class Context:
|
||||
# saves object names, including tensors and operators
|
||||
names: Dict[Any, str] = dict()
|
||||
# counts the occurrence times of each operator for naming
|
||||
count_op: Dict[backend.OpType, int] = dict()
|
||||
# counts input and output tensors for naming
|
||||
count_in, count_out = 0, 0
|
||||
# saves nodes (operators)
|
||||
nodes: List[NodeProto] = []
|
||||
# saves global input tensors
|
||||
inputs: List[ValueInfoProto] = []
|
||||
# saves global output tensors
|
||||
outputs: List[ValueInfoProto] = []
|
||||
# saves global input tensors
|
||||
initializers: List[TensorProto] = []
|
||||
|
||||
def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
|
||||
ty = op.op_type()
|
||||
name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
|
||||
self.names[op] = name
|
||||
self.count_op[ty] += 1
|
||||
return ty, name
|
||||
|
||||
def push_output(self, name: str, tensor: backend.Tensor) -> str:
|
||||
self.names[tensor] = name
|
||||
if not tensor.has_target():
|
||||
shape = tensor.shape()
|
||||
dtype = backend.tensor_dtype(tensor)
|
||||
value_info = make_tensor_value_info(name, dtype, shape)
|
||||
check_value_info(value_info)
|
||||
self.outputs.append(value_info)
|
||||
return name
|
||||
|
||||
def push_input(self, tensor: backend.Tensor) -> str:
|
||||
name = self.names.get(tensor)
|
||||
# means that this input is a global input
|
||||
if name is None:
|
||||
self.count_in += 1
|
||||
name = "input{}".format(self.count_in)
|
||||
self.names[tensor] = name
|
||||
shape = tensor.shape()
|
||||
dtype = backend.tensor_dtype(tensor)
|
||||
value_info = make_tensor_value_info(name, dtype, shape)
|
||||
check_value_info(value_info)
|
||||
self.inputs.append(value_info)
|
||||
|
||||
return name
|
||||
|
||||
def push_data_input(
|
||||
self,
|
||||
node_name: str,
|
||||
attr_name: str,
|
||||
elem_type: int,
|
||||
shape: Sequence[int],
|
||||
vals: Any,
|
||||
) -> str:
|
||||
name = "{}_{}".format(node_name, attr_name)
|
||||
value_info = make_tensor_value_info(name, elem_type, shape)
|
||||
tensor = make_tensor(name, elem_type, shape, vals)
|
||||
check_value_info(value_info)
|
||||
check_tensor(tensor)
|
||||
self.inputs.append(value_info)
|
||||
self.initializers.append(tensor)
|
||||
return name
|
||||
|
||||
def push_node(self, node: NodeProto) -> None:
|
||||
check_node(node)
|
||||
self.nodes.append(node)
|
||||
|
||||
def build(self, name: str) -> ModelProto:
|
||||
print()
|
||||
print(ctx.names)
|
||||
print()
|
||||
print(ctx.inputs)
|
||||
print()
|
||||
print(ctx.outputs)
|
||||
print()
|
||||
print(ctx.nodes)
|
||||
|
||||
graph = make_graph(
|
||||
self.nodes, name, self.inputs, self.outputs, self.initializers
|
||||
for input in model.graph.input:
|
||||
dims = _take_shape_dim(input.type.tensor_type.shape)
|
||||
tensors[input.name] = self.handler.tensor(
|
||||
dims, input.type.tensor_type.elem_type
|
||||
)
|
||||
check_graph(graph)
|
||||
|
||||
model = make_model(graph)
|
||||
check_model(model)
|
||||
for output in model.graph.output:
|
||||
dims = _take_shape_dim(output.type.tensor_type.shape)
|
||||
tensors[output.name] = self.handler.tensor(
|
||||
dims, output.type.tensor_type.elem_type
|
||||
)
|
||||
|
||||
return model
|
||||
for initializer in model.graph.initializer:
|
||||
data[initializer.name] = initializer
|
||||
|
||||
# topological sort
|
||||
if not graph.topo_sort():
|
||||
raise Exception("Sorting fails")
|
||||
|
||||
ops = graph.operators()  # all operators (nodes) in the graph
|
||||
|
||||
ctx = Context()
|
||||
|
||||
for op in ops:
|
||||
ty, name = ctx.name_op(op)
|
||||
inputs = [ctx.push_input(it) for it in op.inputs()]
|
||||
outputs = [
|
||||
ctx.push_output("{}_{}".format(name, i), it)
|
||||
for (i, it) in enumerate(op.outputs())
|
||||
]
|
||||
if ty == backend.OpType.Matmul:
|
||||
ctx.push_node(make_node("MatMul", inputs, outputs, name))
|
||||
elif ty == backend.OpType.BatchNorm:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.MaxPool:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.AvgPool:
|
||||
raise Exception("TODO")
|
||||
elif ty in [
|
||||
backend.OpType.Add,
|
||||
backend.OpType.Sub,
|
||||
backend.OpType.Mul,
|
||||
backend.OpType.Div,
|
||||
backend.OpType.Pow,
|
||||
backend.OpType.Relu,
|
||||
backend.OpType.Sigmoid,
|
||||
backend.OpType.Tanh,
|
||||
backend.OpType.Softmax,
|
||||
backend.OpType.Abs,
|
||||
backend.OpType.Identity,
|
||||
]:
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||
elif ty == backend.OpType.Flatten:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Reshape:
|
||||
shape = backend.reshape_shape_of(op)
|
||||
inputs.append(
|
||||
ctx.push_data_input(
|
||||
name,
|
||||
"shape",
|
||||
TensorProto.INT32,
|
||||
[len(shape)],
|
||||
shape,
|
||||
for node in model.graph.node:
|
||||
if node.op_type == "Conv":
|
||||
attributes = _parse_attribute(
|
||||
node,
|
||||
{
|
||||
"dilations": [1, 1],
|
||||
"pads": [0, 0],
|
||||
"strides": [1, 1],
|
||||
},
|
||||
)
|
||||
)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||
elif ty == backend.OpType.Concat:
|
||||
axis = backend.concat_axis_of(op)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||
elif ty == backend.OpType.Gather:
|
||||
axis = backend.gather_axis_of(op)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||
elif ty == backend.OpType.ReduceMean:
|
||||
axes = backend.reduce_mean_axes_of(op)
|
||||
inputs.append(
|
||||
ctx.push_data_input(name, "axes", TensorProto.INT32, [len(axes)], axes)
|
||||
)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, keepdims=1))
|
||||
elif ty == backend.OpType.Slice:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Pad:
|
||||
raise Exception("TODO")
|
||||
else:
|
||||
raise Exception("Unsupported OpType {}".format(ty.name))
|
||||
(d, p, s) = (
|
||||
attributes[name] for name in ["dilations", "pads", "strides"]
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.conv(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
p[0],
|
||||
p[1],
|
||||
s[0],
|
||||
s[1],
|
||||
d[0],
|
||||
d[1],
|
||||
)
|
||||
elif node.op_type == "MatMul":
|
||||
tensors[node.output[0]] = self.handler.matmul(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
backend.ActType.Linear,
|
||||
)
|
||||
elif node.op_type == "Gemm":
|
||||
attributes = _parse_attribute(
|
||||
node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
|
||||
)
|
||||
(alpha, beta, transA, transB) = (
|
||||
attributes[name] for name in ["alpha", "beta", "transA", "transB"]
|
||||
)
|
||||
# TODO these parameters are not supported
|
||||
assert alpha == 1.0
|
||||
assert beta == 1.0
|
||||
tensors[node.output[0]] = self.handler.matmul(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
transA == 1,
|
||||
transB == 1,
|
||||
tensors[node.input[2]] if len(node.input) > 2 else None,
|
||||
backend.ActType.Linear,
|
||||
)
|
||||
elif node.op_type == "BatchNormalization":
|
||||
(input, mean, var, scale, bias) = (
|
||||
tensors[node.input[i]] for i in [0, 3, 4, 1, 2]
|
||||
)
|
||||
output = tensors.get(node.output[0])
|
||||
attributes = _parse_attribute(
|
||||
node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0}
|
||||
)
|
||||
(momentum, eps, training) = (
|
||||
attributes[name]
|
||||
for name in ["momentum", "epsilon", "training_mode"]
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.batchNorm(
|
||||
input, output, mean, var, scale, bias, momentum, eps, training != 0
|
||||
)
|
||||
elif node.op_type == "MaxPool":
|
||||
attributes = _parse_attribute(
|
||||
node,
|
||||
{
|
||||
"kernel_shape": None,
|
||||
"dilations": [1, 1],
|
||||
"pads": [0, 0],
|
||||
"strides": [1, 1],
|
||||
},
|
||||
)
|
||||
(k, d, p, s) = (
|
||||
attributes[name]
|
||||
for name in ["kernel_shape", "dilations", "pads", "strides"]
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.maxPool(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
k[0],
|
||||
k[1],
|
||||
d[0],
|
||||
d[1],
|
||||
p[0],
|
||||
p[1],
|
||||
s[0],
|
||||
s[1],
|
||||
)
|
||||
elif node.op_type == "AveragePool":
|
||||
attributes = _parse_attribute(
|
||||
node,
|
||||
{
|
||||
"kernel_shape": None,
|
||||
"pads": [0, 0],
|
||||
"strides": [1, 1],
|
||||
},
|
||||
)
|
||||
(k, p, s) = (
|
||||
attributes[name] for name in ["kernel_shape", "pads", "strides"]
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.avgPool(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
k[0],
|
||||
k[1],
|
||||
1,
|
||||
1,
|
||||
p[0],
|
||||
p[1],
|
||||
s[0],
|
||||
s[1],
|
||||
)
|
||||
elif node.op_type == "GlobalAveragePool":
|
||||
shape = next(
|
||||
(
|
||||
value.type.tensor_type.shape
|
||||
for value in model.graph.value_info
|
||||
if value.name == node.input[0]
|
||||
),
|
||||
None,
|
||||
) or next(
|
||||
input.type.tensor_type.shape
|
||||
for input in model.graph.input
|
||||
if input.name == node.input[0]
|
||||
)
|
||||
[_, _, h, w] = _take_shape_dim(shape)
|
||||
tensors[node.output[0]] = self.handler.avgPool(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
h,
|
||||
w,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
)
|
||||
elif node.op_type == "Add":
|
||||
tensors[node.output[0]] = self.handler.add(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Sub":
|
||||
tensors[node.output[0]] = self.handler.sub(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Mul":
|
||||
tensors[node.output[0]] = self.handler.mul(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Div":
|
||||
tensors[node.output[0]] = self.handler.div(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Pow":
|
||||
tensors[node.output[0]] = self.handler.pow(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Relu":
|
||||
tensors[node.output[0]] = self.handler.relu(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Sigmoid":
|
||||
tensors[node.output[0]] = self.handler.sigmoid(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Tanh":
|
||||
tensors[node.output[0]] = self.handler.tanh(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Softmax":
|
||||
tensors[node.output[0]] = self.handler.softmax(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Abs":
|
||||
tensors[node.output[0]] = self.handler.abs(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Identity":
|
||||
tensors[node.output[0]] = self.handler.identity(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Flatten":
|
||||
# TODO the backend operator does not support flattening along an arbitrary axis
|
||||
axis = next(
|
||||
(attr.i for attr in node.attribute if attr.name == "axis"), None
|
||||
)
|
||||
assert axis is None or axis == 1
|
||||
tensors[node.output[0]] = self.handler.flatten(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Reshape":
|
||||
input_shape = next(
|
||||
(
|
||||
value.type.tensor_type.shape
|
||||
for value in model.graph.value_info
|
||||
if value.name == node.input[0]
|
||||
),
|
||||
None,
|
||||
) or next(
|
||||
input.type.tensor_type.shape
|
||||
for input in model.graph.input
|
||||
if input.name == node.input[0]
|
||||
)
|
||||
dims = _take_shape_dim(input_shape)
|
||||
size = reduce(lambda acc, x: acc * x, dims)
|
||||
output_shape = [int(i) for i in data[node.input[1]].int64_data]
|
||||
for i, x in enumerate(output_shape):
|
||||
if x == 0:
|
||||
output_shape[i] = dims[i]
|
||||
temp = reduce(lambda acc, x: acc * x, output_shape)
|
||||
if temp < 0:
|
||||
output_shape[output_shape.index(-1)] = size // -temp
|
||||
tensors[node.output[0]] = self.handler.reshape(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
output_shape,
|
||||
)
|
||||
elif node.op_type == "Concat":
|
||||
tensors[node.output[0]] = self.handler.concat(
|
||||
[tensors[name] for name in node.input],
|
||||
tensors.get(node.output[0]),
|
||||
next((attr.i for attr in node.attribute if attr.name == "axis")),
|
||||
)
|
||||
elif node.op_type == "Gather":
|
||||
tensors[node.output[0]] = self.handler.gather(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
next((attr.i for attr in node.attribute if attr.name == "axis")),
|
||||
)
|
||||
elif node.op_type == "ReduceMean":
|
||||
tensors[node.output[0]] = self.handler.reduceMean(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
tensors[node.input[1]] if len(node.input) > 1 else None,
|
||||
next((attr.i for attr in node.attribute if attr.name == "keepdims"))
|
||||
!= 0,
|
||||
)
|
||||
elif node.op_type == "Slice":
|
||||
tensors[node.output[0]] = self.handler.slice(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
_parse_data(data[node.input[1]]),
|
||||
_parse_data(data[node.input[2]]),
|
||||
_parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
|
||||
_parse_data(data[node.input[4]]) if len(node.input) > 4 else None,
|
||||
)
|
||||
elif node.op_type == "Pad":
|
||||
tensors[node.output[0]] = self.handler.pad(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
_parse_data(data[node.input[1]]),
|
||||
_parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
|
||||
)
|
||||
else:
|
||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
|
||||
return ctx.build(name)
|
||||
self.handler.data_malloc()
|
||||
|
||||
for name, obj in tensors.items():
|
||||
tensor = data.get(name)
|
||||
if tensor is None:
|
||||
if any(input.name == name for input in model.graph.input):
|
||||
self.inputs[name] = obj
|
||||
else:
|
||||
if tensor.data_type == TensorProto.INT32:
|
||||
self.handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
|
||||
elif tensor.data_type == TensorProto.INT64:
|
||||
self.handler.copy_int64(obj, [int(i) for i in tensor.int64_data])
|
||||
elif tensor.data_type == TensorProto.FLOAT:
|
||||
self.handler.copy_float(obj, [float(i) for i in tensor.float_data])
|
||||
else:
|
||||
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
||||
|
||||
for output in model.graph.output:
|
||||
self.outputs[output.name] = tensors[output.name]
|
||||
|
||||
def to_onnx(self, name: str) -> ModelProto:
|
||||
class Context:
|
||||
# saves object names, including tensors and operators
|
||||
names: Dict[Any, str] = dict()
|
||||
# counts the occurrence times of each operator for naming
|
||||
count_op: Dict[backend.OpType, int] = dict()
|
||||
# counts input and output tensors for naming
|
||||
count_in, count_out = 0, 0
|
||||
# saves nodes (operators)
|
||||
nodes: List[NodeProto] = []
|
||||
# saves global input tensors
|
||||
inputs: List[ValueInfoProto] = []
|
||||
# saves global output tensors
|
||||
outputs: List[ValueInfoProto] = []
|
||||
# saves global input tensors
|
||||
initializers: List[TensorProto] = []
|
||||
|
||||
def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
|
||||
ty = op.op_type()
|
||||
name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
|
||||
self.names[op] = name
|
||||
self.count_op[ty] += 1
|
||||
return ty, name
|
||||
|
||||
def push_output(self, name: str, tensor: backend.Tensor) -> str:
|
||||
self.names[tensor] = name
|
||||
if not tensor.has_target():
|
||||
shape = tensor.shape()
|
||||
dtype = backend.tensor_dtype(tensor)
|
||||
value_info = make_tensor_value_info(name, dtype, shape)
|
||||
check_value_info(value_info)
|
||||
self.outputs.append(value_info)
|
||||
return name
|
||||
|
||||
def push_input(self, tensor: backend.Tensor) -> str:
|
||||
name = self.names.get(tensor)
|
||||
# means that this input is a global input
|
||||
if name is None:
|
||||
self.count_in += 1
|
||||
name = "input{}".format(self.count_in)
|
||||
self.names[tensor] = name
|
||||
shape = tensor.shape()
|
||||
dtype = backend.tensor_dtype(tensor)
|
||||
value_info = make_tensor_value_info(name, dtype, shape)
|
||||
check_value_info(value_info)
|
||||
self.inputs.append(value_info)
|
||||
|
||||
return name
|
||||
|
||||
def push_data_input(
|
||||
self,
|
||||
node_name: str,
|
||||
attr_name: str,
|
||||
elem_type: int,
|
||||
shape: Sequence[int],
|
||||
vals: Any,
|
||||
) -> str:
|
||||
name = "{}_{}".format(node_name, attr_name)
|
||||
value_info = make_tensor_value_info(name, elem_type, shape)
|
||||
tensor = make_tensor(name, elem_type, shape, vals)
|
||||
check_value_info(value_info)
|
||||
check_tensor(tensor)
|
||||
self.inputs.append(value_info)
|
||||
self.initializers.append(tensor)
|
||||
return name
|
||||
|
||||
def push_node(self, node: NodeProto) -> None:
|
||||
check_node(node)
|
||||
self.nodes.append(node)
|
||||
|
||||
def build(self, name: str) -> ModelProto:
|
||||
print()
|
||||
print(ctx.names)
|
||||
print()
|
||||
print(ctx.inputs)
|
||||
print()
|
||||
print(ctx.outputs)
|
||||
print()
|
||||
print(ctx.nodes)
|
||||
|
||||
graph = make_graph(
|
||||
self.nodes, name, self.inputs, self.outputs, self.initializers
|
||||
)
|
||||
check_graph(graph)
|
||||
|
||||
model = make_model(graph)
|
||||
check_model(model)
|
||||
|
||||
return model
|
||||
|
||||
# topological sort
|
||||
if not self.handler.topo_sort():
|
||||
raise Exception("Sorting fails")
|
||||
|
||||
ops = self.handler.operators()  # all operators (nodes) in the graph
|
||||
|
||||
ctx = Context()
|
||||
|
||||
for op in ops:
|
||||
ty, name = ctx.name_op(op)
|
||||
inputs = [ctx.push_input(it) for it in op.inputs()]
|
||||
outputs = [
|
||||
ctx.push_output("{}_{}".format(name, i), it)
|
||||
for (i, it) in enumerate(op.outputs())
|
||||
]
|
||||
if ty == backend.OpType.Conv:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Matmul:
|
||||
ctx.push_node(make_node("MatMul", inputs, outputs, name))
|
||||
            elif ty == backend.OpType.BatchNorm:
                inputs = [inputs[i] for i in [0, 3, 4, 1, 2]]
                momentum, eps, training = backend.batch_norm_attrs_of(op)
                ctx.push_node(
                    make_node(
                        "BatchNormalization",
                        inputs,
                        outputs,
                        name,
                        epsilon=eps,
                        momentum=momentum,
                        training_mode=training,
                    )
                )
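The exported node follows ONNX BatchNormalization inference semantics, y = scale * (x - mean) / sqrt(var + eps) + bias. A numpy reference that can be used to sanity-check an exported node (illustration only, assuming NCHW data and per-channel statistics):

import numpy as np

def batch_norm_reference(x, scale, bias, mean, var, eps=1e-5):
    # x is NCHW; scale/bias/mean/var are per-channel vectors of length C
    shape = (1, x.shape[1], 1, 1)
    x_hat = (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + eps)
    return scale.reshape(shape) * x_hat + bias.reshape(shape)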
elif ty == backend.OpType.MaxPool:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.AvgPool:
|
||||
raise Exception("TODO")
|
||||
elif ty in [
|
||||
backend.OpType.Add,
|
||||
backend.OpType.Sub,
|
||||
backend.OpType.Mul,
|
||||
backend.OpType.Div,
|
||||
backend.OpType.Pow,
|
||||
backend.OpType.Relu,
|
||||
backend.OpType.Sigmoid,
|
||||
backend.OpType.Tanh,
|
||||
backend.OpType.Softmax,
|
||||
backend.OpType.Abs,
|
||||
backend.OpType.Identity,
|
||||
]:
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||
elif ty == backend.OpType.Flatten:
|
||||
raise Exception("TODO")
|
||||
            elif ty == backend.OpType.Reshape:
                shape = backend.reshape_shape_of(op)
                inputs.append(
                    ctx.push_data_input(
                        name,
                        "shape",
                        TensorProto.INT32,
                        [len(shape)],
                        shape,
                    )
                )
                ctx.push_node(make_node(ty.name, inputs, outputs, name))
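Since opset 5, Reshape carries its target shape as a second input rather than an attribute, so the exporter materializes it through push_data_input as both a graph input and an initializer named "<node>_shape". A plain-onnx sketch of the same pattern (names are illustrative; the ONNX spec types this input as int64, hence INT64 below):

from onnx import TensorProto
from onnx.helper import make_node, make_tensor, make_tensor_value_info

shape = [1, 2, 3]
shape_name = "Reshape1_shape"                       # hypothetical names
shape_input = make_tensor_value_info(shape_name, TensorProto.INT64, [len(shape)])
shape_init = make_tensor(shape_name, TensorProto.INT64, [len(shape)], shape)
node = make_node("Reshape", ["Reshape1_input", shape_name], ["Reshape1_0"], "Reshape1")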
elif ty == backend.OpType.Concat:
|
||||
axis = backend.concat_axis_of(op)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||
elif ty == backend.OpType.Gather:
|
||||
axis = backend.gather_axis_of(op)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||
elif ty == backend.OpType.ReduceMean:
|
||||
axes = backend.reduce_mean_axes_of(op)
|
||||
inputs.append(
|
||||
ctx.push_data_input(
|
||||
name, "axes", TensorProto.INT32, [len(axes)], axes
|
||||
)
|
||||
)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, keepdims=1))
|
||||
elif ty == backend.OpType.Slice:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Pad:
|
||||
raise Exception("TODO")
|
||||
else:
|
||||
raise Exception("Unsupported OpType {}".format(ty.name))
|
||||
|
||||
return ctx.build(name)
|
||||
|
||||
|
||||
def from_onnx(model: ModelProto, runtime):
    stub = OnnxStub(model, runtime)
    return stub.inputs, stub.outputs, stub.handler
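The module-level from_onnx is now a thin compatibility wrapper around OnnxStub. A minimal round-trip sketch (the model file name is hypothetical):

import onnx
from pyinfinitensor.onnx import OnnxStub, cpu_runtime

stub = OnnxStub(onnx.load("model.onnx"), cpu_runtime())
exported = stub.to_onnx("exported_graph")   # back to an onnx.ModelProto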
|
||||
|
||||
def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
|
||||
|
|
|
@@ -8,7 +8,7 @@ from onnx.helper import (
     make_tensor_value_info,
 )
 from onnx.checker import check_model
-from pyinfinitensor.onnx import from_onnx, backend, to_onnx, cpu_runtime
+from pyinfinitensor.onnx import from_onnx, backend, cpu_runtime


 def make_and_import_model(graph: onnx.GraphProto):

@@ -305,8 +305,6 @@ class TestStringMethods(unittest.TestCase):
         y = handler.tensor([3, 2, 1], 12)
         handler.reshape(x, y, [3, 2, 1])

-        to_onnx(handler, "test_frontend")
-

 if __name__ == "__main__":
     unittest.main()
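A sketch of what a unit test for the new BatchNormalization import path could look like, following the pattern of the existing cases in this file (it assumes make_node, make_graph and TensorProto are imported alongside make_tensor_value_info, uses the make_and_import_model helper shown above, and the shapes are only examples):

    def test_batch_norm(self):
        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 2, 2])
        scale = make_tensor_value_info("scale", TensorProto.FLOAT, [3])
        b = make_tensor_value_info("b", TensorProto.FLOAT, [3])
        mean = make_tensor_value_info("mean", TensorProto.FLOAT, [3])
        var = make_tensor_value_info("var", TensorProto.FLOAT, [3])
        y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 2, 2])
        batch_norm = make_node(
            "BatchNormalization",
            ["x", "scale", "b", "mean", "var"],
            ["y"],
            name="batchNormalization",
        )
        make_and_import_model(
            make_graph([batch_norm], "batchNormalization", [x, scale, b, mean, var], [y])
        )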
@@ -1,4 +1,5 @@
 #include "core/graph_handler.h"
+#include "operators/batch_norm.h"
 #include "operators/concat.h"
 #include "operators/gather.h"
 #include "operators/reduce_mean.h"

@@ -120,6 +121,13 @@ static Shape reshape_shape_of(Operator op) {
     return dynamic_cast<const ReshapeObj *>(op.get())->getShape();
 }

+static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::BatchNorm);
+    auto batchnorm = dynamic_cast<const BatchNormObj *>(op.get());
+    return std::make_tuple(batchnorm->getMomentum(), batchnorm->getEps(),
+                           batchnorm->getTraining());
+}
+
 void export_functions(py::module &m) {
 #define FUNCTION(NAME) def(#NAME, &NAME)
     m.def("cpu_runtime", &CpuRuntimeObj::getInstance)

@@ -130,7 +138,8 @@ void export_functions(py::module &m) {
         .FUNCTION(reshape_shape_of)
         .FUNCTION(concat_axis_of)
         .FUNCTION(gather_axis_of)
-        .FUNCTION(reduce_mean_axes_of);
+        .FUNCTION(reduce_mean_axes_of)
+        .FUNCTION(batch_norm_attrs_of);
 #undef FUNCTION
 }