forked from jiuyuan/InfiniTensor
modified div_kernel
This commit is contained in:
parent
aa1c3222ed
commit
2761d46737
|
@ -5,34 +5,42 @@
|
||||||
constexpr unsigned int num_threads() { return 32 * 4; }
|
constexpr unsigned int num_threads() { return 32 * 4; }
|
||||||
constexpr int thread_work_size() { return 4; }
|
constexpr int thread_work_size() { return 4; }
|
||||||
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
|
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
|
||||||
|
const int repeat = 1;
|
||||||
template <class T>
|
template <class T>
|
||||||
__global__ void _div_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
|
__global__ void _div_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
|
||||||
int a3, int b0, int b1, int b2, int b3, int c0,
|
int a3, int b0, int b1, int b2, int b3, int c0,
|
||||||
int c1, int c2, int c3) {
|
int c1, int c2, int c3) {
|
||||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
int stride = blockDim.x * gridDim.x;
|
|
||||||
int n = c0 * c1 * c2 * c3;
|
|
||||||
|
|
||||||
for (int i = index; i < n; i += stride) {
|
int stride1 = c2 * c3;
|
||||||
int c0_index = i / (c1 * c2 * c3);
|
int stride0 = c1 * stride1;
|
||||||
int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
|
int n = c0 * stride0;
|
||||||
int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
|
int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
|
||||||
int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
|
for (int i = repeat * index; i < end; i++) {
|
||||||
|
int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
|
||||||
|
int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
|
||||||
|
|
||||||
int a0_index = c0_index % a0;
|
bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
|
||||||
int a1_index = c1_index % a1;
|
bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
|
||||||
int a2_index = c2_index % a2;
|
if (aIdx || bIdx) {
|
||||||
int a3_index = c3_index % a3;
|
int c0_index = i / stride0;
|
||||||
|
int c1_index = (i % stride0) / stride1;
|
||||||
|
int c2_index = (i % stride1) / c3;
|
||||||
|
int c3_index = i % c3;
|
||||||
|
if (aIdx) {
|
||||||
|
|
||||||
int b0_index = c0_index % b0;
|
xIdx = (c0_index % a0) * a1 * a2 * a3 +
|
||||||
int b1_index = c1_index % b1;
|
(c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
|
||||||
int b2_index = c2_index % b2;
|
c3_index % a3;
|
||||||
int b3_index = c3_index % b3;
|
}
|
||||||
((T *)z)[i] = ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
|
if (bIdx) {
|
||||||
a2_index * a3 + a3_index] /
|
|
||||||
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
|
yIdx = (c0_index % b0) * b1 * b2 * b3 +
|
||||||
b2_index * b3 + b3_index];
|
(c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
|
||||||
|
c3_index % b3;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
((T *)z)[i] = ((T *)x)[xIdx] / ((T *)y)[yIdx];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,28 +49,36 @@ __global__ void _add_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
|
||||||
int a3, int b0, int b1, int b2, int b3, int c0,
|
int a3, int b0, int b1, int b2, int b3, int c0,
|
||||||
int c1, int c2, int c3) {
|
int c1, int c2, int c3) {
|
||||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
int stride = blockDim.x * gridDim.x;
|
|
||||||
int n = c0 * c1 * c2 * c3;
|
|
||||||
|
|
||||||
for (int i = index; i < n; i += stride) {
|
int stride1 = c2 * c3;
|
||||||
int c0_index = i / (c1 * c2 * c3);
|
int stride0 = c1 * stride1;
|
||||||
int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
|
int n = c0 * stride0;
|
||||||
int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
|
int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
|
||||||
int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
|
for (int i = repeat * index; i < end; i++) {
|
||||||
|
int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
|
||||||
|
int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
|
||||||
|
|
||||||
int a0_index = c0_index % a0;
|
bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
|
||||||
int a1_index = c1_index % a1;
|
bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
|
||||||
int a2_index = c2_index % a2;
|
if (aIdx || bIdx) {
|
||||||
int a3_index = c3_index % a3;
|
int c0_index = i / stride0;
|
||||||
|
int c1_index = (i % stride0) / stride1;
|
||||||
|
int c2_index = (i % stride1) / c3;
|
||||||
|
int c3_index = i % c3;
|
||||||
|
if (aIdx) {
|
||||||
|
|
||||||
int b0_index = c0_index % b0;
|
xIdx = (c0_index % a0) * a1 * a2 * a3 +
|
||||||
int b1_index = c1_index % b1;
|
(c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
|
||||||
int b2_index = c2_index % b2;
|
c3_index % a3;
|
||||||
int b3_index = c3_index % b3;
|
}
|
||||||
((T *)z)[i] = ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
|
if (bIdx) {
|
||||||
a2_index * a3 + a3_index] +
|
|
||||||
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
|
yIdx = (c0_index % b0) * b1 * b2 * b3 +
|
||||||
b2_index * b3 + b3_index];
|
(c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
|
||||||
|
c3_index % b3;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
((T *)z)[i] = ((T *)x)[xIdx] + ((T *)y)[yIdx];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,29 +87,36 @@ __global__ void _pow_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
|
||||||
int a3, int b0, int b1, int b2, int b3, int c0,
|
int a3, int b0, int b1, int b2, int b3, int c0,
|
||||||
int c1, int c2, int c3) {
|
int c1, int c2, int c3) {
|
||||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
int stride = blockDim.x * gridDim.x;
|
|
||||||
int n = c0 * c1 * c2 * c3;
|
|
||||||
|
|
||||||
for (int i = index; i < n; i += stride) {
|
int stride1 = c2 * c3;
|
||||||
int c0_index = i / (c1 * c2 * c3);
|
int stride0 = c1 * stride1;
|
||||||
int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
|
int n = c0 * stride0;
|
||||||
int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
|
int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
|
||||||
int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
|
for (int i = repeat * index; i < end; i++) {
|
||||||
|
int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
|
||||||
|
int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
|
||||||
|
|
||||||
int a0_index = c0_index % a0;
|
bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
|
||||||
int a1_index = c1_index % a1;
|
bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
|
||||||
int a2_index = c2_index % a2;
|
if (aIdx || bIdx) {
|
||||||
int a3_index = c3_index % a3;
|
int c0_index = i / stride0;
|
||||||
|
int c1_index = (i % stride0) / stride1;
|
||||||
|
int c2_index = (i % stride1) / c3;
|
||||||
|
int c3_index = i % c3;
|
||||||
|
if (aIdx) {
|
||||||
|
|
||||||
int b0_index = c0_index % b0;
|
xIdx = (c0_index % a0) * a1 * a2 * a3 +
|
||||||
int b1_index = c1_index % b1;
|
(c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
|
||||||
int b2_index = c2_index % b2;
|
c3_index % a3;
|
||||||
int b3_index = c3_index % b3;
|
}
|
||||||
((T *)z)[i] =
|
if (bIdx) {
|
||||||
pow(((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
|
|
||||||
a2_index * a3 + a3_index],
|
yIdx = (c0_index % b0) * b1 * b2 * b3 +
|
||||||
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
|
(c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
|
||||||
b2_index * b3 + b3_index]);
|
c3_index % b3;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
((T *)z)[i] = pow(((T *)x)[xIdx], ((T *)y)[yIdx]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,31 +125,36 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
|
||||||
int a3, int b0, int b1, int b2, int b3, int c0,
|
int a3, int b0, int b1, int b2, int b3, int c0,
|
||||||
int c1, int c2, int c3) {
|
int c1, int c2, int c3) {
|
||||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
int stride = blockDim.x * gridDim.x;
|
|
||||||
int n = c0 * c1 * c2 * c3;
|
|
||||||
|
|
||||||
for (int i = index; i < n; i += stride) {
|
int stride1 = c2 * c3;
|
||||||
int c0_index = i / (c1 * c2 * c3);
|
int stride0 = c1 * stride1;
|
||||||
int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
|
int n = c0 * stride0;
|
||||||
int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
|
int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
|
||||||
int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
|
for (int i = repeat * index; i < end; i++) {
|
||||||
|
int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
|
||||||
|
int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
|
||||||
|
|
||||||
int a0_index = c0_index % a0;
|
bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
|
||||||
int a1_index = c1_index % a1;
|
bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
|
||||||
int a2_index = c2_index % a2;
|
if (aIdx || bIdx) {
|
||||||
int a3_index = c3_index % a3;
|
int c0_index = i / stride0;
|
||||||
|
int c1_index = (i % stride0) / stride1;
|
||||||
|
int c2_index = (i % stride1) / c3;
|
||||||
|
int c3_index = i % c3;
|
||||||
|
if (aIdx) {
|
||||||
|
|
||||||
int b0_index = c0_index % b0;
|
xIdx = (c0_index % a0) * a1 * a2 * a3 +
|
||||||
int b1_index = c1_index % b1;
|
(c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
|
||||||
int b2_index = c2_index % b2;
|
c3_index % a3;
|
||||||
int b3_index = c3_index % b3;
|
}
|
||||||
((bool *)z)[i] =
|
if (bIdx) {
|
||||||
((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
|
|
||||||
a2_index * a3 + a3_index] <
|
yIdx = (c0_index % b0) * b1 * b2 * b3 +
|
||||||
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
|
(c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
|
||||||
b2_index * b3 + b3_index]
|
c3_index % b3;
|
||||||
? true
|
}
|
||||||
: false;
|
}
|
||||||
|
((bool *)z)[i] = ((T *)x)[xIdx] < ((T *)y)[yIdx] ? true : false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,7 +204,6 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
|
||||||
default: \
|
default: \
|
||||||
IT_TODO_HALT(); \
|
IT_TODO_HALT(); \
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
__global__ void _div_const_kernel(void const *__restrict__ x,
|
__global__ void _div_const_kernel(void const *__restrict__ x,
|
||||||
void const *__restrict__ y,
|
void const *__restrict__ y,
|
||||||
|
@ -269,7 +296,8 @@ void div_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
|
||||||
|
|
||||||
int blocksize = block_work_size();
|
int blocksize = block_work_size();
|
||||||
int num = c0 * c1 * c2 * c3;
|
int num = c0 * c1 * c2 * c3;
|
||||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
int gridsize =
|
||||||
|
(num + repeat * block_work_size() - 1) / (repeat * block_work_size());
|
||||||
SWITCH_DTYPE(div, dType)
|
SWITCH_DTYPE(div, dType)
|
||||||
}
|
}
|
||||||
void add_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
|
void add_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
|
||||||
|
@ -278,7 +306,8 @@ void add_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
|
||||||
|
|
||||||
int blocksize = block_work_size();
|
int blocksize = block_work_size();
|
||||||
int num = c0 * c1 * c2 * c3;
|
int num = c0 * c1 * c2 * c3;
|
||||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
int gridsize =
|
||||||
|
(num + repeat * block_work_size() - 1) / (repeat * block_work_size());
|
||||||
SWITCH_DTYPE(add, dType)
|
SWITCH_DTYPE(add, dType)
|
||||||
}
|
}
|
||||||
void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
|
void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
|
||||||
|
@ -286,7 +315,8 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
|
||||||
int c3) {
|
int c3) {
|
||||||
int blocksize = block_work_size();
|
int blocksize = block_work_size();
|
||||||
int num = c0 * c1 * c2 * c3;
|
int num = c0 * c1 * c2 * c3;
|
||||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
int gridsize =
|
||||||
|
(num + repeat * block_work_size() - 1) / (repeat * block_work_size());
|
||||||
if (dType == 1) {
|
if (dType == 1) {
|
||||||
_pow_kernel<float>
|
_pow_kernel<float>
|
||||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||||
|
@ -324,7 +354,8 @@ void less_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
|
||||||
int c3) {
|
int c3) {
|
||||||
int blocksize = block_work_size();
|
int blocksize = block_work_size();
|
||||||
int num = c0 * c1 * c2 * c3;
|
int num = c0 * c1 * c2 * c3;
|
||||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
int gridsize =
|
||||||
|
(num + repeat * block_work_size() - 1) / (repeat * block_work_size());
|
||||||
SWITCH_DTYPE(less, dType)
|
SWITCH_DTYPE(less, dType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue