From 3c6e208f4200ab5eac24f4b7da0d3a11165ef447 Mon Sep 17 00:00:00 2001 From: wendy12022 Date: Thu, 29 Sep 2022 11:01:30 +0800 Subject: [PATCH] ADD:concat/split operator and cuda kernels (#29) * ADD:concat/split operator and cuda kernels refector minor change comment ADD:concat/split operator and cuda kernels merge split_kernel and concat_kernel to split_concat_kernel. Revert "fix" This reverts commit 459926be09a838658ec55f1e0a72b3cf17037d5c. fix ADD:concat/split operator and cuda kernels change whole tensor name to composed tensor fix some remove unused header. rebase add CudaKernel add test for split. ADD split operator and cuda kernel. modify test. ADD:concat operator and cuda kernel. ADD:concat/split operator and cuda kernels fix some remove unused header. rebase add CudaKernel ADD:concat/split operator and cuda kernels add test for split. ADD split operator and cuda kernel. modify test. ADD:concat operator and cuda kernel. * remove extra comment; typo fix. Co-authored-by: Haojie Wang --- include/core/operator.h | 4 ++ include/cuda/cuda_element_wise.h | 19 ------ include/cuda/cuda_split_concat.h | 35 +++++++++++ include/operators/concat.h | 22 +++++++ include/operators/pad.h | 2 +- include/operators/split.h | 25 ++++++++ src/kernels/cuda/element_wise.cc | 14 ++++- src/kernels/cuda/split_concat.cc | 81 ++++++++++++++++++++++++ src/kernels/cuda/split_concat.cu | 71 +++++++++++++++++++++ src/operators/concat.cc | 58 +++++++++++++++++ src/operators/split.cc | 89 +++++++++++++++++++++++++++ test/kernels/cuda/test_cuda_concat.cc | 76 +++++++++++++++++++++++ test/kernels/cuda/test_cuda_split.cc | 40 ++++++++++++ test/operators/test_concat.cc | 17 +++++ test/operators/test_split.cc | 38 ++++++++++++ 15 files changed, 570 insertions(+), 21 deletions(-) create mode 100644 include/cuda/cuda_split_concat.h create mode 100644 include/operators/concat.h create mode 100644 include/operators/split.h create mode 100644 src/kernels/cuda/split_concat.cc create mode 100644 src/kernels/cuda/split_concat.cu create mode 100644 src/operators/concat.cc create mode 100644 src/operators/split.cc create mode 100644 test/kernels/cuda/test_cuda_concat.cc create mode 100644 test/kernels/cuda/test_cuda_split.cc create mode 100644 test/operators/test_concat.cc create mode 100644 test/operators/test_split.cc diff --git a/include/core/operator.h b/include/core/operator.h index b2db95ce..a40de331 100644 --- a/include/core/operator.h +++ b/include/core/operator.h @@ -186,6 +186,10 @@ class OperatorObj : public Object { IT_ASSERT(outputs.size() == 1, "Unimplemented"); return outputs[0]; } + Tensor getOutput(size_t i) const { + IT_ASSERT(i < outputs.size(), "Index exceeded"); + return outputs.at(i); + } OpType getOpType() const { return type; } // HACK: set correct data type DataType getDType() const { return getInputs(0)->getDType(); } diff --git a/include/cuda/cuda_element_wise.h b/include/cuda/cuda_element_wise.h index b6dede22..d51a04cf 100644 --- a/include/cuda/cuda_element_wise.h +++ b/include/cuda/cuda_element_wise.h @@ -1,25 +1,6 @@ #pragma once -#include "operators/element_wise.h" - namespace infini { void div_kernel(float *a, float *b, float *c, int num); void pow_kernel(float *a, float *b, float *c, int num); - -void element_wise_kernel(const Operator &_op) { - auto op = as(_op); - float *const aData = (op->getInputs(0)->getRawDataPtr()); - float *const bData = (op->getInputs(1)->getRawDataPtr()); - float *const cData = (op->getOutput()->getRawDataPtr()); - - auto dim = op->getInputs(0)->getDims(); - int n = dim[0], c = dim[1], h = dim[2], w = dim[3]; - if (op->getOpType() == OpType::Div) - div_kernel(aData, bData, cData, n * c * h * w); - else if (op->getOpType() == OpType::Pow) - pow_kernel(aData, bData, cData, n * c * h * w); - else - IT_TODO_HALT(); -} - }; // namespace infini \ No newline at end of file diff --git a/include/cuda/cuda_split_concat.h b/include/cuda/cuda_split_concat.h new file mode 100644 index 00000000..454a8f52 --- /dev/null +++ b/include/cuda/cuda_split_concat.h @@ -0,0 +1,35 @@ + +#pragma once +#include + +const int BATCH_SIZE = 32; // parallel tensor number. +const int DIM_MAX_SIZE = 4; + +// Concat operator acts like element tensors composing to one big tensor,and +// split operator acts like one big tensor being composed by element +// tensors. +struct ElementTensorMetadata { + float *data[BATCH_SIZE]; + int dimBgNo[BATCH_SIZE]; // the dimention begin no of the element tensor in + // the composed tensor. + int dimSize[BATCH_SIZE]; // the dimention size of the element tensor. + int nElements[BATCH_SIZE]; // the number of elements of the element tensor. + void print() const { + for (int i = 0; i < BATCH_SIZE; i++) + printf("%d:(data=%p,dimBgNo=%d,dimSize=%d,nElements=%d)\n", i, + data[i], dimBgNo[i], dimSize[i], nElements[i]); + } +}; + +struct ComposedTensorMetadata { + int dimSize[DIM_MAX_SIZE]; + int stride[DIM_MAX_SIZE]; + float *data; +}; + +namespace infini { +void split_concat_kernel(const ElementTensorMetadata &eleMeta, + const ComposedTensorMetadata &compMeta, int dim, + int batchSize, int nDims, bool isSplit); + +} // namespace infini \ No newline at end of file diff --git a/include/operators/concat.h b/include/operators/concat.h new file mode 100644 index 00000000..ebca158c --- /dev/null +++ b/include/operators/concat.h @@ -0,0 +1,22 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +class ConcatObj : public OperatorObj { + int dim; + + public: + ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim); + + optional> inferShape(const TensorVec &inputs) const override; + + std::string toString() const override; + int numInputs() const override { return inputs.size(); } + int numOutputs() const override { return 1; } + int getDim() const { return dim; } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; +}; +} // namespace infini \ No newline at end of file diff --git a/include/operators/pad.h b/include/operators/pad.h index 2e80666a..d60443eb 100644 --- a/include/operators/pad.h +++ b/include/operators/pad.h @@ -15,7 +15,7 @@ class PadObj : public OperatorObj { std::string toString() const override; int numInputs() const override { return 1; } int numOutputs() const override { return 1; } - Shape PadObj::getPads() const { return pads; } + Shape getPads() const { return pads; } private: vector getWorkloadVector() const override; diff --git a/include/operators/split.h b/include/operators/split.h new file mode 100644 index 00000000..980060ce --- /dev/null +++ b/include/operators/split.h @@ -0,0 +1,25 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +class SplitObj : public OperatorObj { + int dim, num; // split dim;Average split num or outputs size + vector ratio; // output dim ratio + public: + SplitObj(GraphObj *graph, Tensor input, std::optional outputs, + int dim, int num); + SplitObj(GraphObj *graph, Tensor input, std::optional outputs, + int dim, const vector &ratio); + + optional> inferShape(const TensorVec &inputs) const override; + + std::string toString() const override; + int numInputs() const override { return 1; } + int numOutputs() const override { return num; } + int getDim() const { return dim; } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; +}; +} // namespace infini \ No newline at end of file diff --git a/src/kernels/cuda/element_wise.cc b/src/kernels/cuda/element_wise.cc index 6e4130e0..374fcf5d 100644 --- a/src/kernels/cuda/element_wise.cc +++ b/src/kernels/cuda/element_wise.cc @@ -78,7 +78,19 @@ class MulCudnn : public ElementWiseCudnn { class ElementWiseCuda : public CudaKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { - element_wise_kernel(_op); + auto op = as(_op); + float *const aData = (op->getInputs(0)->getRawDataPtr()); + float *const bData = (op->getInputs(1)->getRawDataPtr()); + float *const cData = (op->getOutput()->getRawDataPtr()); + + auto dim = op->getInputs(0)->getDims(); + int n = dim[0], c = dim[1], h = dim[2], w = dim[3]; + if (op->getOpType() == OpType::Div) + div_kernel(aData, bData, cData, n * c * h * w); + else if (op->getOpType() == OpType::Pow) + pow_kernel(aData, bData, cData, n * c * h * w); + else + IT_TODO_HALT(); } }; diff --git a/src/kernels/cuda/split_concat.cc b/src/kernels/cuda/split_concat.cc new file mode 100644 index 00000000..b7a12e5a --- /dev/null +++ b/src/kernels/cuda/split_concat.cc @@ -0,0 +1,81 @@ +#include "cuda/cuda_kernel_wihtout_config.h" +#include "cuda/cuda_split_concat.h" +#include "operators/concat.h" +#include "operators/split.h" +#include + +namespace infini { + +void initComposedTensorMetadata(ComposedTensorMetadata &metadata, + Tensor tensor) { + int nDims = tensor->getDims().size(); + auto strides = tensor->getStride(); + IT_ASSERT(strides.size() == (size_t)nDims); + for (int i = 0; i < nDims; ++i) { + metadata.dimSize[i] = tensor->getDims().at(i); + metadata.stride[i] = strides.at(i); + } + metadata.data = tensor->getRawDataPtr(); +} + +void initElementTensorMetadata(ElementTensorMetadata &metadata, + TensorVec tensors, int idx, int dim, + int &dimBgIdx, int &batchCounter) { + int nTensors = tensors.size(); + for (; batchCounter < BATCH_SIZE && idx + batchCounter < nTensors; + ++batchCounter) { + auto tensor = tensors.at(idx + batchCounter); + auto dimSize = tensor->getDims()[dim]; + metadata.data[batchCounter] = tensor->getRawDataPtr(); + metadata.dimBgNo[batchCounter] = dimBgIdx; + metadata.dimSize[batchCounter] = dimSize; + metadata.nElements[batchCounter] = tensor->size(); + dimBgIdx += dimSize; + } +} + +class CudaCompute { + public: + void do_compute(Tensor composedTensor, TensorVec elementsTensor, int dim, + int nDims, bool isSplit) const { + IT_ASSERT(nDims <= DIM_MAX_SIZE); + + ComposedTensorMetadata composedMetadata; + initComposedTensorMetadata(composedMetadata, composedTensor); + + int dimBgNo = 0; + int nElemets = elementsTensor.size(); + for (int i = 0; i < nElemets; i += BATCH_SIZE) { + ElementTensorMetadata elemMetadata; + int batchCounter = 0; + initElementTensorMetadata(elemMetadata, elementsTensor, i, dim, + dimBgNo, batchCounter); + split_concat_kernel(elemMetadata, composedMetadata, dim, + batchCounter, nDims, isSplit); + } + } +}; + +class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + do_compute(_op->getOutput(), _op->getInputs(), + as(_op)->getDim(), + _op->getOutput()->getDims().size(), false); + } +}; + +class SplitCuda : private CudaCompute, public CudaKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + do_compute(_op->getInputs(0), _op->getOutputs(), + as(_op)->getDim(), + _op->getInputs(0)->getDims().size(), true); + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::Concat, DataType::Float32, ConcatCuda, + "Concat_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Split, DataType::Float32, SplitCuda, + "Split_CUDA_Float32"); +} // namespace infini \ No newline at end of file diff --git a/src/kernels/cuda/split_concat.cu b/src/kernels/cuda/split_concat.cu new file mode 100644 index 00000000..e71e7890 --- /dev/null +++ b/src/kernels/cuda/split_concat.cu @@ -0,0 +1,71 @@ +#include "cuda/cuda_common.h" +#include "cuda/cuda_split_concat.h" + +int getMultiProcessorCount() { + int cur_device; + checkCudaError(cudaGetDevice(&cur_device)); + + struct cudaDeviceProp prop; + checkCudaError(cudaGetDeviceProperties(&prop, cur_device)); + return prop.multiProcessorCount; +} + +__host__ __device__ int +elementIdx2ComposedIdx(int elementIndex, int dimBgNo, int dimSize, int dim, + int nDim, ComposedTensorMetadata wholeMeta) { + int offset = 0; + +#pragma unroll + for (int i = nDim - 1; i >= 1; --i) { + int size = (i == dim) ? dimSize : wholeMeta.dimSize[i]; + int p = elementIndex % size; + int oP = (i == dim) ? (p + dimBgNo) : p; + elementIndex = (elementIndex - p) / size; + offset += oP * wholeMeta.stride[i]; + } + + return offset + elementIndex * wholeMeta.stride[0]; +} + +__global__ void _split_concat_kernel(ElementTensorMetadata elemMeta, + ComposedTensorMetadata compMeta, int dim, + int nDims, bool isSplit) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int nElements = elemMeta.nElements[blockIdx.y]; + if (tid >= nElements) + return; + + auto dimBgNo = elemMeta.dimBgNo[blockIdx.y]; + auto dimSize = elemMeta.dimSize[blockIdx.y]; + float *elemData = elemMeta.data[blockIdx.y]; + int stride = gridDim.x * blockDim.x; + + while (tid < nElements) { + int Offset = + elementIdx2ComposedIdx(tid, dimBgNo, dimSize, dim, nDims, compMeta); + // copy data from input to output + // for split:input is composed tensor;for concat:input is element + // tensors. + if (isSplit) + elemData[tid] = compMeta.data[Offset]; + else + compMeta.data[Offset] = elemData[tid]; + tid += stride; + } +} + +namespace infini { + +void split_concat_kernel(const ElementTensorMetadata &eleMeta, + const ComposedTensorMetadata &compMeta, int dim, + int batchSize, int nDims, bool isSplit) { + dim3 blockSize = dim3(32 * 16); + + // y dim is number of tensors. + dim3 gridSize(getMultiProcessorCount(), batchSize); + + _split_concat_kernel<<>>(eleMeta, compMeta, dim, nDims, + isSplit); +} + +} // namespace infini \ No newline at end of file diff --git a/src/operators/concat.cc b/src/operators/concat.cc new file mode 100644 index 00000000..98a5527a --- /dev/null +++ b/src/operators/concat.cc @@ -0,0 +1,58 @@ +#include "operators/concat.h" + +namespace infini { +ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim) + : OperatorObj(OpType::Concat, inputs, {output}), dim(dim) { + IT_ASSERT(checkValid(graph)); +} + +optional> ConcatObj::inferShape(const TensorVec &inputs) const { + IT_ASSERT(inputs.size() > 1); + Shape dims = inputs[0]->getDims(); + ShapeElem n = dims.at(dim); + for (auto itr = inputs.begin() + 1; itr != inputs.end(); ++itr) { + auto input = *itr; + auto iDims = input->getDims(); + if (dims.size() != iDims.size()) + return {}; + int nDims = dims.size(); + for (auto i = 0; i < nDims; i++) { + if (i == dim) { + n += iDims.at(i); + continue; + } + if (iDims.at(i) != dims.at(i)) + return {}; + } + } + dims[dim] = n; + return {{dims}}; +} + +std::string ConcatObj::toString() const { + std::ostringstream os; + os << "Concat[" << getGuid() << "]"; + os << "("; + for (auto input : inputs) + os << vecToString(input->getDims()) << ","; + os << "dim=" << dim << ","; + os << "input="; + for (auto input : inputs) + os << input->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ")"; + return os.str(); +} + +vector ConcatObj::getWorkloadVector() const { + vector ret = getOutput()->getDims(); + ret.emplace(ret.begin(), (int)inputs.size()); + ret.emplace(ret.begin(), dim); + ret.emplace(ret.begin(), enum_to_underlying(type)); + return ret; +} + +vector ConcatObj::getOpAttrVector() const { + return {enum_to_underlying(type), dim}; +} + +} // namespace infini \ No newline at end of file diff --git a/src/operators/split.cc b/src/operators/split.cc new file mode 100644 index 00000000..f387d4e4 --- /dev/null +++ b/src/operators/split.cc @@ -0,0 +1,89 @@ +#include "operators/split.h" +#include + +namespace infini { +SplitObj::SplitObj(GraphObj *graph, Tensor input, + std::optional outputs, int dim, int num) + : OperatorObj(OpType::Split, {input}, + ((!outputs) ? TensorVec{nullptr} : (*outputs))), + dim(dim), num(num), ratio({}) { + int dimSize = input->getDims().at(dim); + int pieceSize = dimSize / num; + int lastSize = dimSize - pieceSize * num; + + if (lastSize > 0) { + ratio = std::vector(num - 1, pieceSize); + ratio.emplace_back(lastSize + pieceSize); + } else + ratio = std::vector(num, pieceSize); + + if (!outputs) { + TensorVec tmp(num, nullptr); + this->outputs = tmp; + } + IT_ASSERT(checkValid(graph)); +} + +SplitObj::SplitObj(GraphObj *graph, Tensor input, + std::optional outputs, int dim, + const vector &ratio) + : OperatorObj(OpType::Split, {input}, + ((!outputs) ? TensorVec{nullptr} : (*outputs))), + dim(dim), num(-1), ratio(ratio) { + num = ratio.size(); + if (!outputs) { + TensorVec tmp(num, nullptr); + this->outputs = tmp; + } + IT_ASSERT(checkValid(graph)); +} + +optional> SplitObj::inferShape(const TensorVec &inputs) const { + if (num == -1 || ratio.size() == 0) + return {}; + auto inputDims = inputs[0]->getDims(); + int totalSize = inputDims.at(dim); + int ratioSum = std::accumulate(ratio.begin(), ratio.end(), 0); + if (totalSize % ratioSum != 0) + return {}; + + int pieceSize = totalSize / ratioSum; + + vector ret; + Shape outShape = inputDims; + for (int i = 0; i < num; i++) { + outShape[dim] = pieceSize * ratio.at(i); + ret.push_back(outShape); + } + return {ret}; +} + +vector SplitObj::getWorkloadVector() const { + vector ret = inputs[0]->getDims(); + ret.emplace(ret.begin(), enum_to_underlying(type)); + ret.emplace_back(dim); + ret.emplace_back(num); + return ret; +} + +vector SplitObj::getOpAttrVector() const { + return {enum_to_underlying(type), dim, num}; +} + +string SplitObj::toString() const { + std::ostringstream os; + os << "Split[" << getGuid() << "]"; + os << "("; + os << vecToString(inputs[0]->getDims()) << ","; + os << "dim=" << dim << ","; + os << "num= " << num << ","; + os << "ratio= " << vecToString(ratio) << ","; + os << "input=" << inputs[0]->getGuid() << ","; + os << "output="; + for (auto i = 0; i < num; i++) + os << outputs[i]->getGuid() << ","; + os << ")"; + return os.str(); +} + +} // namespace infini \ No newline at end of file diff --git a/test/kernels/cuda/test_cuda_concat.cc b/test/kernels/cuda/test_cuda_concat.cc new file mode 100644 index 00000000..13ccb67a --- /dev/null +++ b/test/kernels/cuda/test_cuda_concat.cc @@ -0,0 +1,76 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/concat.h" + +#include "test.h" + +namespace infini { +/* +int inputOffset2CatOffset(int linearIndex, int dimBgNo, int dimSize, + int concatDim, int outputDimSize[4], + int outputStride[4], int nDim) { + int offset = 0; + + for (int i = nDim - 1; i >= 1; --i) { + int size = (i == concatDim) ? dimSize : outputDimSize[i]; + int p = linearIndex % size; + int oP = (i == concatDim) ? (p + dimBgNo) : p; + linearIndex = (linearIndex - p) / size; + + offset += oP * outputStride[i]; + } + + return offset + linearIndex * outputStride[0]; +} + +TEST(Concat, OffsetTrans) { + int dimSize[] = {2, 3}; + int strides[] = {3, 1}; + int catDim = 1, nDim = 2; + EXPECT_EQ(inputOffset2CatOffset(0, 0, 1, catDim, dimSize, strides, nDim), + 0); + EXPECT_EQ(inputOffset2CatOffset(1, 0, 1, catDim, dimSize, strides, nDim), + 3); + EXPECT_EQ(inputOffset2CatOffset(0, 1, 2, catDim, dimSize, strides, nDim), + 1); + EXPECT_EQ(inputOffset2CatOffset(1, 1, 2, catDim, dimSize, strides, nDim), + 2); + EXPECT_EQ(inputOffset2CatOffset(2, 1, 2, catDim, dimSize, strides, nDim), + 4); + EXPECT_EQ(inputOffset2CatOffset(3, 1, 2, catDim, dimSize, strides, nDim), + 5); +} +*/ +TEST(Concat, Cuda) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto t1 = gCpu->addTensor({2, 2, 3, 1}, DataType::Float32); + auto t2 = gCpu->addTensor({2, 2, 1, 1}, DataType::Float32); + auto t3 = gCpu->addTensor({2, 2, 2, 1}, DataType::Float32); + gCpu->dataMalloc(); + t1->setData(IncrementalGenerator()); + t2->setData(OneGenerator()); + t3->setData(OneGenerator()); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto op = gCuda->addOp(TensorVec{gCuda->cloneTensor(t1), + gCuda->cloneTensor(t2), + gCuda->cloneTensor(t3)}, + nullptr, 2); + gCuda->dataMalloc(); + cudaRuntime->run(gCuda); + + // cudaPrintTensor(op->getOutput()); + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput()); + EXPECT_TRUE( + oCpu->equalData(vector{0, 1, 2, 1, 1, 1, 3, 4, 5, 1, 1, 1, + 6, 7, 8, 1, 1, 1, 9, 10, 11, 1, 1, 1})); +} + +} // namespace infini \ No newline at end of file diff --git a/test/kernels/cuda/test_cuda_split.cc b/test/kernels/cuda/test_cuda_split.cc new file mode 100644 index 00000000..38e409f0 --- /dev/null +++ b/test/kernels/cuda/test_cuda_split.cc @@ -0,0 +1,40 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/split.h" + +#include "test.h" + +namespace infini { + +TEST(Split, Cuda) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({2, 10, 2, 1}, DataType::Float32); + gCpu->dataMalloc(); + input->setData(IncrementalGenerator()); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto op = + gCuda->addOp(gCuda->cloneTensor(input), std::nullopt, 1, 3); + gCuda->dataMalloc(); + cudaRuntime->run(gCuda); + + // copy output from CUDA to CPU + EXPECT_EQ(op->getOutputs().size(), (size_t)3); + auto o0Cpu = gCpu->cloneTensor(op->getOutput(0)); + auto o1Cpu = gCpu->cloneTensor(op->getOutput(1)); + auto o2Cpu = gCpu->cloneTensor(op->getOutput(2)); + EXPECT_TRUE(o0Cpu->equalData( + vector{0, 1, 2, 3, 4, 5, 20, 21, 22, 23, 24, 25})); + EXPECT_TRUE(o1Cpu->equalData( + vector{6, 7, 8, 9, 10, 11, 26, 27, 28, 29, 30, 31})); + EXPECT_TRUE(o2Cpu->equalData(vector{ + 12, 13, 14, 15, 16, 17, 18, 19, 32, 33, 34, 35, 36, 37, 38, 39})); +} + +} // namespace infini \ No newline at end of file diff --git a/test/operators/test_concat.cc b/test/operators/test_concat.cc new file mode 100644 index 00000000..15ef074b --- /dev/null +++ b/test/operators/test_concat.cc @@ -0,0 +1,17 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/concat.h" +#include "test.h" + +namespace infini { +TEST(Concat, ShapeInfer) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + auto t1 = g->addTensor({1, 3, 2, 4}, DataType::Float32); + auto t2 = g->addTensor({1, 3, 2, 5}, DataType::Float32); + + auto op = g->addOp(TensorVec{t1, t2}, nullptr, 3); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 9})); +} + +} // namespace infini \ No newline at end of file diff --git a/test/operators/test_split.cc b/test/operators/test_split.cc new file mode 100644 index 00000000..d0e76031 --- /dev/null +++ b/test/operators/test_split.cc @@ -0,0 +1,38 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/split.h" + +#include "test.h" + +namespace infini { +TEST(Split, ShapeInfer) { + { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + auto input = g->addTensor({1, 3, 2, 15}, DataType::Float32); + + auto op = g->addOp(input, std::nullopt, 3, 4); + EXPECT_EQ(op->numOutputs(), 4); + EXPECT_EQ(op->getOutputs().size(), (size_t)4); + EXPECT_EQ(op->getOutput(0)->getDims(), (Shape{1, 3, 2, 3})); + EXPECT_EQ(op->getOutput(1)->getDims(), (Shape{1, 3, 2, 3})); + EXPECT_EQ(op->getOutput(2)->getDims(), (Shape{1, 3, 2, 3})); + EXPECT_EQ(op->getOutput(3)->getDims(), (Shape{1, 3, 2, 6})); + } + + { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + auto input = g->addTensor({1, 3, 2, 15}, DataType::Float32); + + auto op = + g->addOp(input, std::nullopt, 3, vector{1, 2, 2}); + EXPECT_EQ(op->getOutputs().size(), (size_t)3); + EXPECT_EQ(op->numOutputs(), 3); + EXPECT_EQ(op->getOutput(0)->getDims(), (Shape{1, 3, 2, 3})); + EXPECT_EQ(op->getOutput(1)->getDims(), (Shape{1, 3, 2, 6})); + EXPECT_EQ(op->getOutput(2)->getDims(), (Shape{1, 3, 2, 6})); + } +} + +} // namespace infini \ No newline at end of file