diff --git a/include/core/operator.h b/include/core/operator.h index b2db95ce..a40de331 100644 --- a/include/core/operator.h +++ b/include/core/operator.h @@ -186,6 +186,10 @@ class OperatorObj : public Object { IT_ASSERT(outputs.size() == 1, "Unimplemented"); return outputs[0]; } + Tensor getOutput(size_t i) const { // i-th output tensor; asserts that i is in range + IT_ASSERT(i < outputs.size(), "Index exceeded"); + return outputs.at(i); + } OpType getOpType() const { return type; } // HACK: set correct data type DataType getDType() const { return getInputs(0)->getDType(); } diff --git a/include/cuda/cuda_element_wise.h b/include/cuda/cuda_element_wise.h index b6dede22..d51a04cf 100644 --- a/include/cuda/cuda_element_wise.h +++ b/include/cuda/cuda_element_wise.h @@ -1,25 +1,6 @@ #pragma once -#include "operators/element_wise.h" - namespace infini { void div_kernel(float *a, float *b, float *c, int num); void pow_kernel(float *a, float *b, float *c, int num); - -void element_wise_kernel(const Operator &_op) { - auto op = as(_op); - float *const aData = (op->getInputs(0)->getRawDataPtr()); - float *const bData = (op->getInputs(1)->getRawDataPtr()); - float *const cData = (op->getOutput()->getRawDataPtr()); - - auto dim = op->getInputs(0)->getDims(); - int n = dim[0], c = dim[1], h = dim[2], w = dim[3]; - if (op->getOpType() == OpType::Div) - div_kernel(aData, bData, cData, n * c * h * w); - else if (op->getOpType() == OpType::Pow) - pow_kernel(aData, bData, cData, n * c * h * w); - else - IT_TODO_HALT(); -} - }; // namespace infini \ No newline at end of file diff --git a/include/cuda/cuda_split_concat.h b/include/cuda/cuda_split_concat.h new file mode 100644 index 00000000..454a8f52 --- /dev/null +++ b/include/cuda/cuda_split_concat.h @@ -0,0 +1,35 @@ + +#pragma once +#include + +const int BATCH_SIZE = 32; // max number of element tensors handled by one kernel launch.
+const int DIM_MAX_SIZE = 4; // max tensor rank the metadata arrays can hold + +// Concat assembles several element tensors into one composed tensor, and +// Split breaks one composed tensor apart into its element tensors, so both +// operators are described by the same two metadata structs. +struct ElementTensorMetadata { + float *data[BATCH_SIZE]; + int dimBgNo[BATCH_SIZE]; // begin index of the element tensor along the + // split/concat dimension of the composed tensor. + int dimSize[BATCH_SIZE]; // size of the element tensor along that dimension. + int nElements[BATCH_SIZE]; // total number of elements in the element tensor. + void print() const { + for (int i = 0; i < BATCH_SIZE; i++) + printf("%d:(data=%p,dimBgNo=%d,dimSize=%d,nElements=%d)\n", i, + data[i], dimBgNo[i], dimSize[i], nElements[i]); + } +}; + +struct ComposedTensorMetadata { + int dimSize[DIM_MAX_SIZE]; + int stride[DIM_MAX_SIZE]; + float *data; +}; + +namespace infini { +void split_concat_kernel(const ElementTensorMetadata &eleMeta, + const ComposedTensorMetadata &compMeta, int dim, + int batchSize, int nDims, bool isSplit); + +} // namespace infini \ No newline at end of file diff --git a/include/operators/concat.h b/include/operators/concat.h new file mode 100644 index 00000000..ebca158c --- /dev/null +++ b/include/operators/concat.h @@ -0,0 +1,22 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +class ConcatObj : public OperatorObj { + int dim; // axis along which the inputs are concatenated + + public: + ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim); + + optional> inferShape(const TensorVec &inputs) const override; + + std::string toString() const override; + int numInputs() const override { return inputs.size(); } + int numOutputs() const override { return 1; } + int getDim() const { return dim; } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; +}; +} // namespace infini \ No newline at end of file diff --git a/include/operators/pad.h b/include/operators/pad.h index 2e80666a..d60443eb 100644 --- a/include/operators/pad.h +++ b/include/operators/pad.h @@ -15,7 +15,7 @@ class
PadObj : public OperatorObj { std::string toString() const override; int numInputs() const override { return 1; } int numOutputs() const override { return 1; } - Shape PadObj::getPads() const { return pads; } + Shape getPads() const { return pads; } private: vector getWorkloadVector() const override; diff --git a/include/operators/split.h b/include/operators/split.h new file mode 100644 index 00000000..980060ce --- /dev/null +++ b/include/operators/split.h @@ -0,0 +1,25 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +class SplitObj : public OperatorObj { + int dim, num; // split axis; number of outputs (pieces) + vector ratio; // relative size of each output along dim + public: + SplitObj(GraphObj *graph, Tensor input, std::optional outputs, + int dim, int num); + SplitObj(GraphObj *graph, Tensor input, std::optional outputs, + int dim, const vector &ratio); + + optional> inferShape(const TensorVec &inputs) const override; + + std::string toString() const override; + int numInputs() const override { return 1; } + int numOutputs() const override { return num; } + int getDim() const { return dim; } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; +}; +} // namespace infini \ No newline at end of file diff --git a/src/kernels/cuda/element_wise.cc b/src/kernels/cuda/element_wise.cc index 6e4130e0..374fcf5d 100644 --- a/src/kernels/cuda/element_wise.cc +++ b/src/kernels/cuda/element_wise.cc @@ -78,7 +78,19 @@ class MulCudnn : public ElementWiseCudnn { class ElementWiseCuda : public CudaKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { - element_wise_kernel(_op); + auto op = as(_op); + float *const aData = (op->getInputs(0)->getRawDataPtr()); + float *const bData = (op->getInputs(1)->getRawDataPtr()); + float *const cData = (op->getOutput()->getRawDataPtr()); + + auto dim = op->getInputs(0)->getDims(); + int n = dim[0], c = dim[1], h = dim[2], w = dim[3]; + if
(op->getOpType() == OpType::Div) + div_kernel(aData, bData, cData, n * c * h * w); + else if (op->getOpType() == OpType::Pow) + pow_kernel(aData, bData, cData, n * c * h * w); + else + IT_TODO_HALT(); } }; diff --git a/src/kernels/cuda/split_concat.cc b/src/kernels/cuda/split_concat.cc new file mode 100644 index 00000000..b7a12e5a --- /dev/null +++ b/src/kernels/cuda/split_concat.cc @@ -0,0 +1,81 @@ +#include "cuda/cuda_kernel_wihtout_config.h" +#include "cuda/cuda_split_concat.h" +#include "operators/concat.h" +#include "operators/split.h" +#include + +namespace infini { + +void initComposedTensorMetadata(ComposedTensorMetadata &metadata, + Tensor tensor) { + int nDims = tensor->getDims().size(); + auto strides = tensor->getStride(); + IT_ASSERT(strides.size() == (size_t)nDims); + for (int i = 0; i < nDims; ++i) { + metadata.dimSize[i] = tensor->getDims().at(i); + metadata.stride[i] = strides.at(i); + } + metadata.data = tensor->getRawDataPtr(); +} + +void initElementTensorMetadata(ElementTensorMetadata &metadata, + TensorVec tensors, int idx, int dim, + int &dimBgIdx, int &batchCounter) { + int nTensors = tensors.size(); + for (; batchCounter < BATCH_SIZE && idx + batchCounter < nTensors; + ++batchCounter) { + auto tensor = tensors.at(idx + batchCounter); + auto dimSize = tensor->getDims()[dim]; + metadata.data[batchCounter] = tensor->getRawDataPtr(); + metadata.dimBgNo[batchCounter] = dimBgIdx; + metadata.dimSize[batchCounter] = dimSize; + metadata.nElements[batchCounter] = tensor->size(); + dimBgIdx += dimSize; + } +} + +class CudaCompute { + public: + void do_compute(Tensor composedTensor, TensorVec elementsTensor, int dim, + int nDims, bool isSplit) const { + IT_ASSERT(nDims <= DIM_MAX_SIZE); + + ComposedTensorMetadata composedMetadata; + initComposedTensorMetadata(composedMetadata, composedTensor); + + int dimBgNo = 0; + int nElemets = elementsTensor.size(); + for (int i = 0; i < nElemets; i += BATCH_SIZE) { + ElementTensorMetadata elemMetadata; + int
batchCounter = 0; + initElementTensorMetadata(elemMetadata, elementsTensor, i, dim, + dimBgNo, batchCounter); + split_concat_kernel(elemMetadata, composedMetadata, dim, + batchCounter, nDims, isSplit); + } + } +}; + +class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + do_compute(_op->getOutput(), _op->getInputs(), + as(_op)->getDim(), + _op->getOutput()->getDims().size(), false); + } +}; + +class SplitCuda : private CudaCompute, public CudaKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + do_compute(_op->getInputs(0), _op->getOutputs(), + as(_op)->getDim(), + _op->getInputs(0)->getDims().size(), true); + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::Concat, DataType::Float32, ConcatCuda, + "Concat_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Split, DataType::Float32, SplitCuda, + "Split_CUDA_Float32"); +} // namespace infini \ No newline at end of file diff --git a/src/kernels/cuda/split_concat.cu b/src/kernels/cuda/split_concat.cu new file mode 100644 index 00000000..e71e7890 --- /dev/null +++ b/src/kernels/cuda/split_concat.cu @@ -0,0 +1,71 @@ +#include "cuda/cuda_common.h" +#include "cuda/cuda_split_concat.h" + +int getMultiProcessorCount() { + int cur_device; + checkCudaError(cudaGetDevice(&cur_device)); + + struct cudaDeviceProp prop; + checkCudaError(cudaGetDeviceProperties(&prop, cur_device)); + return prop.multiProcessorCount; +} + +__host__ __device__ int +elementIdx2ComposedIdx(int elementIndex, int dimBgNo, int dimSize, int dim, + int nDim, ComposedTensorMetadata wholeMeta) { + int offset = 0; + +#pragma unroll + for (int i = nDim - 1; i >= 1; --i) { + int size = (i == dim) ? dimSize : wholeMeta.dimSize[i]; + int p = elementIndex % size; + int oP = (i == dim) ?
(p + dimBgNo) : p; + elementIndex = (elementIndex - p) / size; + offset += oP * wholeMeta.stride[i]; + } + + return offset + ((dim == 0) ? (elementIndex + dimBgNo) : elementIndex) * wholeMeta.stride[0]; // dim 0 never enters the loop above, so its begin offset is applied here +} + +__global__ void _split_concat_kernel(ElementTensorMetadata elemMeta, + ComposedTensorMetadata compMeta, int dim, + int nDims, bool isSplit) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int nElements = elemMeta.nElements[blockIdx.y]; + if (tid >= nElements) + return; + + auto dimBgNo = elemMeta.dimBgNo[blockIdx.y]; + auto dimSize = elemMeta.dimSize[blockIdx.y]; + float *elemData = elemMeta.data[blockIdx.y]; + int stride = gridDim.x * blockDim.x; + + while (tid < nElements) { + int Offset = + elementIdx2ComposedIdx(tid, dimBgNo, dimSize, dim, nDims, compMeta); + // Copy data from input to output. + // For split the input is the composed tensor; for concat the inputs + // are the element tensors. + if (isSplit) + elemData[tid] = compMeta.data[Offset]; + else + compMeta.data[Offset] = elemData[tid]; + tid += stride; + } +} + +namespace infini { + +void split_concat_kernel(const ElementTensorMetadata &eleMeta, + const ComposedTensorMetadata &compMeta, int dim, + int batchSize, int nDims, bool isSplit) { + dim3 blockSize = dim3(32 * 16); + + // y dim is number of tensors.
+ dim3 gridSize(getMultiProcessorCount(), batchSize); + + _split_concat_kernel<<>>(eleMeta, compMeta, dim, nDims, + isSplit); +} + +} // namespace infini \ No newline at end of file diff --git a/src/operators/concat.cc b/src/operators/concat.cc new file mode 100644 index 00000000..98a5527a --- /dev/null +++ b/src/operators/concat.cc @@ -0,0 +1,58 @@ +#include "operators/concat.h" + +namespace infini { +ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim) + : OperatorObj(OpType::Concat, inputs, {output}), dim(dim) { + IT_ASSERT(checkValid(graph)); +} + +optional> ConcatObj::inferShape(const TensorVec &inputs) const { + IT_ASSERT(inputs.size() > 1); + Shape dims = inputs[0]->getDims(); + ShapeElem n = dims.at(dim); + for (auto itr = inputs.begin() + 1; itr != inputs.end(); ++itr) { + auto input = *itr; + auto iDims = input->getDims(); + if (dims.size() != iDims.size()) + return {}; + int nDims = dims.size(); + for (auto i = 0; i < nDims; i++) { + if (i == dim) { + n += iDims.at(i); + continue; + } + if (iDims.at(i) != dims.at(i)) + return {}; + } + } + dims[dim] = n; + return {{dims}}; +} + +std::string ConcatObj::toString() const { + std::ostringstream os; + os << "Concat[" << getGuid() << "]"; + os << "("; + for (auto input : inputs) + os << vecToString(input->getDims()) << ","; + os << "dim=" << dim << ","; + os << "input="; + for (auto input : inputs) + os << input->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ")"; + return os.str(); +} + +vector ConcatObj::getWorkloadVector() const { + vector ret = getOutput()->getDims(); + ret.emplace(ret.begin(), (int)inputs.size()); + ret.emplace(ret.begin(), dim); + ret.emplace(ret.begin(), enum_to_underlying(type)); + return ret; +} + +vector ConcatObj::getOpAttrVector() const { + return {enum_to_underlying(type), dim}; +} + +} // namespace infini \ No newline at end of file diff --git a/src/operators/split.cc b/src/operators/split.cc new file mode 100644 index
00000000..f387d4e4 --- /dev/null +++ b/src/operators/split.cc @@ -0,0 +1,89 @@ +#include "operators/split.h" +#include + +namespace infini { +SplitObj::SplitObj(GraphObj *graph, Tensor input, + std::optional outputs, int dim, int num) + : OperatorObj(OpType::Split, {input}, + ((!outputs) ? TensorVec{nullptr} : (*outputs))), + dim(dim), num(num), ratio({}) { + int dimSize = input->getDims().at(dim); + int pieceSize = dimSize / num; + int lastSize = dimSize - pieceSize * num; + + if (lastSize > 0) { + ratio = std::vector(num - 1, pieceSize); + ratio.emplace_back(lastSize + pieceSize); // the last piece absorbs the remainder + } else + ratio = std::vector(num, pieceSize); + + if (!outputs) { + TensorVec tmp(num, nullptr); + this->outputs = tmp; + } + IT_ASSERT(checkValid(graph)); +} + +SplitObj::SplitObj(GraphObj *graph, Tensor input, + std::optional outputs, int dim, + const vector &ratio) + : OperatorObj(OpType::Split, {input}, + ((!outputs) ? TensorVec{nullptr} : (*outputs))), + dim(dim), num(-1), ratio(ratio) { + num = ratio.size(); + if (!outputs) { + TensorVec tmp(num, nullptr); + this->outputs = tmp; + } + IT_ASSERT(checkValid(graph)); +} + +optional> SplitObj::inferShape(const TensorVec &inputs) const { + if (num == -1 || ratio.size() == 0) + return {}; + auto inputDims = inputs[0]->getDims(); + int totalSize = inputDims.at(dim); + int ratioSum = std::accumulate(ratio.begin(), ratio.end(), 0); + if (ratioSum == 0 || totalSize % ratioSum != 0) // a zero sum would otherwise divide by zero + return {}; + + int pieceSize = totalSize / ratioSum; + + vector ret; + Shape outShape = inputDims; + for (int i = 0; i < num; i++) { + outShape[dim] = pieceSize * ratio.at(i); + ret.push_back(outShape); + } + return {ret}; +} + +vector SplitObj::getWorkloadVector() const { + vector ret = inputs[0]->getDims(); + ret.emplace(ret.begin(), enum_to_underlying(type)); + ret.emplace_back(dim); + ret.emplace_back(num); + return ret; +} + +vector SplitObj::getOpAttrVector() const { + return {enum_to_underlying(type), dim, num}; +} + +string SplitObj::toString() const { +
std::ostringstream os; + os << "Split[" << getGuid() << "]"; + os << "("; + os << vecToString(inputs[0]->getDims()) << ","; + os << "dim=" << dim << ","; + os << "num= " << num << ","; + os << "ratio= " << vecToString(ratio) << ","; + os << "input=" << inputs[0]->getGuid() << ","; + os << "output="; + for (auto i = 0; i < num; i++) + os << outputs[i]->getGuid() << ","; + os << ")"; + return os.str(); +} + +} // namespace infini \ No newline at end of file diff --git a/test/kernels/cuda/test_cuda_concat.cc b/test/kernels/cuda/test_cuda_concat.cc new file mode 100644 index 00000000..13ccb67a --- /dev/null +++ b/test/kernels/cuda/test_cuda_concat.cc @@ -0,0 +1,76 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/concat.h" + +#include "test.h" + +namespace infini { +/* +int inputOffset2CatOffset(int linearIndex, int dimBgNo, int dimSize, + int concatDim, int outputDimSize[4], + int outputStride[4], int nDim) { + int offset = 0; + + for (int i = nDim - 1; i >= 1; --i) { + int size = (i == concatDim) ? dimSize : outputDimSize[i]; + int p = linearIndex % size; + int oP = (i == concatDim) ?
(p + dimBgNo) : p; + linearIndex = (linearIndex - p) / size; + + offset += oP * outputStride[i]; + } + + return offset + linearIndex * outputStride[0]; +} + +TEST(Concat, OffsetTrans) { + int dimSize[] = {2, 3}; + int strides[] = {3, 1}; + int catDim = 1, nDim = 2; + EXPECT_EQ(inputOffset2CatOffset(0, 0, 1, catDim, dimSize, strides, nDim), + 0); + EXPECT_EQ(inputOffset2CatOffset(1, 0, 1, catDim, dimSize, strides, nDim), + 3); + EXPECT_EQ(inputOffset2CatOffset(0, 1, 2, catDim, dimSize, strides, nDim), + 1); + EXPECT_EQ(inputOffset2CatOffset(1, 1, 2, catDim, dimSize, strides, nDim), + 2); + EXPECT_EQ(inputOffset2CatOffset(2, 1, 2, catDim, dimSize, strides, nDim), + 4); + EXPECT_EQ(inputOffset2CatOffset(3, 1, 2, catDim, dimSize, strides, nDim), + 5); +} +*/ +TEST(Concat, Cuda) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto t1 = gCpu->addTensor({2, 2, 3, 1}, DataType::Float32); + auto t2 = gCpu->addTensor({2, 2, 1, 1}, DataType::Float32); + auto t3 = gCpu->addTensor({2, 2, 2, 1}, DataType::Float32); + gCpu->dataMalloc(); + t1->setData(IncrementalGenerator()); + t2->setData(OneGenerator()); + t3->setData(OneGenerator()); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto op = gCuda->addOp(TensorVec{gCuda->cloneTensor(t1), + gCuda->cloneTensor(t2), + gCuda->cloneTensor(t3)}, + nullptr, 2); + gCuda->dataMalloc(); + cudaRuntime->run(gCuda); + + // cudaPrintTensor(op->getOutput()); + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput()); + EXPECT_TRUE( + oCpu->equalData(vector{0, 1, 2, 1, 1, 1, 3, 4, 5, 1, 1, 1, + 6, 7, 8, 1, 1, 1, 9, 10, 11, 1, 1, 1})); +} + +} // namespace infini \ No newline at end of file diff --git a/test/kernels/cuda/test_cuda_split.cc b/test/kernels/cuda/test_cuda_split.cc new file mode 100644 index 00000000..38e409f0 --- /dev/null +++ b/test/kernels/cuda/test_cuda_split.cc @@ -0,0 +1,40 @@ +#include "core/graph.h" +#include "core/runtime.h"
+#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/split.h" + +#include "test.h" + +namespace infini { + +TEST(Split, Cuda) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({2, 10, 2, 1}, DataType::Float32); + gCpu->dataMalloc(); + input->setData(IncrementalGenerator()); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto op = + gCuda->addOp(gCuda->cloneTensor(input), std::nullopt, 1, 3); // dim 1 (size 10) split into 3 pieces of 3, 3 and 4 + gCuda->dataMalloc(); + cudaRuntime->run(gCuda); + + // check the output count, then copy each output from CUDA to CPU + EXPECT_EQ(op->getOutputs().size(), (size_t)3); + auto o0Cpu = gCpu->cloneTensor(op->getOutput(0)); + auto o1Cpu = gCpu->cloneTensor(op->getOutput(1)); + auto o2Cpu = gCpu->cloneTensor(op->getOutput(2)); + EXPECT_TRUE(o0Cpu->equalData( + vector{0, 1, 2, 3, 4, 5, 20, 21, 22, 23, 24, 25})); + EXPECT_TRUE(o1Cpu->equalData( + vector{6, 7, 8, 9, 10, 11, 26, 27, 28, 29, 30, 31})); + EXPECT_TRUE(o2Cpu->equalData(vector{ + 12, 13, 14, 15, 16, 17, 18, 19, 32, 33, 34, 35, 36, 37, 38, 39})); +} + +} // namespace infini \ No newline at end of file diff --git a/test/operators/test_concat.cc b/test/operators/test_concat.cc new file mode 100644 index 00000000..15ef074b --- /dev/null +++ b/test/operators/test_concat.cc @@ -0,0 +1,17 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/concat.h" +#include "test.h" + +namespace infini { +TEST(Concat, ShapeInfer) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + auto t1 = g->addTensor({1, 3, 2, 4}, DataType::Float32); + auto t2 = g->addTensor({1, 3, 2, 5}, DataType::Float32); + + auto op = g->addOp(TensorVec{t1, t2}, nullptr, 3); // concat along dim 3: 4 + 5 = 9 + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 9})); +} + +} // namespace infini \ No newline at end of file diff --git a/test/operators/test_split.cc b/test/operators/test_split.cc new file mode 100644 index 00000000..d0e76031 ---
/dev/null +++ b/test/operators/test_split.cc @@ -0,0 +1,38 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/split.h" + +#include "test.h" + +namespace infini { +TEST(Split, ShapeInfer) { + { // split by count: 15 / 4 = 3 per piece, the last piece also takes the remainder 3 + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + auto input = g->addTensor({1, 3, 2, 15}, DataType::Float32); + + auto op = g->addOp(input, std::nullopt, 3, 4); + EXPECT_EQ(op->numOutputs(), 4); + EXPECT_EQ(op->getOutputs().size(), (size_t)4); + EXPECT_EQ(op->getOutput(0)->getDims(), (Shape{1, 3, 2, 3})); + EXPECT_EQ(op->getOutput(1)->getDims(), (Shape{1, 3, 2, 3})); + EXPECT_EQ(op->getOutput(2)->getDims(), (Shape{1, 3, 2, 3})); + EXPECT_EQ(op->getOutput(3)->getDims(), (Shape{1, 3, 2, 6})); + } + + { // split by ratio 1:2:2 over size 15 -> 3, 6, 6 + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + auto input = g->addTensor({1, 3, 2, 15}, DataType::Float32); + + auto op = + g->addOp(input, std::nullopt, 3, vector{1, 2, 2}); + EXPECT_EQ(op->getOutputs().size(), (size_t)3); + EXPECT_EQ(op->numOutputs(), 3); + EXPECT_EQ(op->getOutput(0)->getDims(), (Shape{1, 3, 2, 3})); + EXPECT_EQ(op->getOutput(1)->getDims(), (Shape{1, 3, 2, 6})); + EXPECT_EQ(op->getOutput(2)->getDims(), (Shape{1, 3, 2, 6})); + } +} + +} // namespace infini \ No newline at end of file