forked from jiuyuan/InfiniTensor
feat: export batchnorm to onnx
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent 71ca4459d9
commit a5e692baea
@@ -39,9 +39,11 @@ class BatchNormObj : public OperatorObj {
     std::string toString() const override;
 
     // output size will be 3 when training
-    int numInputs() const override { return 5; }
-    int numOutputs() const override { return outputs.size(); }
-    float getEps() const { return eps; }
+    inline int numInputs() const override { return 5; }
+    inline int numOutputs() const override { return outputs.size(); }
+    inline float getMomentum() const { return momentum; }
+    inline float getEps() const { return eps; }
+    inline bool getTraining() const { return training; }
 
   private:
     vector<int> getWorkloadVector() const override;

@@ -32,22 +32,29 @@ def cuda_runtime():
     return backend.cuda_runtime()
 
 
-def from_onnx(
-    model: ModelProto, runtime
-) -> Tuple[Dict[str, backend.Tensor], Dict[str, backend.Tensor], backend.GraphHandler]:
+class OnnxStub:
+    inputs: Dict[str, backend.Tensor] = {}
+    outputs: Dict[str, backend.Tensor] = {}
+    handler: backend.GraphHandler
+
+    def __init__(self, model: ModelProto, runtime):
         model = infer_shapes(model)
-        handler = backend.GraphHandler(runtime)
+        self.handler = backend.GraphHandler(runtime)
 
         tensors: Dict[str, backend.Tensor] = dict()
         data: Dict[str, TensorProto] = dict()
 
         for input in model.graph.input:
             dims = _take_shape_dim(input.type.tensor_type.shape)
-            tensors[input.name] = handler.tensor(dims, input.type.tensor_type.elem_type)
+            tensors[input.name] = self.handler.tensor(
+                dims, input.type.tensor_type.elem_type
+            )
 
         for output in model.graph.output:
             dims = _take_shape_dim(output.type.tensor_type.shape)
-            tensors[output.name] = handler.tensor(dims, output.type.tensor_type.elem_type)
+            tensors[output.name] = self.handler.tensor(
+                dims, output.type.tensor_type.elem_type
+            )
 
         for initializer in model.graph.initializer:
            data[initializer.name] = initializer

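For context, the refactor above replaces the tuple-returning from_onnx entry point with a stateful OnnxStub. A minimal usage sketch, assuming an ONNX file named model.onnx (the file name is hypothetical; OnnxStub, cpu_runtime, and to_onnx are the names introduced or kept by this commit):

    # Hedged usage sketch of the refactored frontend.
    import onnx
    from pyinfinitensor.onnx import OnnxStub, cpu_runtime

    model = onnx.load("model.onnx")          # hypothetical model file
    stub = OnnxStub(model, cpu_runtime())

    # Named tensors and the graph handler now live on the stub instead of
    # being returned as a tuple by from_onnx.
    print(list(stub.inputs), list(stub.outputs))
    exported = stub.to_onnx("graph")         # round-trip back to an ONNX ModelProto
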
@@ -62,8 +69,10 @@ def from_onnx(
                         "strides": [1, 1],
                     },
                 )
-                (d, p, s) = (attributes[name] for name in ["dilations", "pads", "strides"])
-                tensors[node.output[0]] = handler.conv(
+                (d, p, s) = (
+                    attributes[name] for name in ["dilations", "pads", "strides"]
+                )
+                tensors[node.output[0]] = self.handler.conv(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),

@@ -75,7 +84,7 @@ def from_onnx(
                     d[1],
                 )
             elif node.op_type == "MatMul":
-                tensors[node.output[0]] = handler.matmul(
+                tensors[node.output[0]] = self.handler.matmul(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),

@@ -91,10 +100,10 @@ def from_onnx(
                 (alpha, beta, transA, transB) = (
                     attributes[name] for name in ["alpha", "beta", "transA", "transB"]
                 )
-                # FIXME `alpha` and `beta` are not supported
+                # TODO these parameters are not supported
                 assert alpha == 1.0
                 assert beta == 1.0
-                tensors[node.output[0]] = handler.matmul(
+                tensors[node.output[0]] = self.handler.matmul(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),

@@ -112,9 +121,10 @@ def from_onnx(
                     node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0}
                 )
                 (momentum, eps, training) = (
-                    attributes[name] for name in ["momentum", "epsilon", "training_mode"]
+                    attributes[name]
+                    for name in ["momentum", "epsilon", "training_mode"]
                 )
-                tensors[node.output[0]] = handler.batchNorm(
+                tensors[node.output[0]] = self.handler.batchNorm(
                     input, output, mean, var, scale, bias, momentum, eps, training != 0
                 )
             elif node.op_type == "MaxPool":

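For reference, a minimal ONNX graph that would exercise the BatchNormalization branch above can be assembled with onnx.helper; the tensor names and shapes below are made up for illustration and are not part of the commit:

    from onnx import TensorProto
    from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info

    x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 2, 2])
    scale = make_tensor_value_info("scale", TensorProto.FLOAT, [3])
    b = make_tensor_value_info("b", TensorProto.FLOAT, [3])
    mean = make_tensor_value_info("mean", TensorProto.FLOAT, [3])
    var = make_tensor_value_info("var", TensorProto.FLOAT, [3])
    y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 2, 2])

    # ONNX BatchNormalization takes (X, scale, B, input_mean, input_var).
    node = make_node(
        "BatchNormalization",
        ["x", "scale", "b", "mean", "var"],
        ["y"],
        epsilon=1e-5,
        momentum=0.9,
    )
    model = make_model(make_graph([node], "bn_test", [x, scale, b, mean, var], [y]))
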
@@ -131,7 +141,7 @@ def from_onnx(
                     attributes[name]
                     for name in ["kernel_shape", "dilations", "pads", "strides"]
                 )
-                tensors[node.output[0]] = handler.maxPool(
+                tensors[node.output[0]] = self.handler.maxPool(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     k[0],

@@ -155,7 +165,7 @@ def from_onnx(
                 (k, p, s) = (
                     attributes[name] for name in ["kernel_shape", "pads", "strides"]
                 )
-                tensors[node.output[0]] = handler.avgPool(
+                tensors[node.output[0]] = self.handler.avgPool(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     k[0],

@@ -181,7 +191,7 @@ def from_onnx(
                     if input.name == node.input[0]
                 )
                 [_, _, h, w] = _take_shape_dim(shape)
-                tensors[node.output[0]] = handler.avgPool(
+                tensors[node.output[0]] = self.handler.avgPool(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     h,

@@ -194,72 +204,72 @@ def from_onnx(
                     1,
                 )
             elif node.op_type == "Add":
-                tensors[node.output[0]] = handler.add(
+                tensors[node.output[0]] = self.handler.add(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Sub":
-                tensors[node.output[0]] = handler.sub(
+                tensors[node.output[0]] = self.handler.sub(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Mul":
-                tensors[node.output[0]] = handler.mul(
+                tensors[node.output[0]] = self.handler.mul(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Div":
-                tensors[node.output[0]] = handler.div(
+                tensors[node.output[0]] = self.handler.div(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Pow":
-                tensors[node.output[0]] = handler.pow(
+                tensors[node.output[0]] = self.handler.pow(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Relu":
-                tensors[node.output[0]] = handler.relu(
+                tensors[node.output[0]] = self.handler.relu(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Sigmoid":
-                tensors[node.output[0]] = handler.sigmoid(
+                tensors[node.output[0]] = self.handler.sigmoid(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Tanh":
-                tensors[node.output[0]] = handler.tanh(
+                tensors[node.output[0]] = self.handler.tanh(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Softmax":
-                tensors[node.output[0]] = handler.softmax(
+                tensors[node.output[0]] = self.handler.softmax(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Abs":
-                tensors[node.output[0]] = handler.abs(
+                tensors[node.output[0]] = self.handler.abs(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Identity":
-                tensors[node.output[0]] = handler.identity(
+                tensors[node.output[0]] = self.handler.identity(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Flatten":
-                # FIXME the backend operator does not support flattening along an arbitrary axis
+                # TODO the backend operator does not support flattening along an arbitrary axis
                 axis = next(
                     (attr.i for attr in node.attribute if attr.name == "axis"), None
                 )
                 assert axis == None or axis == 1
-                tensors[node.output[0]] = handler.flatten(
+                tensors[node.output[0]] = self.handler.flatten(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )

@@ -285,26 +295,26 @@ def from_onnx(
                 temp = reduce(lambda acc, x: acc * x, output_shape)
                 if temp < 0:
                     output_shape[output_shape.index(-1)] = size // -temp
-                tensors[node.output[0]] = handler.reshape(
+                tensors[node.output[0]] = self.handler.reshape(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     output_shape,
                 )
             elif node.op_type == "Concat":
-                tensors[node.output[0]] = handler.concat(
+                tensors[node.output[0]] = self.handler.concat(
                     [tensors[name] for name in node.input],
                     tensors.get(node.output[0]),
                     next((attr.i for attr in node.attribute if attr.name == "axis")),
                 )
             elif node.op_type == "Gather":
-                tensors[node.output[0]] = handler.gather(
+                tensors[node.output[0]] = self.handler.gather(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                     next((attr.i for attr in node.attribute if attr.name == "axis")),
                 )
             elif node.op_type == "ReduceMean":
-                tensors[node.output[0]] = handler.reduceMean(
+                tensors[node.output[0]] = self.handler.reduceMean(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     tensors[node.input[1]] if len(node.input) > 1 else None,

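The Reshape branch above infers a single -1 dimension from the element count; a standalone sketch of that arithmetic with made-up numbers (variable names mirror the code):

    from functools import reduce

    size = 24                     # total element count of the input tensor
    output_shape = [2, -1, 4]     # requested shape with one unknown dimension
    temp = reduce(lambda acc, x: acc * x, output_shape)       # 2 * -1 * 4 = -8
    if temp < 0:
        output_shape[output_shape.index(-1)] = size // -temp  # 24 // 8 = 3
    assert output_shape == [2, 3, 4]
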
@@ -312,7 +322,7 @@ def from_onnx(
                     != 0,
                 )
             elif node.op_type == "Slice":
-                tensors[node.output[0]] = handler.slice(
+                tensors[node.output[0]] = self.handler.slice(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     _parse_data(data[node.input[1]]),

@@ -321,7 +331,7 @@ def from_onnx(
                     _parse_data(data[node.input[4]]) if len(node.input) > 4 else None,
                 )
             elif node.op_type == "Pad":
-                tensors[node.output[0]] = handler.pad(
+                tensors[node.output[0]] = self.handler.pad(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     _parse_data(data[node.input[1]]),

@@ -330,32 +340,27 @@ def from_onnx(
             else:
                 raise Exception('Unsupported operator "{}"'.format(node.op_type))
 
-        handler.data_malloc()
+        self.handler.data_malloc()
 
-        inputs: Dict[str, backend.Tensor] = {}
         for name, obj in tensors.items():
             tensor = data.get(name)
             if tensor == None:
                 if any(input.name == name for input in model.graph.input):
-                    inputs[name] = obj
+                    self.inputs[name] = obj
             else:
                 if tensor.data_type == TensorProto.INT32:
-                    handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
+                    self.handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
                 elif tensor.data_type == TensorProto.INT64:
-                    handler.copy_int64(obj, [int(i) for i in tensor.int64_data])
+                    self.handler.copy_int64(obj, [int(i) for i in tensor.int64_data])
                 elif tensor.data_type == TensorProto.FLOAT:
-                    handler.copy_float(obj, [float(i) for i in tensor.float_data])
+                    self.handler.copy_float(obj, [float(i) for i in tensor.float_data])
                 else:
                     assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
 
-        outputs: Dict[str, backend.Tensor] = {}
         for output in model.graph.output:
-            outputs[output.name] = tensors[output.name]
+            self.outputs[output.name] = tensors[output.name]
 
-        return inputs, outputs, handler
-
-
-def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
+    def to_onnx(self, name: str) -> ModelProto:
         class Context:
             # saves object names, including tensors and operators
             names: Dict[Any, str] = dict()

@@ -446,10 +451,10 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
             return model
 
         # topological sort
-        if not graph.topo_sort():
+        if not self.handler.topo_sort():
             raise Exception("Sorting fails")
 
-        ops = graph.operators()  # all operators (nodes) in the graph
+        ops = self.handler.operators()  # all operators (nodes) in the graph
 
         ctx = Context()
 

@@ -460,10 +465,24 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
                 ctx.push_output("{}_{}".format(name, i), it)
                 for (i, it) in enumerate(op.outputs())
             ]
-            if ty == backend.OpType.Matmul:
+            if ty == backend.OpType.Conv:
+                raise Exception("TODO")
+            elif ty == backend.OpType.Matmul:
                 ctx.push_node(make_node("MatMul", inputs, outputs, name))
             elif ty == backend.OpType.BatchNorm:
-                raise Exception("TODO")
+                inputs = [inputs[i] for i in [0, 3, 4, 1, 2]]
+                momentum, eps, training = backend.batch_norm_attrs_of(op)
+                ctx.push_node(
+                    make_node(
+                        "BatchNormalization",
+                        inputs,
+                        outputs,
+                        name,
+                        epsilon=eps,
+                        momentum=momentum,
+                        training_mode=training,
+                    )
+                )
             elif ty == backend.OpType.MaxPool:
                 raise Exception("TODO")
             elif ty == backend.OpType.AvgPool:

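The index permutation [0, 3, 4, 1, 2] above reorders the backend operator's inputs into the order ONNX BatchNormalization expects, assuming the backend keeps the same input order that handler.batchNorm receives on import (input, mean, var, scale, bias). A quick sanity check with placeholder names:

    # Assumed backend order (from the import path) -> ONNX BatchNormalization order.
    backend_order = ["input", "mean", "var", "scale", "bias"]
    onnx_order = [backend_order[i] for i in [0, 3, 4, 1, 2]]
    assert onnx_order == ["input", "scale", "bias", "mean", "var"]  # X, scale, B, mean, var
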
@@ -505,7 +524,9 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
             elif ty == backend.OpType.ReduceMean:
                 axes = backend.reduce_mean_axes_of(op)
                 inputs.append(
-                    ctx.push_data_input(name, "axes", TensorProto.INT32, [len(axes)], axes)
+                    ctx.push_data_input(
+                        name, "axes", TensorProto.INT32, [len(axes)], axes
+                    )
                 )
                 ctx.push_node(make_node(ty.name, inputs, outputs, name, keepdims=1))
             elif ty == backend.OpType.Slice:

@@ -518,6 +539,11 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
         return ctx.build(name)
 
 
+def from_onnx(model: ModelProto, runtime):
+    stub = OnnxStub(model, runtime)
+    return stub.inputs, stub.outputs, stub.handler
+
+
 def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
     for attr in node.attribute:
         if attr.name in attrs:

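The wrapper keeps the original calling convention, so existing callers written against the old API should not need changes; a sketch, assuming model is any loaded ModelProto:

    # Hypothetical caller written against the old tuple-returning API.
    inputs, outputs, handler = from_onnx(model, cpu_runtime())
    print(sorted(inputs), sorted(outputs))   # the same dicts an OnnxStub instance would hold
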
@@ -8,7 +8,7 @@ from onnx.helper import (
     make_tensor_value_info,
 )
 from onnx.checker import check_model
-from pyinfinitensor.onnx import from_onnx, backend, to_onnx, cpu_runtime
+from pyinfinitensor.onnx import from_onnx, backend, cpu_runtime
 
 
 def make_and_import_model(graph: onnx.GraphProto):

@@ -305,8 +305,6 @@ class TestStringMethods(unittest.TestCase):
         y = handler.tensor([3, 2, 1], 12)
         handler.reshape(x, y, [3, 2, 1])
 
-        to_onnx(handler, "test_frontend")
-
 
 if __name__ == "__main__":
     unittest.main()

@@ -1,4 +1,5 @@
 #include "core/graph_handler.h"
+#include "operators/batch_norm.h"
 #include "operators/concat.h"
 #include "operators/gather.h"
 #include "operators/reduce_mean.h"

@@ -120,6 +121,13 @@ static Shape reshape_shape_of(Operator op) {
     return dynamic_cast<const ReshapeObj *>(op.get())->getShape();
 }
 
+static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::BatchNorm);
+    auto batchnorm = dynamic_cast<const BatchNormObj *>(op.get());
+    return std::make_tuple(batchnorm->getMomentum(), batchnorm->getEps(),
+                           batchnorm->getTraining());
+}
+
 void export_functions(py::module &m) {
 #define FUNCTION(NAME) def(#NAME, &NAME)
     m.def("cpu_runtime", &CpuRuntimeObj::getInstance)

@@ -130,7 +138,8 @@ void export_functions(py::module &m) {
         .FUNCTION(reshape_shape_of)
         .FUNCTION(concat_axis_of)
         .FUNCTION(gather_axis_of)
-        .FUNCTION(reduce_mean_axes_of);
+        .FUNCTION(reduce_mean_axes_of)
+        .FUNCTION(batch_norm_attrs_of);
 #undef FUNCTION
 }

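Once the extension is rebuilt, the new binding is callable from Python; a hedged sketch, where op is assumed to already be a BatchNorm operator taken from handler.operators() (any other operator type would trip the IT_ASSERT above):

    from pyinfinitensor.onnx import backend

    momentum, eps, training = backend.batch_norm_attrs_of(op)
    print(momentum, eps, training)   # e.g. 0.9 1e-05 False with the importer's defaults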