feat: 导出 batchnorm 到 onnx

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-24 15:42:39 +08:00
parent 71ca4459d9
commit a5e692baea
4 changed files with 515 additions and 480 deletions

View File

@ -39,9 +39,11 @@ class BatchNormObj : public OperatorObj {
std::string toString() const override;
// output size will be 3 when training
int numInputs() const override { return 5; }
int numOutputs() const override { return outputs.size(); }
float getEps() const { return eps; }
inline int numInputs() const override { return 5; }
inline int numOutputs() const override { return outputs.size(); }
inline float getMomentum() const { return momentum; }
inline float getEps() const { return eps; }
inline bool getTraining() const { return training; }
private:
vector<int> getWorkloadVector() const override;

View File

@ -32,22 +32,29 @@ def cuda_runtime():
return backend.cuda_runtime()
def from_onnx(
model: ModelProto, runtime
) -> Tuple[Dict[str, backend.Tensor], Dict[str, backend.Tensor], backend.GraphHandler]:
class OnnxStub:
inputs: Dict[str, backend.Tensor] = {}
outputs: Dict[str, backend.Tensor] = {}
handler: backend.GraphHandler
def __init__(self, model: ModelProto, runtime):
model = infer_shapes(model)
handler = backend.GraphHandler(runtime)
self.handler = backend.GraphHandler(runtime)
tensors: Dict[str, backend.Tensor] = dict()
data: Dict[str, TensorProto] = dict()
for input in model.graph.input:
dims = _take_shape_dim(input.type.tensor_type.shape)
tensors[input.name] = handler.tensor(dims, input.type.tensor_type.elem_type)
tensors[input.name] = self.handler.tensor(
dims, input.type.tensor_type.elem_type
)
for output in model.graph.output:
dims = _take_shape_dim(output.type.tensor_type.shape)
tensors[output.name] = handler.tensor(dims, output.type.tensor_type.elem_type)
tensors[output.name] = self.handler.tensor(
dims, output.type.tensor_type.elem_type
)
for initializer in model.graph.initializer:
data[initializer.name] = initializer
@ -62,8 +69,10 @@ def from_onnx(
"strides": [1, 1],
},
)
(d, p, s) = (attributes[name] for name in ["dilations", "pads", "strides"])
tensors[node.output[0]] = handler.conv(
(d, p, s) = (
attributes[name] for name in ["dilations", "pads", "strides"]
)
tensors[node.output[0]] = self.handler.conv(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
@ -75,7 +84,7 @@ def from_onnx(
d[1],
)
elif node.op_type == "MatMul":
tensors[node.output[0]] = handler.matmul(
tensors[node.output[0]] = self.handler.matmul(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
@ -91,10 +100,10 @@ def from_onnx(
(alpha, beta, transA, transB) = (
attributes[name] for name in ["alpha", "beta", "transA", "transB"]
)
# FIXME 不支持 `alpha` `beta`
# TODO 不支持这些参数
assert alpha == 1.0
assert beta == 1.0
tensors[node.output[0]] = handler.matmul(
tensors[node.output[0]] = self.handler.matmul(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
@ -112,9 +121,10 @@ def from_onnx(
node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0}
)
(momentum, eps, training) = (
attributes[name] for name in ["momentum", "epsilon", "training_mode"]
attributes[name]
for name in ["momentum", "epsilon", "training_mode"]
)
tensors[node.output[0]] = handler.batchNorm(
tensors[node.output[0]] = self.handler.batchNorm(
input, output, mean, var, scale, bias, momentum, eps, training != 0
)
elif node.op_type == "MaxPool":
@ -131,7 +141,7 @@ def from_onnx(
attributes[name]
for name in ["kernel_shape", "dilations", "pads", "strides"]
)
tensors[node.output[0]] = handler.maxPool(
tensors[node.output[0]] = self.handler.maxPool(
tensors[node.input[0]],
tensors.get(node.output[0]),
k[0],
@ -155,7 +165,7 @@ def from_onnx(
(k, p, s) = (
attributes[name] for name in ["kernel_shape", "pads", "strides"]
)
tensors[node.output[0]] = handler.avgPool(
tensors[node.output[0]] = self.handler.avgPool(
tensors[node.input[0]],
tensors.get(node.output[0]),
k[0],
@ -181,7 +191,7 @@ def from_onnx(
if input.name == node.input[0]
)
[_, _, h, w] = _take_shape_dim(shape)
tensors[node.output[0]] = handler.avgPool(
tensors[node.output[0]] = self.handler.avgPool(
tensors[node.input[0]],
tensors.get(node.output[0]),
h,
@ -194,72 +204,72 @@ def from_onnx(
1,
)
elif node.op_type == "Add":
tensors[node.output[0]] = handler.add(
tensors[node.output[0]] = self.handler.add(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "Sub":
tensors[node.output[0]] = handler.sub(
tensors[node.output[0]] = self.handler.sub(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "Mul":
tensors[node.output[0]] = handler.mul(
tensors[node.output[0]] = self.handler.mul(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "Div":
tensors[node.output[0]] = handler.div(
tensors[node.output[0]] = self.handler.div(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "Pow":
tensors[node.output[0]] = handler.pow(
tensors[node.output[0]] = self.handler.pow(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "Relu":
tensors[node.output[0]] = handler.relu(
tensors[node.output[0]] = self.handler.relu(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
elif node.op_type == "Sigmoid":
tensors[node.output[0]] = handler.sigmoid(
tensors[node.output[0]] = self.handler.sigmoid(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
elif node.op_type == "Tanh":
tensors[node.output[0]] = handler.tanh(
tensors[node.output[0]] = self.handler.tanh(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
elif node.op_type == "Softmax":
tensors[node.output[0]] = handler.softmax(
tensors[node.output[0]] = self.handler.softmax(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
elif node.op_type == "Abs":
tensors[node.output[0]] = handler.abs(
tensors[node.output[0]] = self.handler.abs(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
elif node.op_type == "Identity":
tensors[node.output[0]] = handler.identity(
tensors[node.output[0]] = self.handler.identity(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
elif node.op_type == "Flatten":
# FIXME 后端算子不支持沿任意轴展开
# TODO 后端算子不支持沿任意轴展开
axis = next(
(attr.i for attr in node.attribute if attr.name == "axis"), None
)
assert axis == None or axis == 1
tensors[node.output[0]] = handler.flatten(
tensors[node.output[0]] = self.handler.flatten(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
@ -285,26 +295,26 @@ def from_onnx(
temp = reduce(lambda acc, x: acc * x, output_shape)
if temp < 0:
output_shape[output_shape.index(-1)] = size // -temp
tensors[node.output[0]] = handler.reshape(
tensors[node.output[0]] = self.handler.reshape(
tensors[node.input[0]],
tensors.get(node.output[0]),
output_shape,
)
elif node.op_type == "Concat":
tensors[node.output[0]] = handler.concat(
tensors[node.output[0]] = self.handler.concat(
[tensors[name] for name in node.input],
tensors.get(node.output[0]),
next((attr.i for attr in node.attribute if attr.name == "axis")),
)
elif node.op_type == "Gather":
tensors[node.output[0]] = handler.gather(
tensors[node.output[0]] = self.handler.gather(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
next((attr.i for attr in node.attribute if attr.name == "axis")),
)
elif node.op_type == "ReduceMean":
tensors[node.output[0]] = handler.reduceMean(
tensors[node.output[0]] = self.handler.reduceMean(
tensors[node.input[0]],
tensors.get(node.output[0]),
tensors[node.input[1]] if len(node.input) > 1 else None,
@ -312,7 +322,7 @@ def from_onnx(
!= 0,
)
elif node.op_type == "Slice":
tensors[node.output[0]] = handler.slice(
tensors[node.output[0]] = self.handler.slice(
tensors[node.input[0]],
tensors.get(node.output[0]),
_parse_data(data[node.input[1]]),
@ -321,7 +331,7 @@ def from_onnx(
_parse_data(data[node.input[4]]) if len(node.input) > 4 else None,
)
elif node.op_type == "Pad":
tensors[node.output[0]] = handler.pad(
tensors[node.output[0]] = self.handler.pad(
tensors[node.input[0]],
tensors.get(node.output[0]),
_parse_data(data[node.input[1]]),
@ -330,32 +340,27 @@ def from_onnx(
else:
raise Exception('Unsupported operator "{}"'.format(node.op_type))
handler.data_malloc()
self.handler.data_malloc()
inputs: Dict[str, backend.Tensor] = {}
for name, obj in tensors.items():
tensor = data.get(name)
if tensor == None:
if any(input.name == name for input in model.graph.input):
inputs[name] = obj
self.inputs[name] = obj
else:
if tensor.data_type == TensorProto.INT32:
handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
self.handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
elif tensor.data_type == TensorProto.INT64:
handler.copy_int64(obj, [int(i) for i in tensor.int64_data])
self.handler.copy_int64(obj, [int(i) for i in tensor.int64_data])
elif tensor.data_type == TensorProto.FLOAT:
handler.copy_float(obj, [float(i) for i in tensor.float_data])
self.handler.copy_float(obj, [float(i) for i in tensor.float_data])
else:
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
outputs: Dict[str, backend.Tensor] = {}
for output in model.graph.output:
outputs[output.name] = tensors[output.name]
self.outputs[output.name] = tensors[output.name]
return inputs, outputs, handler
def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
def to_onnx(self, name: str) -> ModelProto:
class Context:
# saves object names, including tensors and operators
names: Dict[Any, str] = dict()
@ -446,10 +451,10 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
return model
# 拓扑排序
if not graph.topo_sort():
if not self.handler.topo_sort():
raise Exception("Sorting fails")
ops = graph.operators() # 图中所有算子(节点)
ops = self.handler.operators() # 图中所有算子(节点)
ctx = Context()
@ -460,10 +465,24 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
ctx.push_output("{}_{}".format(name, i), it)
for (i, it) in enumerate(op.outputs())
]
if ty == backend.OpType.Matmul:
if ty == backend.OpType.Conv:
raise Exception("TODO")
elif ty == backend.OpType.Matmul:
ctx.push_node(make_node("MatMul", inputs, outputs, name))
elif ty == backend.OpType.BatchNorm:
raise Exception("TODO")
inputs = [inputs[i] for i in [0, 3, 4, 1, 2]]
momentum, eps, training = backend.batch_norm_attrs_of(op)
ctx.push_node(
make_node(
"BatchNormalization",
inputs,
outputs,
name,
epsilon=eps,
momentum=momentum,
training_mode=training,
)
)
elif ty == backend.OpType.MaxPool:
raise Exception("TODO")
elif ty == backend.OpType.AvgPool:
@ -505,7 +524,9 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
elif ty == backend.OpType.ReduceMean:
axes = backend.reduce_mean_axes_of(op)
inputs.append(
ctx.push_data_input(name, "axes", TensorProto.INT32, [len(axes)], axes)
ctx.push_data_input(
name, "axes", TensorProto.INT32, [len(axes)], axes
)
)
ctx.push_node(make_node(ty.name, inputs, outputs, name, keepdims=1))
elif ty == backend.OpType.Slice:
@ -518,6 +539,11 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
return ctx.build(name)
def from_onnx(model: ModelProto, runtime):
stub = OnnxStub(model, runtime)
return stub.inputs, stub.outputs, stub.handler
def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
for attr in node.attribute:
if attr.name in attrs:

View File

@ -8,7 +8,7 @@ from onnx.helper import (
make_tensor_value_info,
)
from onnx.checker import check_model
from pyinfinitensor.onnx import from_onnx, backend, to_onnx, cpu_runtime
from pyinfinitensor.onnx import from_onnx, backend, cpu_runtime
def make_and_import_model(graph: onnx.GraphProto):
@ -305,8 +305,6 @@ class TestStringMethods(unittest.TestCase):
y = handler.tensor([3, 2, 1], 12)
handler.reshape(x, y, [3, 2, 1])
to_onnx(handler, "test_frontend")
if __name__ == "__main__":
unittest.main()

View File

@ -1,4 +1,5 @@
#include "core/graph_handler.h"
#include "operators/batch_norm.h"
#include "operators/concat.h"
#include "operators/gather.h"
#include "operators/reduce_mean.h"
@ -120,6 +121,13 @@ static Shape reshape_shape_of(Operator op) {
return dynamic_cast<const ReshapeObj *>(op.get())->getShape();
}
static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::BatchNorm);
auto batchnorm = dynamic_cast<const BatchNormObj *>(op.get());
return std::make_tuple(batchnorm->getMomentum(), batchnorm->getEps(),
batchnorm->getTraining());
}
void export_functions(py::module &m) {
#define FUNCTION(NAME) def(#NAME, &NAME)
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
@ -130,7 +138,8 @@ void export_functions(py::module &m) {
.FUNCTION(reshape_shape_of)
.FUNCTION(concat_axis_of)
.FUNCTION(gather_axis_of)
.FUNCTION(reduce_mean_axes_of);
.FUNCTION(reduce_mean_axes_of)
.FUNCTION(batch_norm_attrs_of);
#undef FUNCTION
}