From 595a9906d2b1df7697956e0f27983122f8df5d13 Mon Sep 17 00:00:00 2001
From: xgqdut2016 <140036308+xgqdut2016@users.noreply.github.com>
Date: Fri, 24 Nov 2023 09:24:25 +0800
Subject: [PATCH] add infer index function (#175)

Co-authored-by: Haojie Wang
---
Review notes (this area is ignored when the patch is applied):
 * where.cu: the per-input index arithmetic that was written out three times
   inside _whereKernel is replaced by one __device__ helper, inferIndex(),
   which walks only the trailing `size` dimensions of the padded shape and
   contributes nothing for dimensions whose input extent is 1.
 * where.cc: outputsize (product of the output dims) is now computed on the
   host and passed in, together with the ranks xSize / ySize / cSize.
 * whereKernel picks the block size from {32, 64, 128, 256, 512, 1024} based
   on outputsize instead of the fixed 32 * 16.
 * NOTE(review): for outputsize == 0 the computed gridsize is 0; confirm that
   callers never reach the launch with an empty output.
 * NOTE(review): no cudaGetLastError() follows the launch, so a bad launch
   configuration would go unnoticed here.
 * NOTE(review): broadcastShape() is a non-inline function defined in a
   header; this looks like a multiple-definition hazard if the header is
   included from more than one translation unit - confirm.

 include/cuda/cuda_where.h       |  5 +-
 include/utils/broadcast_shape.h |  6 +--
 src/kernels/cuda/where.cc       |  8 +--
 src/kernels/cuda/where.cu       | 94 ++++++++++++++-------------------
 4 files changed, 52 insertions(+), 61 deletions(-)

diff --git a/include/cuda/cuda_where.h b/include/cuda/cuda_where.h
index 15ad29ec..bc6d3e81 100644
--- a/include/cuda/cuda_where.h
+++ b/include/cuda/cuda_where.h
@@ -5,7 +5,8 @@
 namespace infini {
 void whereKernel(const float *inputX, const float *inputY,
                  const uint8_t *condition, float *output, int nDims,
-                 SmallArray inputXShape, SmallArray inputYShape,
-                 SmallArray conditionShape, SmallArray outputShape);
+                 int outputsize, SmallArray inputXShape, SmallArray inputYShape,
+                 SmallArray conditionShape, SmallArray outputShape, int xSize,
+                 int ySize, int cSize);
 
 }; // namespace infini
