From ee0a5620069fca01f082e6937e0466d51b4e64f1 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Tue, 14 Feb 2023 08:54:58 +0800 Subject: [PATCH] =?UTF-8?q?test:=20batchNorm=20=E5=8D=95=E5=85=83=E6=B5=8B?= =?UTF-8?q?=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- pyinfinitensor/tests/test_onnx.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 1f27c6b1..18b75773 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -35,6 +35,23 @@ class TestStringMethods(unittest.TestCase): matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul") make_and_import_model(make_graph([matmul], "matmul", [x, a], [xa])) + def test_batch_norm(self): + x = make_tensor_value_info("x", TensorProto.UINT32, [1, 3, 2, 2]) + scale = make_tensor_value_info("scale", TensorProto.FLOAT, [1, 3, 1, 1]) + b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 1, 1]) + mean = make_tensor_value_info("mean", TensorProto.FLOAT, [1, 3, 1, 1]) + var = make_tensor_value_info("var", TensorProto.FLOAT, [1, 3, 1, 1]) + y = make_tensor_value_info("y", TensorProto.UINT32, [1, 3, 2, 2]) + batch_norm = make_node( + "BatchNormalization", + ["x", "scale", "b", "mean", "var"], + ["y"], + name="batchNormalization", + ) + make_and_import_model( + make_graph([batch_norm], "batch_norm", [x, scale, b, mean, var], [y]) + ) + def test_add(self): a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7]) b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])