modified transpose and where

This commit is contained in:
xgqdut2016 2024-04-10 10:17:45 +08:00
parent d1de3ab5c2
commit aa1c3222ed
6 changed files with 228 additions and 138 deletions

View File

@ -7,5 +7,5 @@ namespace infini {
void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
SmallArray strides, SmallArray outputShape);
void transposeSpecial_kernel(int dType, void *input, void *output, int size);
}; // namespace infini

View File

@ -1,17 +1,10 @@
#pragma once
#include "operators/unary.h"
#include "utils/small_array.h"
namespace infini {
void whereKernel(const float *inputX, const float *inputY,
const uint8_t *condition, float *output, int nDims,
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
SmallArray conditionShape, SmallArray outputShape, int xSize,
int ySize, int cSize);
void whereKernel(const half *inputX, const half *inputY,
const uint8_t *condition, half *output, int nDims,
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
SmallArray conditionShape, SmallArray outputShape, int xSize,
int ySize, int cSize);
void whereKernel(int dTypeIndex, void *inputX, void *inputY,
const uint8_t *condition, void *output, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3, int d0, int d1, int d2, int d3);
}; // namespace infini

View File

@ -16,31 +16,54 @@ class TransposeCuda : public CudaKernelWithoutConfig {
void *const outputData = output->getRawDataPtr<void *>();
const auto &inputShape = input->getDims();
const auto &outputShape = output->getDims();
const auto &perm = op->getPermute();
const int dType = op->getDType().getIndex();
int size = input->size();
int nDims = input->getDims().size();
// Compute strides
SmallArray strides, buffer;
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
int curStride = 1;
for (int i = nDims - 1; i >= 0; --i) {
buffer.data[i] = curStride;
curStride *= inputShape[i];
}
for (int i = 0; i < nDims; ++i) {
strides.data[i] = buffer.data[perm[i]];
//----------------
bool condition = true;
int gnum = 0;
for (int i = 0; i < nDims; i++) {
if (inputShape[i] > 1) {
while (gnum < nDims) {
if (outputShape[gnum] > 1) {
gnum += 1;
break;
} else {
gnum += 1;
}
}
if (inputShape[i] != outputShape[gnum - 1]) {
condition = false;
break;
}
}
}
//----------------
if (condition) {
transposeSpecial_kernel(dType, inputData, outputData, size);
} else {
const auto &perm = op->getPermute();
SmallArray outputDims;
for (int i = 0; i < nDims; ++i) {
outputDims.data[i] = outputShape[i];
}
// Compute strides
SmallArray strides, buffer;
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
int curStride = 1;
for (int i = nDims - 1; i >= 0; --i) {
buffer.data[i] = curStride;
curStride *= inputShape[i];
}
for (int i = 0; i < nDims; ++i) {
strides.data[i] = buffer.data[perm[i]];
}
const int dType = op->getDType().getIndex();
transpose_kernel(dType, inputData, outputData, nDims, size, strides,
outputDims);
SmallArray outputDims;
for (int i = 0; i < nDims; ++i) {
outputDims.data[i] = outputShape[i];
}
transpose_kernel(dType, inputData, outputData, nDims, size, strides,
outputDims);
}
}
};

View File

@ -24,8 +24,8 @@ __global__ void _transpose_kernel(void *input, void *output, int nDims,
}
#define CASE(T) \
_transpose_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
(input, output, nDims, size, strides, outputShape);
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
input, output, nDims, size, strides, outputShape);
#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
@ -68,7 +68,47 @@ __global__ void _transpose_kernel(void *input, void *output, int nDims,
default: \
IT_TODO_HALT(); \
}
template <class T>
__global__ void _transposeSpecial_kernel(void *input, void *output, int size) {
int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (outputIdx < size) {
((T *)output)[outputIdx] = ((T *)input)[outputIdx];
}
}
#define CASESpecial(T) \
_transposeSpecial_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
input, output, size);
#define SWITCHSpecial_DTYPE(DTYPE) \
switch (DTYPE) { \
case 1: \
CASESpecial(1) break; \
case 2: \
CASESpecial(2) break; \
case 3: \
CASESpecial(3) break; \
case 4: \
CASESpecial(4) break; \
case 5: \
CASESpecial(5) break; \
case 6: \
CASESpecial(6) break; \
case 7: \
CASESpecial(7) break; \
case 10: \
CASESpecial(10) break; \
case 11: \
CASESpecial(11) break; \
case 12: \
CASESpecial(12) break; \
case 13: \
CASESpecial(13) break; \
case 16: \
CASESpecial(16) break; \
default: \
IT_TODO_HALT(); \
}
namespace infini {
void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
SmallArray strides, SmallArray outputShape) {
@ -76,5 +116,10 @@ void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
int gridsize = (size + block_work_size() - 1) / block_work_size();
SWITCH_DTYPE(dType)
}
void transposeSpecial_kernel(int dType, void *input, void *output, int size) {
int blocksize = block_work_size();
int gridsize = (size + block_work_size() - 1) / block_work_size();
SWITCHSpecial_DTYPE(dType)
}
} // namespace infini

