forked from jiuyuan/InfiniTensor
feat: 前端支持 batchNorm(无单元测试)
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
e194dd943b
commit
cca4d2a491
|
@ -40,6 +40,10 @@ class GraphHandlerObj {
|
||||||
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
|
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
|
||||||
Tensor bias, ActType act);
|
Tensor bias, ActType act);
|
||||||
|
|
||||||
|
Tensor batchNorm(Tensor input, Tensor output, Tensor mean, Tensor var,
|
||||||
|
Tensor scale, Tensor bias, float momentum, float eps,
|
||||||
|
bool training);
|
||||||
|
|
||||||
Tensor add(Tensor a, Tensor b, Tensor c);
|
Tensor add(Tensor a, Tensor b, Tensor c);
|
||||||
Tensor sub(Tensor a, Tensor b, Tensor c);
|
Tensor sub(Tensor a, Tensor b, Tensor c);
|
||||||
Tensor mul(Tensor a, Tensor b, Tensor c);
|
Tensor mul(Tensor a, Tensor b, Tensor c);
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import onnx
|
import typing, onnx, backend
|
||||||
import backend
|
|
||||||
|
|
||||||
runtime = backend.cpu_runtime()
|
runtime = backend.cpu_runtime()
|
||||||
|
|
||||||
|
@ -28,6 +27,20 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
None,
|
None,
|
||||||
backend.ActType.Linear,
|
backend.ActType.Linear,
|
||||||
)
|
)
|
||||||
|
elif node.op_type == "BatchNormalization":
|
||||||
|
(input, mean, var, scale, bias) = (
|
||||||
|
tensors[node.input[i]] for i in [0, 3, 4, 1, 2]
|
||||||
|
)
|
||||||
|
output = tensors.get(node.output[0], None)
|
||||||
|
attributes = _parse_attribute(
|
||||||
|
node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0}
|
||||||
|
)
|
||||||
|
(momentum, eps, training) = (
|
||||||
|
attributes[name] for name in ["momentum", "epsilon", "training_mode"]
|
||||||
|
)
|
||||||
|
tensors[node.output[0]] = handler.batchNorm(
|
||||||
|
input, output, mean, var, scale, bias, momentum, eps, training != 0
|
||||||
|
)
|
||||||
elif node.op_type == "Add":
|
elif node.op_type == "Add":
|
||||||
tensors[node.output[0]] = handler.add(
|
tensors[node.output[0]] = handler.add(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
|
@ -136,3 +149,21 @@ def parse_onnx(model: onnx.ModelProto):
|
||||||
print("weight:")
|
print("weight:")
|
||||||
for node in model.graph.initializer:
|
for node in model.graph.initializer:
|
||||||
print(" {}".format(node.name))
|
print(" {}".format(node.name))
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_attribute(node: onnx.NodeProto, attrs: dict = dict()):
|
||||||
|
for attr in node.attribute:
|
||||||
|
if attr.name in attrs:
|
||||||
|
if attr.type == onnx.AttributeProto.INT:
|
||||||
|
attrs[attr.name] = attr.i
|
||||||
|
elif attr.type == onnx.AttributeProto.INTS:
|
||||||
|
attrs[attr.name] = attr.ints
|
||||||
|
elif attr.type == onnx.AttributeProto.FLOAT:
|
||||||
|
attrs[attr.name] = attr.f
|
||||||
|
elif attr.type == onnx.AttributeProto.STRING:
|
||||||
|
attrs[attr.name] = attr.s
|
||||||
|
elif attr.type == onnx.AttributeProto.TENSOR:
|
||||||
|
attrs[attr.name] = attr.t
|
||||||
|
else:
|
||||||
|
assert False, "Unsupported Attribute Type: {}".format(attr.type)
|
||||||
|
return attrs
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
#include "core/graph_handler.h"
|
#include "core/graph_handler.h"
|
||||||
|
#include "operators/batch_norm.h"
|
||||||
#include "operators/element_wise.h"
|
#include "operators/element_wise.h"
|
||||||
#include "operators/matmul.h"
|
#include "operators/matmul.h"
|
||||||
#include "operators/reshape.h"
|
#include "operators/reshape.h"
|
||||||
|
@ -26,14 +27,32 @@ Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Tensor GraphHandlerObj::batchNorm(Tensor input, Tensor output, Tensor mean,
|
||||||
|
Tensor var, Tensor scale, Tensor bias,
|
||||||
|
float momentum, float eps, bool training) {
|
||||||
|
if (output) {
|
||||||
|
g->addOpWithOutputs<BatchNormObj>(
|
||||||
|
std::move(input), output, std::move(mean), std::move(var),
|
||||||
|
std::move(scale), std::move(bias), momentum, eps, training);
|
||||||
|
return output;
|
||||||
|
} else {
|
||||||
|
return g
|
||||||
|
->addOp<BatchNormObj>(std::move(input), output, std::move(mean),
|
||||||
|
std::move(var), std::move(scale),
|
||||||
|
std::move(bias), momentum, eps, training)
|
||||||
|
->getOutput();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// see operators/element_wise.h
|
// see operators/element_wise.h
|
||||||
#define DEFINE_ELEMENT_WISE_METHOD(name, obj) \
|
#define DEFINE_ELEMENT_WISE_METHOD(name, obj) \
|
||||||
Tensor GraphHandlerObj::name(Tensor a, Tensor b, Tensor c) { \
|
Tensor GraphHandlerObj::name(Tensor a, Tensor b, Tensor c) { \
|
||||||
if (c) { \
|
if (c) { \
|
||||||
g->addOpWithOutputs<obj##Obj>(a, b, c); \
|
g->addOpWithOutputs<obj##Obj>(std::move(a), std::move(b), c); \
|
||||||
return c; \
|
return c; \
|
||||||
} else { \
|
} else { \
|
||||||
return g->addOp<obj##Obj>(a, b, c)->getOutput(); \
|
return g->addOp<obj##Obj>(std::move(a), std::move(b), c) \
|
||||||
|
->getOutput(); \
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,10 +66,10 @@ DEFINE_ELEMENT_WISE_METHOD(pow, Pow)
|
||||||
#define DEFINE_UNARY_METHOD(name, obj) \
|
#define DEFINE_UNARY_METHOD(name, obj) \
|
||||||
Tensor GraphHandlerObj::name(Tensor x, Tensor y) { \
|
Tensor GraphHandlerObj::name(Tensor x, Tensor y) { \
|
||||||
if (y) { \
|
if (y) { \
|
||||||
g->addOpWithOutputs<obj##Obj>(x, y); \
|
g->addOpWithOutputs<obj##Obj>(std::move(x), y); \
|
||||||
return y; \
|
return y; \
|
||||||
} else { \
|
} else { \
|
||||||
return g->addOp<obj##Obj>(x, y)->getOutput(); \
|
return g->addOp<obj##Obj>(std::move(x), y)->getOutput(); \
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -46,6 +46,10 @@ void init_graph_builder(py::module &m) {
|
||||||
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
||||||
ActType>(&Handler::matmul),
|
ActType>(&Handler::matmul),
|
||||||
policy::move)
|
policy::move)
|
||||||
|
.def("batchNorm",
|
||||||
|
py::overload_cast<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
|
||||||
|
float, float, bool>(&Handler::batchNorm),
|
||||||
|
policy::move)
|
||||||
.def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add),
|
.def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add),
|
||||||
policy::move)
|
policy::move)
|
||||||
.def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub),
|
.def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub),
|
||||||
|
|
Loading…
Reference in New Issue