Add: improve conv2dreduce kernel

This commit is contained in:
Liyan Zheng 2022-10-22 13:26:41 +08:00
parent 67c06733e6
commit bef4c422a0
1 changed files with 30 additions and 14 deletions

View File

@ -9,18 +9,31 @@ conv2dreduce_kernel_(float *__restrict__ input, float *__restrict__ bias,
const int dh, const int dw, const int sh, const int sw) { const int dh, const int dw, const int sh, const int sw) {
// output shape: (n, oh, ow, f) // output shape: (n, oh, ow, f)
// input shape: (n, h, w, f, r, s) // input shape: (n, h, w, f, r, s)
int nid = blockIdx.x, fid = blockIdx.y; const int tid = blockIdx.x * blockDim.x + threadIdx.x;
int hid = threadIdx.x, wid = threadIdx.y; const int out_N_offset = h * w * f, out_H_offset = w * f, out_W_offset = f,
const int fchunck = r * s, wchunk = f * fchunck, hchunk = w * wchunk, nchunck = n * hchunk; out_F_offset = 1;
float *nfinput = input + nid * nchunck + fid * fchunck; const int num = out_N_offset * n;
if (nid < n && fid < f && hid < oh && wid < ow) { if (tid < num) {
// output index
int tmptid = tid;
const int nid = tmptid / out_N_offset;
tmptid -= nid * out_N_offset;
const int hid = tmptid / out_H_offset;
tmptid -= hid * out_H_offset;
const int wid = tmptid / out_W_offset;
tmptid -= wid * out_W_offset;
const int fid = tmptid / out_F_offset;
// Input index
const int fchunck = r * s, wchunk = f * fchunck, hchunk = w * wchunk,
nchunck = n * hchunk;
float *__restrict__ nfinput = input + nid * nchunck + fid * fchunck;
float imm = 0.0; float imm = 0.0;
int ihst = hid * sh - ph; const int ihst = hid * sh, iwst = wid * sw;
int iwst = wid * sw - pw;
for (int ri = 0; ri < r; ++ri) { for (int ri = 0; ri < r; ++ri) {
for (int si = 0; si < s; ++si) { for (int si = 0; si < s; ++si) {
int ihid = ihst + ri * dh; int ihid = ihst + (ri - r / 2) * dh;
int iwid = iwst + si * dw; int iwid = iwst + (si - s / 2) * dw;
if (ihid >= 0 && ihid < h && iwid >= 0 && iwid < w) { if (ihid >= 0 && ihid < h && iwid >= 0 && iwid < w) {
imm += *(nfinput + ihid * hchunk + iwid * wchunk + ri * s + imm += *(nfinput + ihid * hchunk + iwid * wchunk + ri * s +
si); si);
@ -33,7 +46,7 @@ conv2dreduce_kernel_(float *__restrict__ input, float *__restrict__ bias,
if (PReLU) { if (PReLU) {
imm = imm > 0.0 ? imm : paramReLU * imm; imm = imm > 0.0 ? imm : paramReLU * imm;
} }
output[nid * (oh * ow * f) + hid * (ow * f) + wid * f + fid] = imm; output[tid] = imm;
} }
} }
@ -83,11 +96,14 @@ void conv2dreduce_kernel(float *input, float *bias, float *output, bool PReLU,
float paramReLU, int n, int h, int w, int f, int r, float paramReLU, int n, int h, int w, int f, int r,
int s, int oh, int ow, int ph, int pw, int sh, int sw, int s, int oh, int ow, int ph, int pw, int sh, int sw,
int dh, int dw) { int dh, int dw) {
dim3 grid(n, f); IT_ASSERT(sh == 1 && sw == 1, "conv2dreduce_kernel only support sh=sw=1");
dim3 block(oh, ow); const int blocksize = 512;
const int gridsize = (n * f * oh * ow + blocksize - 1) / blocksize;
cudaStream_t stream(cudaStreamPerThread); cudaStream_t stream(cudaStreamPerThread);
conv2dreduce_kernel_<<<grid, block, 0, stream>>>(input, bias, output, PReLU, paramReLU, n, f, h, w, conv2dreduce_kernel_<<<gridsize, blocksize, 0, stream>>>(
oh, ow, r, s, ph, pw, dh, dw, sh, sw); input, bias, output, PReLU, paramReLU, n, f, h, w, oh, ow, r, s, ph, pw,
dh, dw, sh, sw);
} }
void convTranspose2dreduce_kernel(float *input, float *bias, float *output, void convTranspose2dreduce_kernel(float *input, float *bias, float *output,