diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index dc95873c..caec8ca2 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -70,7 +70,7 @@ class GraphHandlerObj {
                             Tensor input_q, Tensor input_k, Tensor input_v,
                             Tensor position_id, Tensor output_matmul);
     TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
-                    int num_outputs);
+                    std::variant<int, vector<int>> numOrRatio);
     Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
     Tensor gatherElements(Tensor data, Tensor indices, Tensor output, int axis);
     Tensor reduceMean(Tensor data, Tensor reduced,
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index fa1facea..80457ecd 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -585,6 +585,20 @@ class OnnxStub:
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Split":
+                split = (
+                    _parse_data(data[node.input[1]])
+                    if (len(node.input) > 1)
+                    else None
+                )
+                if split is None:
+                    split = next(
+                        (
+                            attr.ints
+                            for attr in node.attribute
+                            if attr.name == "split"
+                        ),
+                        None,
+                    )
                 for name, tensor in zip(
                     node.output,
                     self.handler.split(
@@ -598,7 +612,7 @@ class OnnxStub:
                            ),
                            0,
                        ),
-                        len(node.output),
+                        split if split is not None else len(node.output),
                    ),
                ):
                    tensors[name] = tensor
diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index ca290d76..02911b50 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -435,6 +435,12 @@ class TestStringMethods(unittest.TestCase):
         split = make_node("Split", ["input"], ["output"], name="split", axis=0)
         make_and_import_model(make_graph([split], "split", [input], []))
 
+    def test_split1(self):
+        input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
+        splitAttr = make_tensor_value_info("split", TensorProto.INT64, [2, 1])
+        split = make_node("Split", ["input", "split"], ["output"], name="split", axis=1)
+        make_and_import_model(make_graph([split], "split", [input, splitAttr], []))
+
     def test_allBroadcast(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
         output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index dd2c425f..e323538a 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -25,6 +25,7 @@
 #include "operators/unary.h"
 #include "operators/where.h"
 #include <numeric>
+#include <variant>
 
 namespace infini {
 
@@ -284,14 +285,29 @@ Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
 }
 
 TensorVec GraphHandlerObj::split(Tensor input, std::optional<TensorVec> outputs,
-                                 int axis, int num_outputs) {
+                                 int axis,
+                                 std::variant<int, vector<int>> numOrRatio) {
     if (outputs) {
-        g->addOpWithOutputs<SplitObj>(std::move(input), outputs, axis,
-                                      num_outputs);
+        if (std::holds_alternative<int>(numOrRatio)) {
+            g->addOpWithOutputs<SplitObj>(std::move(input), outputs, axis,
+                                          std::get<int>(numOrRatio));
+        } else {
+            g->addOpWithOutputs<SplitObj>(std::move(input), outputs, axis,
+                                          std::get<vector<int>>(numOrRatio));
+        }
         return *outputs;
     } else {
-        return g->addOp<SplitObj>(std::move(input), outputs, axis, num_outputs)
-            ->getOutputs();
+        if (std::holds_alternative<int>(numOrRatio)) {
+            return g
+                ->addOp<SplitObj>(std::move(input), outputs, axis,
+                                  std::get<int>(numOrRatio))
+                ->getOutputs();
+        } else {
+            return g
+                ->addOp<SplitObj>(std::move(input), outputs, axis,
+                                  std::get<vector<int>>(numOrRatio))
+                ->getOutputs();
+        }
     }
 }
diff --git a/test/kernels/cuda/test_cuda_split.cc b/test/kernels/cuda/test_cuda_split.cc
index 2cab944e..43700b77 100644
--- a/test/kernels/cuda/test_cuda_split.cc
+++ b/test/kernels/cuda/test_cuda_split.cc
@@ -73,6 +73,38 @@ TEST(Split, CudaHigh) {
                                     44., 45., 46., 47.}));
 }
 
+TEST(Split, SplitWithRatio) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(runtime);
+
+    auto input = gCpu->addTensor({2, 6, 2, 1, 2}, DataType::Float32);
+    gCpu->dataMalloc();
+    input->setData(IncrementalGenerator());
+
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+
+    auto inputGpu = gCuda->cloneTensor(input);
+    vector<int> split = {2, 4};
+    auto op = gCuda->addOp<SplitObj>(inputGpu, std::nullopt, 1, split);
+    gCuda->dataMalloc();
+    inputGpu->setData(IncrementalGenerator());
+
+    cudaRuntime->run(gCuda);
+
+    // copy output from CUDA to CPU
+    EXPECT_EQ(op->getOutputs().size(), (size_t)2);
+    auto o0Cpu = gCpu->cloneTensor(op->getOutput(0));
+    auto o1Cpu = gCpu->cloneTensor(op->getOutput(1));
+    EXPECT_TRUE(
+        o0Cpu->equalData(vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 24., 25.,
+                                       26., 27., 28., 29., 30., 31.}));
+    EXPECT_TRUE(o1Cpu->equalData(
+        vector<float>{8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.,
+                      19., 20., 21., 22., 23., 32., 33., 34., 35., 36., 37.,
+                      38., 39., 40., 41., 42., 43., 44., 45., 46., 47.}));
+}
+
 TEST(Split, Cuda_dim0) {
     Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph gCpu = make_ref<GraphObj>(runtime);