diff --git a/src/kernels/cuda/attention.cu b/src/kernels/cuda/attention.cu
index 8af835e6..ba4b5d8c 100644
--- a/src/kernels/cuda/attention.cu
+++ b/src/kernels/cuda/attention.cu
@@ -1,90 +1,107 @@
 #include "cuda/cuda_common.h"
+const int Rq = 4;
+const int Rv = 8; // must be a multiple of 4
+const int Br = 16;
+const int Bc = 16;
+const int Bk = 4; // must be a multiple of 4
+
 template
 __device__ void matmulRQK(const float *__restrict inputQ,
-                          const float *__restrict inputK, float *Qds,
-                          float *Kds, int N, int d, int width, int indQ,
-                          int indK, float *regLeft, float *val) {
-
+                          const float *__restrict inputK, float *shareQK,
+                          float *shareVK, int N, int d, int width, int indQ,
+                          int indK, float *val) {
+    float a[4];
     for (int ph = 0; ph < width; ph++) {
-        if (threadIdx.y < Bc) {
-            Kds[threadIdx.y * Bc + threadIdx.x] = 0.0f;
-        }
-        if (threadIdx.y < Bc) {
-            if (indK < N && threadIdx.y + ph * Bc < d) {
-                Kds[threadIdx.y * Bc + threadIdx.x] =
-                    inputK[indK * d + threadIdx.y + ph * Bc];
+        for (int index_k = 0; index_k < Bk; index_k++) {
+            (float4 &)a[0] = (float4 &)
+                inputK[(indK + index_k) * d + (threadIdx.y + ph * Bc) * Bk];
+            for (int idk = 0; idk < Bk; idk++) {
+                if (threadIdx.y < Bc) {
+                    shareVK[(threadIdx.y * Bk + idk) * Bc * Bk +
+                            threadIdx.x * Bk + index_k] = a[idk];
+                    if (indK + index_k >= N ||
+                        (threadIdx.y + ph * Bc) * Bk + idk >= d) {
+
+                        shareVK[(threadIdx.y * Bk + idk) * Bc * Bk +
+                                threadIdx.x * Bk + index_k] = 0.0f;
+                    }
+                }
             }
         }
         for (int index_q = 0; index_q < Rq; index_q++) {
-            if (indQ + index_q < N && threadIdx.x + ph * Bc < d) {
-                Qds[(threadIdx.y * Rq + index_q) * Bc + threadIdx.x] =
-                    inputQ[(indQ + index_q) * d + threadIdx.x + ph * Bc];
-            } else {
-                Qds[(threadIdx.y * Rq + index_q) * Bc + threadIdx.x] = 0.0f;
+            (float4 &)shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk +
+                              threadIdx.x * Bk] = (float4 &)
+                inputQ[(indQ + index_q) * d + (threadIdx.x + ph * Bc) * Bk];
+            for (int idk = 0; idk < Bk; idk++) {
+                if (indQ + index_q >= N ||
+                    (threadIdx.x + ph * Bc) * Bk + idk >= d) {
+                    shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk +
+                            threadIdx.x * Bk + idk] = 0.0f;
+                }
             }
         }
         __syncthreads();
-        for (int index = 0; index < Bc; index++) {
+
+        for (int index = 0; index < Bc * Bk; index++) {
             for (int index_q = 0; index_q < Rq; index_q++) {
-                regLeft[index_q] =
-                    Qds[(threadIdx.y * Rq + index_q) * Bc + index];
-                val[index_q] =
-                    std::fma(regLeft[index_q], Kds[index * Bc + threadIdx.x],
-                             val[index_q]);
+                for (int index_k = 0; index_k < Bk; index_k++) {
+                    val[index_q * Bk + index_k] = std::fma(
+                        shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk + index],
+                        shareVK[index * Bc * Bk + threadIdx.x * Bk + index_k],
+                        val[index_q * Bk + index_k]);
+                }
             }
         }
         __syncthreads();
     }
 }
 template
-__device__ void matmulSV(float *sumQK, const float *__restrict inputV,
-                         float *Vds, int N, int d, int j, int indQ, int indK,
-                         int indV, float *regLeft, float *regRight, float *val,
-                         float *newMax, float *sumSV) {
-
+__device__ void matmulSV(float *shareQK, const float *__restrict inputV,
+                         float *shareVK, int N, int d, int j, int indQ,
+                         int indK, int indV, float *val, float *newMax,
+                         float *sumSV) {
     if (threadIdx.y < Bc) {
-        for (int index_v = 0; index_v < Rv; index_v++) {
-            if (threadIdx.y + j * Bc < N && indV + index_v < d) {
-                Vds[threadIdx.y * Bc * Rv + threadIdx.x * Rv + index_v] =
-                    inputV[(threadIdx.y + j * Bc) * d + indV + index_v];
-            } else {
-                Vds[threadIdx.y * Bc * Rv + threadIdx.x * Rv + index_v] = 0.0f;
+        for (int index_k = 0; index_k < Bk; index_k++) {
+            for (int id = 0; id < (int)(Rv / 4); id++) {
+                (float4 &)shareVK[(threadIdx.y * Bk + index_k) * Bc * Rv +
+                                  threadIdx.x * Rv + id * 4] = (float4 &)
+                    inputV[((threadIdx.y + j * Bc) * Bk + index_k) * d + indV +
+                           id * 4];
+            }
+            for (int index_v = 0; index_v < Rv; index_v++) {
+                if ((threadIdx.y + j * Bc) * Bk + index_k >= N ||
+                    indV + index_v >= d) {
+                    shareVK[(threadIdx.y * Bk + index_k) * Bc * Rv +
+                            threadIdx.x * Rv + index_v] = 0.0f;
+                }
             }
         }
     }
     for (int index_q = 0; index_q < Rq; index_q++) {
-        if (indQ + index_q < N && indK < N) {
-            sumQK[(threadIdx.y * Rq + index_q) * Bc + threadIdx.x] =
-                __expf(val[index_q] - newMax[index_q]);
-        } else {
+        for (int index_k = 0; index_k < Bk; index_k++) {
+            if (indQ + index_q < N && indK + index_k < N) {
+                shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk +
+                        threadIdx.x * Bk + index_k] =
+                    __expf(val[index_q * Bk + index_k] - newMax[index_q]);
+            } else {
-            sumQK[(threadIdx.y * Rq + index_q) * Bc + threadIdx.x] = 0.0f;
+                shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk +
+                        threadIdx.x * Bk + index_k] = 0.0f;
+            }
         }
     }
     __syncthreads();
-    for (int phc = 0; phc < Bc; phc++) {
-        for (int index_q = 0; index_q < Rq; index_q++) {
-            regLeft[index_q] = sumQK[(threadIdx.y * Rq + index_q) * Bc + phc];
-        }
-        //--------
-        for (int index_v = 0; index_v < Rv; index_v++) {
-            regRight[index_v] = Vds[phc * Bc * Rv + threadIdx.x * Rv + index_v];
-        }
-        //--------
+
+    for (int phc = 0; phc < Bc * Bk; phc++) {
         for (int index_q = 0; index_q < Rq; index_q++) {
+            for (int index_v = 0; index_v < Rv; index_v++) {
                 sumSV[index_q * Rv + index_v] +=
-                    regLeft[index_q] * regRight[index_v];
+                    shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk + phc] *
+                    shareVK[phc * Bc * Rv + threadIdx.x * Rv + index_v];
             }
         }
-        // for (int index_q = 0; index_q < Rq; index_q++) {
-        //     for (int index_v = 0; index_v < Rv; index_v++) {
-        //         sumSV[index_q * Rv + index_v] +=
-        //             sumQK[(threadIdx.y * Rq + index_q) * Bc + phc] *
-        //             Vds[phc * Bc * Rv + threadIdx.x * Rv + index_v];
-        //     }
-        // }
     }
 }
 template struct SumOp {
@@ -99,7 +116,7 @@ template struct MaxOp {
     }
 };
 template