// --- Patch context ---------------------------------------------------------
// Commit "ADD: Gather operator and cuda kernel. (#41)": adds a Gather
// operator with a CUDA Float32 kernel, fixes a memory leak, and adds tests.
// Besides the new files below, it edits include/core/runtime.h: the
// pure-virtual `copyBlobToCPU(void *dst, const void *src, size_t bytes)`
// declaration moves from the `protected:` section of RuntimeObj to the
// `public:` section (next to copyBlobFromCPU), so operator-side validation
// (GatherObj::CheckIndexValid) can copy index data back to the host.
// NOTE(review): the patch text was whitespace-mangled in transit; template
// arguments and kernel-launch chevrons stripped by the mangling (e.g.
// as<GatherObj>, getRawDataPtr<int *>, <<<gridSize, blockSize>>>) have been
// reconstructed — confirm against the upstream commit fe14c91.

// ===========================================================================
// include/cuda/gather.h
// ===========================================================================
#pragma once

// Per-launch metadata for the gather kernel.  Filled on the host by
// initGatherMetaData() and passed to the kernel BY VALUE, so the descriptor
// itself needs no device allocation.
// NOTE(review): the shape/stride arrays are fixed at 4 entries, so tensors
// of rank > 4 are not representable — confirm callers never exceed this.
typedef struct {
    int *indexValue;  // device pointer to the index tensor's data
                      // (tensor dtype is UInt32; read here as int)
    int axis;         // axis of the input tensor being gathered over
    int inNDim;       // rank of the input tensor
    int outNDim;      // rank of the output tensor (inNDim - 1 + idxNDim)
    int idxNDim;      // rank of the index tensor
    int outDim[4];    // output shape
    int idxDim[4];    // index shape
    int idxStride[4]; // index strides, in elements
    int inStride[4];  // input strides, in elements
} GatherMetaData;

namespace infini {
// Launches the CUDA gather kernel: for every flat output offset i < num,
// out[i] = in[input offset derived from metaData and the index values].
void gather_kernel(float *in, float *out, GatherMetaData metaData, int num);
} // namespace infini

// ===========================================================================
// include/operators/gather.h
// ===========================================================================
#pragma once

#include "core/operator.h"

namespace infini {
// Gather operator (ONNX-style): selects slices of `input` along `axis`
// according to `index`; the output shape is the input shape with the `axis`
// dimension replaced by the whole index shape.
class GatherObj : public OperatorObj {
    int axis; // gather axis; only non-negative values are accepted

  public:
    GatherObj(GraphObj *graph, Tensor input, Tensor index, Tensor output,
              int axis);
    std::string toString() const override;
    int numInputs() const override { return 2; } // input, index
    int numOutputs() const override { return 1; }
    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
    int getAxis() const { return axis; }
    // Output dtype follows inputs[0]; inputs[1] (index) must be UInt32.
    vector<DataType> inferDataType(const TensorVec &inputs) const override;

  private:
    // Returns false iff some bound index value is outside [0, inDim[axis]).
    bool CheckIndexValid() const;
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;
};
} // namespace infini

// ===========================================================================
// src/kernels/cuda/gather.cc
// ===========================================================================
#include "operators/gather.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "cuda/gather.h"

namespace infini {

// Populates `metaData` (axis, ranks, shapes, strides, index pointer) from a
// Gather operator so the kernel can map output offsets to input offsets.
void initGatherMetaData(GatherMetaData &metaData, const Operator &_op) {
    memset(&metaData, 0, sizeof(metaData)); // zero unused dim/stride slots
    auto op = as<GatherObj>(_op);
    auto in = op->getInputs(0);
    auto index = op->getInputs(1);
    auto out = op->getOutput();
    metaData.indexValue = index->getRawDataPtr<int *>();
    metaData.axis = op->getAxis();
    metaData.inNDim = in->getDims().size();
    metaData.outNDim = out->getDims().size();
    metaData.idxNDim = index->getDims().size();
    for (int i = 0; i < metaData.outNDim; ++i)
        metaData.outDim[i] = out->getDims()[i];
    for (int i = 0; i < metaData.idxNDim; ++i) {
        metaData.idxDim[i] = index->getDims()[i];
        metaData.idxStride[i] = index->getStride()[i];
    }
    for (int i = 0; i < metaData.inNDim; ++i) {
        metaData.inStride[i] = in->getStride()[i];
    }
}

// CUDA Float32 kernel wrapper for the Gather operator.
class GatherCuda : public CudaKernelWithoutConfig {
    void compute(const Operator &op,
                 const RuntimeObj *_context) const override {

        auto input = op->getInputs(0);
        auto index = op->getInputs(1);

        GatherMetaData metaData;
        initGatherMetaData(metaData, op);

        auto inData = input->getRawDataPtr<float *>();
        auto outData = op->getOutput()->getRawDataPtr<float *>();
        // One work-item per output element.
        gather_kernel(inData, outData, metaData, op->getOutput()->size());
    }
};

REGISTER_KERNEL(Device::CUDA, OpType::Gather, DataType::Float32, GatherCuda,
                "Gather_CUDA_Float32");
} // namespace infini

