diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 1f27c6b1..18b75773 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -35,6 +35,23 @@ class TestStringMethods(unittest.TestCase): matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul") make_and_import_model(make_graph([matmul], "matmul", [x, a], [xa])) + def test_batch_norm(self): + x = make_tensor_value_info("x", TensorProto.UINT32, [1, 3, 2, 2]) + scale = make_tensor_value_info("scale", TensorProto.FLOAT, [1, 3, 1, 1]) + b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 1, 1]) + mean = make_tensor_value_info("mean", TensorProto.FLOAT, [1, 3, 1, 1]) + var = make_tensor_value_info("var", TensorProto.FLOAT, [1, 3, 1, 1]) + y = make_tensor_value_info("y", TensorProto.UINT32, [1, 3, 2, 2]) + batch_norm = make_node( + "BatchNormalization", + ["x", "scale", "b", "mean", "var"], + ["y"], + name="batchNormalization", + ) + make_and_import_model( + make_graph([batch_norm], "batch_norm", [x, scale, b, mean, var], [y]) + ) + def test_add(self): a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7]) b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])