From ec3adf6fa73cc6390f09a9bbd23910640d9ed000 Mon Sep 17 00:00:00 2001 From: xgqdut2016 <140036308+xgqdut2016@users.noreply.github.com> Date: Tue, 31 Oct 2023 10:47:36 +0800 Subject: [PATCH] support 8D tensor, add test example (#170) Co-authored-by: Haojie Wang --- include/cuda/cuda_split_concat.h | 2 +- test/kernels/cuda/test_cuda_concat.cc | 36 +++++++++++++++++++++++++++ test/kernels/cuda/test_cuda_split.cc | 34 +++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 1 deletion(-) diff --git a/include/cuda/cuda_split_concat.h b/include/cuda/cuda_split_concat.h index 699f15bc..58bdf330 100644 --- a/include/cuda/cuda_split_concat.h +++ b/include/cuda/cuda_split_concat.h @@ -3,7 +3,7 @@ #include const int BATCH_SIZE = 32; // parallel tensor number. -const int DIM_MAX_SIZE = 4; +const int DIM_MAX_SIZE = 8; // Concat operator acts like element tensors composing to one big tensor,and // split operator acts like one big tensor being composed by element diff --git a/test/kernels/cuda/test_cuda_concat.cc b/test/kernels/cuda/test_cuda_concat.cc index 013d25b5..2c76f405 100644 --- a/test/kernels/cuda/test_cuda_concat.cc +++ b/test/kernels/cuda/test_cuda_concat.cc @@ -122,4 +122,40 @@ TEST(Concat, Cuda_dim0) { EXPECT_TRUE(oCpu->equalData(vector{0, 1, 2, 1, 1, 1, 0, 1, 2})); } +TEST(Concat, CudaHigh) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto t1 = gCpu->addTensor({2, 2, 3, 1, 2}, DataType::Float32); + auto t2 = gCpu->addTensor({2, 2, 1, 1, 2}, DataType::Float32); + auto t3 = gCpu->addTensor({2, 2, 2, 1, 2}, DataType::Float32); + gCpu->dataMalloc(); + t1->setData(IncrementalGenerator()); + t2->setData(OneGenerator()); + t3->setData(OneGenerator()); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto t1Gpu = gCuda->cloneTensor(t1); + auto t2Gpu = gCuda->cloneTensor(t2); + auto t3Gpu = gCuda->cloneTensor(t3); + + auto op = + gCuda->addOp(TensorVec{t1Gpu, t2Gpu, t3Gpu}, nullptr, 2); + gCuda->dataMalloc(); + t1Gpu->setData(IncrementalGenerator()); + t2Gpu->setData(OneGenerator()); + t3Gpu->setData(OneGenerator()); + cudaRuntime->run(gCuda); + + // cudaPrintTensor(op->getOutput()); + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput()); + EXPECT_TRUE(oCpu->equalData( + vector{0., 1., 2., 3., 4., 5., 1., 1., 1., 1., 1., 1., + 6., 7., 8., 9., 10., 11., 1., 1., 1., 1., 1., 1., + 12., 13., 14., 15., 16., 17., 1., 1., 1., 1., 1., 1., + 18., 19., 20., 21., 22., 23., 1., 1., 1., 1., 1., 1.})); +} } // namespace infini diff --git a/test/kernels/cuda/test_cuda_split.cc b/test/kernels/cuda/test_cuda_split.cc index 5a32f27f..2cab944e 100644 --- a/test/kernels/cuda/test_cuda_split.cc +++ b/test/kernels/cuda/test_cuda_split.cc @@ -39,6 +39,40 @@ TEST(Split, Cuda) { 12, 13, 14, 15, 16, 17, 18, 19, 32, 33, 34, 35, 36, 37, 38, 39})); } +TEST(Split, CudaHigh) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({2, 6, 2, 1, 2}, DataType::Float32); + gCpu->dataMalloc(); + input->setData(IncrementalGenerator()); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto inputGpu = gCuda->cloneTensor(input); + auto op = gCuda->addOp(inputGpu, std::nullopt, 1, 3); + gCuda->dataMalloc(); + inputGpu->setData(IncrementalGenerator()); + + cudaRuntime->run(gCuda); + + // copy output from CUDA to CPU + EXPECT_EQ(op->getOutputs().size(), (size_t)3); + auto o0Cpu = gCpu->cloneTensor(op->getOutput(0)); + auto o1Cpu = gCpu->cloneTensor(op->getOutput(1)); + auto o2Cpu = gCpu->cloneTensor(op->getOutput(2)); + EXPECT_TRUE( + o0Cpu->equalData(vector{0., 1., 2., 3., 4., 5., 6., 7., 24., 25., + 26., 27., 28., 29., 30., 31.})); + EXPECT_TRUE(o1Cpu->equalData(vector{8., 9., 10., 11., 12., 13., 14., + 15., 32., 33., 34., 35., 36., + 37., 38., 39.})); + EXPECT_TRUE(o2Cpu->equalData(vector{16., 17., 18., 19., 20., 21., + 22., 23., 40., 41., 42., 43., + 44., 45., 46., 47.})); +} + TEST(Split, Cuda_dim0) { Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime);