forked from jiuyuan/InfiniTensor
test: batchNorm 单元测试
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
26be533faa
commit
ee0a562006
|
@ -35,6 +35,23 @@ class TestStringMethods(unittest.TestCase):
|
||||||
matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
|
matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
|
||||||
make_and_import_model(make_graph([matmul], "matmul", [x, a], [xa]))
|
make_and_import_model(make_graph([matmul], "matmul", [x, a], [xa]))
|
||||||
|
|
||||||
|
def test_batch_norm(self):
|
||||||
|
x = make_tensor_value_info("x", TensorProto.UINT32, [1, 3, 2, 2])
|
||||||
|
scale = make_tensor_value_info("scale", TensorProto.FLOAT, [1, 3, 1, 1])
|
||||||
|
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 1, 1])
|
||||||
|
mean = make_tensor_value_info("mean", TensorProto.FLOAT, [1, 3, 1, 1])
|
||||||
|
var = make_tensor_value_info("var", TensorProto.FLOAT, [1, 3, 1, 1])
|
||||||
|
y = make_tensor_value_info("y", TensorProto.UINT32, [1, 3, 2, 2])
|
||||||
|
batch_norm = make_node(
|
||||||
|
"BatchNormalization",
|
||||||
|
["x", "scale", "b", "mean", "var"],
|
||||||
|
["y"],
|
||||||
|
name="batchNormalization",
|
||||||
|
)
|
||||||
|
make_and_import_model(
|
||||||
|
make_graph([batch_norm], "batch_norm", [x, scale, b, mean, var], [y])
|
||||||
|
)
|
||||||
|
|
||||||
def test_add(self):
|
def test_add(self):
|
||||||
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
|
Loading…
Reference in New Issue