forked from jiuyuan/InfiniTensor
Fix Issue 187: Split infers the wrong output shape (#197)
* fix: support unequal portions in splitOp * fix: address review comments --------- Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
6e7bd6ca0c
commit
3967b437c8
|
@ -70,7 +70,7 @@ class GraphHandlerObj {
|
|||
Tensor input_q, Tensor input_k, Tensor input_v,
|
||||
Tensor position_id, Tensor output_matmul);
|
||||
TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
|
||||
int num_outputs);
|
||||
std::variant<int, vector<int>> numOrRatio);
|
||||
Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
|
||||
Tensor gatherElements(Tensor data, Tensor indices, Tensor output, int axis);
|
||||
Tensor reduceMean(Tensor data, Tensor reduced,
|
||||
|
|
|
@ -585,6 +585,20 @@ class OnnxStub:
|
|||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Split":
|
||||
split = (
|
||||
_parse_data(data[node.input[1]])
|
||||
if (len(node.input) > 1)
|
||||
else None
|
||||
)
|
||||
if split is None:
|
||||
split = next(
|
||||
(
|
||||
attr.ints
|
||||
for attr in node.attribute
|
||||
if attr.name == "split"
|
||||
),
|
||||
None,
|
||||
)
|
||||
for name, tensor in zip(
|
||||
node.output,
|
||||
self.handler.split(
|
||||
|
@ -598,7 +612,7 @@ class OnnxStub:
|
|||
),
|
||||
0,
|
||||
),
|
||||
len(node.output),
|
||||
split if split is not None else len(node.output),
|
||||
),
|
||||
):
|
||||
tensors[name] = tensor
|
||||
|
|
|
@ -435,6 +435,12 @@ class TestStringMethods(unittest.TestCase):
|
|||
split = make_node("Split", ["input"], ["output"], name="split", axis=0)
|
||||
make_and_import_model(make_graph([split], "split", [input], []))
|
||||
|
||||
def test_split1(self):
|
||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||
splitAttr = make_tensor_value_info("split", TensorProto.INT64, [2, 1])
|
||||
split = make_node("Split", ["input", "split"], ["output"], name="split", axis=1)
|
||||
make_and_import_model(make_graph([split], "split", [input, splitAttr], []))
|
||||
|
||||
def test_allBroadcast(self):
|
||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "operators/unary.h"
|
||||
#include "operators/where.h"
|
||||
#include <numeric>
|
||||
#include <variant>
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
@ -284,14 +285,29 @@ Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
|
|||
}
|
||||
|
||||
TensorVec GraphHandlerObj::split(Tensor input, std::optional<TensorVec> outputs,
|
||||
int axis, int num_outputs) {
|
||||
int axis,
|
||||
std::variant<int, vector<int>> numOrRatio) {
|
||||
if (outputs) {
|
||||
g->addOpWithOutputs<SplitObj>(std::move(input), outputs, axis,
|
||||
num_outputs);
|
||||
if (std::holds_alternative<int>(numOrRatio)) {
|
||||
g->addOpWithOutputs<SplitObj>(std::move(input), outputs, axis,
|
||||
std::get<int>(numOrRatio));
|
||||
} else {
|
||||
g->addOpWithOutputs<SplitObj>(std::move(input), outputs, axis,
|
||||
std::get<vector<int>>(numOrRatio));
|
||||
}
|
||||
return *outputs;
|
||||
} else {
|
||||
return g->addOp<SplitObj>(std::move(input), outputs, axis, num_outputs)
|
||||
->getOutputs();
|
||||
if (std::holds_alternative<int>(numOrRatio)) {
|
||||
return g
|
||||
->addOp<SplitObj>(std::move(input), outputs, axis,
|
||||
std::get<int>(numOrRatio))
|
||||
->getOutputs();
|
||||
} else {
|
||||
return g
|
||||
->addOp<SplitObj>(std::move(input), outputs, axis,
|
||||
std::get<vector<int>>(numOrRatio))
|
||||
->getOutputs();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -73,6 +73,38 @@ TEST(Split, CudaHigh) {
|
|||
44., 45., 46., 47.}));
|
||||
}
|
||||
|
||||
TEST(Split, SplitWithRatio) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({2, 6, 2, 1, 2}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->setData(IncrementalGenerator());
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto inputGpu = gCuda->cloneTensor(input);
|
||||
vector<int> split = {2, 4};
|
||||
auto op = gCuda->addOp<SplitObj>(inputGpu, std::nullopt, 1, split);
|
||||
gCuda->dataMalloc();
|
||||
inputGpu->setData(IncrementalGenerator());
|
||||
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
EXPECT_EQ(op->getOutputs().size(), (size_t)2);
|
||||
auto o0Cpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
auto o1Cpu = gCpu->cloneTensor(op->getOutput(1));
|
||||
EXPECT_TRUE(
|
||||
o0Cpu->equalData(vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 24., 25.,
|
||||
26., 27., 28., 29., 30., 31.}));
|
||||
EXPECT_TRUE(o1Cpu->equalData(
|
||||
vector<float>{8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.,
|
||||
19., 20., 21., 22., 23., 32., 33., 34., 35., 36., 37.,
|
||||
38., 39., 40., 41., 42., 43., 44., 45., 46., 47.}));
|
||||
}
|
||||
|
||||
TEST(Split, Cuda_dim0) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
|
Loading…
Reference in New Issue