Compare commits

...

5 Commits

Author SHA1 Message Date
xgqdut2016 7146294baa memcopy instead of special kernel 2024-05-06 14:49:39 +08:00
xgqdut2016 73e3f1fc6f add currency operator 2024-04-10 15:01:22 +08:00
xgqdut2016 86133c8d0a modified expand 2024-04-10 11:16:54 +08:00
xgqdut2016 2761d46737 modified div_kernel 2024-04-10 10:51:35 +08:00
xgqdut2016 aa1c3222ed modified transpose and where 2024-04-10 10:17:45 +08:00
11 changed files with 535 additions and 225 deletions

View File

@@ -3,10 +3,11 @@
#include "operators/unary.h"
#include "utils/small_array.h"
namespace infini {
void expandKernel(int dType, void *input, void *output, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3);
void expandKernel(int dType, void *input, void *output, int nDims,
int outputsize, SmallArray inputShape,
SmallArray outputShape);
void expandRowKernel(int dType, void *input, void *output, int n_rows,
int row_len);
}; // namespace infini

View File

@@ -1,16 +1,14 @@
#pragma once
#include "operators/unary.h"
#include "utils/small_array.h"
namespace infini {
void whereKernel(const float *inputX, const float *inputY,
const uint8_t *condition, float *output, int nDims,
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
SmallArray conditionShape, SmallArray outputShape, int xSize,
int ySize, int cSize);
void whereKernel(const half *inputX, const half *inputY,
const uint8_t *condition, half *output, int nDims,
void whereKernel(int dTypeIndex, void *inputX, void *inputY,
const uint8_t *condition, void *output, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3, int d0, int d1, int d2, int d3);
void whereKernel(int dTypeIndex, void *inputX, void *inputY,
const uint8_t *condition, void *output, int nDims,
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
SmallArray conditionShape, SmallArray outputShape, int xSize,
int ySize, int cSize);

View File

@@ -5,34 +5,42 @@
constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
const int repeat = 1;
template <class T>
__global__ void _div_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0,
int c1, int c2, int c3) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
int n = c0 * c1 * c2 * c3;
for (int i = index; i < n; i += stride) {
int c0_index = i / (c1 * c2 * c3);
int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
int stride1 = c2 * c3;
int stride0 = c1 * stride1;
int n = c0 * stride0;
int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
for (int i = repeat * index; i < end; i++) {
int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
int a0_index = c0_index % a0;
int a1_index = c1_index % a1;
int a2_index = c2_index % a2;
int a3_index = c3_index % a3;
bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
if (aIdx || bIdx) {
int c0_index = i / stride0;
int c1_index = (i % stride0) / stride1;
int c2_index = (i % stride1) / c3;
int c3_index = i % c3;
if (aIdx) {
int b0_index = c0_index % b0;
int b1_index = c1_index % b1;
int b2_index = c2_index % b2;
int b3_index = c3_index % b3;
((T *)z)[i] = ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
a2_index * a3 + a3_index] /
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
b2_index * b3 + b3_index];
xIdx = (c0_index % a0) * a1 * a2 * a3 +
(c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
c3_index % a3;
}
if (bIdx) {
yIdx = (c0_index % b0) * b1 * b2 * b3 +
(c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
c3_index % b3;
}
}
((T *)z)[i] = ((T *)x)[xIdx] / ((T *)y)[yIdx];
}
}
@@ -41,28 +49,36 @@ __global__ void _add_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0,
int c1, int c2, int c3) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
int n = c0 * c1 * c2 * c3;
for (int i = index; i < n; i += stride) {
int c0_index = i / (c1 * c2 * c3);
int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
int stride1 = c2 * c3;
int stride0 = c1 * stride1;
int n = c0 * stride0;
int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
for (int i = repeat * index; i < end; i++) {
int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
int a0_index = c0_index % a0;
int a1_index = c1_index % a1;
int a2_index = c2_index % a2;
int a3_index = c3_index % a3;
bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
if (aIdx || bIdx) {
int c0_index = i / stride0;
int c1_index = (i % stride0) / stride1;
int c2_index = (i % stride1) / c3;
int c3_index = i % c3;
if (aIdx) {
int b0_index = c0_index % b0;
int b1_index = c1_index % b1;
int b2_index = c2_index % b2;
int b3_index = c3_index % b3;
((T *)z)[i] = ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
a2_index * a3 + a3_index] +
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
b2_index * b3 + b3_index];
xIdx = (c0_index % a0) * a1 * a2 * a3 +
(c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
c3_index % a3;
}
if (bIdx) {
yIdx = (c0_index % b0) * b1 * b2 * b3 +
(c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
c3_index % b3;
}
}
((T *)z)[i] = ((T *)x)[xIdx] + ((T *)y)[yIdx];
}
}
@@ -71,29 +87,36 @@ __global__ void _pow_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0,
int c1, int c2, int c3) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
int n = c0 * c1 * c2 * c3;
for (int i = index; i < n; i += stride) {
int c0_index = i / (c1 * c2 * c3);
int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
int stride1 = c2 * c3;
int stride0 = c1 * stride1;
int n = c0 * stride0;
int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
for (int i = repeat * index; i < end; i++) {
int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
int a0_index = c0_index % a0;
int a1_index = c1_index % a1;
int a2_index = c2_index % a2;
int a3_index = c3_index % a3;
bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
if (aIdx || bIdx) {
int c0_index = i / stride0;
int c1_index = (i % stride0) / stride1;
int c2_index = (i % stride1) / c3;
int c3_index = i % c3;
if (aIdx) {
int b0_index = c0_index % b0;
int b1_index = c1_index % b1;
int b2_index = c2_index % b2;
int b3_index = c3_index % b3;
((T *)z)[i] =
pow(((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
a2_index * a3 + a3_index],
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
b2_index * b3 + b3_index]);
xIdx = (c0_index % a0) * a1 * a2 * a3 +
(c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
c3_index % a3;
}
if (bIdx) {
yIdx = (c0_index % b0) * b1 * b2 * b3 +
(c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
c3_index % b3;
}
}
((T *)z)[i] = pow(((T *)x)[xIdx], ((T *)y)[yIdx]);
}
}
@@ -102,31 +125,36 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0,
int c1, int c2, int c3) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
int n = c0 * c1 * c2 * c3;
for (int i = index; i < n; i += stride) {
int c0_index = i / (c1 * c2 * c3);
int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
int stride1 = c2 * c3;
int stride0 = c1 * stride1;
int n = c0 * stride0;
int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
for (int i = repeat * index; i < end; i++) {
int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
int a0_index = c0_index % a0;
int a1_index = c1_index % a1;
int a2_index = c2_index % a2;
int a3_index = c3_index % a3;
bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
if (aIdx || bIdx) {
int c0_index = i / stride0;
int c1_index = (i % stride0) / stride1;
int c2_index = (i % stride1) / c3;
int c3_index = i % c3;
if (aIdx) {
int b0_index = c0_index % b0;
int b1_index = c1_index % b1;
int b2_index = c2_index % b2;
int b3_index = c3_index % b3;
((bool *)z)[i] =
((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
a2_index * a3 + a3_index] <
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
b2_index * b3 + b3_index]
? true
: false;
xIdx = (c0_index % a0) * a1 * a2 * a3 +
(c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
c3_index % a3;
}
if (bIdx) {
yIdx = (c0_index % b0) * b1 * b2 * b3 +
(c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
c3_index % b3;
}
}
((bool *)z)[i] = ((T *)x)[xIdx] < ((T *)y)[yIdx] ? true : false;
}
}
@@ -176,7 +204,6 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
default: \
IT_TODO_HALT(); \
}
template <class T>
__global__ void _div_const_kernel(void const *__restrict__ x,
void const *__restrict__ y,
@@ -269,7 +296,8 @@ void div_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size();
int gridsize =
(num + repeat * block_work_size() - 1) / (repeat * block_work_size());
SWITCH_DTYPE(div, dType)
}
void add_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
@@ -278,7 +306,8 @@ void add_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size();
int gridsize =
(num + repeat * block_work_size() - 1) / (repeat * block_work_size());
SWITCH_DTYPE(add, dType)
}
void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
@@ -286,7 +315,8 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int c3) {
int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size();
int gridsize =
(num + repeat * block_work_size() - 1) / (repeat * block_work_size());
if (dType == 1) {
_pow_kernel<float>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
@@ -324,7 +354,8 @@ void less_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int c3) {
int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size();
int gridsize =
(num + repeat * block_work_size() - 1) / (repeat * block_work_size());
SWITCH_DTYPE(less, dType)
}

View File

@@ -12,22 +12,33 @@ class ExpandCuda : public CudaKernelWithoutConfig {
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
const auto &in_Shape = op->getInputs(0)->getDims(); // input shape
const auto &out_Shape = op->getShape(); // output shape
auto a_dim = op->getInputs(0)->getDims();
auto b_dim = op->getOutput()->getDims(); // output shape
SmallArray inputShape, outputShape;
int nDims = op->getInputs(0)->getDims().size();
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
int outputsize = 1; // the length of the output vector after flatten
for (int i = 0; i < nDims; ++i) {
outputShape.data[i] = out_Shape[i];
inputShape.data[i] = in_Shape[i];
outputsize *= out_Shape[i];
}
const int dType = op->getDType().getIndex();
expandKernel(dType, inputData, outputData, nDims, outputsize,
inputShape, outputShape);
if (a_dim.size() > 4 || b_dim.size() > 4) {
SmallArray inputShape, outputShape;
int nDims = op->getInputs(0)->getDims().size();
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
int outputsize = 1; // the length of the output vector after flatten
for (int i = 0; i < nDims; ++i) {
outputShape.data[i] = b_dim[i];
inputShape.data[i] = a_dim[i];
outputsize *= b_dim[i];
}
const int dType = op->getDType().getIndex();
expandKernel(dType, inputData, outputData, nDims, outputsize,
inputShape, outputShape);
} else {
int a[4] = {1, 1, 1, 1};
int b[4] = {1, 1, 1, 1};
std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
expandKernel(dType, inputData, outputData, a[0], a[1], a[2], a[3],
b[0], b[1], b[2], b[3]);
}
}
};
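
For ranks of at most 4, the new kernel expects fixed 4-D shapes, so both shapes are right-aligned into four-element arrays padded with leading ones. A small standalone sketch of that padding step follows; the concrete shapes are made-up examples, not taken from the commit.

#include <algorithm>
#include <vector>

int main() {
    // Expanding a rank-2 input to a rank-3 output: both shapes are
    // right-aligned into fixed 4-D arrays, leading dims filled with 1.
    std::vector<int> a_dim = {3, 1};    // input shape
    std::vector<int> b_dim = {2, 3, 4}; // output shape
    int a[4] = {1, 1, 1, 1}, b[4] = {1, 1, 1, 1};
    std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
    std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
    // a = {1, 1, 3, 1}, b = {1, 2, 3, 4}: the 4-D kernel broadcasts a over b.
    return 0;
}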

View File

@@ -6,7 +6,31 @@
constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
const int repeat = 1;
template <class T>
__global__ void _expandKernel(void *input, void *output, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride1 = b2 * b3;
int stride0 = b1 * stride1;
int n = b0 * stride0;
int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
for (int i = repeat * index; i < end; i++) {
int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
if (aIdx) {
int b0_index = i / stride0;
int b1_index = (i % stride0) / stride1;
int b2_index = (i % stride1) / b3;
int b3_index = i % b3;
xIdx = (b0_index % a0) * a1 * a2 * a3 + (b1_index % a1) * a2 * a3 +
(b2_index % a2) * a3 + b3_index % a3;
}
((T *)output)[i] = ((T *)input)[xIdx];
}
}
template <class T>
__global__ void _expandKernel(void *input, void *output, int nDims,
int outputsize, infini::SmallArray inputShape,
@@ -38,7 +62,6 @@ __global__ void _expandKernel(void *input, void *output, int nDims,
((T *)output)[outputIdx] = ((T *)input)[inputIdx];
}
}
template <class T>
static __global__ void _expandRowKernel(void *__restrict__ dst,
void const *__restrict__ src) {
@@ -50,9 +73,9 @@ static __global__ void _expandRowKernel(void *__restrict__ dst,
namespace infini {
#define CASE(T) \
_expandKernel<DT_CUDA<T>::t><<<gridsize, blocksize, \
0, CUDAStream::getCurrentStream()>>>( \
input, output, nDims, outputsize, inputShape, outputShape);
_expandKernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
input, output, a0, a1, a2, a3, b0, b1, b2, b3);
#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
@@ -96,14 +119,56 @@ namespace infini {
IT_TODO_HALT(); \
}
void expandKernel(int dType, void *input, void *output, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3) {
int blocksize = block_work_size();
int outputsize = b0 * b1 * b2 * b3;
int gridsize = (outputsize + repeat * block_work_size() - 1) /
(repeat * block_work_size());
SWITCH_DTYPE(dType)
}
#define CASECurrency(T) \
_expandKernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
input, output, nDims, outputsize, inputShape, outputShape);
#define SWITCHCurrency_DTYPE(DTYPE) \
switch (DTYPE) { \
case 1: \
CASECurrency(1) break; \
case 2: \
CASECurrency(2) break; \
case 3: \
CASECurrency(3) break; \
case 4: \
CASECurrency(4) break; \
case 5: \
CASECurrency(5) break; \
case 6: \
CASECurrency(6) break; \
case 7: \
CASECurrency(7) break; \
case 10: \
CASECurrency(10) break; \
case 11: \
CASECurrency(11) break; \
case 12: \
CASECurrency(12) break; \
case 13: \
CASECurrency(13) break; \
case 16: \
CASECurrency(16) break; \
default: \
IT_TODO_HALT(); \
}
void expandKernel(int dType, void *input, void *output, int nDims,
int outputsize, SmallArray inputShape,
SmallArray outputShape) {
int blocksize = block_work_size();
int gridsize = (outputsize + block_work_size() - 1) / block_work_size();
SWITCH_DTYPE(dType)
SWITCHCurrency_DTYPE(dType)
}
#define CASE_ROW(T) \
_expandRowKernel<float> \
<<<grid, block, 0, CUDAStream::getCurrentStream()>>>(output, input);
@@ -150,7 +215,8 @@ void expandKernel(int dType, void *input, void *output, int nDims,
IT_TODO_HALT(); \
}
// Optimization for expanding a row vector. The row length must be a multiple of 32
// Optimization for expanding a row vector. The row length must be a multiple of
// 32
void expandRowKernel(int dType, void *input, void *output, int n_rows,
int row_len) {
// Factorize row_len: row_len = a x b x 32 (32 is the warp size), b<=32
@@ -160,7 +226,8 @@ void expandRowKernel(int dType, void *input, void *output, int n_rows,
// block: b x 32
auto c = row_len / 32, b = c;
if (b > 32) {
for (b = 32; c % b != 0; --b);
for (b = 32; c % b != 0; --b)
;
}
auto a = c / b;
dim3 grid(a, n_rows), block(32, b);

View File

@@ -87,20 +87,7 @@ class matmulCublas : public Kernel {
beta_naive = 1.f;
auto inC = op->getInputs(2);
auto out = op->getOutput();
SmallArray inputShape, outputShape;
int nDims = out->getRank();
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
// FIXME(constroy): use size_t for outputsize.
int outputsize = 1; // the length of the output vector after flatten
int offset = nDims - inC->getRank();
for (int i = 0; i < offset; ++i)
inputShape.data[i] = 1;
for (int i = 0; i < nDims; ++i) {
outputShape.data[i] = out->getDims()[i];
outputsize *= outputShape.data[i];
if (i >= offset)
inputShape.data[i] = inC->getDims()[i - offset];
}
const int dType = dataType.getIndex();
// Bias in linear layer is row vector of (1,n), n is the number of
@@ -111,9 +98,40 @@ class matmulCublas : public Kernel {
out->size() / inC->getDims()[0],
inC->getDims()[0]);
} else {
expandKernel(dType, inC->getRawDataPtr<void *>(),
out->getRawDataPtr<void *>(), nDims, outputsize,
inputShape, outputShape);
auto a_dim = inC->getDims(); // bias (input) shape
auto b_dim = out->getDims(); // output shape
if (a_dim.size() > 4 || b_dim.size() > 4) {
SmallArray inputShape, outputShape;
int nDims = out->getRank();
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
// FIXME(constroy): use size_t for outputsize.
int outputsize =
1; // the length of the output vector after flatten
int offset = nDims - inC->getRank();
for (int i = 0; i < offset; ++i)
inputShape.data[i] = 1;
for (int i = 0; i < nDims; ++i) {
outputShape.data[i] = out->getDims()[i];
outputsize *= outputShape.data[i];
if (i >= offset)
inputShape.data[i] = inC->getDims()[i - offset];
}
expandKernel(dType, inC->getRawDataPtr<void *>(),
out->getRawDataPtr<void *>(), nDims,
outputsize, inputShape, outputShape);
} else {
int a[4] = {1, 1, 1, 1};
int b[4] = {1, 1, 1, 1};
std::copy(a_dim.begin(), a_dim.end(),
a + (4 - a_dim.size()));
std::copy(b_dim.begin(), b_dim.end(),
b + (4 - b_dim.size()));
expandKernel(dType, inC->getRawDataPtr<void *>(),
out->getRawDataPtr<void *>(), a[0], a[1], a[2],
a[3], b[0], b[1], b[2], b[3]);
}
}
}
// TODO:use compute type

View File

@@ -16,31 +16,57 @@ class TransposeCuda : public CudaKernelWithoutConfig {
void *const outputData = output->getRawDataPtr<void *>();
const auto &inputShape = input->getDims();
const auto &outputShape = output->getDims();
const auto &perm = op->getPermute();
const int dType = op->getDType().getIndex();
int size = input->size();
int nDims = input->getDims().size();
// Compute strides
SmallArray strides, buffer;
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
int curStride = 1;
for (int i = nDims - 1; i >= 0; --i) {
buffer.data[i] = curStride;
curStride *= inputShape[i];
}
for (int i = 0; i < nDims; ++i) {
strides.data[i] = buffer.data[perm[i]];
//----------------
bool condition = true;
int gnum = 0;
for (int i = 0; i < nDims; i++) {
if (inputShape[i] > 1) {
while (gnum < nDims) {
if (outputShape[gnum] > 1) {
gnum += 1;
break;
} else {
gnum += 1;
}
}
if (inputShape[i] != outputShape[gnum - 1]) {
condition = false;
break;
}
}
}
//----------------
if (condition) {
cudaMemcpyAsync(outputData, inputData, op->getInputs(0)->getBytes(),
cudaMemcpyDeviceToDevice,
CUDAStream::getCurrentStream());
SmallArray outputDims;
for (int i = 0; i < nDims; ++i) {
outputDims.data[i] = outputShape[i];
}
} else {
const auto &perm = op->getPermute();
const int dType = op->getDType().getIndex();
transpose_kernel(dType, inputData, outputData, nDims, size, strides,
outputDims);
// Compute strides
SmallArray strides, buffer;
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
int curStride = 1;
for (int i = nDims - 1; i >= 0; --i) {
buffer.data[i] = curStride;
curStride *= inputShape[i];
}
for (int i = 0; i < nDims; ++i) {
strides.data[i] = buffer.data[perm[i]];
}
SmallArray outputDims;
for (int i = 0; i < nDims; ++i) {
outputDims.data[i] = outputShape[i];
}
transpose_kernel(dType, inputData, outputData, nDims, size, strides,
outputDims);
}
}
};
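
The new fast path issues a plain cudaMemcpyAsync when every non-unit input dimension lines up, in order, with a non-unit output dimension, i.e. when the permutation only moves size-1 axes. The same test can be stated directly on the permutation; the sketch below is illustrative and not the committed helper. Note the caveat it makes explicit: comparing shapes alone cannot distinguish a swap of two equal-sized axes, which still needs the real transpose kernel.

#include <cassert>
#include <vector>

// The copy is valid iff perm, restricted to non-unit axes, keeps them in
// their original order.
static bool transposeIsCopy(const std::vector<int> &inputShape,
                            const std::vector<int> &perm) {
    int last = -1; // most recent non-unit source axis, in output order
    for (int axis : perm) {
        if (inputShape[axis] == 1)
            continue; // size-1 axes may move freely
        if (axis < last)
            return false; // non-unit axes reordered: real transpose needed
        last = axis;
    }
    return true;
}

int main() {
    assert(transposeIsCopy({1, 4, 5}, {1, 0, 2}));  // only a size-1 axis moves
    assert(!transposeIsCopy({2, 3, 2}, {2, 1, 0})); // equal-sized axes swapped
    return 0;
}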

View File

@@ -24,8 +24,8 @@ __global__ void _transpose_kernel(void *input, void *output, int nDims,
}
#define CASE(T) \
_transpose_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
(input, output, nDims, size, strides, outputShape);
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
input, output, nDims, size, strides, outputShape);
#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \

View File

@@ -1,8 +1,8 @@
#include "operators/where.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "cuda/cuda_where.h"
#include "utils/operator_utils.h"
namespace infini {
@@ -15,39 +15,50 @@ class WhereCuda : public CudaKernelWithoutConfig {
void *const inputYData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const conditionData = (op->getInputs(2)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
const auto &opInputXShape = op->getInputs(0)->getDims();
const auto &opInputYShape = op->getInputs(1)->getDims();
const auto &opConditionShape = op->getInputs(2)->getDims();
const auto &opOutputShape = op->getOutput()->getDims();
const int xSize = op->getInputs(0)->getRank();
const int ySize = op->getInputs(1)->getRank();
const int cSize = op->getInputs(2)->getRank();
auto a_dim = op->getInputs(0)->getDims();
auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getInputs(2)->getDims();
auto d_dim = op->getOutput()->getDims();
const int dTypeIndex = op->getDType().getIndex();
if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4 ||
d_dim.size() > 4) {
const int xSize = op->getInputs(0)->getRank();
const int ySize = op->getInputs(1)->getRank();
const int cSize = op->getInputs(2)->getRank();
int nDims = op->getOutput()->getDims().size();
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
int outputsize = 1;
SmallArray inputXShape, inputYShape, conditionShape, outputShape;
for (int i = nDims - 1; i >= 0; --i) {
outputShape.data[i] = opOutputShape[i];
outputsize *= outputShape.data[i];
int nDims = op->getOutput()->getDims().size();
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
int outputsize = 1;
SmallArray inputXShape, inputYShape, conditionShape, outputShape;
for (int i = nDims - 1; i >= 0; --i) {
outputShape.data[i] = d_dim[i];
outputsize *= outputShape.data[i];
}
broadcastShape(a_dim, inputXShape, nDims, xSize);
broadcastShape(b_dim, inputYShape, nDims, ySize);
broadcastShape(c_dim, conditionShape, nDims, cSize);
whereKernel(dTypeIndex, inputXData, inputYData,
(uint8_t *)conditionData, outputData, nDims, outputsize,
inputXShape, inputYShape, conditionShape, outputShape,
xSize, ySize, cSize);
}
broadcastShape(opInputXShape, inputXShape, nDims, xSize);
broadcastShape(opInputYShape, inputYShape, nDims, ySize);
broadcastShape(opConditionShape, conditionShape, nDims, cSize);
if (op->getDType() == DataType::Float32) {
whereKernel((float *)inputXData, (float *)inputYData,
(uint8_t *)conditionData, (float *)outputData, nDims,
outputsize, inputXShape, inputYShape, conditionShape,
outputShape, xSize, ySize, cSize);
} else if (op->getDType() == DataType::Float16) {
whereKernel((half *)inputXData, (half *)inputYData,
(uint8_t *)conditionData, (half *)outputData, nDims,
outputsize, inputXShape, inputYShape, conditionShape,
outputShape, xSize, ySize, cSize);
} else {
IT_ASSERT(false);
} else {
int a[4] = {1, 1, 1, 1};
int b[4] = {1, 1, 1, 1};
int c[4] = {1, 1, 1, 1};
int d[4] = {1, 1, 1, 1};
std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
std::copy(d_dim.begin(), d_dim.end(), d + (4 - d_dim.size()));
whereKernel(dTypeIndex, inputXData, inputYData,
(uint8_t *)conditionData, outputData, a[0], a[1], a[2],
a[3], b[0], b[1], b[2], b[3], c[0], c[1], c[2], c[3],
d[0], d[1], d[2], d[3]);
}
}
};

View File

@@ -1,6 +1,109 @@
#include "cuda/cuda_common.h"
#include "cuda/cuda_utility.h"
#include "utils/small_array.h"
const int repeat = 1;
template <typename T>
__global__ void
_whereKernel(void *inputX, void *inputY, const uint8_t *condition, void *output,
int a0, int a1, int a2, int a3, int b0, int b1, int b2, int b3,
int c0, int c1, int c2, int c3, int d0, int d1, int d2, int d3) {
int stride1 = d2 * d3;
int stride0 = d1 * stride1;
int n = d0 * stride0;
int index = threadIdx.x + blockIdx.x * blockDim.x;
int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
for (int i = repeat * index; i < end; i++) {
int inputXIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
int inputYIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
int conditionIdx = (c0 * c1 * c2 * c3 == n ? i : 0);
bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
bool cIdx = (c0 * c1 * c2 * c3 < n && c0 * c1 * c2 * c3 > 1);
if (aIdx || bIdx || cIdx) {
int d0_index = i / stride0;
int d1_index = (i % stride0) / stride1;
int d2_index = (i % stride1) / d3;
int d3_index = i % d3;
if (aIdx) {
int a0_index = d0_index % a0;
int a1_index = d1_index % a1;
int a2_index = d2_index % a2;
int a3_index = d3_index % a3;
inputXIdx = a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
a2_index * a3 + a3_index;
}
if (bIdx) {
int b0_index = d0_index % b0;
int b1_index = d1_index % b1;
int b2_index = d2_index % b2;
int b3_index = d3_index % b3;
inputYIdx = b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
b2_index * b3 + b3_index;
}
if (cIdx) {
int c0_index = d0_index % c0;
int c1_index = d1_index % c1;
int c2_index = d2_index % c2;
int c3_index = d3_index % c3;
conditionIdx = c0_index * c1 * c2 * c3 + c1_index * c2 * c3 +
c2_index * c3 + c3_index;
}
}
((T *)output)[i] = condition[conditionIdx] ? ((T *)inputX)[inputXIdx]
: ((T *)inputY)[inputYIdx];
}
}
#define CASE(T) \
_whereKernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
inputX, inputY, condition, output, a0, a1, a2, a3, b0, b1, b2, b3, \
c0, c1, c2, c3, d0, d1, d2, d3);
#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
case 1: \
CASE(1) \
break; \
case 2: \
CASE(2) \
break; \
case 3: \
CASE(3) \
break; \
case 4: \
CASE(4) \
break; \
case 5: \
CASE(5) \
break; \
case 6: \
CASE(6) \
break; \
case 7: \
CASE(7) \
break; \
case 10: \
CASE(10) \
break; \
case 11: \
CASE(11) \
break; \
case 12: \
CASE(12) \
break; \
case 13: \
CASE(13) \
break; \
case 16: \
CASE(16) \
break; \
default: \
IT_TODO_HALT(); \
}
__device__ int inferIndex(infini::SmallArray inputShape,
infini::SmallArray outputShape, int nDims, int size,
int outputIdx) {
@@ -19,11 +122,10 @@ __device__ int inferIndex(infini::SmallArray inputShape,
}
template <typename T>
__global__ void
_whereKernel(const T *inputX, const T *inputY, const uint8_t *condition,
T *output, int nDims, int outputsize,
infini::SmallArray inputXShape, infini::SmallArray inputYShape,
infini::SmallArray conditionShape, infini::SmallArray outputShape,
int xSize, int ySize, int cSize) {
_whereKernel(void *inputX, void *inputY, const uint8_t *condition, void *output,
int nDims, int outputsize, infini::SmallArray inputXShape,
infini::SmallArray inputYShape, infini::SmallArray conditionShape,
infini::SmallArray outputShape, int xSize, int ySize, int cSize) {
int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (outputIdx < outputsize) {
@@ -35,14 +137,74 @@ _whereKernel(const T *inputX, const T *inputY, const uint8_t *condition,
int inputYIdx =
inferIndex(inputYShape, outputShape, nDims, ySize, outputIdx);
output[outputIdx] =
condition[conditionIdx] ? inputX[inputXIdx] : inputY[inputYIdx];
((T *)output)[outputIdx] = condition[conditionIdx]
? ((T *)inputX)[inputXIdx]
: ((T *)inputY)[inputYIdx];
}
}
#define CASECurrency(T) \
_whereKernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
inputX, inputY, condition, output, nDims, outputsize, inputXShape, \
inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
#define SWITCHCurrency_DTYPE(DTYPE) \
switch (DTYPE) { \
case 1: \
CASECurrency(1) break; \
case 2: \
CASECurrency(2) break; \
case 3: \
CASECurrency(3) break; \
case 4: \
CASECurrency(4) break; \
case 5: \
CASECurrency(5) break; \
case 6: \
CASECurrency(6) break; \
case 7: \
CASECurrency(7) break; \
case 10: \
CASECurrency(10) break; \
case 11: \
CASECurrency(11) break; \
case 12: \
CASECurrency(12) break; \
case 13: \
CASECurrency(13) break; \
case 16: \
CASECurrency(16) break; \
default: \
IT_TODO_HALT(); \
}
namespace infini {
void whereKernel(int dTypeIndex, void *inputX, void *inputY,
const uint8_t *condition, void *output, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3, int d0, int d1, int d2, int d3) {
int blocksize;
int outputsize = d0 * d1 * d2 * d3;
if (outputsize > 511 * repeat) {
blocksize = 1024;
} else if (outputsize > 255 * repeat) {
blocksize = 512;
} else if (outputsize > 127 * repeat) {
blocksize = 256;
} else if (outputsize > 63 * repeat) {
blocksize = 128;
} else if (outputsize > 31 * repeat) {
blocksize = 64;
} else {
blocksize = 32;
}
int gridsize = (outputsize + repeat * blocksize - 1) / (repeat * blocksize);
SWITCH_DTYPE(dTypeIndex)
}
namespace infini {
void whereKernel(const float *inputX, const float *inputY,
const uint8_t *condition, float *output, int nDims,
void whereKernel(int dTypeIndex, void *inputX, void *inputY,
const uint8_t *condition, void *output, int nDims,
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
SmallArray conditionShape, SmallArray outputShape, int xSize,
int ySize, int cSize) {
@@ -61,34 +223,8 @@ void whereKernel(const float *inputX, const float *inputY,
blocksize = 32;
}
int gridsize = (outputsize + blocksize - 1) / blocksize;
_whereKernel<float>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
inputX, inputY, condition, output, nDims, outputsize, inputXShape,
inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
}
void whereKernel(const half *inputX, const half *inputY,
const uint8_t *condition, half *output, int nDims,
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
SmallArray conditionShape, SmallArray outputShape, int xSize,
int ySize, int cSize) {
int blocksize;
if (outputsize > 511) {
blocksize = 1024;
} else if (outputsize > 255) {
blocksize = 512;
} else if (outputsize > 127) {
blocksize = 256;
} else if (outputsize > 63) {
blocksize = 128;
} else if (outputsize > 31) {
blocksize = 64;
} else {
blocksize = 32;
}
int gridsize = (outputsize + blocksize - 1) / blocksize;
_whereKernel<half>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
inputX, inputY, condition, output, nDims, outputsize, inputXShape,
inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
SWITCHCurrency_DTYPE(dTypeIndex)
}
} // namespace infini
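
The launcher sizes the block as the smallest power of two between 32 and 1024 that keeps roughly one output element per thread (scaled by repeat), then takes a ceiling division for the grid. A small self-contained check of that heuristic; the sizes in main are illustrative.

#include <cassert>

const int repeat = 1;

// Mirrors the blocksize ladder in whereKernel above.
static int pickBlocksize(int outputsize) {
    if (outputsize > 511 * repeat) return 1024;
    if (outputsize > 255 * repeat) return 512;
    if (outputsize > 127 * repeat) return 256;
    if (outputsize > 63 * repeat)  return 128;
    if (outputsize > 31 * repeat)  return 64;
    return 32;
}

int main() {
    // A 2x2x3x2 output (24 elements) fits one 32-thread block.
    assert(pickBlocksize(24) == 32);
    int outputsize = 5000, blocksize = pickBlocksize(outputsize);
    int gridsize = (outputsize + repeat * blocksize - 1) / (repeat * blocksize);
    assert(blocksize == 1024 && gridsize == 5); // ceiling division for the grid
    return 0;
}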

View File

@@ -84,6 +84,17 @@ void test_whereFp16(
}
TEST(CUDA_WhereFp32, run) {
test_whereFp32(
Shape{2, 2, 3, 1, 2},
vector<float>{0., 1., 2., 3., 4., 5., 6., 7.,
8., 9., 10., 11., 12., 13., 14., 15.,
16., 17., 18., 19., 20., 21., 22., 23.},
Shape{2, 2, 3, 1, 2},
vector<float>{0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.},
Shape{2, 3, 1, 2}, vector<uint8_t>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
vector<float>{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.,
0., 13., 14., 0., 0., 0., 18., 19., 0., 21., 22., 23.});
test_whereFp32(
Shape{2, 2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
Shape{2, 2, 3, 1}, vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},