// ===========================================================================
// src/kernels/cuda/gather.cu
// ===========================================================================
#include "cuda/cuda_common.h"
#include "cuda/gather.h"

// Maps a flat OUTPUT offset to the corresponding flat INPUT offset.
// Walks input dimensions from innermost to outermost (i); k tracks the
// matching output dimension.  For the gather axis the coordinate comes from
// the index tensor (consuming idxNDim output dims at once); for every other
// dimension the output coordinate is used directly.
__device__ int gatheredOffset2Offset(int gOffset, GatherMetaData metaData) {
    int offset = 0;
    for (int i = metaData.inNDim - 1, k = metaData.outNDim - 1; i >= 0; --i) {
        int idx = 0;
        if (i == metaData.axis) {
            // Decompose the remaining output offset over the index shape to
            // locate the index element supplying the axis coordinate.
            int idxOffset = 0;
            for (int j = metaData.idxNDim - 1; j >= 0; --j) {
                int p = gOffset % metaData.idxDim[j];
                gOffset = gOffset / metaData.idxDim[j];
                idxOffset += p * metaData.idxStride[j];
            }

            idx = metaData.indexValue[idxOffset];
            k = k - metaData.idxNDim; // skip the output dims just consumed

        } else {
            idx = gOffset % metaData.outDim[k];
            gOffset = gOffset / metaData.outDim[k];
            --k;
        }
        offset += idx * metaData.inStride[i];
    }
    return offset;
}

// Grid-stride loop over the num output elements; the while-condition guards
// against out-of-range tids, so any grid size is safe.
__global__ void _gather_kernel(float *in, float *out, GatherMetaData metaData,
                               int num) {
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    int stride = blockDim.x * gridDim.x;
    while (tid < num) {
        int offset = gatheredOffset2Offset(tid, metaData);
        out[tid] = in[offset];
        tid += stride;
    }
}

namespace infini {
// Host-side launcher; num is the number of output elements.
void gather_kernel(float *in, float *out, GatherMetaData metaData, int num) {
    int blockSize = 32 * 16; // 512 threads per block
    int gridSize = (num + blockSize - 1) / blockSize; // ceil-div

    _gather_kernel<<<gridSize, blockSize>>>(in, out, metaData, num);
}
} // namespace infini

// ===========================================================================
// src/operators/gather.cc
// ===========================================================================
#include "operators/gather.h"

namespace infini {
GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor index, Tensor output,
                     int axis)
    : OperatorObj(OpType::Gather, {input, index}, {output}), axis(axis) {
    IT_ASSERT(checkValid(graph));
}

// Output shape: input shape with dim[axis] replaced by the full index shape.
// Returns {} (inference failure) when axis is out of range.
optional<vector<Shape>> GatherObj::inferShape(const TensorVec &inputs) const {
    auto dims0 = inputs[0]->getDims();
    auto dims1 = inputs[1]->getDims();

    if (axis < 0) // negative-axis (ONNX-style) support not implemented yet
        IT_TODO_HALT();

    if ((size_t)axis >= dims0.size())
        return {};

    IT_ASSERT(CheckIndexValid());

    Shape dim = dims0;
    dim.erase(dim.begin() + axis);
    dim.insert(dim.begin() + axis, dims1.begin(), dims1.end());
    return {{dim}};
}

vector<DataType> GatherObj::inferDataType(const TensorVec &inputs) const {
    IT_ASSERT(inputs.size() == 2);
    auto index = inputs[1];
    IT_ASSERT(index->getDType() == DataType::UInt32);
    return {inputs[0]->getDType()};
}

