forked from jiuyuan/InfiniTensor
add infer index function (#175)
Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
331f7ab2b8
commit
595a9906d2
|
@ -5,7 +5,8 @@
|
|||
namespace infini {
|
||||
void whereKernel(const float *inputX, const float *inputY,
|
||||
const uint8_t *condition, float *output, int nDims,
|
||||
SmallArray inputXShape, SmallArray inputYShape,
|
||||
SmallArray conditionShape, SmallArray outputShape);
|
||||
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
|
||||
SmallArray conditionShape, SmallArray outputShape, int xSize,
|
||||
int ySize, int cSize);
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -3,11 +3,11 @@
|
|||
namespace infini {
|
||||
void broadcastShape(const Shape &originShape, SmallArray &modifyShape,
|
||||
int nDims, int size) {
|
||||
for (int i = nDims - 1; i >= 0; --i) {
|
||||
for (int i = nDims - size - 1; i >= 0; --i) {
|
||||
modifyShape.data[i] = 1;
|
||||
}
|
||||
for (int i = size - 1; i >= 0; --i) {
|
||||
modifyShape.data[i + nDims - size] = originShape[i];
|
||||
for (int i = nDims - 1; i >= nDims - size; --i) {
|
||||
modifyShape.data[i] = originShape[i - nDims + size];
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -23,21 +23,23 @@ class WhereCuda : public CudaKernelWithoutConfig {
|
|||
const int xSize = op->getInputs(0)->getRank();
|
||||
const int ySize = op->getInputs(1)->getRank();
|
||||
const int cSize = op->getInputs(2)->getRank();
|
||||
|
||||
int nDims = op->getOutput()->getDims().size();
|
||||
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
|
||||
|
||||
int outputsize = 1;
|
||||
SmallArray inputXShape, inputYShape, conditionShape, outputShape;
|
||||
for (int i = nDims - 1; i >= 0; --i) {
|
||||
outputShape.data[i] = opOutputShape[i];
|
||||
outputsize *= outputShape.data[i];
|
||||
}
|
||||
|
||||
broadcastShape(opInputXShape, inputXShape, nDims, xSize);
|
||||
broadcastShape(opInputYShape, inputYShape, nDims, ySize);
|
||||
broadcastShape(opConditionShape, conditionShape, nDims, cSize);
|
||||
|
||||
whereKernel((float *)inputXData, (float *)inputYData,
|
||||
(uint8_t *)conditionData, (float *)outputData, nDims,
|
||||
inputXShape, inputYShape, conditionShape, outputShape);
|
||||
outputsize, inputXShape, inputYShape, conditionShape,
|
||||
outputShape, xSize, ySize, cSize);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -1,61 +1,40 @@
|
|||
#include "cuda/cuda_common.h"
|
||||
#include "utils/small_array.h"
|
||||
|
||||
__device__ int inferIndex(infini::SmallArray inputShape,
|
||||
infini::SmallArray outputShape, int nDims, int size,
|
||||
int outputIdx) {
|
||||
int inputIdx = 0;
|
||||
int tempInput = 1;
|
||||
int tempOutput = 1;
|
||||
for (int i = nDims - 1; i >= nDims - size; --i) {
|
||||
tempOutput = outputIdx % outputShape.data[i];
|
||||
if (inputShape.data[i] != 1) {
|
||||
inputIdx += tempInput * tempOutput;
|
||||
}
|
||||
tempInput *= inputShape.data[i];
|
||||
outputIdx /= outputShape.data[i];
|
||||
}
|
||||
return inputIdx;
|
||||
}
|
||||
__global__ void _whereKernel(const float *inputX, const float *inputY,
|
||||
const uint8_t *condition, float *output, int nDims,
|
||||
int outputsize, infini::SmallArray inputXShape,
|
||||
infini::SmallArray inputYShape,
|
||||
infini::SmallArray conditionShape,
|
||||
infini::SmallArray outputShape) {
|
||||
infini::SmallArray outputShape, int xSize,
|
||||
int ySize, int cSize) {
|
||||
|
||||
int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (outputIdx < outputsize) {
|
||||
int inputXIdx = 0;
|
||||
int temp_inputX = 1;
|
||||
int conditionIdx =
|
||||
inferIndex(conditionShape, outputShape, nDims, cSize, outputIdx);
|
||||
int inputXIdx =
|
||||
inferIndex(inputXShape, outputShape, nDims, xSize, outputIdx);
|
||||
|
||||
int inputYIdx = 0;
|
||||
int temp_inputY = 1;
|
||||
int inputYIdx =
|
||||
inferIndex(inputYShape, outputShape, nDims, ySize, outputIdx);
|
||||
|
||||
int conditionIdx = 0;
|
||||
int temp_condition = 1;
|
||||
|
||||
int tmp = 1; // stored s,k,j,i in order
|
||||
int v = outputIdx; // v = i(JKS) + j(KS) + k(S) + s
|
||||
for (int i = nDims - 1; i >= 0; --i) {
|
||||
if (i == 0) {
|
||||
tmp = v; // i = outputIdx/(JKS)
|
||||
} else {
|
||||
tmp = v % outputShape.data[i]; // store s,k,j in order
|
||||
}
|
||||
if (inputXShape.data[i] == 1) {
|
||||
inputXIdx += 0;
|
||||
} else {
|
||||
inputXIdx +=
|
||||
tmp *
|
||||
temp_inputX; // otherwise +i(JKS) or j(KS) or k(S) or s
|
||||
}
|
||||
temp_inputX *= inputXShape.data[i];
|
||||
//----------------------------
|
||||
if (inputYShape.data[i] == 1) {
|
||||
inputYIdx += 0;
|
||||
} else {
|
||||
inputYIdx +=
|
||||
tmp *
|
||||
temp_inputY; // otherwise +i(JKS) or j(KS) or k(S) or s
|
||||
}
|
||||
temp_inputY *= inputYShape.data[i];
|
||||
//--------------------------
|
||||
if (conditionShape.data[i] == 1) {
|
||||
conditionIdx += 0;
|
||||
} else {
|
||||
conditionIdx +=
|
||||
tmp *
|
||||
temp_condition; // otherwise +i(JKS) or j(KS) or k(S) or s
|
||||
}
|
||||
temp_condition *= conditionShape.data[i];
|
||||
//-------------------------
|
||||
v = v / outputShape.data[i];
|
||||
}
|
||||
output[outputIdx] =
|
||||
condition[conditionIdx] ? inputX[inputXIdx] : inputY[inputYIdx];
|
||||
}
|
||||
|
@ -64,17 +43,26 @@ __global__ void _whereKernel(const float *inputX, const float *inputY,
|
|||
namespace infini {
|
||||
void whereKernel(const float *inputX, const float *inputY,
|
||||
const uint8_t *condition, float *output, int nDims,
|
||||
SmallArray inputXShape, SmallArray inputYShape,
|
||||
SmallArray conditionShape, SmallArray outputShape) {
|
||||
int outputsize = 1;
|
||||
|
||||
for (int i = 0; i < nDims; i++) {
|
||||
outputsize *= outputShape.data[i];
|
||||
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
|
||||
SmallArray conditionShape, SmallArray outputShape, int xSize,
|
||||
int ySize, int cSize) {
|
||||
int blocksize;
|
||||
if (outputsize > 511) {
|
||||
blocksize = 1024;
|
||||
} else if (outputsize > 255) {
|
||||
blocksize = 512;
|
||||
} else if (outputsize > 127) {
|
||||
blocksize = 256;
|
||||
} else if (outputsize > 63) {
|
||||
blocksize = 128;
|
||||
} else if (outputsize > 31) {
|
||||
blocksize = 64;
|
||||
} else {
|
||||
blocksize = 32;
|
||||
}
|
||||
int blocksize = 32 * 16;
|
||||
int gridsize = (outputsize + blocksize - 1) / blocksize;
|
||||
_whereKernel<<<gridsize, blocksize>>>(
|
||||
inputX, inputY, condition, output, nDims, outputsize, inputXShape,
|
||||
inputYShape, conditionShape, outputShape);
|
||||
inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
Loading…
Reference in New Issue