add infer index function (#175)

Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
xgqdut2016 2023-11-24 09:24:25 +08:00 committed by GitHub
parent 331f7ab2b8
commit 595a9906d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 52 additions and 61 deletions

View File

@ -5,7 +5,8 @@
namespace infini {
void whereKernel(const float *inputX, const float *inputY,
const uint8_t *condition, float *output, int nDims,
SmallArray inputXShape, SmallArray inputYShape,
SmallArray conditionShape, SmallArray outputShape);
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
SmallArray conditionShape, SmallArray outputShape, int xSize,
int ySize, int cSize);
}; // namespace infini

View File

@ -3,11 +3,11 @@
namespace infini {
void broadcastShape(const Shape &originShape, SmallArray &modifyShape,
int nDims, int size) {
for (int i = nDims - 1; i >= 0; --i) {
for (int i = nDims - size - 1; i >= 0; --i) {
modifyShape.data[i] = 1;
}
for (int i = size - 1; i >= 0; --i) {
modifyShape.data[i + nDims - size] = originShape[i];
for (int i = nDims - 1; i >= nDims - size; --i) {
modifyShape.data[i] = originShape[i - nDims + size];
}
}

View File

@ -23,21 +23,23 @@ class WhereCuda : public CudaKernelWithoutConfig {
const int xSize = op->getInputs(0)->getRank();
const int ySize = op->getInputs(1)->getRank();
const int cSize = op->getInputs(2)->getRank();
int nDims = op->getOutput()->getDims().size();
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
int outputsize = 1;
SmallArray inputXShape, inputYShape, conditionShape, outputShape;
for (int i = nDims - 1; i >= 0; --i) {
outputShape.data[i] = opOutputShape[i];
outputsize *= outputShape.data[i];
}
broadcastShape(opInputXShape, inputXShape, nDims, xSize);
broadcastShape(opInputYShape, inputYShape, nDims, ySize);
broadcastShape(opConditionShape, conditionShape, nDims, cSize);
whereKernel((float *)inputXData, (float *)inputYData,
(uint8_t *)conditionData, (float *)outputData, nDims,
inputXShape, inputYShape, conditionShape, outputShape);
outputsize, inputXShape, inputYShape, conditionShape,
outputShape, xSize, ySize, cSize);
}
};

View File

@ -1,61 +1,40 @@
#include "cuda/cuda_common.h"
#include "utils/small_array.h"
/*
 * Translate a flat index into the output tensor into the matching flat index
 * of one (possibly broadcast) input tensor.
 *
 * inputShape  - per-axis extents of the input, indexed in the same axis
 *               positions as outputShape.
 * outputShape - per-axis extents of the output.
 * nDims       - number of axes in outputShape.
 * size        - how many trailing axes are walked; axes with index below
 *               nDims - size are never read and contribute nothing.
 * outputIdx   - flat index into the output, decomposed from the last axis
 *               towards the first.
 *
 * Returns the flat index into the input. An axis whose input extent is 1 is
 * treated as broadcast: the output coordinate on that axis is ignored.
 */
__device__ int inferIndex(infini::SmallArray inputShape,
                          infini::SmallArray outputShape, int nDims, int size,
                          int outputIdx) {
    const int firstAxis = nDims - size; // lowest axis index that is visited
    int offset = 0;                     // flat input index built so far
    int stride = 1;          // product of input extents of axes already seen
    int remaining = outputIdx; // part of the output index not yet decomposed

    for (int axis = nDims - 1; axis >= firstAxis; --axis) {
        const int outExtent = outputShape.data[axis];
        const int inExtent = inputShape.data[axis];

        // Peel the coordinate along this axis off the flat output index.
        const int coord = remaining % outExtent;
        remaining /= outExtent;

        // Only axes that are not broadcast (extent != 1) move the input index.
        if (inExtent != 1) {
            offset += stride * coord;
        }
        stride *= inExtent;
    }
    return offset;
}
__global__ void _whereKernel(const float *inputX, const float *inputY,
const uint8_t *condition, float *output, int nDims,
int outputsize, infini::SmallArray inputXShape,
infini::SmallArray inputYShape,
infini::SmallArray conditionShape,
infini::SmallArray outputShape) {
infini::SmallArray outputShape, int xSize,
int ySize, int cSize) {
int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (outputIdx < outputsize) {
int inputXIdx = 0;
int temp_inputX = 1;
int conditionIdx =
inferIndex(conditionShape, outputShape, nDims, cSize, outputIdx);
int inputXIdx =
inferIndex(inputXShape, outputShape, nDims, xSize, outputIdx);
int inputYIdx = 0;
int temp_inputY = 1;
int inputYIdx =
inferIndex(inputYShape, outputShape, nDims, ySize, outputIdx);
int conditionIdx = 0;
int temp_condition = 1;
int tmp = 1; // stored s,k,j,i in order
int v = outputIdx; // v = i(JKS) + j(KS) + k(S) + s
for (int i = nDims - 1; i >= 0; --i) {
if (i == 0) {
tmp = v; // i = outputIdx/(JKS)
} else {
tmp = v % outputShape.data[i]; // store s,k,j in order
}
if (inputXShape.data[i] == 1) {
inputXIdx += 0;
} else {
inputXIdx +=
tmp *
temp_inputX; // otherwise +i(JKS) or j(KS) or k(S) or s
}
temp_inputX *= inputXShape.data[i];
//----------------------------
if (inputYShape.data[i] == 1) {
inputYIdx += 0;
} else {
inputYIdx +=
tmp *
temp_inputY; // otherwise +i(JKS) or j(KS) or k(S) or s
}
temp_inputY *= inputYShape.data[i];
//--------------------------
if (conditionShape.data[i] == 1) {
conditionIdx += 0;
} else {
conditionIdx +=
tmp *
temp_condition; // otherwise +i(JKS) or j(KS) or k(S) or s
}
temp_condition *= conditionShape.data[i];
//-------------------------
v = v / outputShape.data[i];
}
output[outputIdx] =
condition[conditionIdx] ? inputX[inputXIdx] : inputY[inputYIdx];
}
@ -64,17 +43,26 @@ __global__ void _whereKernel(const float *inputX, const float *inputY,
namespace infini {
void whereKernel(const float *inputX, const float *inputY,
const uint8_t *condition, float *output, int nDims,
SmallArray inputXShape, SmallArray inputYShape,
SmallArray conditionShape, SmallArray outputShape) {
int outputsize = 1;
for (int i = 0; i < nDims; i++) {
outputsize *= outputShape.data[i];
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
SmallArray conditionShape, SmallArray outputShape, int xSize,
int ySize, int cSize) {
int blocksize;
if (outputsize > 511) {
blocksize = 1024;
} else if (outputsize > 255) {
blocksize = 512;
} else if (outputsize > 127) {
blocksize = 256;
} else if (outputsize > 63) {
blocksize = 128;
} else if (outputsize > 31) {
blocksize = 64;
} else {
blocksize = 32;
}
int blocksize = 32 * 16;
int gridsize = (outputsize + blocksize - 1) / blocksize;
_whereKernel<<<gridsize, blocksize>>>(
inputX, inputY, condition, output, nDims, outputsize, inputXShape,
inputYShape, conditionShape, outputShape);
inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
}
} // namespace infini