// TODO: should re-check every time the index tensor's data is updated.
// Validates bound index data: every value must lie in [0, inDim[axis]).
// Returns true when no data blob is attached yet (nothing to check).
// The index tensor may live on another runtime (e.g. CUDA), so its bytes
// are copied into a host buffer first; the buffer is freed on every path.
bool GatherObj::CheckIndexValid() const {
    auto index = inputs[1];
    if (index->getDataBlob() == nullptr)
        return true;

    Runtime runtime = CpuRuntimeObj::getInstance();
    int *data = (int *)runtime->alloc(index->getBytes());
    index->getRuntime()->copyBlobToCPU(
        (void *)data, index->getRawDataPtr<void *>(), index->getBytes());

    bool ret = true;
    auto value = inputs[0]->getDims()[axis]; // exclusive upper bound
    for (size_t i = 0; i < index->size(); ++i) {
        if (data[i] < 0 || data[i] >= value) {
            ret = false;
            break;
        }
    }
    runtime->dealloc(data);
    return ret;
}

// Human-readable summary: Gather[guid](inDims,idxDims,axis=...,input=...,output=...).
std::string GatherObj::toString() const {
    std::ostringstream os;
    os << "Gather"
       << "[" << getGuid() << "]";
    os << "(";
    if (inputs.size() == 2) {
        os << vecToString(inputs[0]->getDims()) << ",";
        os << vecToString(inputs[1]->getDims()) << ",";
    }
    os << "axis=" << axis << ",";
    os << "input=" << inputs[0]->getGuid() << ",";
    os << "output=" << outputs[0]->getGuid() << ")";
    return os.str();
}

// Workload key: [opType, input dims..., index dims..., axis].
vector<int> GatherObj::getWorkloadVector() const {
    vector<int> ret = inputs[0]->getDims();
    ret.emplace(ret.begin(), enum_to_underlying(type));
    for (auto it : inputs[1]->getDims())
        ret.emplace_back(it);
    ret.emplace_back(axis);
    return ret;
}

// Attribute key: [opType, axis].
vector<int> GatherObj::getOpAttrVector() const {
    return {enum_to_underlying(type), axis};
}

} // namespace infini

// ===========================================================================
// test/kernels/cuda/test_cuda_gather.cc
// ===========================================================================
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "cuda/gather.h"
#include "operators/gather.h"

