forked from jiuyuan/InfiniTensor
Fix Issue 187: wrong shape inference for Split (#197)

* fix: support unequal split portions in SplitOp
* fix: address review comments

Co-authored-by: Haojie Wang <haojie0429@gmail.com>
parent 6e7bd6ca0c
commit 3967b437c8
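
For context: Issue 187 is about Split inferring wrong output shapes when the portions are unequal; before this commit the graph handler only accepted an output count, i.e. equal parts. Below is a minimal sketch of the scenario, using only the standard onnx helpers and numpy; the tensor names are illustrative, not taken from the repository.

```python
import numpy as np
from onnx import TensorProto
from onnx.helper import make_graph, make_model, make_node, make_tensor, make_tensor_value_info

# Split a [1, 3, 2, 4] tensor into unequal portions of 2 and 1 along axis 1.
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 2, 4])
y0 = make_tensor_value_info("y0", TensorProto.FLOAT, [1, 2, 2, 4])
y1 = make_tensor_value_info("y1", TensorProto.FLOAT, [1, 1, 2, 4])
sizes = make_tensor("sizes", TensorProto.INT64, [2], [2, 1])
node = make_node("Split", ["x", "sizes"], ["y0", "y1"], name="split", axis=1)
model = make_model(make_graph([node], "unequal_split", [x], [y0, y1], initializer=[sizes]))

# The shapes the importer should infer, reproduced with numpy
# (np.split takes cut points, so sizes [2, 1] become the single cut point [2]):
parts = np.split(np.zeros((1, 3, 2, 4), dtype=np.float32), [2], axis=1)
print([p.shape for p in parts])  # [(1, 2, 2, 4), (1, 1, 2, 4)]
```

With the old num_outputs-only interface there was no way to pass the [2, 1] sizes through to the handler, which is what this commit changes.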
@@ -70,7 +70,7 @@ class GraphHandlerObj {
                             Tensor input_q, Tensor input_k, Tensor input_v,
                             Tensor position_id, Tensor output_matmul);
     TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
-                    int num_outputs);
+                    std::variant<int, vector<int>> numOrRatio);
     Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
     Tensor gatherElements(Tensor data, Tensor indices, Tensor output, int axis);
     Tensor reduceMean(Tensor data, Tensor reduced,
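
The new std::variant<int, vector<int>> numOrRatio parameter keeps both call styles behind one entry point: an int still means "split into this many equal parts" (the old num_outputs behaviour), while a vector describes how the axis is divided across the outputs. The following is a toy Python sketch of those semantics, for illustration only; it is not the SplitObj implementation.

```python
from typing import List, Union

def split_sections(axis_len: int, num_or_sizes: Union[int, List[int]]) -> List[int]:
    """Toy model of the two forms: an int asks for equal parts, a list gives per-output sizes."""
    if isinstance(num_or_sizes, int):
        if axis_len % num_or_sizes != 0:
            raise ValueError("equal split requires the axis length to be divisible")
        return [axis_len // num_or_sizes] * num_or_sizes
    if sum(num_or_sizes) != axis_len:
        raise ValueError("section sizes must sum to the axis length")
    return list(num_or_sizes)

print(split_sections(6, 3))       # [2, 2, 2] -- the old num_outputs behaviour
print(split_sections(6, [2, 4]))  # [2, 4]    -- unequal portions, now expressible
```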

@@ -585,6 +585,20 @@ class OnnxStub:
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Split":
+                split = (
+                    _parse_data(data[node.input[1]])
+                    if (len(node.input) > 1)
+                    else None
+                )
+                if split is None:
+                    split = next(
+                        (
+                            attr.ints
+                            for attr in node.attribute
+                            if attr.name == "split"
+                        ),
+                        None,
+                    )
                 for name, tensor in zip(
                     node.output,
                     self.handler.split(

@@ -598,7 +612,7 @@ class OnnxStub:
                             ),
                             0,
                         ),
-                        len(node.output),
+                        split if split is not None else len(node.output),
                     ),
                 ):
                     tensors[name] = tensor
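
In the importer above, the split sizes come either from the optional second input (opset 13 and later store them as a tensor) or from the legacy `split` attribute (opset 12 and earlier); if neither is present, `split` stays None and the handler falls back to an equal split into `len(node.output)` parts. Below is a standalone sketch of that lookup order, with `resolve_split_sizes` as a simplified stand-in for the `_parse_data`-based code in the diff, not the repository's function.

```python
from onnx import TensorProto, numpy_helper
from onnx.helper import make_node, make_tensor

def resolve_split_sizes(node, initializers):
    """Lookup order mirrored from the importer: input[1] -> 'split' attribute -> None."""
    if len(node.input) > 1 and node.input[1] in initializers:
        return numpy_helper.to_array(initializers[node.input[1]]).tolist()
    return next((list(a.ints) for a in node.attribute if a.name == "split"), None)

sizes = make_tensor("sizes", TensorProto.INT64, [2], [2, 4])
by_input = make_node("Split", ["x", "sizes"], ["y0", "y1"], axis=1)
by_attr = make_node("Split", ["x"], ["y0", "y1"], axis=1, split=[2, 4])
neither = make_node("Split", ["x"], ["y0", "y1"], axis=1)

print(resolve_split_sizes(by_input, {"sizes": sizes}))  # [2, 4]
print(resolve_split_sizes(by_attr, {}))                 # [2, 4]
print(resolve_split_sizes(neither, {}))                 # None -> equal split into len(node.output)
```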

@@ -435,6 +435,12 @@ class TestStringMethods(unittest.TestCase):
         split = make_node("Split", ["input"], ["output"], name="split", axis=0)
         make_and_import_model(make_graph([split], "split", [input], []))
 
+    def test_split1(self):
+        input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
+        splitAttr = make_tensor_value_info("split", TensorProto.INT64, [2, 1])
+        split = make_node("Split", ["input", "split"], ["output"], name="split", axis=1)
+        make_and_import_model(make_graph([split], "split", [input, splitAttr], []))
+
     def test_allBroadcast(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
         output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])

@@ -25,6 +25,7 @@
 #include "operators/unary.h"
 #include "operators/where.h"
 #include <numeric>
+#include <variant>
 
 namespace infini {
 
@@ -284,14 +285,29 @@ Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
 }
 
 TensorVec GraphHandlerObj::split(Tensor input, std::optional<TensorVec> outputs,
-                                 int axis, int num_outputs) {
+                                 int axis,
+                                 std::variant<int, vector<int>> numOrRatio) {
     if (outputs) {
-        g->addOpWithOutputs<SplitObj>(std::move(input), outputs, axis,
-                                      num_outputs);
+        if (std::holds_alternative<int>(numOrRatio)) {
+            g->addOpWithOutputs<SplitObj>(std::move(input), outputs, axis,
+                                          std::get<int>(numOrRatio));
+        } else {
+            g->addOpWithOutputs<SplitObj>(std::move(input), outputs, axis,
+                                          std::get<vector<int>>(numOrRatio));
+        }
         return *outputs;
     } else {
-        return g->addOp<SplitObj>(std::move(input), outputs, axis, num_outputs)
-            ->getOutputs();
+        if (std::holds_alternative<int>(numOrRatio)) {
+            return g
+                ->addOp<SplitObj>(std::move(input), outputs, axis,
+                                  std::get<int>(numOrRatio))
+                ->getOutputs();
+        } else {
+            return g
+                ->addOp<SplitObj>(std::move(input), outputs, axis,
+                                  std::get<vector<int>>(numOrRatio))
+                ->getOutputs();
+        }
     }
 }
 
@@ -73,6 +73,38 @@ TEST(Split, CudaHigh) {
                               44., 45., 46., 47.}));
 }
 
+TEST(Split, SplitWithRatio) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(runtime);
+
+    auto input = gCpu->addTensor({2, 6, 2, 1, 2}, DataType::Float32);
+    gCpu->dataMalloc();
+    input->setData(IncrementalGenerator());
+
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+
+    auto inputGpu = gCuda->cloneTensor(input);
+    vector<int> split = {2, 4};
+    auto op = gCuda->addOp<SplitObj>(inputGpu, std::nullopt, 1, split);
+    gCuda->dataMalloc();
+    inputGpu->setData(IncrementalGenerator());
+
+    cudaRuntime->run(gCuda);
+
+    // copy output from CUDA to CPU
+    EXPECT_EQ(op->getOutputs().size(), (size_t)2);
+    auto o0Cpu = gCpu->cloneTensor(op->getOutput(0));
+    auto o1Cpu = gCpu->cloneTensor(op->getOutput(1));
+    EXPECT_TRUE(
+        o0Cpu->equalData(vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 24., 25.,
+                                       26., 27., 28., 29., 30., 31.}));
+    EXPECT_TRUE(o1Cpu->equalData(
+        vector<float>{8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.,
+                      19., 20., 21., 22., 23., 32., 33., 34., 35., 36., 37.,
+                      38., 39., 40., 41., 42., 43., 44., 45., 46., 47.}));
+}
+
 TEST(Split, Cuda_dim0) {
     Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph gCpu = make_ref<GraphObj>(runtime);
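
The expected values in SplitWithRatio can be cross-checked outside the framework: the input is 0..47 laid out as (2, 6, 2, 1, 2) and split along axis 1 into sections of 2 and 4. A short numpy check, not part of the test suite:

```python
import numpy as np

x = np.arange(48, dtype=np.float32).reshape(2, 6, 2, 1, 2)
o0, o1 = np.split(x, [2], axis=1)  # section sizes 2 and 4 along axis 1

print(o0.flatten())  # 0..7 and 24..31, the values in the first EXPECT_TRUE
print(o1.flatten())  # 8..23 and 32..47, the values in the second EXPECT_TRUE
```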