modified div_kernel

This commit is contained in:
xgqdut2016 2024-04-10 10:51:35 +08:00
parent aa1c3222ed
commit 2761d46737
1 changed files with 117 additions and 86 deletions

View File

@ -5,34 +5,42 @@
constexpr unsigned int num_threads() { return 32 * 4; } constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; } constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); } constexpr int block_work_size() { return thread_work_size() * num_threads(); }
const int repeat = 1;
template <class T> template <class T>
__global__ void _div_kernel(void *x, void *y, void *z, int a0, int a1, int a2, __global__ void _div_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0, int a3, int b0, int b1, int b2, int b3, int c0,
int c1, int c2, int c3) { int c1, int c2, int c3) {
int index = threadIdx.x + blockIdx.x * blockDim.x; int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
int n = c0 * c1 * c2 * c3;
for (int i = index; i < n; i += stride) { int stride1 = c2 * c3;
int c0_index = i / (c1 * c2 * c3); int stride0 = c1 * stride1;
int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3); int n = c0 * stride0;
int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3; int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3; for (int i = repeat * index; i < end; i++) {
int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
int a0_index = c0_index % a0; bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
int a1_index = c1_index % a1; bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
int a2_index = c2_index % a2; if (aIdx || bIdx) {
int a3_index = c3_index % a3; int c0_index = i / stride0;
int c1_index = (i % stride0) / stride1;
int c2_index = (i % stride1) / c3;
int c3_index = i % c3;
if (aIdx) {
int b0_index = c0_index % b0; xIdx = (c0_index % a0) * a1 * a2 * a3 +
int b1_index = c1_index % b1; (c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
int b2_index = c2_index % b2; c3_index % a3;
int b3_index = c3_index % b3; }
((T *)z)[i] = ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 + if (bIdx) {
a2_index * a3 + a3_index] /
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 + yIdx = (c0_index % b0) * b1 * b2 * b3 +
b2_index * b3 + b3_index]; (c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
c3_index % b3;
}
}
((T *)z)[i] = ((T *)x)[xIdx] / ((T *)y)[yIdx];
} }
} }
@ -41,28 +49,36 @@ __global__ void _add_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0, int a3, int b0, int b1, int b2, int b3, int c0,
int c1, int c2, int c3) { int c1, int c2, int c3) {
int index = threadIdx.x + blockIdx.x * blockDim.x; int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
int n = c0 * c1 * c2 * c3;
for (int i = index; i < n; i += stride) { int stride1 = c2 * c3;
int c0_index = i / (c1 * c2 * c3); int stride0 = c1 * stride1;
int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3); int n = c0 * stride0;
int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3; int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3; for (int i = repeat * index; i < end; i++) {
int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
int a0_index = c0_index % a0; bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
int a1_index = c1_index % a1; bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
int a2_index = c2_index % a2; if (aIdx || bIdx) {
int a3_index = c3_index % a3; int c0_index = i / stride0;
int c1_index = (i % stride0) / stride1;
int c2_index = (i % stride1) / c3;
int c3_index = i % c3;
if (aIdx) {
int b0_index = c0_index % b0; xIdx = (c0_index % a0) * a1 * a2 * a3 +
int b1_index = c1_index % b1; (c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
int b2_index = c2_index % b2; c3_index % a3;
int b3_index = c3_index % b3; }
((T *)z)[i] = ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 + if (bIdx) {
a2_index * a3 + a3_index] +
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 + yIdx = (c0_index % b0) * b1 * b2 * b3 +
b2_index * b3 + b3_index]; (c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
c3_index % b3;
}
}
((T *)z)[i] = ((T *)x)[xIdx] + ((T *)y)[yIdx];
} }
} }
@ -71,29 +87,36 @@ __global__ void _pow_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0, int a3, int b0, int b1, int b2, int b3, int c0,
int c1, int c2, int c3) { int c1, int c2, int c3) {
int index = threadIdx.x + blockIdx.x * blockDim.x; int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
int n = c0 * c1 * c2 * c3;
for (int i = index; i < n; i += stride) { int stride1 = c2 * c3;
int c0_index = i / (c1 * c2 * c3); int stride0 = c1 * stride1;
int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3); int n = c0 * stride0;
int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3; int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3; for (int i = repeat * index; i < end; i++) {
int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
int a0_index = c0_index % a0; bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
int a1_index = c1_index % a1; bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
int a2_index = c2_index % a2; if (aIdx || bIdx) {
int a3_index = c3_index % a3; int c0_index = i / stride0;
int c1_index = (i % stride0) / stride1;
int c2_index = (i % stride1) / c3;
int c3_index = i % c3;
if (aIdx) {
int b0_index = c0_index % b0; xIdx = (c0_index % a0) * a1 * a2 * a3 +
int b1_index = c1_index % b1; (c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
int b2_index = c2_index % b2; c3_index % a3;
int b3_index = c3_index % b3; }
((T *)z)[i] = if (bIdx) {
pow(((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
a2_index * a3 + a3_index], yIdx = (c0_index % b0) * b1 * b2 * b3 +
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 + (c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
b2_index * b3 + b3_index]); c3_index % b3;
}
}
((T *)z)[i] = pow(((T *)x)[xIdx], ((T *)y)[yIdx]);
} }
} }
@ -102,31 +125,36 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0, int a3, int b0, int b1, int b2, int b3, int c0,
int c1, int c2, int c3) { int c1, int c2, int c3) {
int index = threadIdx.x + blockIdx.x * blockDim.x; int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
int n = c0 * c1 * c2 * c3;
for (int i = index; i < n; i += stride) { int stride1 = c2 * c3;
int c0_index = i / (c1 * c2 * c3); int stride0 = c1 * stride1;
int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3); int n = c0 * stride0;
int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3; int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3; for (int i = repeat * index; i < end; i++) {
int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
int a0_index = c0_index % a0; bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
int a1_index = c1_index % a1; bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
int a2_index = c2_index % a2; if (aIdx || bIdx) {
int a3_index = c3_index % a3; int c0_index = i / stride0;
int c1_index = (i % stride0) / stride1;
int c2_index = (i % stride1) / c3;
int c3_index = i % c3;
if (aIdx) {
int b0_index = c0_index % b0; xIdx = (c0_index % a0) * a1 * a2 * a3 +
int b1_index = c1_index % b1; (c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
int b2_index = c2_index % b2; c3_index % a3;
int b3_index = c3_index % b3; }
((bool *)z)[i] = if (bIdx) {
((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
a2_index * a3 + a3_index] < yIdx = (c0_index % b0) * b1 * b2 * b3 +
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 + (c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
b2_index * b3 + b3_index] c3_index % b3;
? true }
: false; }
((bool *)z)[i] = ((T *)x)[xIdx] < ((T *)y)[yIdx] ? true : false;
} }
} }
@ -176,7 +204,6 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
default: \ default: \
IT_TODO_HALT(); \ IT_TODO_HALT(); \
} }
template <class T> template <class T>
__global__ void _div_const_kernel(void const *__restrict__ x, __global__ void _div_const_kernel(void const *__restrict__ x,
void const *__restrict__ y, void const *__restrict__ y,
@ -269,7 +296,8 @@ void div_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int blocksize = block_work_size(); int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3; int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize =
(num + repeat * block_work_size() - 1) / (repeat * block_work_size());
SWITCH_DTYPE(div, dType) SWITCH_DTYPE(div, dType)
} }
void add_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2, void add_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
@ -278,7 +306,8 @@ void add_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int blocksize = block_work_size(); int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3; int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize =
(num + repeat * block_work_size() - 1) / (repeat * block_work_size());
SWITCH_DTYPE(add, dType) SWITCH_DTYPE(add, dType)
} }
void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2, void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
@ -286,7 +315,8 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int c3) { int c3) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3; int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize =
(num + repeat * block_work_size() - 1) / (repeat * block_work_size());
if (dType == 1) { if (dType == 1) {
_pow_kernel<float> _pow_kernel<float>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
@ -324,7 +354,8 @@ void less_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int c3) { int c3) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3; int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize =
(num + repeat * block_work_size() - 1) / (repeat * block_work_size());
SWITCH_DTYPE(less, dType) SWITCH_DTYPE(less, dType)
} }