forked from jiuyuan/InfiniTensor
fix test-onnx
This commit is contained in:
parent
8baa34a1d2
commit
9b63a62b70
|
@ -463,13 +463,16 @@ class TestStringMethods(unittest.TestCase):
|
|||
def test_split(self):
|
||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||
split = make_node("Split", ["input"], ["output"], name="split", axis=0)
|
||||
make_and_import_model(make_graph([split], "split", [input], []))
|
||||
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||
make_and_import_model(make_graph([split], "split", [input], [output]))
|
||||
|
||||
def test_split1(self):
|
||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||
splitAttr = make_tensor_value_info("split", TensorProto.INT64, [2, 1])
|
||||
split = make_node("Split", ["input", "split"], ["output"], name="split", axis=1)
|
||||
make_and_import_model(make_graph([split], "split", [input, splitAttr], []))
|
||||
splitAttr = make_tensor("split", TensorProto.INT64, [2], [2, 1])
|
||||
output1 = make_tensor_value_info("output1", TensorProto.FLOAT, [1, 2, 2, 4])
|
||||
output2 = make_tensor_value_info("output2", TensorProto.FLOAT, [1, 1, 2, 4])
|
||||
split = make_node("Split", ["input", "split"], ["output1", "output2"], name="split", axis=1)
|
||||
make_and_import_model(make_graph([split], "split", [input], [output1, output2], [splitAttr]))
|
||||
|
||||
def test_allBroadcast(self):
|
||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||
|
|
|
@ -47,6 +47,7 @@ class DropoutCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Dropout, DropoutCnnl, "Dropout_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Dropout, DropoutCnnl,
|
||||
"Dropout_cnnl_BANG_Float32");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
Loading…
Reference in New Issue