From 830b28913ce36f46a556f20d83ecf7f88501ea1f Mon Sep 17 00:00:00 2001
From: xxcclong
Date: Sun, 23 Apr 2023 21:36:25 +0800
Subject: [PATCH] better transposed convreduce

---
 src/kernels/cuda/conv2dreduce.cu | 81 +++++++++++++++++++++++++++-----
 1 file changed, 73 insertions(+), 8 deletions(-)

diff --git a/src/kernels/cuda/conv2dreduce.cu b/src/kernels/cuda/conv2dreduce.cu
index 6254fac0..a7026c16 100644
--- a/src/kernels/cuda/conv2dreduce.cu
+++ b/src/kernels/cuda/conv2dreduce.cu
@@ -41,6 +41,63 @@ __global__ void conv2dreduce_kernel_(float *__restrict__ input,
         output[nid * (oh * ow * f) + hid * (ow * f) + wid * f + fid] = imm;
     }
 }
+// One warp computes one output element (nid, fid, hid, wid): the 32 lanes
+// split the r*s kernel window and combine partial sums with __shfl_down_sync.
+// Expects a 1-D launch with blockDim.x a multiple of 32.
+// NOTE(review): dh/dw are accepted but never read; assumes dilation == 1.
+__global__ void convTranspose2dreduce_kernel2_(
+    float *__restrict__ input, float *__restrict__ bias,
+    float *__restrict__ output, const bool PReLU, const int n, const int f,
+    const int h, const int w, const int oh, const int ow, const int r,
+    const int s, const int ph, const int pw, const int dh, const int dw,
+    const int sh, const int sw) {
+    int warp_id = (blockDim.x / 32) * blockIdx.x + threadIdx.x / 32;
+    int lane = threadIdx.x % 32;
+    int nid = warp_id / (f * oh * ow);
+    int fid = (warp_id - nid * (f * oh * ow)) / (oh * ow);
+    int hid = (warp_id - nid * (f * oh * ow) - fid * (oh * ow)) / ow;
+    int wid = warp_id % ow;
+    // ">=" (not ">"): the grid is rounded up, so tail warps can decode
+    // nid == n (with hid == wid == fid == 0) and would write out of bounds.
+    if (hid >= oh || wid >= ow || nid >= n || fid >= f)
+        return;
+
+    const int fchunck = r * s, wchunk = f * fchunck, hchunk = w * wchunk,
+              nchunck = h * hchunk;
+    float *nfinput = input + nid * nchunck + fid * fchunck;
+    // view as conv, the true ph and pw
+    int tph = r - ph - 1, tpw = s - pw - 1;
+    int th = (h - 1) * sh + 1, tw = (w - 1) * sw + 1;
+
+    float imm = 0.0f;
+    int ihst = hid - tph;
+    int iwst = wid - tpw;
+    for (int idx = lane; idx < r * s; idx += 32) {
+        int ri = idx / s;
+        int si = idx % s;
+        int ihid = ihst + r - ri - 1;
+        int iwid = iwst + s - si - 1;
+        if (ihid >= 0 && ihid < th && iwid >= 0 && iwid < tw &&
+            (ihid % sh == 0) && (iwid % sw == 0)) {
+            imm += *(nfinput + (ihid / sh) * hchunk + (iwid / sw) * wchunk +
+                     ri * s + si);
+        }
+    }
+
+    // warp-level tree reduction: lane 0 ends up with the full sum
+    for (int k = 16; k > 0; k >>= 1) {
+        imm += __shfl_down_sync(0xffffffff, imm, k); // sum
+    }
+    if (lane == 0) {
+        if (bias) {
+            imm += bias[fid];
+        }
+        if (PReLU) {
+            imm = imm > 0.0f ? imm : 0.0f; // NOTE(review): plain ReLU
+        }
+        output[nid * (oh * ow * f) + hid * (ow * f) + wid * f + fid] = imm;
+    }
+}
 
 __global__ void convTranspose2dreduce_kernel_(
     float *__restrict__ input, float *__restrict__ bias,
@@ -167,15 +224,23 @@ void convTranspose2dreduce_kernel(float *input, float *bias, float *output,
                                       ow, h, w);
     } else {
         // puts("why use this conv2dreduce");
-        block.x = 32;
-        block.y = 32;
-        int block_x_num = (oh + block.x - 1) / block.x;
-        int block_y_num = (ow + block.y - 1) / block.y;
-        grid.x = n * (block_x_num);
-        grid.y = f * (block_y_num);
-        convTranspose2dreduce_kernel_<<<grid, block>>>(
+        // block.x = 32;
+        // block.y = 32;
+        // int block_x_num = (oh + block.x - 1) / block.x;
+        // int block_y_num = (ow + block.y - 1) / block.y;
+        // grid.x = n * (block_x_num);
+        // grid.y = f * (block_y_num);
+        // convTranspose2dreduce_kernel_<<<grid, block>>>(
+        //     input, bias, output, (bool)act, n, f, h, w, oh, ow, r, s, ph, pw,
+        //     dh, dw, sh, sw, block_x_num, block_y_num);
+
+        block.x = 128; // 4 warps per block, one warp per output element
+        block.y = 1;
+        grid.x = (n * f * ow * oh + block.x / 32 - 1) / (block.x / 32);
+        grid.y = 1;
+        convTranspose2dreduce_kernel2_<<<grid, block>>>(
             input, bias, output, (bool)act, n, f, h, w, oh, ow, r, s, ph, pw,
-            dh, dw, sh, sw, block_x_num, block_y_num);
+            dh, dw, sh, sw);
     }
 }
 } // namespace infini