forked from jiuyuan/InfiniTensor
add infer index function (#175)
Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
331f7ab2b8
commit
595a9906d2
|
@ -5,7 +5,8 @@
|
||||||
namespace infini {
|
namespace infini {
|
||||||
void whereKernel(const float *inputX, const float *inputY,
|
void whereKernel(const float *inputX, const float *inputY,
|
||||||
const uint8_t *condition, float *output, int nDims,
|
const uint8_t *condition, float *output, int nDims,
|
||||||
SmallArray inputXShape, SmallArray inputYShape,
|
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
|
||||||
SmallArray conditionShape, SmallArray outputShape);
|
SmallArray conditionShape, SmallArray outputShape, int xSize,
|
||||||
|
int ySize, int cSize);
|
||||||
|
|
||||||
}; // namespace infini
|
}; // namespace infini
|
||||||
|
|
|
@ -3,11 +3,11 @@
|
||||||
namespace infini {
|
namespace infini {
|
||||||
void broadcastShape(const Shape &originShape, SmallArray &modifyShape,
|
void broadcastShape(const Shape &originShape, SmallArray &modifyShape,
|
||||||
int nDims, int size) {
|
int nDims, int size) {
|
||||||
for (int i = nDims - 1; i >= 0; --i) {
|
for (int i = nDims - size - 1; i >= 0; --i) {
|
||||||
modifyShape.data[i] = 1;
|
modifyShape.data[i] = 1;
|
||||||
}
|
}
|
||||||
for (int i = size - 1; i >= 0; --i) {
|
for (int i = nDims - 1; i >= nDims - size; --i) {
|
||||||
modifyShape.data[i + nDims - size] = originShape[i];
|
modifyShape.data[i] = originShape[i - nDims + size];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,21 +23,23 @@ class WhereCuda : public CudaKernelWithoutConfig {
|
||||||
const int xSize = op->getInputs(0)->getRank();
|
const int xSize = op->getInputs(0)->getRank();
|
||||||
const int ySize = op->getInputs(1)->getRank();
|
const int ySize = op->getInputs(1)->getRank();
|
||||||
const int cSize = op->getInputs(2)->getRank();
|
const int cSize = op->getInputs(2)->getRank();
|
||||||
|
|
||||||
int nDims = op->getOutput()->getDims().size();
|
int nDims = op->getOutput()->getDims().size();
|
||||||
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
|
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
|
||||||
|
int outputsize = 1;
|
||||||
SmallArray inputXShape, inputYShape, conditionShape, outputShape;
|
SmallArray inputXShape, inputYShape, conditionShape, outputShape;
|
||||||
for (int i = nDims - 1; i >= 0; --i) {
|
for (int i = nDims - 1; i >= 0; --i) {
|
||||||
outputShape.data[i] = opOutputShape[i];
|
outputShape.data[i] = opOutputShape[i];
|
||||||
|
outputsize *= outputShape.data[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
broadcastShape(opInputXShape, inputXShape, nDims, xSize);
|
broadcastShape(opInputXShape, inputXShape, nDims, xSize);
|
||||||
broadcastShape(opInputYShape, inputYShape, nDims, ySize);
|
broadcastShape(opInputYShape, inputYShape, nDims, ySize);
|
||||||
broadcastShape(opConditionShape, conditionShape, nDims, cSize);
|
broadcastShape(opConditionShape, conditionShape, nDims, cSize);
|
||||||
|
|
||||||
whereKernel((float *)inputXData, (float *)inputYData,
|
whereKernel((float *)inputXData, (float *)inputYData,
|
||||||
(uint8_t *)conditionData, (float *)outputData, nDims,
|
(uint8_t *)conditionData, (float *)outputData, nDims,
|
||||||
inputXShape, inputYShape, conditionShape, outputShape);
|
outputsize, inputXShape, inputYShape, conditionShape,
|
||||||
|
outputShape, xSize, ySize, cSize);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -1,61 +1,40 @@
|
||||||
#include "cuda/cuda_common.h"
|
#include "cuda/cuda_common.h"
|
||||||
#include "utils/small_array.h"
|
#include "utils/small_array.h"
|
||||||
|
|
||||||
|
__device__ int inferIndex(infini::SmallArray inputShape,
|
||||||
|
infini::SmallArray outputShape, int nDims, int size,
|
||||||
|
int outputIdx) {
|
||||||
|
int inputIdx = 0;
|
||||||
|
int tempInput = 1;
|
||||||
|
int tempOutput = 1;
|
||||||
|
for (int i = nDims - 1; i >= nDims - size; --i) {
|
||||||
|
tempOutput = outputIdx % outputShape.data[i];
|
||||||
|
if (inputShape.data[i] != 1) {
|
||||||
|
inputIdx += tempInput * tempOutput;
|
||||||
|
}
|
||||||
|
tempInput *= inputShape.data[i];
|
||||||
|
outputIdx /= outputShape.data[i];
|
||||||
|
}
|
||||||
|
return inputIdx;
|
||||||
|
}
|
||||||
__global__ void _whereKernel(const float *inputX, const float *inputY,
|
__global__ void _whereKernel(const float *inputX, const float *inputY,
|
||||||
const uint8_t *condition, float *output, int nDims,
|
const uint8_t *condition, float *output, int nDims,
|
||||||
int outputsize, infini::SmallArray inputXShape,
|
int outputsize, infini::SmallArray inputXShape,
|
||||||
infini::SmallArray inputYShape,
|
infini::SmallArray inputYShape,
|
||||||
infini::SmallArray conditionShape,
|
infini::SmallArray conditionShape,
|
||||||
infini::SmallArray outputShape) {
|
infini::SmallArray outputShape, int xSize,
|
||||||
|
int ySize, int cSize) {
|
||||||
|
|
||||||
int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
if (outputIdx < outputsize) {
|
if (outputIdx < outputsize) {
|
||||||
int inputXIdx = 0;
|
int conditionIdx =
|
||||||
int temp_inputX = 1;
|
inferIndex(conditionShape, outputShape, nDims, cSize, outputIdx);
|
||||||
|
int inputXIdx =
|
||||||
|
inferIndex(inputXShape, outputShape, nDims, xSize, outputIdx);
|
||||||
|
|
||||||
int inputYIdx = 0;
|
int inputYIdx =
|
||||||
int temp_inputY = 1;
|
inferIndex(inputYShape, outputShape, nDims, ySize, outputIdx);
|
||||||
|
|
||||||
int conditionIdx = 0;
|
|
||||||
int temp_condition = 1;
|
|
||||||
|
|
||||||
int tmp = 1; // stored s,k,j,i in order
|
|
||||||
int v = outputIdx; // v = i(JKS) + j(KS) + k(S) + s
|
|
||||||
for (int i = nDims - 1; i >= 0; --i) {
|
|
||||||
if (i == 0) {
|
|
||||||
tmp = v; // i = outputIdx/(JKS)
|
|
||||||
} else {
|
|
||||||
tmp = v % outputShape.data[i]; // store s,k,j in order
|
|
||||||
}
|
|
||||||
if (inputXShape.data[i] == 1) {
|
|
||||||
inputXIdx += 0;
|
|
||||||
} else {
|
|
||||||
inputXIdx +=
|
|
||||||
tmp *
|
|
||||||
temp_inputX; // otherwise +i(JKS) or j(KS) or k(S) or s
|
|
||||||
}
|
|
||||||
temp_inputX *= inputXShape.data[i];
|
|
||||||
//----------------------------
|
|
||||||
if (inputYShape.data[i] == 1) {
|
|
||||||
inputYIdx += 0;
|
|
||||||
} else {
|
|
||||||
inputYIdx +=
|
|
||||||
tmp *
|
|
||||||
temp_inputY; // otherwise +i(JKS) or j(KS) or k(S) or s
|
|
||||||
}
|
|
||||||
temp_inputY *= inputYShape.data[i];
|
|
||||||
//--------------------------
|
|
||||||
if (conditionShape.data[i] == 1) {
|
|
||||||
conditionIdx += 0;
|
|
||||||
} else {
|
|
||||||
conditionIdx +=
|
|
||||||
tmp *
|
|
||||||
temp_condition; // otherwise +i(JKS) or j(KS) or k(S) or s
|
|
||||||
}
|
|
||||||
temp_condition *= conditionShape.data[i];
|
|
||||||
//-------------------------
|
|
||||||
v = v / outputShape.data[i];
|
|
||||||
}
|
|
||||||
output[outputIdx] =
|
output[outputIdx] =
|
||||||
condition[conditionIdx] ? inputX[inputXIdx] : inputY[inputYIdx];
|
condition[conditionIdx] ? inputX[inputXIdx] : inputY[inputYIdx];
|
||||||
}
|
}
|
||||||
|
@ -64,17 +43,26 @@ __global__ void _whereKernel(const float *inputX, const float *inputY,
|
||||||
namespace infini {
|
namespace infini {
|
||||||
void whereKernel(const float *inputX, const float *inputY,
|
void whereKernel(const float *inputX, const float *inputY,
|
||||||
const uint8_t *condition, float *output, int nDims,
|
const uint8_t *condition, float *output, int nDims,
|
||||||
SmallArray inputXShape, SmallArray inputYShape,
|
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
|
||||||
SmallArray conditionShape, SmallArray outputShape) {
|
SmallArray conditionShape, SmallArray outputShape, int xSize,
|
||||||
int outputsize = 1;
|
int ySize, int cSize) {
|
||||||
|
int blocksize;
|
||||||
for (int i = 0; i < nDims; i++) {
|
if (outputsize > 511) {
|
||||||
outputsize *= outputShape.data[i];
|
blocksize = 1024;
|
||||||
|
} else if (outputsize > 255) {
|
||||||
|
blocksize = 512;
|
||||||
|
} else if (outputsize > 127) {
|
||||||
|
blocksize = 256;
|
||||||
|
} else if (outputsize > 63) {
|
||||||
|
blocksize = 128;
|
||||||
|
} else if (outputsize > 31) {
|
||||||
|
blocksize = 64;
|
||||||
|
} else {
|
||||||
|
blocksize = 32;
|
||||||
}
|
}
|
||||||
int blocksize = 32 * 16;
|
|
||||||
int gridsize = (outputsize + blocksize - 1) / blocksize;
|
int gridsize = (outputsize + blocksize - 1) / blocksize;
|
||||||
_whereKernel<<<gridsize, blocksize>>>(
|
_whereKernel<<<gridsize, blocksize>>>(
|
||||||
inputX, inputY, condition, output, nDims, outputsize, inputXShape,
|
inputX, inputY, condition, output, nDims, outputsize, inputXShape,
|
||||||
inputYShape, conditionShape, outputShape);
|
inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
|
||||||
}
|
}
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
Loading…
Reference in New Issue