test: batchNorm 单元测试

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-14 08:54:58 +08:00
parent 26be533faa
commit ee0a562006
1 changed files with 17 additions and 0 deletions

View File

@ -35,6 +35,23 @@ class TestStringMethods(unittest.TestCase):
matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul") matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
make_and_import_model(make_graph([matmul], "matmul", [x, a], [xa])) make_and_import_model(make_graph([matmul], "matmul", [x, a], [xa]))
def test_batch_norm(self):
x = make_tensor_value_info("x", TensorProto.UINT32, [1, 3, 2, 2])
scale = make_tensor_value_info("scale", TensorProto.FLOAT, [1, 3, 1, 1])
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 1, 1])
mean = make_tensor_value_info("mean", TensorProto.FLOAT, [1, 3, 1, 1])
var = make_tensor_value_info("var", TensorProto.FLOAT, [1, 3, 1, 1])
y = make_tensor_value_info("y", TensorProto.UINT32, [1, 3, 2, 2])
batch_norm = make_node(
"BatchNormalization",
["x", "scale", "b", "mean", "var"],
["y"],
name="batchNormalization",
)
make_and_import_model(
make_graph([batch_norm], "batch_norm", [x, scale, b, mean, var], [y])
)
def test_add(self): def test_add(self):
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7]) a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7]) b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])