diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 31e10c31..50eb6481 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -40,6 +40,10 @@ class GraphHandlerObj {
     Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
                   Tensor bias, ActType act);
 
+    Tensor batchNorm(Tensor input, Tensor output, Tensor mean, Tensor var,
+                     Tensor scale, Tensor bias, float momentum, float eps,
+                     bool training);
+
     Tensor add(Tensor a, Tensor b, Tensor c);
     Tensor sub(Tensor a, Tensor b, Tensor c);
     Tensor mul(Tensor a, Tensor b, Tensor c);
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index f3fd2898..ac46f0d8 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -1,5 +1,4 @@
-import onnx
-import backend
+import typing, onnx, backend
 
 runtime = backend.cpu_runtime()
 
@@ -28,6 +27,20 @@ def from_onnx(model: onnx.ModelProto):
                 None,
                 backend.ActType.Linear,
             )
+        elif node.op_type == "BatchNormalization":
+            (input, mean, var, scale, bias) = (
+                tensors[node.input[i]] for i in [0, 3, 4, 1, 2]
+            )
+            output = tensors.get(node.output[0], None)
+            attributes = _parse_attribute(
+                node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0}
+            )
+            (momentum, eps, training) = (
+                attributes[name] for name in ["momentum", "epsilon", "training_mode"]
+            )
+            tensors[node.output[0]] = handler.batchNorm(
+                input, output, mean, var, scale, bias, momentum, eps, training != 0
+            )
         elif node.op_type == "Add":
             tensors[node.output[0]] = handler.add(
                 tensors[node.input[0]],
@@ -136,3 +149,21 @@ def parse_onnx(model: onnx.ModelProto):
     print("weight:")
     for node in model.graph.initializer:
         print("  {}".format(node.name))
+
+
+def _parse_attribute(node: onnx.NodeProto, attrs: dict = dict()):
+    for attr in node.attribute:
+        if attr.name in attrs:
+            if attr.type == onnx.AttributeProto.INT:
+                attrs[attr.name] = attr.i
+            elif attr.type == onnx.AttributeProto.INTS:
+                attrs[attr.name] = attr.ints
+            elif attr.type == onnx.AttributeProto.FLOAT:
+                attrs[attr.name] = attr.f
+            elif attr.type == onnx.AttributeProto.STRING:
+                attrs[attr.name] = attr.s
+            elif attr.type == onnx.AttributeProto.TENSOR:
+                attrs[attr.name] = attr.t
+            else:
+                assert False, "Unsupported Attribute Type: {}".format(attr.type)
+    return attrs
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index 662e9e44..49b8ec57 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -1,4 +1,5 @@
 #include "core/graph_handler.h"
+#include "operators/batch_norm.h"
 #include "operators/element_wise.h"
 #include "operators/matmul.h"
 #include "operators/reshape.h"
@@ -26,14 +27,32 @@ Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
     }
 }
 
+Tensor GraphHandlerObj::batchNorm(Tensor input, Tensor output, Tensor mean,
+                                  Tensor var, Tensor scale, Tensor bias,
+                                  float momentum, float eps, bool training) {
+    if (output) {
+        g->addOpWithOutputs<BatchNormObj>(
+            std::move(input), output, std::move(mean), std::move(var),
+            std::move(scale), std::move(bias), momentum, eps, training);
+        return output;
+    } else {
+        return g
+            ->addOp<BatchNormObj>(std::move(input), output, std::move(mean),
+                                  std::move(var), std::move(scale),
+                                  std::move(bias), momentum, eps, training)
+            ->getOutput();
+    }
+}
+
 // see operators/element_wise.h
 #define DEFINE_ELEMENT_WISE_METHOD(name, obj)                                 \
     Tensor GraphHandlerObj::name(Tensor a, Tensor b, Tensor c) {              \
         if (c) {                                                              \
-            g->addOpWithOutputs<obj##Obj>(a, b, c);                           \
+            g->addOpWithOutputs<obj##Obj>(std::move(a), std::move(b), c);     \
             return c;                                                         \
         } else {                                                              \
-            return g->addOp<obj##Obj>(a, b, c)->getOutput();                  \
+            return g->addOp<obj##Obj>(std::move(a), std::move(b), c)          \
+                ->getOutput();                                                \
         }                                                                     \
     }
 
@@ -47,10 +66,10 @@ DEFINE_ELEMENT_WISE_METHOD(pow, Pow)
 #define DEFINE_UNARY_METHOD(name, obj)                                        \
     Tensor GraphHandlerObj::name(Tensor x, Tensor y) {                        \
         if (y) {                                                              \
-            g->addOpWithOutputs<obj##Obj>(x, y);                              \
+            g->addOpWithOutputs<obj##Obj>(std::move(x), y);                   \
             return y;                                                         \
         } else {                                                              \
-            return g->addOp<obj##Obj>(x, y)->getOutput();                     \
+            return g->addOp<obj##Obj>(std::move(x), y)->getOutput();          \
         }                                                                     \
     }
 
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index be78a4d5..9d604d7f 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -46,7 +46,11 @@ void init_graph_builder(py::module &m) {
         .def("matmul",
              py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
                                ActType>(&Handler::matmul),
              policy::move)
+        .def("batchNorm",
+             py::overload_cast<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
+                               float, float, bool>(&Handler::batchNorm),
+             policy::move)
         .def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add),
              policy::move)
         .def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub),
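
Usage sketch (illustration only, not part of the patch): the new frontend
branch maps ONNX's BatchNormalization input order (X, scale, B, input_mean,
input_var) onto the handler's (input, output, mean, var, scale, bias) order,
which is why it indexes node.input with [0, 3, 4, 1, 2]. A minimal model like
the one below should exercise that path, assuming from_onnx registers every
graph input as a handler tensor (that setup is outside this diff). The
onnx.helper calls are the standard ONNX API; the shapes and names are made up
for the example.

    import onnx
    from onnx import helper, TensorProto
    from pyinfinitensor.onnx import from_onnx

    c = 3  # channel count; scale/bias/mean/var are 1-D tensors of length c
    bn = helper.make_node(
        "BatchNormalization",
        ["x", "scale", "bias", "mean", "var"],  # ONNX input order
        ["y"],
        epsilon=1e-5,
        momentum=0.9,
    )
    graph = helper.make_graph(
        [bn],
        "bn_example",
        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, c, 2, 2])]
        + [
            helper.make_tensor_value_info(name, TensorProto.FLOAT, [c])
            for name in ["scale", "bias", "mean", "var"]
        ],
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, c, 2, 2])],
    )
    from_onnx(helper.make_model(graph))  # dispatches to handler.batchNorm(...)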
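A quick check of _parse_attribute's contract (again illustrative, with a
hypothetical node): it only overwrites keys already present in the defaults
dict it is given, so attributes the caller did not ask for are ignored and
absent ones keep their default values.

    import onnx
    from onnx import helper
    from pyinfinitensor.onnx import _parse_attribute

    node = helper.make_node(
        "BatchNormalization",
        ["x", "scale", "bias", "mean", "var"],
        ["y"],
        epsilon=1e-3,  # the only attribute actually set on the node
    )
    attrs = _parse_attribute(
        node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0}
    )
    # epsilon is overridden (stored as float32, so approximately 1e-3);
    # momentum and training_mode keep their defaults
    print(attrs)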