From bef4c422a0b90fdedc378fc545f7aa6c89e3c5c8 Mon Sep 17 00:00:00 2001 From: Liyan Zheng Date: Sat, 22 Oct 2022 13:26:41 +0800 Subject: [PATCH] Add: improve conv2dreduce kernel --- src/kernels/cuda/conv2dreduce.cu | 44 ++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/src/kernels/cuda/conv2dreduce.cu b/src/kernels/cuda/conv2dreduce.cu index 7402b940..fa7285d2 100644 --- a/src/kernels/cuda/conv2dreduce.cu +++ b/src/kernels/cuda/conv2dreduce.cu @@ -9,18 +9,31 @@ conv2dreduce_kernel_(float *__restrict__ input, float *__restrict__ bias, const int dh, const int dw, const int sh, const int sw) { // output shape: (n, oh, ow, f) // input shape: (n, h, w, f, r, s) - int nid = blockIdx.x, fid = blockIdx.y; - int hid = threadIdx.x, wid = threadIdx.y; - const int fchunck = r * s, wchunk = f * fchunck, hchunk = w * wchunk, nchunck = n * hchunk; - float *nfinput = input + nid * nchunck + fid * fchunck; - if (nid < n && fid < f && hid < oh && wid < ow) { + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + const int out_N_offset = h * w * f, out_H_offset = w * f, out_W_offset = f, + out_F_offset = 1; + const int num = out_N_offset * n; + if (tid < num) { + // output index + int tmptid = tid; + const int nid = tmptid / out_N_offset; + tmptid -= nid * out_N_offset; + const int hid = tmptid / out_H_offset; + tmptid -= hid * out_H_offset; + const int wid = tmptid / out_W_offset; + tmptid -= wid * out_W_offset; + const int fid = tmptid / out_F_offset; + + // Input index + const int fchunck = r * s, wchunk = f * fchunck, hchunk = w * wchunk, + nchunck = n * hchunk; + float *__restrict__ nfinput = input + nid * nchunck + fid * fchunck; float imm = 0.0; - int ihst = hid * sh - ph; - int iwst = wid * sw - pw; + const int ihst = hid * sh, iwst = wid * sw; for (int ri = 0; ri < r; ++ri) { for (int si = 0; si < s; ++si) { - int ihid = ihst + ri * dh; - int iwid = iwst + si * dw; + int ihid = ihst + (ri - r / 2) * dh; + int iwid = iwst + (si - s / 2) * dw; if (ihid >= 0 && ihid < h && iwid >= 0 && iwid < w) { imm += *(nfinput + ihid * hchunk + iwid * wchunk + ri * s + si); @@ -33,7 +46,7 @@ conv2dreduce_kernel_(float *__restrict__ input, float *__restrict__ bias, if (PReLU) { imm = imm > 0.0 ? imm : paramReLU * imm; } - output[nid * (oh * ow * f) + hid * (ow * f) + wid * f + fid] = imm; + output[tid] = imm; } } @@ -83,11 +96,14 @@ void conv2dreduce_kernel(float *input, float *bias, float *output, bool PReLU, float paramReLU, int n, int h, int w, int f, int r, int s, int oh, int ow, int ph, int pw, int sh, int sw, int dh, int dw) { - dim3 grid(n, f); - dim3 block(oh, ow); + IT_ASSERT(sh == 1 && sw == 1, "conv2dreduce_kernel only support sh=sw=1"); + const int blocksize = 512; + const int gridsize = (n * f * oh * ow + blocksize - 1) / blocksize; + cudaStream_t stream(cudaStreamPerThread); - conv2dreduce_kernel_<<>>(input, bias, output, PReLU, paramReLU, n, f, h, w, - oh, ow, r, s, ph, pw, dh, dw, sh, sw); + conv2dreduce_kernel_<<>>( + input, bias, output, PReLU, paramReLU, n, f, h, w, oh, ow, r, s, ph, pw, + dh, dw, sh, sw); } void convTranspose2dreduce_kernel(float *input, float *bias, float *output,