warp reduce

xgqdut2016 2024-05-11 16:53:02 +08:00
parent 131d1cb6d0
commit 9b6c44dd40
1 changed file with 102 additions and 74 deletions
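The core of this change is replacing the per-row shared-memory strip reduction (block_max/block_sum arrays of size Br * Bc plus a __syncthreads() per step) with a register-only warp all-reduce built on __shfl_xor_sync. The sketch below shows that pattern in isolation; SumOp, MaxOp, and WarpAllReduce mirror the helpers added in this diff, while the rowMaxSumDemo kernel, its one-warp-per-row layout, and the fixed width of 32 are illustrative assumptions, not repository code.

#include <cuda_runtime.h>

template <typename T> struct SumOp {
    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
        return a + b;
    }
};
template <typename T> struct MaxOp {
    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
        return max(a, b);
    }
};

// Butterfly (XOR) shuffle: after log2(width) steps every lane in the group
// holds the reduced value, with no shared memory and no __syncthreads().
template <template <typename> class ReductionOp, typename T, int width = 32>
__inline__ __device__ T WarpAllReduce(T val) {
    for (int mask = width / 2; mask > 0; mask >>= 1) {
        val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
    }
    return val;
}

// One warp per 32-element row: every lane ends up holding the row max and sum.
// Launch as rowMaxSumDemo<<<numRows, 32>>>(in, outMax, outSum);
__global__ void rowMaxSumDemo(const float *in, float *outMax, float *outSum) {
    float v = in[blockIdx.x * 32 + threadIdx.x];
    float rowMax = WarpAllReduce<MaxOp, float>(v);
    float rowSum = WarpAllReduce<SumOp, float>(v);
    if (threadIdx.x == 0) {
        outMax[blockIdx.x] = rowMax;
        outSum[blockIdx.x] = rowSum;
    }
}

In the attention kernel the group width is Bc (the tile width along the key dimension), so each row of the Br x Bc score tile can be reduced inside one warp as long as Bc divides the warp size; only lane 0 then writes block_max[threadIdx.y] and block_sum[threadIdx.y], which shrinks both shared arrays from Br * Bc to Br entries.

The kernel also keeps a running row maximum and normalizer (newMax, newSum) in the flash-attention style and rescales them as each Bc-wide block of scores arrives. The per-block update it performs is equivalent to the scalar recurrence below; this is a restatement for clarity, with OnlineSoftmax and merge() being illustrative names rather than repository code.

#include <cfloat>
#include <cmath>

// Running softmax statistics for one row: runSum always stores
// sum(exp(x - runMax)) over everything merged so far.
struct OnlineSoftmax {
    float runMax = -FLT_MAX;
    float runSum = 0.0f;

    // Fold in one block's maximum and its sum of exponentials taken
    // relative to that block maximum.
    void merge(float blockMax, float blockSum) {
        if (runMax > blockMax) {
            runSum = runSum + blockSum * std::exp(blockMax - runMax);
        } else {
            runSum = blockSum + runSum * std::exp(runMax - blockMax);
            runMax = blockMax;
        }
    }
};

The same exp(oldMax - newMax) factor is what rescales the partial output out before the new block's sumSV contribution is added.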

@@ -1,17 +1,78 @@
#include "cuda/cuda_common.h"
+ template <int Br, int Bc>
+ __device__ float matmul(const float *__restrict A, const float *__restrict B,
+ int d, int indA, int indB) {
+ float sum_qk = 0.0f;
+ for (int index = 0; index < d; index++) {
+ sum_qk += A[indA * d + index] * B[indB * d + index];
+ }
+ return sum_qk;
+ }
+ template <int Br, int Bc>
+ __device__ float matmulShare(const float *__restrict inputQ,
+ const float *__restrict inputK, float *Qds,
+ float *Kds, int N, int d, int width, int indQ,
+ int indK) {
+ float sum_qk = 0.0f;
+ for (int ph = 0; ph < width; ph++) {
+ if (indQ < N && threadIdx.x + ph * Bc < d) {
+ Qds[threadIdx.y * Bc + threadIdx.x] =
+ inputQ[indQ * d + threadIdx.x + ph * Bc];
+ } else {
+ Qds[threadIdx.y * Bc + threadIdx.x] = 0.0f;
+ }
+ if (threadIdx.y < Bc) {
+ Kds[threadIdx.y * Bc + threadIdx.x] = 0.0f;
+ }
+ if (threadIdx.y < Bc) {
+ if (indK < N && threadIdx.y + ph * Bc < d) {
+ Kds[threadIdx.y * Bc + threadIdx.x] =
+ inputK[indK * d + threadIdx.y + ph * Bc];
+ }
+ }
+ __syncthreads();
+ for (int index = 0; index < Bc; index++) {
+ sum_qk = std::fma(Qds[threadIdx.y * Bc + index],
+ Kds[index * Bc + threadIdx.x], sum_qk);
+ }
+ __syncthreads();
+ }
+ return sum_qk;
+ }
+ template <typename T> struct SumOp {
+ __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+ return a + b;
+ }
+ };
+ template <typename T> struct MaxOp {
+ __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+ return max(a, b);
+ }
+ };
+ template <template <typename> class ReductionOp, typename T,
+ int thread_group_width = warpSize>
+ __inline__ __device__ T WarpAllReduce(T val) {
+ for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
+ val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
+ }
+ return val;
+ }
template <int Br, int Bc>
__global__ void _attentionKernel(const float *__restrict inputQ,
const float *__restrict inputK,
const float *__restrict inputV, int N, int d,
float *__restrict output) {
int Tc = (N + Bc - 1) / Bc;
__shared__ float sumQK[Br * Bc];
- __shared__ float sumSV[Br * Bc];
- __shared__ float block_max[Br * Bc];
- __shared__ float block_sum[Br * Bc];
+ float sumSV;
+ __shared__ float block_max[Br];
+ __shared__ float block_sum[Br];
__shared__ float Vds[Bc * Bc];
__shared__ float Qds[Br * Bc];
__shared__ float Kds[Bc * Bc];
@@ -22,80 +83,49 @@ __global__ void _attentionKernel(const float *__restrict inputQ,
float newSum;
newMax = -__FLT_MAX__;
oldMax = -__FLT_MAX__;
- newSum = 1.0f;
+ newSum = 0.0f;
float out = 0.0f;
for (int j = 0; j < Tc; j++) {
- sumSV[threadIdx.y * Bc + threadIdx.x] = 0.0f;
+ sumSV = 0.0f;
int indK = threadIdx.x + j * Bc;
float sum_qk = 0.0f;
- for (int ph = 0; ph < gridDim.x; ph++) {
- if (indQ < N && threadIdx.x + ph * Bc < d) {
- Qds[threadIdx.y * Bc + threadIdx.x] =
- inputQ[indQ * d + threadIdx.x + ph * Bc];
- } else {
- Qds[threadIdx.y * Bc + threadIdx.x] = 0.0f;
- }
- if (threadIdx.y < Bc) {
- Kds[threadIdx.y * Bc + threadIdx.x] = 0.0f;
- }
- if (threadIdx.y < Bc) {
- if (indK < N && threadIdx.y + ph * Bc < d) {
- Kds[threadIdx.y * Bc + threadIdx.x] =
- inputK[indK * d + threadIdx.y + ph * Bc];
- }
- }
- __syncthreads();
- for (int index = 0; index < Bc; index++) {
- sum_qk = std::fma(Qds[threadIdx.y * Bc + index],
- Kds[index * Bc + threadIdx.x], sum_qk);
- }
- __syncthreads();
- }
+ float tmp_qk = 0.0f;
+ sum_qk = matmulShare<Br, Bc>(inputQ, inputK, Qds, Kds, N, d, gridDim.x,
+ indQ, indK);
if (indQ < N && indK < N) {
- block_max[threadIdx.y * Bc + threadIdx.x] = sum_qk;
- block_sum[threadIdx.y * Bc + threadIdx.x] = 1.0f;
sumQK[threadIdx.y * Bc + threadIdx.x] = sum_qk;
+ tmp_qk = sum_qk;
} else {
- block_max[threadIdx.y * Bc + threadIdx.x] = -__FLT_MAX__;
- block_sum[threadIdx.y * Bc + threadIdx.x] = 0.0f;
sumQK[threadIdx.y * Bc + threadIdx.x] = 0.0f;
+ sum_qk = -__FLT_MAX__;
+ tmp_qk = 0.0f;
}
__syncthreads();
- for (int strip = Bc / 2; strip > 0; strip /= 2) {
- if (threadIdx.x < strip) {
- if (block_max[threadIdx.y * Bc + threadIdx.x] >
- block_max[threadIdx.y * Bc + threadIdx.x + strip]) {
- block_sum[threadIdx.y * Bc + threadIdx.x] =
- block_sum[threadIdx.y * Bc + threadIdx.x] +
- block_sum[threadIdx.y * Bc + threadIdx.x + strip] *
- __expf(block_max[threadIdx.y * Bc + threadIdx.x +
- strip] -
- block_max[threadIdx.y * Bc + threadIdx.x]);
- } else {
- block_sum[threadIdx.y * Bc + threadIdx.x] =
- block_sum[threadIdx.y * Bc + threadIdx.x + strip] +
- block_sum[threadIdx.y * Bc + threadIdx.x] *
- __expf(block_max[threadIdx.y * Bc + threadIdx.x] -
- block_max[threadIdx.y * Bc + threadIdx.x +
- strip]);
- block_max[threadIdx.y * Bc + threadIdx.x] =
- block_max[threadIdx.y * Bc + threadIdx.x + strip];
- }
- }
- __syncthreads();
+ // softmax reduce
+ sum_qk = WarpAllReduce<MaxOp, float, Bc>(sum_qk);
+ if (threadIdx.x == 0) {
+ block_max[threadIdx.y] = sum_qk;
+ }
- if (newMax > block_max[threadIdx.y * Bc]) {
- newSum = newSum + block_sum[threadIdx.y * Bc] *
- __expf(block_max[threadIdx.y * Bc] - newMax);
+ __syncthreads();
+ float localMax = block_max[threadIdx.y];
+ //--------------------
+ float sum_s = 0.0f;
+ if (indQ < N && indK < N) {
+ sum_s = __expf(tmp_qk - localMax);
+ }
+ sum_s = WarpAllReduce<SumOp, float, Bc>(sum_s);
+ if (threadIdx.x == 0) {
+ block_sum[threadIdx.y] = sum_s;
+ }
+ __syncthreads();
+ float localSum = block_sum[threadIdx.y];
+ if (newMax > localMax) {
+ newSum = std::fma(localSum, __expf(localMax - newMax), newSum);
+ // newSum = newSum + localSum * __expf(localMax - newMax);
} else {
- newSum = block_sum[threadIdx.y * Bc] +
- newSum * __expf(newMax - block_max[threadIdx.y * Bc]);
- newMax = block_max[threadIdx.y * Bc];
+ newSum = std::fma(newSum, __expf(newMax - localMax), localSum);
+ // newSum = localSum + newSum * __expf(newMax - localMax);
+ newMax = localMax;
}
if (threadIdx.y < Bc) {
if (threadIdx.y + j * Bc < N && indV < d) {
@@ -106,20 +136,18 @@ __global__ void _attentionKernel(const float *__restrict inputQ,
}
}
if (indQ < N && indK < N) {
- sumQK[threadIdx.y * Bc + threadIdx.x] =
- __expf(sumQK[threadIdx.y * Bc + threadIdx.x] - newMax);
+ sumQK[threadIdx.y * Bc + threadIdx.x] = __expf(tmp_qk - newMax);
} else {
sumQK[threadIdx.y * Bc + threadIdx.x] = 0.0f;
}
__syncthreads();
for (int phc = 0; phc < Bc; phc++) {
- sumSV[threadIdx.y * Bc + threadIdx.x] = std::fma(
- sumQK[threadIdx.y * Bc + phc], Vds[threadIdx.x * Bc + phc],
- sumSV[threadIdx.y * Bc + threadIdx.x]);
+ sumSV = std::fma(sumQK[threadIdx.y * Bc + phc],
+ Vds[threadIdx.x * Bc + phc], sumSV);
}
- out = __expf(oldMax - newMax) * out +
- sumSV[threadIdx.y * Bc + threadIdx.x];
+ out = std::fma(__expf(oldMax - newMax), out, sumSV);
+ // out = __expf(oldMax - newMax) * out + sumSV;
oldMax = newMax;
//__syncthreads();