fast float4

This commit is contained in:
xgqdut2016 2024-05-29 15:21:54 +08:00
parent a66ff430ec
commit 80cd1c951e
1 changed file with 140 additions and 142 deletions
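
Editorial note: the commit replaces scalar global-memory accesses in the attention kernel with 128-bit float4 loads and stores (the (float4 &) reinterpret casts on inputQ, inputK, inputV and output in the diff below), and reworks the shared-memory tiles (shareQK, shareVK) and tile constants to match. As a minimal sketch of that access pattern, assuming 16-byte-aligned pointers and a length that is a multiple of 4 (the vec_copy kernel and its names are illustrative only, not part of the commit):

// Illustrative only: 128-bit vectorized copy via reinterpret casts, the same
// (float4 &) pattern the updated kernel applies to inputQ/inputK/inputV/output.
// Assumes src and dst are 16-byte aligned and n is a multiple of 4.
__global__ void vec_copy(const float *__restrict src, float *__restrict dst, int n) {
    int i = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
    if (i + 3 < n) {
        float4 v = (const float4 &)src[i]; // one 128-bit load instead of four 32-bit loads
        (float4 &)dst[i] = v;              // one 128-bit store
    }
}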


@@ -1,90 +1,107 @@
#include "cuda/cuda_common.h"
const int Rq = 4;
const int Rv = 8; // must be a multiple of 4
const int Br = 16;
const int Bc = 16;
const int Bk = 4; // must be a multiple of 4
template <int Br, int Bc, int Rq>
__device__ void matmulRQK(const float *__restrict inputQ,
const float *__restrict inputK, float *Qds,
float *Kds, int N, int d, int width, int indQ,
int indK, float *regLeft, float *val) {
const float *__restrict inputK, float *shareQK,
float *shareVK, int N, int d, int width, int indQ,
int indK, float *val) {
float a[4];
for (int ph = 0; ph < width; ph++) {
for (int index_k = 0; index_k < Bk; index_k++) {
(float4 &)a[0] = (float4 &)
inputK[(indK + index_k) * d + (threadIdx.y + ph * Bc) * Bk];
for (int idk = 0; idk < Bk; idk++) {
if (threadIdx.y < Bc) {
Kds[threadIdx.y * Bc + threadIdx.x] = 0.0f;
shareVK[(threadIdx.y * Bk + idk) * Bc * Bk +
threadIdx.x * Bk + index_k] = a[idk];
if (indK + index_k >= N ||
(threadIdx.y + ph * Bc) * Bk + idk >= d) {
shareVK[(threadIdx.y * Bk + idk) * Bc * Bk +
threadIdx.x * Bk + index_k] = 0.0f;
}
}
if (threadIdx.y < Bc) {
if (indK < N && threadIdx.y + ph * Bc < d) {
Kds[threadIdx.y * Bc + threadIdx.x] =
inputK[indK * d + threadIdx.y + ph * Bc];
}
}
for (int index_q = 0; index_q < Rq; index_q++) {
if (indQ + index_q < N && threadIdx.x + ph * Bc < d) {
Qds[(threadIdx.y * Rq + index_q) * Bc + threadIdx.x] =
inputQ[(indQ + index_q) * d + threadIdx.x + ph * Bc];
} else {
Qds[(threadIdx.y * Rq + index_q) * Bc + threadIdx.x] = 0.0f;
(float4 &)shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk +
threadIdx.x * Bk] = (float4 &)
inputQ[(indQ + index_q) * d + (threadIdx.x + ph * Bc) * Bk];
for (int idk = 0; idk < Bk; idk++) {
if (indQ + index_q >= N ||
(threadIdx.x + ph * Bc) * Bk + idk >= d) {
shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk +
threadIdx.x * Bk + idk] = 0.0f;
}
}
}
__syncthreads();
for (int index = 0; index < Bc; index++) {
for (int index = 0; index < Bc * Bk; index++) {
for (int index_q = 0; index_q < Rq; index_q++) {
regLeft[index_q] =
Qds[(threadIdx.y * Rq + index_q) * Bc + index];
val[index_q] =
std::fma(regLeft[index_q], Kds[index * Bc + threadIdx.x],
val[index_q]);
for (int index_k = 0; index_k < Bk; index_k++) {
val[index_q * Bk + index_k] = std::fma(
shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk + index],
shareVK[index * Bc * Bk + threadIdx.x * Bk + index_k],
val[index_q * Bk + index_k]);
}
}
}
__syncthreads();
}
}
template <int Br, int Bc, int Rq, int Rv>
__device__ void matmulSV(float *sumQK, const float *__restrict inputV,
float *Vds, int N, int d, int j, int indQ, int indK,
int indV, float *regLeft, float *regRight, float *val,
float *newMax, float *sumSV) {
__device__ void matmulSV(float *shareQK, const float *__restrict inputV,
float *shareVK, int N, int d, int j, int indQ,
int indK, int indV, float *val, float *newMax,
float *sumSV) {
if (threadIdx.y < Bc) {
for (int index_k = 0; index_k < Bk; index_k++) {
for (int id = 0; id < (int)(Rv / 4); id++) {
(float4 &)shareVK[(threadIdx.y * Bk + index_k) * Bc * Rv +
threadIdx.x * Rv + id * 4] = (float4 &)
inputV[((threadIdx.y + j * Bc) * Bk + index_k) * d + indV +
id * 4];
}
for (int index_v = 0; index_v < Rv; index_v++) {
if (threadIdx.y + j * Bc < N && indV + index_v < d) {
Vds[threadIdx.y * Bc * Rv + threadIdx.x * Rv + index_v] =
inputV[(threadIdx.y + j * Bc) * d + indV + index_v];
} else {
Vds[threadIdx.y * Bc * Rv + threadIdx.x * Rv + index_v] = 0.0f;
if ((threadIdx.y + j * Bc) * Bk + index_k >= N ||
indV + index_v >= d) {
shareVK[(threadIdx.y * Bk + index_k) * Bc * Rv +
threadIdx.x * Rv + index_v] = 0.0f;
}
}
}
}
for (int index_q = 0; index_q < Rq; index_q++) {
if (indQ + index_q < N && indK < N) {
sumQK[(threadIdx.y * Rq + index_q) * Bc + threadIdx.x] =
__expf(val[index_q] - newMax[index_q]);
for (int index_k = 0; index_k < Bk; index_k++) {
if (indQ + index_q < N && indK + index_k < N) {
shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk +
threadIdx.x * Bk + index_k] =
__expf(val[index_q * Bk + index_k] - newMax[index_q]);
} else {
sumQK[(threadIdx.y * Rq + index_q) * Bc + threadIdx.x] = 0.0f;
shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk +
threadIdx.x * Bk + index_k] = 0.0f;
}
}
}
__syncthreads();
for (int phc = 0; phc < Bc; phc++) {
for (int index_q = 0; index_q < Rq; index_q++) {
regLeft[index_q] = sumQK[(threadIdx.y * Rq + index_q) * Bc + phc];
}
//--------
for (int index_v = 0; index_v < Rv; index_v++) {
regRight[index_v] = Vds[phc * Bc * Rv + threadIdx.x * Rv + index_v];
}
//--------
for (int phc = 0; phc < Bc * Bk; phc++) {
for (int index_q = 0; index_q < Rq; index_q++) {
for (int index_v = 0; index_v < Rv; index_v++) {
sumSV[index_q * Rv + index_v] +=
regLeft[index_q] * regRight[index_v];
shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk + phc] *
shareVK[phc * Bc * Rv + threadIdx.x * Rv + index_v];
}
}
// for (int index_q = 0; index_q < Rq; index_q++) {
// for (int index_v = 0; index_v < Rv; index_v++) {
// sumSV[index_q * Rv + index_v] +=
// sumQK[(threadIdx.y * Rq + index_q) * Bc + phc] *
// Vds[phc * Bc * Rv + threadIdx.x * Rv + index_v];
// }
// }
}
}
template <typename T> struct SumOp {
@@ -99,7 +116,7 @@ template <typename T> struct MaxOp {
}
};
template <template <typename> class ReductionOp, typename T,
int thread_group_width = warpSize>
int thread_group_width = 32>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
@@ -114,112 +131,91 @@ __global__ void _attentionKernel(const float *__restrict inputQ,
const float *__restrict inputV, int N, int d,
float *__restrict output) {
__shared__ float sumQK[Rq * Br * Bc];
float sumSV[Rq * Rv];
__shared__ float block_max[Rq][Br];
__shared__ float block_sum[Rq][Br];
__shared__ float shareQK[Rq * Br * Bc * Bk];
__shared__ float shareVK[Bk * Bc * Bc * Rv];
float sumSV[Rq * Rv] = {0.0f};
float newMax[Rq];
float oldMax[Rq];
float newSum[Rq] = {0.0f};
float val[Rq * Bk];
int indV = Rv * (threadIdx.x + blockIdx.x * blockDim.x);
int indQ = Rq * (threadIdx.y + blockIdx.y * blockDim.y);
float newMax[Rq];
float oldMax[Rq];
float newSum[Rq];
float out[Rq * Rv];
for (int index_q = 0; index_q < Rq; index_q++) {
newMax[index_q] = -__FLT_MAX__;
oldMax[index_q] = -__FLT_MAX__;
newSum[index_q] = 0.0f;
for (int index_v = 0; index_v < Rv; index_v++) {
out[index_q * Rv + index_v] = 0.0f;
}
}
float regTmp[Rq];
int Tc = (N + Bc - 1) / Bc;
__shared__ float Vds[Bc * Bc * Rv];
__shared__ float Qds[Rq * Br * Bc];
__shared__ float Kds[Bc * Bc];
float regLeft[Rq];
float regRight[Rv];
float val[Rq];
int width = (d + Bc - 1) / Bc;
int Tc = (N + Bc * Bk - 1) / (Bc * Bk);
int width = (d + Bc * Bk - 1) / (Bc * Bk);
for (int j = 0; j < Tc; j++) {
int indK = threadIdx.x + j * Bc;
int indK = Bk * (threadIdx.x + j * Bc);
for (int index_q = 0; index_q < Rq; index_q++) {
val[index_q] = 0.0f;
for (int index_k = 0; index_k < Bk; index_k++) {
val[index_q * Bk + index_k] = 0.0f;
}
matmulRQK<Br, Bc, Rq>(inputQ, inputK, Qds, Kds, N, d, width, indQ, indK,
regLeft, val);
}
matmulRQK<Br, Bc, Rq>(inputQ, inputK, shareQK, shareVK, N, d, width,
indQ, indK, val);
for (int index_q = 0; index_q < Rq; index_q++) {
if (indQ + index_q < N && indK < N) {
float tmpReduceMax = -__FLT_MAX__;
for (int index_k = 0; index_k < Bk; index_k++) {
if (indQ + index_q < N && indK + index_k < N) {
regTmp[index_q] = val[index_q];
} else {
regTmp[index_q] = -__FLT_MAX__;
tmpReduceMax =
max(tmpReduceMax, val[index_q * Bk + index_k]);
}
}
__syncthreads();
// softmax reduce
for (int index_q = 0; index_q < Rq; index_q++) {
regTmp[index_q] = WarpAllReduce<MaxOp, float, Bc>(regTmp[index_q]);
tmpReduceMax = WarpAllReduce<MaxOp, float, Bc>(tmpReduceMax);
if (threadIdx.x == 0) {
block_max[index_q][threadIdx.y] = regTmp[index_q];
shareQK[threadIdx.y * Rq + index_q] = tmpReduceMax;
}
__syncthreads();
float tmpReduceSum = 0.0f;
for (int index_k = 0; index_k < Bk; index_k++) {
if (indQ + index_q < N && indK + index_k < N) {
tmpReduceSum += __expf(val[index_q * Bk + index_k] -
shareQK[threadIdx.y * Rq + index_q]);
}
}
__syncthreads();
//--------------------
for (int index_q = 0; index_q < Rq; index_q++) {
if (indQ + index_q < N && indK < N) {
regTmp[index_q] =
__expf(val[index_q] - block_max[index_q][threadIdx.y]);
} else {
regTmp[index_q] = 0.0f;
}
}
__syncthreads();
for (int index_q = 0; index_q < Rq; index_q++) {
regTmp[index_q] = WarpAllReduce<SumOp, float, Bc>(regTmp[index_q]);
tmpReduceSum = WarpAllReduce<SumOp, float, Bc>(tmpReduceSum);
if (threadIdx.x == 0) {
block_sum[index_q][threadIdx.y] = regTmp[index_q];
}
shareQK[threadIdx.y * Rq + index_q + Rq * Br] = tmpReduceSum;
}
__syncthreads();
for (int index_q = 0; index_q < Rq; index_q++) {
if (newMax[index_q] > block_max[index_q][threadIdx.y]) {
newSum[index_q] = std::fma(
block_sum[index_q][threadIdx.y],
__expf(block_max[index_q][threadIdx.y] - newMax[index_q]),
if (newMax[index_q] > shareQK[threadIdx.y * Rq + index_q]) {
newSum[index_q] =
std::fma(shareQK[threadIdx.y * Rq + index_q + Rq * Br],
__expf(shareQK[threadIdx.y * Rq + index_q] -
newMax[index_q]),
newSum[index_q]);
} else {
newSum[index_q] = std::fma(
newSum[index_q],
__expf(newMax[index_q] - block_max[index_q][threadIdx.y]),
block_sum[index_q][threadIdx.y]);
newSum[index_q] =
std::fma(newSum[index_q],
__expf(newMax[index_q] -
shareQK[threadIdx.y * Rq + index_q]),
shareQK[threadIdx.y * Rq + index_q + Rq * Br]);
newMax[index_q] = block_max[index_q][threadIdx.y];
newMax[index_q] = shareQK[threadIdx.y * Rq + index_q];
}
}
for (int index_q = 0; index_q < Rq; index_q++) {
// PV
for (int index_v = 0; index_v < Rv; index_v++) {
sumSV[index_q * Rv + index_v] = 0.0f;
}
}
matmulSV<Br, Bc, Rq, Rv>(sumQK, inputV, Vds, N, d, j, indQ, indK, indV,
regLeft, regRight, val, newMax, sumSV);
for (int index_q = 0; index_q < Rq; index_q++) {
for (int index_v = 0; index_v < Rv; index_v++) {
out[index_q * Rv + index_v] = std::fma(
__expf(oldMax[index_q] - newMax[index_q]),
out[index_q * Rv + index_v], sumSV[index_q * Rv + index_v]);
sumSV[index_q * Rv + index_v] *=
__expf(oldMax[index_q] - newMax[index_q]);
}
}
matmulSV<Br, Bc, Rq, Rv>(shareQK, inputV, shareVK, N, d, j, indQ, indK,
indV, val, newMax, sumSV);
for (int index_q = 0; index_q < Rq; index_q++) {
oldMax[index_q] = newMax[index_q];
}
@@ -229,9 +225,15 @@ __global__ void _attentionKernel(const float *__restrict inputQ,
for (int index_q = 0; index_q < Rq; index_q++) {
float inv = __fdividef(1.0F, newSum[index_q]);
for (int index_v = 0; index_v < Rv; index_v++) {
if (indQ + index_q < N && indV + index_v < d) {
output[(indQ + index_q) * d + indV + index_v] =
out[index_q * Rv + index_v] * inv;
sumSV[index_q * Rv + index_v] = sumSV[index_q * Rv + index_v] * inv;
}
}
for (int index_q = 0; index_q < Rq; index_q++) {
for (int id = 0; id < (int)(Rv / 4); id++) {
if (indQ + index_q < N) {
(float4 &)output[(indQ + index_q) * d + indV + id * 4] =
(float4 &)sumSV[index_q * Rv + id * 4];
}
}
}
@@ -239,16 +241,12 @@ __global__ void _attentionKernel(const float *__restrict inputQ,
namespace infini {
void attentionKernel(const float *inputQ, const float *inputK,
const float *inputV, int N, int d, float *output) {
int Br = 32;
int Bc = 32; // Br>=Bc
int Rq = 3;
int Rv = 4;
int num_block_x = (d + Rv * Bc - 1) / (Rv * Bc);
int num_block_y = (N + Rq * Br - 1) / (Rq * Br);
dim3 grid_dim(num_block_x, num_block_y, 1);
dim3 block_dim(Bc, Br, 1);
_attentionKernel<32, 32, 3, 4>
_attentionKernel<Br, Bc, Rq, Rv>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
}
} // namespace infini
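
A note on the launch configuration after this change: the tile sizes are now the file-scope constants Br = Bc = 16, Rq = 4, Rv = 8, so each thread block produces an Rq * Br = 64 row by Rv * Bc = 128 column tile of the output. The sketch below, assuming those constants and hypothetical device pointers dQ, dK, dV, dOut, shows how the grid dimensions fall out (for example, N = 1024 and d = 128 give grid_dim = (1, 16) and block_dim = (16, 16)); it is an illustration, not part of the commit:

// Illustration only: grid/block dimensions implied by the new compile-time
// tile constants. The buffer names dQ, dK, dV, dOut are hypothetical.
#include <cuda_runtime.h>

void launchAttention(const float *dQ, const float *dK, const float *dV,
                     float *dOut, int N, int d) {
    const int Br = 16, Bc = 16, Rq = 4, Rv = 8;       // must match the kernel's constants
    int num_block_x = (d + Rv * Bc - 1) / (Rv * Bc);  // output columns covered per block
    int num_block_y = (N + Rq * Br - 1) / (Rq * Br);  // output rows covered per block
    dim3 grid_dim(num_block_x, num_block_y, 1);
    dim3 block_dim(Bc, Br, 1);
    // e.g. N = 1024, d = 128  ->  grid_dim = (1, 16), block_dim = (16, 16)
    // _attentionKernel<Br, Bc, Rq, Rv><<<grid_dim, block_dim>>>(dQ, dK, dV, N, d, dOut);
}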