better transposed convreduce

xxcclong 2023-04-23 21:36:25 +08:00
parent 777aebafc9
commit 830b28913c
1 changed file with 66 additions and 8 deletions


@@ -41,6 +41,56 @@ __global__ void conv2dreduce_kernel_(float *__restrict__ input,
         output[nid * (oh * ow * f) + hid * (ow * f) + wid * f + fid] = imm;
     }
 }
+__global__ void convTranspose2dreduce_kernel2_(
+    float *__restrict__ input, float *__restrict__ bias,
+    float *__restrict__ output, const bool PReLU, const int n, const int f,
+    const int h, const int w, const int oh, const int ow, const int r,
+    const int s, const int ph, const int pw, const int dh, const int dw,
+    const int sh, const int sw) {
+    // one warp per output element: decompose the global warp id into
+    // (batch, output channel, output row, output column)
+    int warp_id = (blockDim.x / 32) * blockIdx.x + threadIdx.x / 32;
+    int lane = threadIdx.x % 32;
+    int nid = warp_id / (f * oh * ow);
+    int fid = (warp_id - nid * (f * oh * ow)) / (oh * ow);
+    int hid = (warp_id - nid * (f * oh * ow) - fid * (oh * ow)) / ow;
+    int wid = warp_id % ow;
+    if (hid >= oh || wid >= ow || nid >= n || fid >= f)
+        return;
+    const int fchunck = r * s, wchunk = f * fchunck, hchunk = w * wchunk,
+              nchunck = h * hchunk;
+    float *nfinput = input + nid * nchunck + fid * fchunck;
+    // view the transposed conv as a conv: effective (true) padding and the
+    // zero-stretched input extent
+    int tph = r - ph - 1, tpw = s - pw - 1;
+    int th = (h - 1) * sh + 1, tw = (w - 1) * sw + 1;
+    float imm = 0.0;
+    int ihst = hid - tph;
+    int iwst = wid - tpw;
+    for (int idx = lane; idx < r * s; idx += 32) {
+        int ri = idx / s;
+        int si = idx % s;
+        // flipped-kernel offsets; only offsets that land on a real
+        // (non-zero-inserted) pixel of the stride-stretched input contribute
+        int ihid = ihst + r - ri - 1;
+        int iwid = iwst + s - si - 1;
+        if (ihid >= 0 && ihid < th && iwid >= 0 && iwid < tw &&
+            (ihid % sh == 0) && (iwid % sw == 0)) {
+            imm += *(nfinput + (ihid / sh) * hchunk + (iwid / sw) * wchunk +
+                     ri * s + si);
+        }
+    }
+    // warp-level tree reduction: accumulate all 32 lane partials into lane 0
+    for (int k = 16; k > 0; k >>= 1) {
+        imm += __shfl_down_sync(0xffffffff, imm, k);
+    }
+    if (lane == 0) {
+        if (bias) {
+            imm += bias[fid];
+        }
+        if (PReLU) {
+            // note: clamps negatives to zero (plain ReLU; no learned slope)
+            imm = imm > 0.0 ? imm : 0.0;
+        }
+        output[nid * (oh * ow * f) + hid * (ow * f) + wid * f + fid] = imm;
+    }
+}
 __global__ void convTranspose2dreduce_kernel_(
     float *__restrict__ input, float *__restrict__ bias,
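
For reference, the new kernel treats the transposed convolution as an ordinary convolution over a zero-stretched input with a flipped kernel. The sketch below is a minimal CPU version of that index mapping for a single output element; it assumes the per-pixel partial-product layout [n][h][w][f][r*s] that the kernel's stride arithmetic implies, and the function and variable names are illustrative only, not part of this commit.

// Illustrative CPU sketch: gather the partial products that one
// transposed-conv output element (nid, fid, hid, wid) reduces over.
// `input` is assumed to be laid out as [n][h][w][f][r*s].
float reduce_one_output(const float *input, int n, int f, int h, int w, int r,
                        int s, int ph, int pw, int sh, int sw, int nid, int fid,
                        int hid, int wid) {
    const int fchunk = r * s, wchunk = f * fchunk, hchunk = w * wchunk,
              nchunk = h * hchunk;
    const float *nfinput = input + nid * nchunk + fid * fchunk;
    const int tph = r - ph - 1, tpw = s - pw - 1;            // effective padding
    const int th = (h - 1) * sh + 1, tw = (w - 1) * sw + 1;  // stretched extent
    float acc = 0.0f;
    for (int ri = 0; ri < r; ++ri) {
        for (int si = 0; si < s; ++si) {
            int ihid = hid - tph + (r - 1 - ri);  // flipped kernel offset
            int iwid = wid - tpw + (s - 1 - si);
            // only positions that fall on a real (non-inserted-zero) pixel of
            // the stride-stretched input contribute
            if (ihid >= 0 && ihid < th && iwid >= 0 && iwid < tw &&
                ihid % sh == 0 && iwid % sw == 0) {
                acc += nfinput[(ihid / sh) * hchunk + (iwid / sw) * wchunk +
                               ri * s + si];
            }
        }
    }
    return acc; // bias/activation epilogue omitted
}
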
@@ -167,15 +217,23 @@ void convTranspose2dreduce_kernel(float *input, float *bias, float *output,
             ow, h, w);
     } else {
         // puts("why use this conv2dreduce");
-        block.x = 32;
-        block.y = 32;
-        int block_x_num = (oh + block.x - 1) / block.x;
-        int block_y_num = (ow + block.y - 1) / block.y;
-        grid.x = n * (block_x_num);
-        grid.y = f * (block_y_num);
-        convTranspose2dreduce_kernel_<<<grid, block, 0>>>(
+        // block.x = 32;
+        // block.y = 32;
+        // int block_x_num = (oh + block.x - 1) / block.x;
+        // int block_y_num = (ow + block.y - 1) / block.y;
+        // grid.x = n * (block_x_num);
+        // grid.y = f * (block_y_num);
+        // convTranspose2dreduce_kernel_<<<grid, block, 0>>>(
+        //     input, bias, output, (bool)act, n, f, h, w, oh, ow, r, s, ph, pw,
+        //     dh, dw, sh, sw, block_x_num, block_y_num);
+        block.x = 128;
+        block.y = 1;
+        // one warp per output element, four warps (128 threads) per block
+        grid.x = (n * f * ow * oh + block.x / 32 - 1) / (block.x / 32);
+        grid.y = 1;
+        convTranspose2dreduce_kernel2_<<<grid, block, 0>>>(
             input, bias, output, (bool)act, n, f, h, w, oh, ow, r, s, ph, pw,
-            dh, dw, sh, sw, block_x_num, block_y_num);
+            dh, dw, sh, sw);
     }
 }
 } // namespace infini
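
Because the new kernel assigns one warp per output element, the grid size is simply the number of output elements divided, rounding up, by the number of warps per block. A small host-side sketch of that arithmetic follows, assuming the 128-thread blocks used in this commit; the example sizes and names are illustrative only.

// Illustrative launch-size arithmetic for convTranspose2dreduce_kernel2_:
// one warp (32 threads) reduces one output element (n, f, oh, ow).
#include <cstdio>

int main() {
    const int n = 1, f = 64, oh = 32, ow = 32;  // example sizes (assumed)
    const int block_x = 128;                    // 4 warps per block
    const int warps_per_block = block_x / 32;
    const int total_warps = n * f * oh * ow;    // one warp per output element
    const int grid_x = (total_warps + warps_per_block - 1) / warps_per_block;
    std::printf("grid.x = %d blocks of %d threads (%d warps each)\n", grid_x,
                block_x, warps_per_block);
    return 0;
}
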