fix Issue 187 split infershape wrong (#197)

* fix: fix splitOp to support unequal portions

* fix: fix as review comment

---------

Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
zhangyunze 2023-12-28 21:39:24 +08:00 committed by GitHub
parent 6e7bd6ca0c
commit 3967b437c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 75 additions and 7 deletions

View File

@ -70,7 +70,7 @@ class GraphHandlerObj {
Tensor input_q, Tensor input_k, Tensor input_v,
Tensor position_id, Tensor output_matmul);
TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
int num_outputs);
std::variant<int, vector<int>> numOrRatio);
Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
Tensor gatherElements(Tensor data, Tensor indices, Tensor output, int axis);
Tensor reduceMean(Tensor data, Tensor reduced,

View File

@ -585,6 +585,20 @@ class OnnxStub:
tensors.get(node.output[0]),
)
elif node.op_type == "Split":
split = (
_parse_data(data[node.input[1]])
if (len(node.input) > 1)
else None
)
if split is None:
split = next(
(
attr.ints
for attr in node.attribute
if attr.name == "split"
),
None,
)
for name, tensor in zip(
node.output,
self.handler.split(
@ -598,7 +612,7 @@ class OnnxStub:
),
0,
),
len(node.output),
split if split is not None else len(node.output),
),
):
tensors[name] = tensor

View File

@ -435,6 +435,12 @@ class TestStringMethods(unittest.TestCase):
split = make_node("Split", ["input"], ["output"], name="split", axis=0)
make_and_import_model(make_graph([split], "split", [input], []))
def test_split1(self):
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
splitAttr = make_tensor_value_info("split", TensorProto.INT64, [2, 1])
split = make_node("Split", ["input", "split"], ["output"], name="split", axis=1)
make_and_import_model(make_graph([split], "split", [input, splitAttr], []))
def test_allBroadcast(self):
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])

View File

@ -25,6 +25,7 @@
#include "operators/unary.h"
#include "operators/where.h"
#include <numeric>
#include <variant>
namespace infini {
@ -284,14 +285,29 @@ Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
}
TensorVec GraphHandlerObj::split(Tensor input, std::optional<TensorVec> outputs,
int axis, int num_outputs) {
int axis,
std::variant<int, vector<int>> numOrRatio) {
if (outputs) {
if (std::holds_alternative<int>(numOrRatio)) {
g->addOpWithOutputs<SplitObj>(std::move(input), outputs, axis,
num_outputs);
std::get<int>(numOrRatio));
} else {
g->addOpWithOutputs<SplitObj>(std::move(input), outputs, axis,
std::get<vector<int>>(numOrRatio));
}
return *outputs;
} else {
return g->addOp<SplitObj>(std::move(input), outputs, axis, num_outputs)
if (std::holds_alternative<int>(numOrRatio)) {
return g
->addOp<SplitObj>(std::move(input), outputs, axis,
std::get<int>(numOrRatio))
->getOutputs();
} else {
return g
->addOp<SplitObj>(std::move(input), outputs, axis,
std::get<vector<int>>(numOrRatio))
->getOutputs();
}
}
}

View File

@ -73,6 +73,38 @@ TEST(Split, CudaHigh) {
44., 45., 46., 47.}));
}
TEST(Split, SplitWithRatio) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
auto input = gCpu->addTensor({2, 6, 2, 1, 2}, DataType::Float32);
gCpu->dataMalloc();
input->setData(IncrementalGenerator());
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto inputGpu = gCuda->cloneTensor(input);
vector<int> split = {2, 4};
auto op = gCuda->addOp<SplitObj>(inputGpu, std::nullopt, 1, split);
gCuda->dataMalloc();
inputGpu->setData(IncrementalGenerator());
cudaRuntime->run(gCuda);
// copy output from CUDA to CPU
EXPECT_EQ(op->getOutputs().size(), (size_t)2);
auto o0Cpu = gCpu->cloneTensor(op->getOutput(0));
auto o1Cpu = gCpu->cloneTensor(op->getOutput(1));
EXPECT_TRUE(
o0Cpu->equalData(vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 24., 25.,
26., 27., 28., 29., 30., 31.}));
EXPECT_TRUE(o1Cpu->equalData(
vector<float>{8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.,
19., 20., 21., 22., 23., 32., 33., 34., 35., 36., 37.,
38., 39., 40., 41., 42., 43., 44., 45., 46., 47.}));
}
TEST(Split, Cuda_dim0) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);