From 9b63a62b708a4c003d7f4dacffe0464fc004c568 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Mon, 15 Jan 2024 06:55:59 +0000 Subject: [PATCH] fix test-onnx --- pyinfinitensor/tests/test_onnx.py | 11 +++++++---- src/kernels/bang/dropout.cc | 3 ++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 4d9c7574..7b5bad9c 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -463,13 +463,16 @@ class TestStringMethods(unittest.TestCase): def test_split(self): input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4]) split = make_node("Split", ["input"], ["output"], name="split", axis=0) - make_and_import_model(make_graph([split], "split", [input], [])) + output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4]) + make_and_import_model(make_graph([split], "split", [input], [output])) def test_split1(self): input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4]) - splitAttr = make_tensor_value_info("split", TensorProto.INT64, [2, 1]) - split = make_node("Split", ["input", "split"], ["output"], name="split", axis=1) - make_and_import_model(make_graph([split], "split", [input, splitAttr], [])) + splitAttr = make_tensor("split", TensorProto.INT64, [2], [2, 1]) + output1 = make_tensor_value_info("output1", TensorProto.FLOAT, [1, 2, 2, 4]) + output2 = make_tensor_value_info("output2", TensorProto.FLOAT, [1, 1, 2, 4]) + split = make_node("Split", ["input", "split"], ["output1", "output2"], name="split", axis=1) + make_and_import_model(make_graph([split], "split", [input], [output1, output2], [splitAttr])) def test_allBroadcast(self): input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4]) diff --git a/src/kernels/bang/dropout.cc b/src/kernels/bang/dropout.cc index bedf4d11..f46a587d 100644 --- a/src/kernels/bang/dropout.cc +++ b/src/kernels/bang/dropout.cc @@ -47,6 +47,7 @@ class DropoutCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Dropout, DropoutCnnl, "Dropout_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Dropout, DropoutCnnl, + "Dropout_cnnl_BANG_Float32"); }; // namespace infini