#include "test.h"
namespace infini {
/*
test1:
input   = [[1, 2], [3, 4], [5, 6]]           // (3,2)
indices = [[0, 1], [1, 2]]                   // (2,2)
axis = 0
output  = [[[1, 2], [3, 4]],
           [[3, 4], [5, 6]]]                 // (2,2,2)
*/

/*
test2:
input   = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]  // (3,3)
indices = [[0, 2]]                           // (1,2)
axis = 1
output  = [[[0, 2]], [[3, 5]], [[6, 8]]]     // (3,1,2)
*/
/*
test3:
input   = [[[ 0,  1], [ 2,  3], [ 4,  5], [ 6,  7]],
           [[ 8,  9], [10, 11], [12, 13], [14, 15]]] // (2,4,2)
indices = [[0], [3], [1]]                            // (3,1)
axis = 1
output  = [[[[ 0,  1]], [[ 6,  7]], [[ 2,  3]]],
           [[[ 8,  9]], [[14, 15]], [[10, 11]]]]     // (2,3,1,2)
*/

// Host-side copy of the __device__ helper in gather.cu: maps a flat output
// offset to the corresponding flat input offset, so the mapping logic can be
// unit-tested without a GPU.  Keep in sync with src/kernels/cuda/gather.cu.
int gatheredOffset2Offset(int gOffset, GatherMetaData metaData) {
    int offset = 0;
    for (int i = metaData.inNDim - 1, k = metaData.outNDim - 1; i >= 0; --i) {
        int idx = 0;
        if (i == metaData.axis) {
            // Decompose the remaining output offset over the index shape.
            int idxOffset = 0;
            for (int j = metaData.idxNDim - 1; j >= 0; --j) {
                int p = gOffset % metaData.idxDim[j];
                gOffset = gOffset / metaData.idxDim[j];
                idxOffset += p * metaData.idxStride[j];
            }

            idx = metaData.indexValue[idxOffset];
            k = k - metaData.idxNDim; // skip the output dims just consumed

        } else {
            idx = gOffset % metaData.outDim[k];
            gOffset = gOffset / metaData.outDim[k];
            --k;
        }
        offset += idx * metaData.inStride[i];
    }
    return offset;
}

// Pure offset-mapping tests (no GPU involved); each case builds a
// GatherMetaData by hand and checks output-offset -> input-offset pairs.
TEST(Gather, offsetTrans) {
    {
        // Case 1: axis 0, input (3,2), indices (2,2) -> output (2,2,2).
        GatherMetaData meta;
        int data[] = {0, 1, 1, 2};
        meta.indexValue = data;
        meta.axis = 0;
        meta.inNDim = 2;
        meta.outNDim = 3;
        meta.idxNDim = 2;
        int tmp[] = {2, 2, 2, 0};
        memcpy(&meta.outDim, &tmp, sizeof(tmp));
        int tmp2[] = {2, 2, 0, 0};
        // FIX: these three memcpys previously copied sizeof(tmp) bytes from
        // tmp2/tmp3 — correct only because every array happens to be 4 ints.
        // Use each source array's own size (matching cases 2 and 3 below).
        memcpy(&meta.idxDim, &tmp2, sizeof(tmp2));
        int tmp3[] = {2, 1, 0, 0};
        memcpy(&meta.idxStride, &tmp3, sizeof(tmp3));
        memcpy(&meta.inStride, &tmp3, sizeof(tmp3));

        EXPECT_EQ(gatheredOffset2Offset(0, meta), 0);
        EXPECT_EQ(gatheredOffset2Offset(1, meta), 1);
        EXPECT_EQ(gatheredOffset2Offset(2, meta), 2);
        EXPECT_EQ(gatheredOffset2Offset(3, meta), 3);
        EXPECT_EQ(gatheredOffset2Offset(4, meta), 2);
        EXPECT_EQ(gatheredOffset2Offset(5, meta), 3);
        EXPECT_EQ(gatheredOffset2Offset(6, meta), 4);
        EXPECT_EQ(gatheredOffset2Offset(7, meta), 5);
    }
    {
        // Case 2: axis 1, input (3,3), indices (1,2) -> output (3,1,2).
        GatherMetaData meta;
        int data[] = {0, 2};
        meta.indexValue = data;
        meta.axis = 1;
        meta.inNDim = 2;
        meta.outNDim = 3;
        meta.idxNDim = 2;

        int tmp[] = {3, 1, 2, 0};
        memcpy(&meta.outDim, &tmp, sizeof(tmp));
        int tmp2[] = {1, 2, 0, 0};
        memcpy(&meta.idxDim, &tmp2, sizeof(tmp2));
        int tmp3[] = {2, 1, 0, 0};
        memcpy(&meta.idxStride, &tmp3, sizeof(tmp3));
        int tmp4[] = {3, 1, 0, 0};
        memcpy(&meta.inStride, &tmp4, sizeof(tmp4));

        EXPECT_EQ(gatheredOffset2Offset(0, meta), 0);
        EXPECT_EQ(gatheredOffset2Offset(1, meta), 2);
        EXPECT_EQ(gatheredOffset2Offset(2, meta), 3);
        EXPECT_EQ(gatheredOffset2Offset(3, meta), 5);
        EXPECT_EQ(gatheredOffset2Offset(4, meta), 6);
        EXPECT_EQ(gatheredOffset2Offset(5, meta), 8);
    }
    {
        // Case 3: axis 1, input (2,4,2), indices (3,1) -> output (2,3,1,2).
        GatherMetaData meta;
        int data[] = {0, 3, 1};
        meta.indexValue = data;
        meta.axis = 1;
        meta.inNDim = 3;
        meta.outNDim = 4;
        meta.idxNDim = 2;

        int tmp[] = {2, 3, 1, 2};
        memcpy(&meta.outDim, &tmp, sizeof(tmp));
        int tmp2[] = {3, 1, 0, 0};
        memcpy(&meta.idxDim, &tmp2, sizeof(tmp2));
        int tmp3[] = {1, 1, 0, 0};
        memcpy(&meta.idxStride, &tmp3, sizeof(tmp3));
        int tmp4[] = {8, 2, 1, 0};
        memcpy(&meta.inStride, &tmp4, sizeof(tmp4));

        EXPECT_EQ(gatheredOffset2Offset(0, meta), 0);
        EXPECT_EQ(gatheredOffset2Offset(1, meta), 1);
        EXPECT_EQ(gatheredOffset2Offset(2, meta), 6);
        EXPECT_EQ(gatheredOffset2Offset(3, meta), 7);
        EXPECT_EQ(gatheredOffset2Offset(4, meta), 2);
        EXPECT_EQ(gatheredOffset2Offset(5, meta), 3);
        EXPECT_EQ(gatheredOffset2Offset(6, meta), 8);
        EXPECT_EQ(gatheredOffset2Offset(7, meta), 9);
        EXPECT_EQ(gatheredOffset2Offset(8, meta), 14);
        EXPECT_EQ(gatheredOffset2Offset(9, meta), 15);
        EXPECT_EQ(gatheredOffset2Offset(10, meta), 10);
        EXPECT_EQ(gatheredOffset2Offset(11, meta), 11);
    }
}

// End-to-end CUDA tests (test1/test2/test3 above): build the graph on CPU,
// clone to the CUDA runtime, run, and compare the output back on the host.
// NOTE(review): template arguments below (make_ref<GraphObj>, vector<float>,
// addOp<GatherObj>, ...) were stripped by patch mangling and reconstructed.
TEST(Gather, Cuda) {
    {
        Runtime runtime = CpuRuntimeObj::getInstance();
        Graph gCpu = make_ref<GraphObj>(runtime);
        auto input = gCpu->addTensor({3, 2}, DataType::Float32);
        auto index = gCpu->addTensor({2, 2}, DataType::UInt32);
        gCpu->dataMalloc();
        input->copyData(vector<float>{1, 2, 3, 4, 5, 6});
        index->copyData(vector<uint32_t>{0, 1, 1, 2});
        auto cudaRuntime = make_ref<CudaRuntimeObj>();
        Graph gCuda = make_ref<GraphObj>(cudaRuntime);

        auto op = gCuda->addOp<GatherObj>(
            gCuda->cloneTensor(input), gCuda->cloneTensor(index), nullptr, 0);
        gCuda->dataMalloc();
        cudaRuntime->run(gCuda);

        // copy output from CUDA to CPU
        auto oCpu = gCpu->cloneTensor(op->getOutput());
        EXPECT_TRUE(oCpu->equalData(vector<float>{1, 2, 3, 4, 3, 4, 5, 6}));
    }
    {
        Runtime runtime = CpuRuntimeObj::getInstance();
        Graph gCpu = make_ref<GraphObj>(runtime);
        auto input = gCpu->addTensor({3, 3}, DataType::Float32);
        auto index = gCpu->addTensor({1, 2}, DataType::UInt32);
        gCpu->dataMalloc();
        input->setData(IncrementalGenerator());
        index->copyData(vector<uint32_t>{0, 2});
        auto cudaRuntime = make_ref<CudaRuntimeObj>();
        Graph gCuda = make_ref<GraphObj>(cudaRuntime);

        auto op = gCuda->addOp<GatherObj>(
            gCuda->cloneTensor(input), gCuda->cloneTensor(index), nullptr, 1);
        gCuda->dataMalloc();
        cudaRuntime->run(gCuda);

        // copy output from CUDA to CPU
        auto oCpu = gCpu->cloneTensor(op->getOutput());
        EXPECT_TRUE(oCpu->equalData(vector<float>{0, 2, 3, 5, 6, 8}));
    }
    {
        Runtime runtime = CpuRuntimeObj::getInstance();
        Graph gCpu = make_ref<GraphObj>(runtime);
        auto input = gCpu->addTensor({2, 4, 2}, DataType::Float32);
        auto index = gCpu->addTensor({3, 1}, DataType::UInt32);
        gCpu->dataMalloc();
        input->setData(IncrementalGenerator());
        index->copyData(vector<uint32_t>{0, 3, 1});
        auto cudaRuntime = make_ref<CudaRuntimeObj>();
        Graph gCuda = make_ref<GraphObj>(cudaRuntime);

        auto op = gCuda->addOp<GatherObj>(
            gCuda->cloneTensor(input), gCuda->cloneTensor(index), nullptr, 1);
        gCuda->dataMalloc();
        cudaRuntime->run(gCuda);

        // copy output from CUDA to CPU
        auto oCpu = gCpu->cloneTensor(op->getOutput());
        EXPECT_TRUE(oCpu->equalData(
            vector<float>{0, 1, 6, 7, 2, 3, 8, 9, 14, 15, 10, 11}));
    }
}

} // namespace infini
/dev/null +++ b/test/operators/test_gather.cc @@ -0,0 +1,19 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/gather.h" + +#include "test.h" + +namespace infini { + +TEST(Gather, ShapeInference) { + Runtime runtime = CpuRuntimeObj::getInstance(); + + Graph g = make_ref(runtime); + Tensor i = g->addTensor({1, 3, 4, 4}, DataType::UInt32); + Tensor index = g->addTensor({2, 1, 2}, DataType::UInt32); + auto op = g->addOp(i, index, nullptr, 1); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4})); +} +} // namespace infini \ No newline at end of file