forked from jiuyuan/InfiniTensor
modified transpose and where
This commit is contained in:
parent
d1de3ab5c2
commit
aa1c3222ed
|
@ -7,5 +7,5 @@ namespace infini {
|
|||
|
||||
void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
|
||||
SmallArray strides, SmallArray outputShape);
|
||||
|
||||
void transposeSpecial_kernel(int dType, void *input, void *output, int size);
|
||||
}; // namespace infini
|
||||
|
|
|
@ -1,17 +1,10 @@
|
|||
#pragma once
#include "operators/unary.h"
#include "utils/small_array.h"

namespace infini {

// Launch the broadcasted `where` kernel.
// dTypeIndex: runtime DataType index (DataType::getIndex()) selecting the
//             element type of inputX / inputY / output.
// condition:  uint8 mask, nonzero selects inputX, zero selects inputY.
// a0..a3, b0..b3, c0..c3, d0..d3: 4-D shapes (left-padded with 1) of
// inputX, inputY, condition and the output respectively; operands with
// fewer than 4 dims must be padded by the caller.
void whereKernel(int dTypeIndex, void *inputX, void *inputY,
                 const uint8_t *condition, void *output, int a0, int a1, int a2,
                 int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
                 int c3, int d0, int d1, int d2, int d3);
}; // namespace infini
|
||||
|
|
|
@ -16,10 +16,33 @@ class TransposeCuda : public CudaKernelWithoutConfig {
|
|||
void *const outputData = output->getRawDataPtr<void *>();
|
||||
const auto &inputShape = input->getDims();
|
||||
const auto &outputShape = output->getDims();
|
||||
|
||||
const auto &perm = op->getPermute();
|
||||
const int dType = op->getDType().getIndex();
|
||||
int size = input->size();
|
||||
int nDims = input->getDims().size();
|
||||
//----------------
|
||||
bool condition = true;
|
||||
int gnum = 0;
|
||||
for (int i = 0; i < nDims; i++) {
|
||||
if (inputShape[i] > 1) {
|
||||
while (gnum < nDims) {
|
||||
if (outputShape[gnum] > 1) {
|
||||
gnum += 1;
|
||||
break;
|
||||
} else {
|
||||
gnum += 1;
|
||||
}
|
||||
}
|
||||
if (inputShape[i] != outputShape[gnum - 1]) {
|
||||
condition = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
//----------------
|
||||
if (condition) {
|
||||
transposeSpecial_kernel(dType, inputData, outputData, size);
|
||||
} else {
|
||||
const auto &perm = op->getPermute();
|
||||
|
||||
// Compute strides
|
||||
SmallArray strides, buffer;
|
||||
|
@ -38,10 +61,10 @@ class TransposeCuda : public CudaKernelWithoutConfig {
|
|||
outputDims.data[i] = outputShape[i];
|
||||
}
|
||||
|
||||
const int dType = op->getDType().getIndex();
|
||||
transpose_kernel(dType, inputData, outputData, nDims, size, strides,
|
||||
outputDims);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class DepthToSpaceCuda : public CudaKernelWithoutConfig {
|
||||
|
|
|
@ -24,8 +24,8 @@ __global__ void _transpose_kernel(void *input, void *output, int nDims,
|
|||
}
|
||||
// Instantiate and launch the generic transpose kernel for the element type
// mapped from runtime dtype index T (DT_CUDA<T>::t is the C++ type).
// Expects `gridsize` and `blocksize` to be in scope at the expansion site.
#define CASE(T)                                                                \
    _transpose_kernel<DT_CUDA<T>::t>                                           \
        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(          \
            input, output, nDims, size, strides, outputShape);
|
||||
|
||||
#define SWITCH_DTYPE(DTYPE) \
|
||||
switch (DTYPE) { \
|
||||
|
@ -68,7 +68,47 @@ __global__ void _transpose_kernel(void *input, void *output, int nDims,
|
|||
default: \
|
||||
IT_TODO_HALT(); \
|
||||
}
|
||||
// Fast path for transposes whose permutation leaves the linear layout
// unchanged: the result is a plain elementwise copy of the input buffer.
// One thread copies one element; `size` is the total element count.
template <class T>
__global__ void _transposeSpecial_kernel(void *input, void *output, int size) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= size)
        return; // grid tail guard: grids rarely divide the data evenly
    ((T *)output)[idx] = ((T *)input)[idx];
}
|
||||
// Instantiate and launch the identity-copy transpose kernel for the element
// type mapped from runtime dtype index T. Expects `gridsize` and `blocksize`
// in scope at the expansion site.
#define CASESpecial(T)                                                         \
    _transposeSpecial_kernel<DT_CUDA<T>::t>                                    \
        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(          \
            input, output, size);
|
||||
|
||||
// Dispatch the identity-copy kernel on a runtime dtype index.
// The listed case values mirror the framework's DataType indices
// (presumably from DataType::getIndex(); confirm against data_type.h).
// Unsupported indices halt via IT_TODO_HALT.
#define SWITCHSpecial_DTYPE(DTYPE)                                             \
    switch (DTYPE) {                                                           \
    case 1:                                                                    \
        CASESpecial(1) break;                                                  \
    case 2:                                                                    \
        CASESpecial(2) break;                                                  \
    case 3:                                                                    \
        CASESpecial(3) break;                                                  \
    case 4:                                                                    \
        CASESpecial(4) break;                                                  \
    case 5:                                                                    \
        CASESpecial(5) break;                                                  \
    case 6:                                                                    \
        CASESpecial(6) break;                                                  \
    case 7:                                                                    \
        CASESpecial(7) break;                                                  \
    case 10:                                                                   \
        CASESpecial(10) break;                                                 \
    case 11:                                                                   \
        CASESpecial(11) break;                                                 \
    case 12:                                                                   \
        CASESpecial(12) break;                                                 \
    case 13:                                                                   \
        CASESpecial(13) break;                                                 \
    case 16:                                                                   \
        CASESpecial(16) break;                                                 \
    default:                                                                   \
        IT_TODO_HALT();                                                        \
    }
|
||||
namespace infini {
|
||||
void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
|
||||
SmallArray strides, SmallArray outputShape) {
|
||||
|
@ -76,5 +116,10 @@ void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
|
|||
int gridsize = (size + block_work_size() - 1) / block_work_size();
|
||||
SWITCH_DTYPE(dType)
|
||||
}
|
||||
// Host-side launcher for the identity-copy transpose fast path.
// dType is the runtime DataType index; size is the element count.
void transposeSpecial_kernel(int dType, void *input, void *output, int size) {
    const int blocksize = block_work_size();
    const int gridsize = (size + blocksize - 1) / blocksize; // ceil-div
    SWITCHSpecial_DTYPE(dType)
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
#include "operators/where.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "cuda/cuda_where.h"
|
||||
#include "utils/operator_utils.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
@ -15,40 +15,31 @@ class WhereCuda : public CudaKernelWithoutConfig {
|
|||
void *const inputYData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const conditionData = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
const auto &opInputXShape = op->getInputs(0)->getDims();
|
||||
const auto &opInputYShape = op->getInputs(1)->getDims();
|
||||
const auto &opConditionShape = op->getInputs(2)->getDims();
|
||||
const auto &opOutputShape = op->getOutput()->getDims();
|
||||
|
||||
const int xSize = op->getInputs(0)->getRank();
|
||||
const int ySize = op->getInputs(1)->getRank();
|
||||
const int cSize = op->getInputs(2)->getRank();
|
||||
auto a_dim = op->getInputs(0)->getDims();
|
||||
auto b_dim = op->getInputs(1)->getDims();
|
||||
auto c_dim = op->getInputs(2)->getDims();
|
||||
auto d_dim = op->getOutput()->getDims();
|
||||
|
||||
int nDims = op->getOutput()->getDims().size();
|
||||
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
|
||||
int outputsize = 1;
|
||||
SmallArray inputXShape, inputYShape, conditionShape, outputShape;
|
||||
for (int i = nDims - 1; i >= 0; --i) {
|
||||
outputShape.data[i] = opOutputShape[i];
|
||||
outputsize *= outputShape.data[i];
|
||||
}
|
||||
broadcastShape(opInputXShape, inputXShape, nDims, xSize);
|
||||
broadcastShape(opInputYShape, inputYShape, nDims, ySize);
|
||||
broadcastShape(opConditionShape, conditionShape, nDims, cSize);
|
||||
if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4 ||
|
||||
d_dim.size() > 4)
|
||||
IT_TODO_HALT();
|
||||
|
||||
if (op->getDType() == DataType::Float32) {
|
||||
whereKernel((float *)inputXData, (float *)inputYData,
|
||||
(uint8_t *)conditionData, (float *)outputData, nDims,
|
||||
outputsize, inputXShape, inputYShape, conditionShape,
|
||||
outputShape, xSize, ySize, cSize);
|
||||
} else if (op->getDType() == DataType::Float16) {
|
||||
whereKernel((half *)inputXData, (half *)inputYData,
|
||||
(uint8_t *)conditionData, (half *)outputData, nDims,
|
||||
outputsize, inputXShape, inputYShape, conditionShape,
|
||||
outputShape, xSize, ySize, cSize);
|
||||
} else {
|
||||
IT_ASSERT(false);
|
||||
}
|
||||
int a[4] = {1, 1, 1, 1};
|
||||
int b[4] = {1, 1, 1, 1};
|
||||
int c[4] = {1, 1, 1, 1};
|
||||
int d[4] = {1, 1, 1, 1};
|
||||
|
||||
std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
|
||||
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
|
||||
std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
|
||||
std::copy(d_dim.begin(), d_dim.end(), d + (4 - d_dim.size()));
|
||||
|
||||
const int dTypeIndex = op->getDType().getIndex();
|
||||
whereKernel(dTypeIndex, inputXData, inputYData,
|
||||
(uint8_t *)conditionData, outputData, a[0], a[1], a[2],
|
||||
a[3], b[0], b[1], b[2], b[3], c[0], c[1], c[2], c[3], d[0],
|
||||
d[1], d[2], d[3]);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -1,94 +1,132 @@
|
|||
#include "cuda/cuda_common.h"
|
||||
#include "utils/small_array.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
const int repeat = 3;
|
||||
|
||||
// Map a linear output index to the linear index of a broadcast operand.
// Walks the trailing `size` dimensions (the operand's rank) from innermost
// outwards; dimensions where the operand's extent is 1 contribute nothing
// (broadcast), all others accumulate coordinate * running input stride.
__device__ int inferIndex(infini::SmallArray inputShape,
                          infini::SmallArray outputShape, int nDims, int size,
                          int outputIdx) {
    int mappedIdx = 0;
    int inStride = 1;
    for (int dim = nDims - 1; dim >= nDims - size; --dim) {
        const int coord = outputIdx % outputShape.data[dim];
        if (inputShape.data[dim] != 1)
            mappedIdx += inStride * coord;
        inStride *= inputShape.data[dim];
        outputIdx /= outputShape.data[dim];
    }
    return mappedIdx;
}
|
||||
// Elementwise select with broadcasting over at most 4 dimensions:
//   output[i] = condition[ci] ? inputX[xi] : inputY[yi]
// a0..a3 / b0..b3 / c0..c3 / d0..d3 are the 4-D shapes (left-padded with 1)
// of inputX, inputY, condition and the output respectively. Each thread
// processes `repeat` consecutive output elements (file-level constant).
template <typename T>
__global__ void
_whereKernel(void *inputX, void *inputY, const uint8_t *condition, void *output,
             int a0, int a1, int a2, int a3, int b0, int b1, int b2, int b3,
             int c0, int c1, int c2, int c3, int d0, int d1, int d2, int d3) {

    int stride1 = d2 * d3;
    int stride0 = d1 * stride1;
    int n = d0 * stride0; // total number of output elements
    int index = threadIdx.x + blockIdx.x * blockDim.x;
    int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
    for (int i = repeat * index; i < end; i++) {
        // Same-size operands index linearly with the output; fully-scalar
        // operands (product == 1) stay pinned at element 0.
        int inputXIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
        int inputYIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
        int conditionIdx = (c0 * c1 * c2 * c3 == n ? i : 0);

        // Operands smaller than the output but larger than a scalar need a
        // per-dimension broadcast index computed from the output coordinate.
        bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
        bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
        bool cIdx = (c0 * c1 * c2 * c3 < n && c0 * c1 * c2 * c3 > 1);
        if (aIdx || bIdx || cIdx) {
            // Decompose linear output index i into 4-D coordinates.
            int d0_index = i / stride0;
            int d1_index = (i % stride0) / stride1;
            int d2_index = (i % stride1) / d3;
            int d3_index = i % d3;
            if (aIdx) {
                int a0_index = d0_index % a0;
                int a1_index = d1_index % a1;
                int a2_index = d2_index % a2;
                int a3_index = d3_index % a3;
                inputXIdx = a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
                            a2_index * a3 + a3_index;
            }
            if (bIdx) {
                int b0_index = d0_index % b0;
                int b1_index = d1_index % b1;
                int b2_index = d2_index % b2;
                int b3_index = d3_index % b3;
                inputYIdx = b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
                            b2_index * b3 + b3_index;
            }
            if (cIdx) {
                int c0_index = d0_index % c0;
                int c1_index = d1_index % c1;
                int c2_index = d2_index % c2;
                int c3_index = d3_index % c3;
                conditionIdx = c0_index * c1 * c2 * c3 + c1_index * c2 * c3 +
                               c2_index * c3 + c3_index;
            }
        }

        ((T *)output)[i] = condition[conditionIdx] ? ((T *)inputX)[inputXIdx]
                                                   : ((T *)inputY)[inputYIdx];
    }
}
|
||||
// Instantiate and launch the where kernel for the element type mapped from
// runtime dtype index T. Expects `gridsize`, `blocksize` and all shape
// arguments in scope at the expansion site.
#define CASE(T)                                                                \
    _whereKernel<DT_CUDA<T>::t>                                                \
        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(          \
            inputX, inputY, condition, output, a0, a1, a2, a3, b0, b1, b2, b3, \
            c0, c1, c2, c3, d0, d1, d2, d3);
|
||||
|
||||
// Dispatch the where kernel on a runtime dtype index. The listed case values
// mirror the framework's DataType indices (presumably DataType::getIndex();
// confirm against data_type.h). Unsupported indices halt via IT_TODO_HALT.
#define SWITCH_DTYPE(DTYPE)                                                    \
    switch (DTYPE) {                                                           \
    case 1:                                                                    \
        CASE(1) break;                                                         \
    case 2:                                                                    \
        CASE(2) break;                                                         \
    case 3:                                                                    \
        CASE(3) break;                                                         \
    case 4:                                                                    \
        CASE(4) break;                                                         \
    case 5:                                                                    \
        CASE(5) break;                                                         \
    case 6:                                                                    \
        CASE(6) break;                                                         \
    case 7:                                                                    \
        CASE(7) break;                                                         \
    case 10:                                                                   \
        CASE(10) break;                                                        \
    case 11:                                                                   \
        CASE(11) break;                                                        \
    case 12:                                                                   \
        CASE(12) break;                                                        \
    case 13:                                                                   \
        CASE(13) break;                                                        \
    case 16:                                                                   \
        CASE(16) break;                                                        \
    default:                                                                   \
        IT_TODO_HALT();                                                        \
    }
|
||||
namespace infini {
|
||||
// Host-side launcher for the broadcasted where kernel.
// dTypeIndex selects the element type of inputX/inputY/output at runtime;
// a0..a3, b0..b3, c0..c3, d0..d3 are the 4-D (left-padded) shapes of
// inputX, inputY, condition and the output. Each device thread handles
// `repeat` output elements, so the block size is chosen against
// outputsize / repeat and the grid covers ceil(outputsize / (repeat * block)).
void whereKernel(int dTypeIndex, void *inputX, void *inputY,
                 const uint8_t *condition, void *output, int a0, int a1, int a2,
                 int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
                 int c3, int d0, int d1, int d2, int d3) {
    int blocksize;
    int outputsize = d0 * d1 * d2 * d3;
    if (outputsize > 511 * repeat) {
        blocksize = 1024;
    } else if (outputsize > 255 * repeat) {
        blocksize = 512;
    } else if (outputsize > 127 * repeat) {
        blocksize = 256;
    } else if (outputsize > 63 * repeat) {
        blocksize = 128;
    } else if (outputsize > 31 * repeat) {
        blocksize = 64;
    } else {
        blocksize = 32;
    }
    // ceil-div by the per-grid work (repeat elements per thread).
    int gridsize = (outputsize + repeat * blocksize - 1) / (repeat * blocksize);

    SWITCH_DTYPE(dTypeIndex)
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
Loading…
Reference in New Issue