Compare commits

...

5 Commits

Author SHA1 Message Date
xgqdut2016 7146294baa memcopy instead of special kernel 2024-05-06 14:49:39 +08:00
xgqdut2016 73e3f1fc6f add currency operator 2024-04-10 15:01:22 +08:00
xgqdut2016 86133c8d0a modified expand 2024-04-10 11:16:54 +08:00
xgqdut2016 2761d46737 modified div_kernel 2024-04-10 10:51:35 +08:00
xgqdut2016 aa1c3222ed modified transpose and where 2024-04-10 10:17:45 +08:00
11 changed files with 535 additions and 225 deletions

View File

@@ -3,10 +3,11 @@
 #include "operators/unary.h"
 #include "utils/small_array.h"
 namespace infini {
+void expandKernel(int dType, void *input, void *output, int a0, int a1, int a2,
+                  int a3, int b0, int b1, int b2, int b3);
 void expandKernel(int dType, void *input, void *output, int nDims,
                   int outputsize, SmallArray inputShape,
                   SmallArray outputShape);
 void expandRowKernel(int dType, void *input, void *output, int n_rows,
                      int row_len);
 }; // namespace infini

View File

@@ -1,16 +1,14 @@
 #pragma once
 #include "operators/unary.h"
 #include "utils/small_array.h"
 namespace infini {
-void whereKernel(const float *inputX, const float *inputY,
-                 const uint8_t *condition, float *output, int nDims,
-                 int outputsize, SmallArray inputXShape, SmallArray inputYShape,
-                 SmallArray conditionShape, SmallArray outputShape, int xSize,
-                 int ySize, int cSize);
-void whereKernel(const half *inputX, const half *inputY,
-                 const uint8_t *condition, half *output, int nDims,
+void whereKernel(int dTypeIndex, void *inputX, void *inputY,
+                 const uint8_t *condition, void *output, int a0, int a1, int a2,
+                 int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
+                 int c3, int d0, int d1, int d2, int d3);
+void whereKernel(int dTypeIndex, void *inputX, void *inputY,
+                 const uint8_t *condition, void *output, int nDims,
                  int outputsize, SmallArray inputXShape, SmallArray inputYShape,
                  SmallArray conditionShape, SmallArray outputShape, int xSize,
                  int ySize, int cSize);

View File

@@ -5,34 +5,42 @@
 constexpr unsigned int num_threads() { return 32 * 4; }
 constexpr int thread_work_size() { return 4; }
 constexpr int block_work_size() { return thread_work_size() * num_threads(); }
+const int repeat = 1;
 template <class T>
 __global__ void _div_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
                             int a3, int b0, int b1, int b2, int b3, int c0,
                             int c1, int c2, int c3) {
     int index = threadIdx.x + blockIdx.x * blockDim.x;
-    int stride = blockDim.x * gridDim.x;
-    int n = c0 * c1 * c2 * c3;
-    for (int i = index; i < n; i += stride) {
-        int c0_index = i / (c1 * c2 * c3);
-        int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
-        int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
-        int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
-        int a0_index = c0_index % a0;
-        int a1_index = c1_index % a1;
-        int a2_index = c2_index % a2;
-        int a3_index = c3_index % a3;
-        int b0_index = c0_index % b0;
-        int b1_index = c1_index % b1;
-        int b2_index = c2_index % b2;
-        int b3_index = c3_index % b3;
-        ((T *)z)[i] = ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
-                               a2_index * a3 + a3_index] /
-                      ((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
-                               b2_index * b3 + b3_index];
+    int stride1 = c2 * c3;
+    int stride0 = c1 * stride1;
+    int n = c0 * stride0;
+    int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
+    for (int i = repeat * index; i < end; i++) {
+        int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
+        int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
+        bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
+        bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
+        if (aIdx || bIdx) {
+            int c0_index = i / stride0;
+            int c1_index = (i % stride0) / stride1;
+            int c2_index = (i % stride1) / c3;
+            int c3_index = i % c3;
+            if (aIdx) {
+                xIdx = (c0_index % a0) * a1 * a2 * a3 +
+                       (c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
+                       c3_index % a3;
+            }
+            if (bIdx) {
+                yIdx = (c0_index % b0) * b1 * b2 * b3 +
+                       (c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
+                       c3_index % b3;
+            }
+        }
+        ((T *)z)[i] = ((T *)x)[xIdx] / ((T *)y)[yIdx];
     }
 }
@@ -41,28 +49,36 @@ __global__ void _add_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
                             int a3, int b0, int b1, int b2, int b3, int c0,
                             int c1, int c2, int c3) {
     int index = threadIdx.x + blockIdx.x * blockDim.x;
-    int stride = blockDim.x * gridDim.x;
-    int n = c0 * c1 * c2 * c3;
-    for (int i = index; i < n; i += stride) {
-        int c0_index = i / (c1 * c2 * c3);
-        int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
-        int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
-        int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
-        int a0_index = c0_index % a0;
-        int a1_index = c1_index % a1;
-        int a2_index = c2_index % a2;
-        int a3_index = c3_index % a3;
-        int b0_index = c0_index % b0;
-        int b1_index = c1_index % b1;
-        int b2_index = c2_index % b2;
-        int b3_index = c3_index % b3;
-        ((T *)z)[i] = ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
-                               a2_index * a3 + a3_index] +
-                      ((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
-                               b2_index * b3 + b3_index];
+    int stride1 = c2 * c3;
+    int stride0 = c1 * stride1;
+    int n = c0 * stride0;
+    int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
+    for (int i = repeat * index; i < end; i++) {
+        int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
+        int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
+        bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
+        bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
+        if (aIdx || bIdx) {
+            int c0_index = i / stride0;
+            int c1_index = (i % stride0) / stride1;
+            int c2_index = (i % stride1) / c3;
+            int c3_index = i % c3;
+            if (aIdx) {
+                xIdx = (c0_index % a0) * a1 * a2 * a3 +
+                       (c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
+                       c3_index % a3;
+            }
+            if (bIdx) {
+                yIdx = (c0_index % b0) * b1 * b2 * b3 +
+                       (c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
+                       c3_index % b3;
+            }
+        }
+        ((T *)z)[i] = ((T *)x)[xIdx] + ((T *)y)[yIdx];
     }
 }
@@ -71,29 +87,36 @@ __global__ void _pow_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
                             int a3, int b0, int b1, int b2, int b3, int c0,
                             int c1, int c2, int c3) {
     int index = threadIdx.x + blockIdx.x * blockDim.x;
-    int stride = blockDim.x * gridDim.x;
-    int n = c0 * c1 * c2 * c3;
-    for (int i = index; i < n; i += stride) {
-        int c0_index = i / (c1 * c2 * c3);
-        int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
-        int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
-        int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
-        int a0_index = c0_index % a0;
-        int a1_index = c1_index % a1;
-        int a2_index = c2_index % a2;
-        int a3_index = c3_index % a3;
-        int b0_index = c0_index % b0;
-        int b1_index = c1_index % b1;
-        int b2_index = c2_index % b2;
-        int b3_index = c3_index % b3;
-        ((T *)z)[i] =
-            pow(((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
-                         a2_index * a3 + a3_index],
-                ((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
-                         b2_index * b3 + b3_index]);
+    int stride1 = c2 * c3;
+    int stride0 = c1 * stride1;
+    int n = c0 * stride0;
+    int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
+    for (int i = repeat * index; i < end; i++) {
+        int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
+        int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
+        bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
+        bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
+        if (aIdx || bIdx) {
+            int c0_index = i / stride0;
+            int c1_index = (i % stride0) / stride1;
+            int c2_index = (i % stride1) / c3;
+            int c3_index = i % c3;
+            if (aIdx) {
+                xIdx = (c0_index % a0) * a1 * a2 * a3 +
+                       (c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
+                       c3_index % a3;
+            }
+            if (bIdx) {
+                yIdx = (c0_index % b0) * b1 * b2 * b3 +
+                       (c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
+                       c3_index % b3;
+            }
+        }
+        ((T *)z)[i] = pow(((T *)x)[xIdx], ((T *)y)[yIdx]);
     }
 }
@@ -102,31 +125,36 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
                              int a3, int b0, int b1, int b2, int b3, int c0,
                              int c1, int c2, int c3) {
     int index = threadIdx.x + blockIdx.x * blockDim.x;
-    int stride = blockDim.x * gridDim.x;
-    int n = c0 * c1 * c2 * c3;
-    for (int i = index; i < n; i += stride) {
-        int c0_index = i / (c1 * c2 * c3);
-        int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
-        int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
-        int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
-        int a0_index = c0_index % a0;
-        int a1_index = c1_index % a1;
-        int a2_index = c2_index % a2;
-        int a3_index = c3_index % a3;
-        int b0_index = c0_index % b0;
-        int b1_index = c1_index % b1;
-        int b2_index = c2_index % b2;
-        int b3_index = c3_index % b3;
-        ((bool *)z)[i] =
-            ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
-                     a2_index * a3 + a3_index] <
-                    ((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
-                             b2_index * b3 + b3_index]
-                ? true
-                : false;
+    int stride1 = c2 * c3;
+    int stride0 = c1 * stride1;
+    int n = c0 * stride0;
+    int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
+    for (int i = repeat * index; i < end; i++) {
+        int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
+        int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
+        bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
+        bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
+        if (aIdx || bIdx) {
+            int c0_index = i / stride0;
+            int c1_index = (i % stride0) / stride1;
+            int c2_index = (i % stride1) / c3;
+            int c3_index = i % c3;
+            if (aIdx) {
+                xIdx = (c0_index % a0) * a1 * a2 * a3 +
+                       (c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
+                       c3_index % a3;
+            }
+            if (bIdx) {
+                yIdx = (c0_index % b0) * b1 * b2 * b3 +
+                       (c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
+                       c3_index % b3;
+            }
+        }
+        ((bool *)z)[i] = ((T *)x)[xIdx] < ((T *)y)[yIdx] ? true : false;
     }
 }
@@ -176,7 +204,6 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
     default: \
         IT_TODO_HALT(); \
     }
-
 template <class T>
 __global__ void _div_const_kernel(void const *__restrict__ x,
                                   void const *__restrict__ y,
@@ -269,7 +296,8 @@ void div_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
    int blocksize = block_work_size();
    int num = c0 * c1 * c2 * c3;
-   int gridsize = (num + block_work_size() - 1) / block_work_size();
+   int gridsize =
+       (num + repeat * block_work_size() - 1) / (repeat * block_work_size());
    SWITCH_DTYPE(div, dType)
 }
 void add_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
@@ -278,7 +306,8 @@ void add_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
    int blocksize = block_work_size();
    int num = c0 * c1 * c2 * c3;
-   int gridsize = (num + block_work_size() - 1) / block_work_size();
+   int gridsize =
+       (num + repeat * block_work_size() - 1) / (repeat * block_work_size());
    SWITCH_DTYPE(add, dType)
 }
 void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
@@ -286,7 +315,8 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
                 int c3) {
    int blocksize = block_work_size();
    int num = c0 * c1 * c2 * c3;
-   int gridsize = (num + block_work_size() - 1) / block_work_size();
+   int gridsize =
+       (num + repeat * block_work_size() - 1) / (repeat * block_work_size());
    if (dType == 1) {
        _pow_kernel<float>
            <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
@@ -324,7 +354,8 @@ void less_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
                 int c3) {
    int blocksize = block_work_size();
    int num = c0 * c1 * c2 * c3;
-   int gridsize = (num + block_work_size() - 1) / block_work_size();
+   int gridsize =
+       (num + repeat * block_work_size() - 1) / (repeat * block_work_size());
    SWITCH_DTYPE(less, dType)
 }

View File

@@ -12,22 +12,33 @@ class ExpandCuda : public CudaKernelWithoutConfig {
         void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
-        const auto &in_Shape = op->getInputs(0)->getDims(); // input shape
-        const auto &out_Shape = op->getShape();             // output shape
-        SmallArray inputShape, outputShape;
-        int nDims = op->getInputs(0)->getDims().size();
-        IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
-        int outputsize = 1; // the length of the output vector after flatten
-        for (int i = 0; i < nDims; ++i) {
-            outputShape.data[i] = out_Shape[i];
-            inputShape.data[i] = in_Shape[i];
-            outputsize *= out_Shape[i];
-        }
-        const int dType = op->getDType().getIndex();
-        expandKernel(dType, inputData, outputData, nDims, outputsize,
-                     inputShape, outputShape);
+        auto a_dim = op->getInputs(0)->getDims();
+        auto b_dim = op->getOutput()->getDims(); // output shape
+
+        const int dType = op->getDType().getIndex();
+        if (a_dim.size() > 4 || b_dim.size() > 4) {
+            SmallArray inputShape, outputShape;
+            int nDims = op->getInputs(0)->getDims().size();
+            IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
+            int outputsize = 1; // the length of the output vector after flatten
+            for (int i = 0; i < nDims; ++i) {
+                outputShape.data[i] = b_dim[i];
+                inputShape.data[i] = a_dim[i];
+                outputsize *= b_dim[i];
+            }
+            expandKernel(dType, inputData, outputData, nDims, outputsize,
+                         inputShape, outputShape);
+        } else {
+            int a[4] = {1, 1, 1, 1};
+            int b[4] = {1, 1, 1, 1};
+            std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
+            std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
+            expandKernel(dType, inputData, outputData, a[0], a[1], a[2], a[3],
+                         b[0], b[1], b[2], b[3]);
+        }
     }
 };

View File

@@ -6,7 +6,31 @@
 constexpr unsigned int num_threads() { return 32 * 4; }
 constexpr int thread_work_size() { return 4; }
 constexpr int block_work_size() { return thread_work_size() * num_threads(); }
+const int repeat = 1;
+template <class T>
+__global__ void _expandKernel(void *input, void *output, int a0, int a1, int a2,
+                              int a3, int b0, int b1, int b2, int b3) {
+    int index = threadIdx.x + blockIdx.x * blockDim.x;
+    int stride1 = b2 * b3;
+    int stride0 = b1 * stride1;
+    int n = b0 * stride0;
+    int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
+    for (int i = repeat * index; i < end; i++) {
+        int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
+        bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
+        if (aIdx) {
+            int b0_index = i / stride0;
+            int b1_index = (i % stride0) / stride1;
+            int b2_index = (i % stride1) / b3;
+            int b3_index = i % b3;
+            xIdx = (b0_index % a0) * a1 * a2 * a3 + (b1_index % a1) * a2 * a3 +
+                   (b2_index % a2) * a3 + b3_index % a3;
+        }
+        ((T *)output)[i] = ((T *)input)[xIdx];
+    }
+}
 template <class T>
 __global__ void _expandKernel(void *input, void *output, int nDims,
                               int outputsize, infini::SmallArray inputShape,
@@ -38,7 +62,6 @@ __global__ void _expandKernel(void *input, void *output, int nDims,
         ((T *)output)[outputIdx] = ((T *)input)[inputIdx];
     }
 }
-
 template <class T>
 static __global__ void _expandRowKernel(void *__restrict__ dst,
                                         void const *__restrict__ src) {
@@ -50,9 +73,9 @@ static __global__ void _expandRowKernel(void *__restrict__ dst,
 namespace infini {
 #define CASE(T) \
-    _expandKernel<DT_CUDA<T>::t><<<gridsize, blocksize, \
-                                   0, CUDAStream::getCurrentStream()>>>( \
-        input, output, nDims, outputsize, inputShape, outputShape);
+    _expandKernel<DT_CUDA<T>::t> \
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
+            input, output, a0, a1, a2, a3, b0, b1, b2, b3);
 #define SWITCH_DTYPE(DTYPE) \
     switch (DTYPE) { \
@@ -96,14 +119,56 @@
         IT_TODO_HALT(); \
     }
+void expandKernel(int dType, void *input, void *output, int a0, int a1, int a2,
+                  int a3, int b0, int b1, int b2, int b3) {
+    int blocksize = block_work_size();
+    int outputsize = b0 * b1 * b2 * b3;
+    int gridsize = (outputsize + repeat * block_work_size() - 1) /
+                   (repeat * block_work_size());
+    SWITCH_DTYPE(dType)
+}
+
+#define CASECurrency(T) \
+    _expandKernel<DT_CUDA<T>::t> \
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
+            input, output, nDims, outputsize, inputShape, outputShape);
+#define SWITCHCurrency_DTYPE(DTYPE) \
+    switch (DTYPE) { \
+    case 1: \
+        CASECurrency(1) break; \
+    case 2: \
+        CASECurrency(2) break; \
+    case 3: \
+        CASECurrency(3) break; \
+    case 4: \
+        CASECurrency(4) break; \
+    case 5: \
+        CASECurrency(5) break; \
+    case 6: \
+        CASECurrency(6) break; \
+    case 7: \
+        CASECurrency(7) break; \
+    case 10: \
+        CASECurrency(10) break; \
+    case 11: \
+        CASECurrency(11) break; \
+    case 12: \
+        CASECurrency(12) break; \
+    case 13: \
+        CASECurrency(13) break; \
+    case 16: \
+        CASECurrency(16) break; \
+    default: \
+        IT_TODO_HALT(); \
+    }
+
 void expandKernel(int dType, void *input, void *output, int nDims,
                   int outputsize, SmallArray inputShape,
                   SmallArray outputShape) {
     int blocksize = block_work_size();
     int gridsize = (outputsize + block_work_size() - 1) / block_work_size();
-    SWITCH_DTYPE(dType)
+    SWITCHCurrency_DTYPE(dType)
 }
 #define CASE_ROW(T) \
     _expandRowKernel<float> \
         <<<grid, block, 0, CUDAStream::getCurrentStream()>>>(output, input);
@@ -150,7 +215,8 @@ void expandKernel(int dType, void *input, void *output, int nDims,
         IT_TODO_HALT(); \
     }
-// Optimization for expanding a row vector. The row length must be a multiple of 32
+// Optimization for expanding a row vector. The row length must be a multiple of
+// 32
 void expandRowKernel(int dType, void *input, void *output, int n_rows,
                      int row_len) {
     // Factorize row_len: row_len = a x b x 32 (32 is the warp size), b<=32
@@ -160,7 +226,8 @@ void expandRowKernel(int dType, void *input, void *output, int n_rows,
     // block: b x 32
     auto c = row_len / 32, b = c;
     if (b > 32) {
-        for (b = 32; c % b != 0; --b);
+        for (b = 32; c % b != 0; --b)
+            ;
     }
     auto a = c / b;
     dim3 grid(a, n_rows), block(32, b);
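
Note: with the repeat factor, each thread covers repeat consecutive output elements, so the launchers divide the grid by repeat * block_work_size(), rounding up. A host-side sketch of that arithmetic with illustrative values (repeat is 1 in the diff, so the grids are unchanged for now):

#include <cstdio>

constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
const int repeat = 1;

int main() {
    int outputsize = 100000;
    // Ceiling division: enough blocks so that
    // gridsize * repeat * block_work_size() >= outputsize.
    int gridsize = (outputsize + repeat * block_work_size() - 1) /
                   (repeat * block_work_size());
    // block_work_size() = 512, so 100000 elements need ceil(100000/512)
    // = 196 blocks when repeat == 1; doubling repeat would halve this.
    printf("gridsize = %d\n", gridsize);
    return 0;
}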

View File

@@ -87,20 +87,7 @@ class matmulCublas : public Kernel {
             beta_naive = 1.f;
             auto inC = op->getInputs(2);
             auto out = op->getOutput();
-            SmallArray inputShape, outputShape;
-            int nDims = out->getRank();
-            IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
-            // FIXME(constroy): use size_t for outputsize.
-            int outputsize = 1; // the length of the output vector after flatten
-            int offset = nDims - inC->getRank();
-            for (int i = 0; i < offset; ++i)
-                inputShape.data[i] = 1;
-            for (int i = 0; i < nDims; ++i) {
-                outputShape.data[i] = out->getDims()[i];
-                outputsize *= outputShape.data[i];
-                if (i >= offset)
-                    inputShape.data[i] = inC->getDims()[i - offset];
-            }
             const int dType = dataType.getIndex();
             // Bias in linear layer is row vector of (1,n), n is the number of
@@ -111,9 +98,40 @@ class matmulCublas : public Kernel {
                               out->size() / inC->getDims()[0],
                               inC->getDims()[0]);
             } else {
-                expandKernel(dType, inC->getRawDataPtr<void *>(),
-                             out->getRawDataPtr<void *>(), nDims, outputsize,
-                             inputShape, outputShape);
+                auto a_dim = out->getDims();
+                auto b_dim = inC->getDims(); // output shape
+
+                if (a_dim.size() > 4 || b_dim.size() > 4) {
+                    SmallArray inputShape, outputShape;
+                    int nDims = out->getRank();
+                    IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
+                    // FIXME(constroy): use size_t for outputsize.
+                    int outputsize =
+                        1; // the length of the output vector after flatten
+                    int offset = nDims - inC->getRank();
+                    for (int i = 0; i < offset; ++i)
+                        inputShape.data[i] = 1;
+                    for (int i = 0; i < nDims; ++i) {
+                        outputShape.data[i] = out->getDims()[i];
+                        outputsize *= outputShape.data[i];
+                        if (i >= offset)
+                            inputShape.data[i] = inC->getDims()[i - offset];
+                    }
+                    expandKernel(dType, inC->getRawDataPtr<void *>(),
+                                 out->getRawDataPtr<void *>(), nDims,
+                                 outputsize, inputShape, outputShape);
+                } else {
+                    int a[4] = {1, 1, 1, 1};
+                    int b[4] = {1, 1, 1, 1};
+                    std::copy(a_dim.begin(), a_dim.end(),
+                              a + (4 - a_dim.size()));
+                    std::copy(b_dim.begin(), b_dim.end(),
+                              b + (4 - b_dim.size()));
+                    expandKernel(dType, inC->getRawDataPtr<void *>(),
+                                 out->getRawDataPtr<void *>(), a[0], a[1], a[2],
+                                 a[3], b[0], b[1], b[2], b[3]);
+                }
             }
         }
         // TODO:use compute type

View File

@@ -16,31 +16,57 @@ class TransposeCuda : public CudaKernelWithoutConfig {
         void *const outputData = output->getRawDataPtr<void *>();
         const auto &inputShape = input->getDims();
         const auto &outputShape = output->getDims();
-        const auto &perm = op->getPermute();
+        const int dType = op->getDType().getIndex();
         int size = input->size();
         int nDims = input->getDims().size();
-        //----------------
-        // Compute strides
-        SmallArray strides, buffer;
-        IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
-        int curStride = 1;
-        for (int i = nDims - 1; i >= 0; --i) {
-            buffer.data[i] = curStride;
-            curStride *= inputShape[i];
-        }
-        for (int i = 0; i < nDims; ++i) {
-            strides.data[i] = buffer.data[perm[i]];
-        }
-        //----------------
-        SmallArray outputDims;
-        for (int i = 0; i < nDims; ++i) {
-            outputDims.data[i] = outputShape[i];
-        }
-        const int dType = op->getDType().getIndex();
-        transpose_kernel(dType, inputData, outputData, nDims, size, strides,
-                         outputDims);
+        bool condition = true;
+        int gnum = 0;
+        for (int i = 0; i < nDims; i++) {
+            if (inputShape[i] > 1) {
+                while (gnum < nDims) {
+                    if (outputShape[gnum] > 1) {
+                        gnum += 1;
+                        break;
+                    } else {
+                        gnum += 1;
+                    }
+                }
+                if (inputShape[i] != outputShape[gnum - 1]) {
+                    condition = false;
+                    break;
+                }
+            }
+        }
+        if (condition) {
+            cudaMemcpyAsync(outputData, inputData, op->getInputs(0)->getBytes(),
+                            cudaMemcpyDeviceToDevice,
+                            CUDAStream::getCurrentStream());
+        } else {
+            const auto &perm = op->getPermute();
+            // Compute strides
+            SmallArray strides, buffer;
+            IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
+            int curStride = 1;
+            for (int i = nDims - 1; i >= 0; --i) {
+                buffer.data[i] = curStride;
+                curStride *= inputShape[i];
+            }
+            for (int i = 0; i < nDims; ++i) {
+                strides.data[i] = buffer.data[perm[i]];
+            }
+            SmallArray outputDims;
+            for (int i = 0; i < nDims; ++i) {
+                outputDims.data[i] = outputShape[i];
+            }
+            transpose_kernel(dType, inputData, outputData, nDims, size,
+                             strides, outputDims);
+        }
     }
 };

View File

@@ -24,8 +24,8 @@ __global__ void _transpose_kernel(void *input, void *output, int nDims,
 }
 #define CASE(T) \
     _transpose_kernel<DT_CUDA<T>::t> \
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
-        (input, output, nDims, size, strides, outputShape);
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
+            input, output, nDims, size, strides, outputShape);
 #define SWITCH_DTYPE(DTYPE) \
     switch (DTYPE) { \

View File

@@ -1,8 +1,8 @@
 #include "operators/where.h"
 #include "cuda/cuda_kernel_wihtout_config.h"
 #include "cuda/cuda_runtime.h"
-#include "cuda/cuda_utility.h"
 #include "cuda/cuda_where.h"
+#include "utils/operator_utils.h"
 namespace infini {
@@ -15,39 +15,50 @@ class WhereCuda : public CudaKernelWithoutConfig {
         void *const inputYData = (op->getInputs(1)->getRawDataPtr<void *>());
         void *const conditionData = (op->getInputs(2)->getRawDataPtr<void *>());
         void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
-        const auto &opInputXShape = op->getInputs(0)->getDims();
-        const auto &opInputYShape = op->getInputs(1)->getDims();
-        const auto &opConditionShape = op->getInputs(2)->getDims();
-        const auto &opOutputShape = op->getOutput()->getDims();
-        const int xSize = op->getInputs(0)->getRank();
-        const int ySize = op->getInputs(1)->getRank();
-        const int cSize = op->getInputs(2)->getRank();
-        int nDims = op->getOutput()->getDims().size();
-        IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
-        int outputsize = 1;
-        SmallArray inputXShape, inputYShape, conditionShape, outputShape;
-        for (int i = nDims - 1; i >= 0; --i) {
-            outputShape.data[i] = opOutputShape[i];
-            outputsize *= outputShape.data[i];
-        }
-        broadcastShape(opInputXShape, inputXShape, nDims, xSize);
-        broadcastShape(opInputYShape, inputYShape, nDims, ySize);
-        broadcastShape(opConditionShape, conditionShape, nDims, cSize);
-        if (op->getDType() == DataType::Float32) {
-            whereKernel((float *)inputXData, (float *)inputYData,
-                        (uint8_t *)conditionData, (float *)outputData, nDims,
-                        outputsize, inputXShape, inputYShape, conditionShape,
-                        outputShape, xSize, ySize, cSize);
-        } else if (op->getDType() == DataType::Float16) {
-            whereKernel((half *)inputXData, (half *)inputYData,
-                        (uint8_t *)conditionData, (half *)outputData, nDims,
-                        outputsize, inputXShape, inputYShape, conditionShape,
-                        outputShape, xSize, ySize, cSize);
-        } else {
-            IT_ASSERT(false);
-        }
+        auto a_dim = op->getInputs(0)->getDims();
+        auto b_dim = op->getInputs(1)->getDims();
+        auto c_dim = op->getInputs(2)->getDims();
+        auto d_dim = op->getOutput()->getDims();
+
+        const int dTypeIndex = op->getDType().getIndex();
+        if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4 ||
+            d_dim.size() > 4) {
+            const int xSize = op->getInputs(0)->getRank();
+            const int ySize = op->getInputs(1)->getRank();
+            const int cSize = op->getInputs(2)->getRank();
+            int nDims = op->getOutput()->getDims().size();
+            IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
+            int outputsize = 1;
+            SmallArray inputXShape, inputYShape, conditionShape, outputShape;
+            for (int i = nDims - 1; i >= 0; --i) {
+                outputShape.data[i] = d_dim[i];
+                outputsize *= outputShape.data[i];
+            }
+            broadcastShape(a_dim, inputXShape, nDims, xSize);
+            broadcastShape(b_dim, inputYShape, nDims, ySize);
+            broadcastShape(c_dim, conditionShape, nDims, cSize);
+            whereKernel(dTypeIndex, inputXData, inputYData,
+                        (uint8_t *)conditionData, outputData, nDims, outputsize,
+                        inputXShape, inputYShape, conditionShape, outputShape,
+                        xSize, ySize, cSize);
+        } else {
+            int a[4] = {1, 1, 1, 1};
+            int b[4] = {1, 1, 1, 1};
+            int c[4] = {1, 1, 1, 1};
+            int d[4] = {1, 1, 1, 1};
+
+            std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
+            std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
+            std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
+            std::copy(d_dim.begin(), d_dim.end(), d + (4 - d_dim.size()));
+
+            whereKernel(dTypeIndex, inputXData, inputYData,
+                        (uint8_t *)conditionData, outputData, a[0], a[1], a[2],
+                        a[3], b[0], b[1], b[2], b[3], c[0], c[1], c[2], c[3],
+                        d[0], d[1], d[2], d[3]);
+        }
     }
 };

View File

@@ -1,6 +1,109 @@
 #include "cuda/cuda_common.h"
+#include "cuda/cuda_utility.h"
 #include "utils/small_array.h"
+const int repeat = 1;
+template <typename T>
+__global__ void
+_whereKernel(void *inputX, void *inputY, const uint8_t *condition, void *output,
+             int a0, int a1, int a2, int a3, int b0, int b1, int b2, int b3,
+             int c0, int c1, int c2, int c3, int d0, int d1, int d2, int d3) {
+    int stride1 = d2 * d3;
+    int stride0 = d1 * stride1;
+    int n = d0 * stride0;
+    int index = threadIdx.x + blockIdx.x * blockDim.x;
+    int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
+    for (int i = repeat * index; i < end; i++) {
+        int inputXIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
+        int inputYIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
+        int conditionIdx = (c0 * c1 * c2 * c3 == n ? i : 0);
+        bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
+        bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
+        bool cIdx = (c0 * c1 * c2 * c3 < n && c0 * c1 * c2 * c3 > 1);
+        if (aIdx || bIdx || cIdx) {
+            int d0_index = i / stride0;
+            int d1_index = (i % stride0) / stride1;
+            int d2_index = (i % stride1) / d3;
+            int d3_index = i % d3;
+            if (aIdx) {
+                int a0_index = d0_index % a0;
+                int a1_index = d1_index % a1;
+                int a2_index = d2_index % a2;
+                int a3_index = d3_index % a3;
+                inputXIdx = a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
+                            a2_index * a3 + a3_index;
+            }
+            if (bIdx) {
+                int b0_index = d0_index % b0;
+                int b1_index = d1_index % b1;
+                int b2_index = d2_index % b2;
+                int b3_index = d3_index % b3;
+                inputYIdx = b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
+                            b2_index * b3 + b3_index;
+            }
+            if (cIdx) {
+                int c0_index = d0_index % c0;
+                int c1_index = d1_index % c1;
+                int c2_index = d2_index % c2;
+                int c3_index = d3_index % c3;
+                conditionIdx = c0_index * c1 * c2 * c3 + c1_index * c2 * c3 +
+                               c2_index * c3 + c3_index;
+            }
+        }
+        ((T *)output)[i] = condition[conditionIdx] ? ((T *)inputX)[inputXIdx]
+                                                   : ((T *)inputY)[inputYIdx];
+    }
+}
+#define CASE(T) \
+    _whereKernel<DT_CUDA<T>::t> \
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
+            inputX, inputY, condition, output, a0, a1, a2, a3, b0, b1, b2, b3, \
+            c0, c1, c2, c3, d0, d1, d2, d3);
+#define SWITCH_DTYPE(DTYPE) \
+    switch (DTYPE) { \
+    case 1: \
+        CASE(1) \
+        break; \
+    case 2: \
+        CASE(2) \
+        break; \
+    case 3: \
+        CASE(3) \
+        break; \
+    case 4: \
+        CASE(4) \
+        break; \
+    case 5: \
+        CASE(5) \
+        break; \
+    case 6: \
+        CASE(6) \
+        break; \
+    case 7: \
+        CASE(7) \
+        break; \
+    case 10: \
+        CASE(10) \
+        break; \
+    case 11: \
+        CASE(11) \
+        break; \
+    case 12: \
+        CASE(12) \
+        break; \
+    case 13: \
+        CASE(13) \
+        break; \
+    case 16: \
+        CASE(16) \
+        break; \
+    default: \
+        IT_TODO_HALT(); \
+    }
 __device__ int inferIndex(infini::SmallArray inputShape,
                           infini::SmallArray outputShape, int nDims, int size,
                           int outputIdx) {
@@ -19,11 +122,10 @@ __device__ int inferIndex(infini::SmallArray inputShape,
 }
 template <typename T>
 __global__ void
-_whereKernel(const T *inputX, const T *inputY, const uint8_t *condition,
-             T *output, int nDims, int outputsize,
-             infini::SmallArray inputXShape, infini::SmallArray inputYShape,
-             infini::SmallArray conditionShape, infini::SmallArray outputShape,
-             int xSize, int ySize, int cSize) {
+_whereKernel(void *inputX, void *inputY, const uint8_t *condition, void *output,
+             int nDims, int outputsize, infini::SmallArray inputXShape,
+             infini::SmallArray inputYShape, infini::SmallArray conditionShape,
+             infini::SmallArray outputShape, int xSize, int ySize, int cSize) {
     int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
     if (outputIdx < outputsize) {
@@ -35,14 +137,74 @@ _whereKernel(void *inputX, void *inputY, const uint8_t *condition, void *output,
         int inputYIdx =
             inferIndex(inputYShape, outputShape, nDims, ySize, outputIdx);
-        output[outputIdx] =
-            condition[conditionIdx] ? inputX[inputXIdx] : inputY[inputYIdx];
+        ((T *)output)[outputIdx] = condition[conditionIdx]
+                                       ? ((T *)inputX)[inputXIdx]
+                                       : ((T *)inputY)[inputYIdx];
     }
 }
+#define CASECurrency(T) \
+    _whereKernel<DT_CUDA<T>::t> \
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
+            inputX, inputY, condition, output, nDims, outputsize, inputXShape, \
+            inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
+#define SWITCHCurrency_DTYPE(DTYPE) \
+    switch (DTYPE) { \
+    case 1: \
+        CASECurrency(1) break; \
+    case 2: \
+        CASECurrency(2) break; \
+    case 3: \
+        CASECurrency(3) break; \
+    case 4: \
+        CASECurrency(4) break; \
+    case 5: \
+        CASECurrency(5) break; \
+    case 6: \
+        CASECurrency(6) break; \
+    case 7: \
+        CASECurrency(7) break; \
+    case 10: \
+        CASECurrency(10) break; \
+    case 11: \
+        CASECurrency(11) break; \
+    case 12: \
+        CASECurrency(12) break; \
+    case 13: \
+        CASECurrency(13) break; \
+    case 16: \
+        CASECurrency(16) break; \
+    default: \
+        IT_TODO_HALT(); \
+    }
+
+namespace infini {
+void whereKernel(int dTypeIndex, void *inputX, void *inputY,
+                 const uint8_t *condition, void *output, int a0, int a1, int a2,
+                 int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
+                 int c3, int d0, int d1, int d2, int d3) {
+    int blocksize;
+    int outputsize = d0 * d1 * d2 * d3;
+    if (outputsize > 511 * repeat) {
+        blocksize = 1024;
+    } else if (outputsize > 255 * repeat) {
+        blocksize = 512;
+    } else if (outputsize > 127 * repeat) {
+        blocksize = 256;
+    } else if (outputsize > 63 * repeat) {
+        blocksize = 128;
+    } else if (outputsize > 31 * repeat) {
+        blocksize = 64;
+    } else {
+        blocksize = 32;
+    }
+    int gridsize = (outputsize + repeat * blocksize - 1) / (repeat * blocksize);
+    SWITCH_DTYPE(dTypeIndex)
+}
-namespace infini {
-void whereKernel(const float *inputX, const float *inputY,
-                 const uint8_t *condition, float *output, int nDims,
+void whereKernel(int dTypeIndex, void *inputX, void *inputY,
+                 const uint8_t *condition, void *output, int nDims,
                  int outputsize, SmallArray inputXShape, SmallArray inputYShape,
                  SmallArray conditionShape, SmallArray outputShape, int xSize,
                  int ySize, int cSize) {
@@ -61,34 +223,8 @@ void whereKernel(const float *inputX, const float *inputY,
         blocksize = 32;
     }
     int gridsize = (outputsize + blocksize - 1) / blocksize;
-    _whereKernel<float>
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
-            inputX, inputY, condition, output, nDims, outputsize, inputXShape,
-            inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
-}
-void whereKernel(const half *inputX, const half *inputY,
-                 const uint8_t *condition, half *output, int nDims,
-                 int outputsize, SmallArray inputXShape, SmallArray inputYShape,
-                 SmallArray conditionShape, SmallArray outputShape, int xSize,
-                 int ySize, int cSize) {
-    int blocksize;
-    if (outputsize > 511) {
-        blocksize = 1024;
-    } else if (outputsize > 255) {
-        blocksize = 512;
-    } else if (outputsize > 127) {
-        blocksize = 256;
-    } else if (outputsize > 63) {
-        blocksize = 128;
-    } else if (outputsize > 31) {
-        blocksize = 64;
-    } else {
-        blocksize = 32;
-    }
-    int gridsize = (outputsize + blocksize - 1) / blocksize;
-    _whereKernel<half>
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
-            inputX, inputY, condition, output, nDims, outputsize, inputXShape,
-            inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
+    SWITCHCurrency_DTYPE(dTypeIndex)
 }
 } // namespace infini
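
Note: the new 4-D launcher scales its block-size ladder by repeat, so the thresholds track elements per thread rather than raw output size, and the grid is again a ceiling division by repeat * blocksize. A host-side sketch of the same selection, assuming repeat == 1 as in the diff:

#include <cstdio>

const int repeat = 1;

// Pick a block size that grows with the amount of per-thread work,
// from a 32-thread minimum up to the 1024-thread CUDA block limit.
int pickBlocksize(int outputsize) {
    if (outputsize > 511 * repeat)
        return 1024;
    if (outputsize > 255 * repeat)
        return 512;
    if (outputsize > 127 * repeat)
        return 256;
    if (outputsize > 63 * repeat)
        return 128;
    if (outputsize > 31 * repeat)
        return 64;
    return 32;
}

int main() {
    // Prints "32 512 1024": tiny outputs stay at one warp, mid-sized
    // outputs get a few hundred threads, large ones the full block.
    printf("%d %d %d\n", pickBlocksize(24), pickBlocksize(300),
           pickBlocksize(100000));
    return 0;
}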

View File

@@ -84,6 +84,17 @@ void test_whereFp16(
 }
 
 TEST(CUDA_WhereFp32, run) {
+    test_whereFp32(
+        Shape{2, 2, 3, 1, 2},
+        vector<float>{0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,
+                      8.,  9.,  10., 11., 12., 13., 14., 15.,
+                      16., 17., 18., 19., 20., 21., 22., 23.},
+        Shape{2, 2, 3, 1, 2},
+        vector<float>{0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
+                      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.},
+        Shape{2, 3, 1, 2}, vector<uint8_t>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
+        vector<float>{0., 1.,  2.,  0., 0., 0., 6.,  7.,  0., 9.,  10., 11.,
+                      0., 13., 14., 0., 0., 0., 18., 19., 0., 21., 22., 23.});
     test_whereFp32(
         Shape{2, 2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
         Shape{2, 2, 3, 1}, vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},