forked from jiuyuan/InfiniTensor
fix test-onnx
This commit is contained in:
parent
8baa34a1d2
commit
9b63a62b70
|
@ -463,13 +463,16 @@ class TestStringMethods(unittest.TestCase):
|
||||||
def test_split(self):
|
def test_split(self):
|
||||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
split = make_node("Split", ["input"], ["output"], name="split", axis=0)
|
split = make_node("Split", ["input"], ["output"], name="split", axis=0)
|
||||||
make_and_import_model(make_graph([split], "split", [input], []))
|
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
|
make_and_import_model(make_graph([split], "split", [input], [output]))
|
||||||
|
|
||||||
def test_split1(self):
|
def test_split1(self):
|
||||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
splitAttr = make_tensor_value_info("split", TensorProto.INT64, [2, 1])
|
splitAttr = make_tensor("split", TensorProto.INT64, [2], [2, 1])
|
||||||
split = make_node("Split", ["input", "split"], ["output"], name="split", axis=1)
|
output1 = make_tensor_value_info("output1", TensorProto.FLOAT, [1, 2, 2, 4])
|
||||||
make_and_import_model(make_graph([split], "split", [input, splitAttr], []))
|
output2 = make_tensor_value_info("output2", TensorProto.FLOAT, [1, 1, 2, 4])
|
||||||
|
split = make_node("Split", ["input", "split"], ["output1", "output2"], name="split", axis=1)
|
||||||
|
make_and_import_model(make_graph([split], "split", [input], [output1, output2], [splitAttr]))
|
||||||
|
|
||||||
def test_allBroadcast(self):
|
def test_allBroadcast(self):
|
||||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
|
|
|
@ -47,6 +47,7 @@ class DropoutCnnl : public BangKernelWithoutConfig {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_KERNEL(Device::BANG, OpType::Dropout, DropoutCnnl, "Dropout_cnnl_BANG_Float32");
|
REGISTER_KERNEL(Device::BANG, OpType::Dropout, DropoutCnnl,
|
||||||
|
"Dropout_cnnl_BANG_Float32");
|
||||||
|
|
||||||
}; // namespace infini
|
}; // namespace infini
|
||||||
|
|
Loading…
Reference in New Issue