diff --git a/include/utils/broadcast_shape.h b/include/utils/broadcast_shape.h
index 1f45ddcc..e794ff90 100644
--- a/include/utils/broadcast_shape.h
+++ b/include/utils/broadcast_shape.h
@@ -3,11 +3,11 @@
 namespace infini {
 void broadcastShape(const Shape &originShape, SmallArray &modifyShape,
                     int nDims, int size) {
-    for (int i = nDims - 1; i >= 0; --i) {
+    for (int i = nDims - size - 1; i >= 0; --i) {
         modifyShape.data[i] = 1;
     }
-    for (int i = size - 1; i >= 0; --i) {
-        modifyShape.data[i + nDims - size] = originShape[i];
+    for (int i = nDims - 1; i >= nDims - size; --i) {
+        modifyShape.data[i] = originShape[i - nDims + size];
     }
 }
 
diff --git a/src/kernels/cuda/where.cc b/src/kernels/cuda/where.cc
index 9898ab7d..df5e4476 100644
--- a/src/kernels/cuda/where.cc
+++ b/src/kernels/cuda/where.cc
@@ -23,21 +23,23 @@ class WhereCuda : public CudaKernelWithoutConfig {
         const int xSize = op->getInputs(0)->getRank();
         const int ySize = op->getInputs(1)->getRank();
         const int cSize = op->getInputs(2)->getRank();
+
         int nDims = op->getOutput()->getDims().size();
         IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
-
+        int outputsize = 1;
         SmallArray inputXShape, inputYShape, conditionShape, outputShape;
         for (int i = nDims - 1; i >= 0; --i) {
             outputShape.data[i] = opOutputShape[i];
+            outputsize *= outputShape.data[i];
         }
-
         broadcastShape(opInputXShape, inputXShape, nDims, xSize);
         broadcastShape(opInputYShape, inputYShape, nDims, ySize);
         broadcastShape(opConditionShape, conditionShape, nDims, cSize);
 
         whereKernel((float *)inputXData, (float *)inputYData,
                     (uint8_t *)conditionData, (float *)outputData, nDims,
-                    inputXShape, inputYShape, conditionShape, outputShape);
+                    outputsize, inputXShape, inputYShape, conditionShape,
+                    outputShape, xSize, ySize, cSize);
     }
 };
 
diff --git a/src/kernels/cuda/where.cu b/src/kernels/cuda/where.cu
index ce6579f8..ac8b514a 100644
--- a/src/kernels/cuda/where.cu
+++ b/src/kernels/cuda/where.cu
@@ -1,61 +1,40 @@
 #include "cuda/cuda_common.h"
 #include "utils/small_array.h"
 
+__device__ int inferIndex(infini::SmallArray inputShape,
+                          infini::SmallArray outputShape, int nDims, int size,
+                          int outputIdx) {
+    int inputIdx = 0;
+    int tempInput = 1;
+    int tempOutput = 1;
+    for (int i = nDims - 1; i >= nDims - size; --i) {
+        tempOutput = outputIdx % outputShape.data[i];
+        if (inputShape.data[i] != 1) {
+            inputIdx += tempInput * tempOutput;
+        }
+        tempInput *= inputShape.data[i];
+        outputIdx /= outputShape.data[i];
+    }
+    return inputIdx;
+}
 __global__ void _whereKernel(const float *inputX, const float *inputY,
                              const uint8_t *condition, float *output, int nDims,
                              int outputsize, infini::SmallArray inputXShape,
                              infini::SmallArray inputYShape,
                              infini::SmallArray conditionShape,
-                             infini::SmallArray outputShape) {
+                             infini::SmallArray outputShape, int xSize,
+                             int ySize, int cSize) {
 
     int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
     if (outputIdx < outputsize) {
-        int inputXIdx = 0;
-        int temp_inputX = 1;
+        int conditionIdx =
+            inferIndex(conditionShape, outputShape, nDims, cSize, outputIdx);
+        int inputXIdx =
+            inferIndex(inputXShape, outputShape, nDims, xSize, outputIdx);
 
-        int inputYIdx = 0;
-        int temp_inputY = 1;
+        int inputYIdx =
+            inferIndex(inputYShape, outputShape, nDims, ySize, outputIdx);
 
-        int conditionIdx = 0;
-        int temp_condition = 1;
-
-        int tmp = 1;       // stored s,k,j,i in order
-        int v = outputIdx; // v = i(JKS) + j(KS) + k(S) + s
-        for (int i = nDims - 1; i >= 0; --i) {
-            if (i == 0) {
-                tmp = v; // i = outputIdx/(JKS)
-            } else {
-                tmp = v % outputShape.data[i]; // store s,k,j in order
-            }
-            if (inputXShape.data[i] == 1) {
-                inputXIdx += 0;
-            } else {
-                inputXIdx +=
-                    tmp *
-                    temp_inputX; // otherwise +i(JKS) or j(KS) or k(S) or s
-            }
-            temp_inputX *= inputXShape.data[i];
-            //----------------------------
-            if (inputYShape.data[i] == 1) {
-                inputYIdx += 0;
-            } else {
-                inputYIdx +=
-                    tmp *
-                    temp_inputY; // otherwise +i(JKS) or j(KS) or k(S) or s
-            }
-            temp_inputY *= inputYShape.data[i];
-            //--------------------------
-            if (conditionShape.data[i] == 1) {
-                conditionIdx += 0;
-            } else {
-                conditionIdx +=
-                    tmp *
-                    temp_condition; // otherwise +i(JKS) or j(KS) or k(S) or s
-            }
-            temp_condition *= conditionShape.data[i];
-            //-------------------------
-            v = v / outputShape.data[i];
-        }
         output[outputIdx] =
             condition[conditionIdx] ? inputX[inputXIdx] : inputY[inputYIdx];
     }
@@ -64,17 +43,26 @@ __global__ void _whereKernel(const float *inputX, const float *inputY,
 namespace infini {
 void whereKernel(const float *inputX, const float *inputY,
                  const uint8_t *condition, float *output, int nDims,
-                 SmallArray inputXShape, SmallArray inputYShape,
-                 SmallArray conditionShape, SmallArray outputShape) {
-    int outputsize = 1;
-
-    for (int i = 0; i < nDims; i++) {
-        outputsize *= outputShape.data[i];
+                 int outputsize, SmallArray inputXShape, SmallArray inputYShape,
+                 SmallArray conditionShape, SmallArray outputShape, int xSize,
+                 int ySize, int cSize) {
+    int blocksize;
+    if (outputsize > 511) {
+        blocksize = 1024;
+    } else if (outputsize > 255) {
+        blocksize = 512;
+    } else if (outputsize > 127) {
+        blocksize = 256;
+    } else if (outputsize > 63) {
+        blocksize = 128;
+    } else if (outputsize > 31) {
+        blocksize = 64;
+    } else {
+        blocksize = 32;
     }
-    int blocksize = 32 * 16;
     int gridsize = (outputsize + blocksize - 1) / blocksize;
     _whereKernel<<<gridsize, blocksize>>>(
         inputX, inputY, condition, output, nDims, outputsize, inputXShape,
-        inputYShape, conditionShape, outputShape);
+        inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
 }
 } // namespace infini