fix test-onnx

This commit is contained in:
kilinchange 2024-01-15 06:55:59 +00:00
parent 8baa34a1d2
commit 9b63a62b70
2 changed files with 9 additions and 5 deletions

View File

@ -463,13 +463,16 @@ class TestStringMethods(unittest.TestCase):
def test_split(self):
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
split = make_node("Split", ["input"], ["output"], name="split", axis=0)
make_and_import_model(make_graph([split], "split", [input], []))
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
make_and_import_model(make_graph([split], "split", [input], [output]))
def test_split1(self):
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
splitAttr = make_tensor_value_info("split", TensorProto.INT64, [2, 1])
split = make_node("Split", ["input", "split"], ["output"], name="split", axis=1)
make_and_import_model(make_graph([split], "split", [input, splitAttr], []))
splitAttr = make_tensor("split", TensorProto.INT64, [2], [2, 1])
output1 = make_tensor_value_info("output1", TensorProto.FLOAT, [1, 2, 2, 4])
output2 = make_tensor_value_info("output2", TensorProto.FLOAT, [1, 1, 2, 4])
split = make_node("Split", ["input", "split"], ["output1", "output2"], name="split", axis=1)
make_and_import_model(make_graph([split], "split", [input], [output1, output2], [splitAttr]))
def test_allBroadcast(self):
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])

View File

@ -47,6 +47,7 @@ class DropoutCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Dropout, DropoutCnnl, "Dropout_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Dropout, DropoutCnnl,
"Dropout_cnnl_BANG_Float32");
}; // namespace infini