View File

@ -1,8 +1,8 @@
#include "operators/where.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "cuda/cuda_where.h"
#include "utils/operator_utils.h"
namespace infini {
@ -15,40 +15,31 @@ class WhereCuda : public CudaKernelWithoutConfig {
void *const inputYData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const conditionData = (op->getInputs(2)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
const auto &opInputXShape = op->getInputs(0)->getDims();
const auto &opInputYShape = op->getInputs(1)->getDims();
const auto &opConditionShape = op->getInputs(2)->getDims();
const auto &opOutputShape = op->getOutput()->getDims();
const int xSize = op->getInputs(0)->getRank();
const int ySize = op->getInputs(1)->getRank();
const int cSize = op->getInputs(2)->getRank();
auto a_dim = op->getInputs(0)->getDims();
auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getInputs(2)->getDims();
auto d_dim = op->getOutput()->getDims();
int nDims = op->getOutput()->getDims().size();
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
int outputsize = 1;
SmallArray inputXShape, inputYShape, conditionShape, outputShape;
for (int i = nDims - 1; i >= 0; --i) {
outputShape.data[i] = opOutputShape[i];
outputsize *= outputShape.data[i];
}
broadcastShape(opInputXShape, inputXShape, nDims, xSize);
broadcastShape(opInputYShape, inputYShape, nDims, ySize);
broadcastShape(opConditionShape, conditionShape, nDims, cSize);
if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4 ||
d_dim.size() > 4)
IT_TODO_HALT();
if (op->getDType() == DataType::Float32) {
whereKernel((float *)inputXData, (float *)inputYData,
(uint8_t *)conditionData, (float *)outputData, nDims,
outputsize, inputXShape, inputYShape, conditionShape,
outputShape, xSize, ySize, cSize);
} else if (op->getDType() == DataType::Float16) {
whereKernel((half *)inputXData, (half *)inputYData,
(uint8_t *)conditionData, (half *)outputData, nDims,
outputsize, inputXShape, inputYShape, conditionShape,
outputShape, xSize, ySize, cSize);
} else {
IT_ASSERT(false);
}
int a[4] = {1, 1, 1, 1};
int b[4] = {1, 1, 1, 1};
int c[4] = {1, 1, 1, 1};
int d[4] = {1, 1, 1, 1};
std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
std::copy(d_dim.begin(), d_dim.end(), d + (4 - d_dim.size()));
const int dTypeIndex = op->getDType().getIndex();
whereKernel(dTypeIndex, inputXData, inputYData,
(uint8_t *)conditionData, outputData, a[0], a[1], a[2],
a[3], b[0], b[1], b[2], b[3], c[0], c[1], c[2], c[3], d[0],
d[1], d[2], d[3]);
}
};

View File

@ -1,94 +1,132 @@
#include "cuda/cuda_common.h"
#include "utils/small_array.h"
#include "cuda/cuda_utility.h"
const int repeat = 3;
__device__ int inferIndex(infini::SmallArray inputShape,
infini::SmallArray outputShape, int nDims, int size,
int outputIdx) {
int inputIdx = 0;
int tempInput = 1;
int tempOutput = 1;
for (int i = nDims - 1; i >= nDims - size; --i) {
tempOutput = outputIdx % outputShape.data[i];
if (inputShape.data[i] != 1) {
inputIdx += tempInput * tempOutput;
}
tempInput *= inputShape.data[i];
outputIdx /= outputShape.data[i];
}
return inputIdx;
}
template <typename T>
__global__ void
_whereKernel(const T *inputX, const T *inputY, const uint8_t *condition,
T *output, int nDims, int outputsize,
infini::SmallArray inputXShape, infini::SmallArray inputYShape,
infini::SmallArray conditionShape, infini::SmallArray outputShape,
int xSize, int ySize, int cSize) {
_whereKernel(void *inputX, void *inputY, const uint8_t *condition, void *output,
int a0, int a1, int a2, int a3, int b0, int b1, int b2, int b3,
int c0, int c1, int c2, int c3, int d0, int d1, int d2, int d3) {
int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (outputIdx < outputsize) {
int conditionIdx =
inferIndex(conditionShape, outputShape, nDims, cSize, outputIdx);
int inputXIdx =
inferIndex(inputXShape, outputShape, nDims, xSize, outputIdx);
int stride1 = d2 * d3;
int stride0 = d1 * stride1;
int n = d0 * stride0;
int index = threadIdx.x + blockIdx.x * blockDim.x;
int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
for (int i = repeat * index; i < end; i++) {
int inputXIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
int inputYIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
int conditionIdx = (c0 * c1 * c2 * c3 == n ? i : 0);
int inputYIdx =
inferIndex(inputYShape, outputShape, nDims, ySize, outputIdx);
bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
bool cIdx = (c0 * c1 * c2 * c3 < n && c0 * c1 * c2 * c3 > 1);
if (aIdx || bIdx || cIdx) {
int d0_index = i / stride0;
int d1_index = (i % stride0) / stride1;
int d2_index = (i % stride1) / d3;
int d3_index = i % d3;
if (aIdx) {
int a0_index = d0_index % a0;
int a1_index = d1_index % a1;
int a2_index = d2_index % a2;
int a3_index = d3_index % a3;
inputXIdx = a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
a2_index * a3 + a3_index;
}
if (bIdx) {
int b0_index = d0_index % b0;
int b1_index = d1_index % b1;
int b2_index = d2_index % b2;
int b3_index = d3_index % b3;
inputYIdx = b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
b2_index * b3 + b3_index;
}
if (cIdx) {
int c0_index = d0_index % c0;
int c1_index = d1_index % c1;
int c2_index = d2_index % c2;
int c3_index = d3_index % c3;
conditionIdx = c0_index * c1 * c2 * c3 + c1_index * c2 * c3 +
c2_index * c3 + c3_index;
}
}
output[outputIdx] =
condition[conditionIdx] ? inputX[inputXIdx] : inputY[inputYIdx];
((T *)output)[i] = condition[conditionIdx] ? ((T *)inputX)[inputXIdx]
: ((T *)inputY)[inputYIdx];
}
}
#define CASE(T) \
_whereKernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
inputX, inputY, condition, output, a0, a1, a2, a3, b0, b1, b2, b3, \
c0, c1, c2, c3, d0, d1, d2, d3);
#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
case 1: \
CASE(1) \
break; \
case 2: \
CASE(2) \
break; \
case 3: \
CASE(3) \
break; \
case 4: \
CASE(4) \
break; \
case 5: \
CASE(5) \
break; \
case 6: \
CASE(6) \
break; \
case 7: \
CASE(7) \
break; \
case 10: \
CASE(10) \
break; \
case 11: \
CASE(11) \
break; \
case 12: \
CASE(12) \
break; \
case 13: \
CASE(13) \
break; \
case 16: \
CASE(16) \
break; \
default: \
IT_TODO_HALT(); \
}
namespace infini {
void whereKernel(const float *inputX, const float *inputY,
const uint8_t *condition, float *output, int nDims,
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
SmallArray conditionShape, SmallArray outputShape, int xSize,
int ySize, int cSize) {
void whereKernel(int dTypeIndex, void *inputX, void *inputY,
const uint8_t *condition, void *output, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3, int d0, int d1, int d2, int d3) {
int blocksize;
if (outputsize > 511) {
int outputsize = d0 * d1 * d2 * d3;
if (outputsize > 511 * repeat) {
blocksize = 1024;
} else if (outputsize > 255) {
} else if (outputsize > 255 * repeat) {
blocksize = 512;
} else if (outputsize > 127) {
} else if (outputsize > 127 * repeat) {
blocksize = 256;
} else if (outputsize > 63) {
} else if (outputsize > 63 * repeat) {
blocksize = 128;
} else if (outputsize > 31) {
} else if (outputsize > 31 * repeat) {
blocksize = 64;
} else {
blocksize = 32;
}
int gridsize = (outputsize + blocksize - 1) / blocksize;
_whereKernel<float>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
inputX, inputY, condition, output, nDims, outputsize, inputXShape,
inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
}
void whereKernel(const half *inputX, const half *inputY,
const uint8_t *condition, half *output, int nDims,
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
SmallArray conditionShape, SmallArray outputShape, int xSize,
int ySize, int cSize) {
int blocksize;
if (outputsize > 511) {
blocksize = 1024;
} else if (outputsize > 255) {
blocksize = 512;
} else if (outputsize > 127) {
blocksize = 256;
} else if (outputsize > 63) {
blocksize = 128;
} else if (outputsize > 31) {
blocksize = 64;
} else {
blocksize = 32;
}
int gridsize = (outputsize + blocksize - 1) / blocksize;
_whereKernel<half>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
inputX, inputY, condition, output, nDims, outputsize, inputXShape,
inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
int gridsize = (outputsize + repeat * blocksize - 1) / (repeat * blocksize);
SWITCH_DTYPE(dTypeIndex)
}
} // namespace infini