forked from jiuyuan/InfiniTensor
warp reduce
This commit is contained in:
parent 131d1cb6d0
commit 9b6c44dd40
@@ -1,17 +1,78 @@
 #include "cuda/cuda_common.h"
+template <int Br, int Bc>
+__device__ float matmul(const float *__restrict A, const float *__restrict B,
+                        int d, int indA, int indB) {
+    float sum_qk = 0.0f;
+    for (int index = 0; index < d; index++) {
+        sum_qk += A[indA * d + index] * B[indB * d + index];
+    }
+    return sum_qk;
+}
+template <int Br, int Bc>
+__device__ float matmulShare(const float *__restrict inputQ,
+                             const float *__restrict inputK, float *Qds,
+                             float *Kds, int N, int d, int width, int indQ,
+                             int indK) {
+    float sum_qk = 0.0f;
+    for (int ph = 0; ph < width; ph++) {
+        if (indQ < N && threadIdx.x + ph * Bc < d) {
+            Qds[threadIdx.y * Bc + threadIdx.x] =
+                inputQ[indQ * d + threadIdx.x + ph * Bc];
+        } else {
+            Qds[threadIdx.y * Bc + threadIdx.x] = 0.0f;
+        }
+        if (threadIdx.y < Bc) {
+            Kds[threadIdx.y * Bc + threadIdx.x] = 0.0f;
+        }
+        if (threadIdx.y < Bc) {
+            if (indK < N && threadIdx.y + ph * Bc < d) {
+                Kds[threadIdx.y * Bc + threadIdx.x] =
+                    inputK[indK * d + threadIdx.y + ph * Bc];
+            }
+        }
+
+        __syncthreads();
+        for (int index = 0; index < Bc; index++) {
+            sum_qk = std::fma(Qds[threadIdx.y * Bc + index],
+                              Kds[index * Bc + threadIdx.x], sum_qk);
+        }
+        __syncthreads();
+    }
+    return sum_qk;
+}
+template <typename T> struct SumOp {
+    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+        return a + b;
+    }
+};
+
+template <typename T> struct MaxOp {
+    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+        return max(a, b);
+    }
+};
+template <template <typename> class ReductionOp, typename T,
+          int thread_group_width = warpSize>
+__inline__ __device__ T WarpAllReduce(T val) {
+    for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
+        val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
+    }
+
+    return val;
+}
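A minimal usage sketch (editor's illustration, not part of the commit): the XOR butterfly pairs each lane with partners at strides thread_group_width/2, /4, ..., 1, so after the loop every lane in the group holds the reduced value. Assuming the MaxOp and WarpAllReduce definitions above are in scope:

// Editor's sketch: one warp reduces 32 values to their maximum; every lane
// ends up with the result, and lane 0 writes it out.
__global__ void warpMaxDemo(const float *in, float *out) {
    float v = in[threadIdx.x];              // one value per lane
    v = WarpAllReduce<MaxOp, float, 32>(v); // butterfly max over the full warp
    if (threadIdx.x == 0) {
        *out = v;
    }
}
// Host side: warpMaxDemo<<<1, 32>>>(d_in, d_out);

In the attention kernel below the group width is Bc, so the Bc threads of one row (threadIdx.x = 0..Bc-1) reduce among themselves; this relies on Bc being a power of two no larger than the warp size of 32.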
 template <int Br, int Bc>
 __global__ void _attentionKernel(const float *__restrict inputQ,
                                  const float *__restrict inputK,
                                  const float *__restrict inputV, int N, int d,
                                  float *__restrict output) {
     int Tc = (N + Bc - 1) / Bc;
 
     __shared__ float sumQK[Br * Bc];
-    __shared__ float sumSV[Br * Bc];
-    __shared__ float block_max[Br * Bc];
-    __shared__ float block_sum[Br * Bc];
+    float sumSV;
+    __shared__ float block_max[Br];
+    __shared__ float block_sum[Br];
     __shared__ float Vds[Bc * Bc];
     __shared__ float Qds[Br * Bc];
     __shared__ float Kds[Bc * Bc];
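For a sense of scale (editor's note, not part of the commit), the static shared memory per block after this change is dominated by the four Br*Bc / Bc*Bc tiles; the per-row block_max/block_sum buffers are now negligible. A small sketch, assuming Br = Bc = 32 (the real tile sizes come from the launch code):

// Editor's sketch: per-block static shared memory, assuming Br = Bc = 32.
constexpr int Br = 32, Bc = 32;
constexpr int smemFloats = Br * Bc   // sumQK
                           + Bc * Bc // Vds
                           + Br * Bc // Qds
                           + Bc * Bc // Kds
                           + Br      // block_max
                           + Br;     // block_sum
constexpr int smemBytes = smemFloats * sizeof(float); // 4160 floats = 16640 bytes

Before this commit, sumSV, block_max, and block_sum were Br * Bc arrays as well, so switching to a register sumSV and per-row reduction buffers saves roughly three Br * Bc tiles (about 12 KB at 32 x 32).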
@@ -22,80 +83,49 @@ __global__ void _attentionKernel(const float *__restrict inputQ,
     float newSum;
     newMax = -__FLT_MAX__;
     oldMax = -__FLT_MAX__;
-    newSum = 1.0f;
+    newSum = 0.0f;
 
     float out = 0.0f;
     for (int j = 0; j < Tc; j++) {
-        sumSV[threadIdx.y * Bc + threadIdx.x] = 0.0f;
+        sumSV = 0.0f;
         int indK = threadIdx.x + j * Bc;
         float sum_qk = 0.0f;
-        for (int ph = 0; ph < gridDim.x; ph++) {
-            if (indQ < N && threadIdx.x + ph * Bc < d) {
-                Qds[threadIdx.y * Bc + threadIdx.x] =
-                    inputQ[indQ * d + threadIdx.x + ph * Bc];
-            } else {
-                Qds[threadIdx.y * Bc + threadIdx.x] = 0.0f;
-            }
-            if (threadIdx.y < Bc) {
-                Kds[threadIdx.y * Bc + threadIdx.x] = 0.0f;
-            }
-            if (threadIdx.y < Bc) {
-                if (indK < N && threadIdx.y + ph * Bc < d) {
-                    Kds[threadIdx.y * Bc + threadIdx.x] =
-                        inputK[indK * d + threadIdx.y + ph * Bc];
-                }
-            }
-
-            __syncthreads();
-            for (int index = 0; index < Bc; index++) {
-                sum_qk = std::fma(Qds[threadIdx.y * Bc + index],
-                                  Kds[index * Bc + threadIdx.x], sum_qk);
-            }
-            __syncthreads();
-        }
+        float tmp_qk = 0.0f;
+        sum_qk = matmulShare<Br, Bc>(inputQ, inputK, Qds, Kds, N, d, gridDim.x,
+                                     indQ, indK);
         if (indQ < N && indK < N) {
-            block_max[threadIdx.y * Bc + threadIdx.x] = sum_qk;
-            block_sum[threadIdx.y * Bc + threadIdx.x] = 1.0f;
-            sumQK[threadIdx.y * Bc + threadIdx.x] = sum_qk;
+            tmp_qk = sum_qk;
         } else {
-            block_max[threadIdx.y * Bc + threadIdx.x] = -__FLT_MAX__;
-            block_sum[threadIdx.y * Bc + threadIdx.x] = 0.0f;
-            sumQK[threadIdx.y * Bc + threadIdx.x] = 0.0f;
+            sum_qk = -__FLT_MAX__;
+            tmp_qk = 0.0f;
         }
-        __syncthreads();
 
-        for (int strip = Bc / 2; strip > 0; strip /= 2) {
-            if (threadIdx.x < strip) {
-                if (block_max[threadIdx.y * Bc + threadIdx.x] >
-                    block_max[threadIdx.y * Bc + threadIdx.x + strip]) {
-                    block_sum[threadIdx.y * Bc + threadIdx.x] =
-                        block_sum[threadIdx.y * Bc + threadIdx.x] +
-                        block_sum[threadIdx.y * Bc + threadIdx.x + strip] *
-                            __expf(block_max[threadIdx.y * Bc + threadIdx.x +
-                                             strip] -
-                                   block_max[threadIdx.y * Bc + threadIdx.x]);
-                } else {
-                    block_sum[threadIdx.y * Bc + threadIdx.x] =
-                        block_sum[threadIdx.y * Bc + threadIdx.x + strip] +
-                        block_sum[threadIdx.y * Bc + threadIdx.x] *
-                            __expf(block_max[threadIdx.y * Bc + threadIdx.x] -
-                                   block_max[threadIdx.y * Bc + threadIdx.x +
-                                             strip]);
-                    block_max[threadIdx.y * Bc + threadIdx.x] =
-                        block_max[threadIdx.y * Bc + threadIdx.x + strip];
-                }
-            }
-            __syncthreads();
-        }
+        // softmax reduce
+        sum_qk = WarpAllReduce<MaxOp, float, Bc>(sum_qk);
+        if (threadIdx.x == 0) {
+            block_max[threadIdx.y] = sum_qk;
+        }
-        if (newMax > block_max[threadIdx.y * Bc]) {
-            newSum = newSum + block_sum[threadIdx.y * Bc] *
-                                  __expf(block_max[threadIdx.y * Bc] - newMax);
+        __syncthreads();
+        float localMax = block_max[threadIdx.y];
+        //--------------------
+        float sum_s = 0.0f;
+        if (indQ < N && indK < N) {
+            sum_s = __expf(tmp_qk - localMax);
+        }
+        sum_s = WarpAllReduce<SumOp, float, Bc>(sum_s);
+        if (threadIdx.x == 0) {
+            block_sum[threadIdx.y] = sum_s;
+        }
+        __syncthreads();
+        float localSum = block_sum[threadIdx.y];
+        if (newMax > localMax) {
+            newSum = std::fma(localSum, __expf(localMax - newMax), newSum);
+            // newSum = newSum + localSum * __expf(localMax - newMax);
         } else {
-            newSum = block_sum[threadIdx.y * Bc] +
-                     newSum * __expf(newMax - block_max[threadIdx.y * Bc]);
-            newMax = block_max[threadIdx.y * Bc];
+            newSum = std::fma(newSum, __expf(newMax - localMax), localSum);
+            // newSum = localSum + newSum * __expf(newMax - localMax);
+            newMax = localMax;
         }
         if (threadIdx.y < Bc) {
             if (threadIdx.y + j * Bc < N && indV < d) {
@@ -106,20 +136,18 @@ __global__ void _attentionKernel(const float *__restrict inputQ,
             }
         }
         if (indQ < N && indK < N) {
-            sumQK[threadIdx.y * Bc + threadIdx.x] =
-                __expf(sumQK[threadIdx.y * Bc + threadIdx.x] - newMax);
+            sumQK[threadIdx.y * Bc + threadIdx.x] = __expf(tmp_qk - newMax);
         } else {
             sumQK[threadIdx.y * Bc + threadIdx.x] = 0.0f;
         }
         __syncthreads();
 
         for (int phc = 0; phc < Bc; phc++) {
-            sumSV[threadIdx.y * Bc + threadIdx.x] = std::fma(
-                sumQK[threadIdx.y * Bc + phc], Vds[threadIdx.x * Bc + phc],
-                sumSV[threadIdx.y * Bc + threadIdx.x]);
+            sumSV = std::fma(sumQK[threadIdx.y * Bc + phc],
+                             Vds[threadIdx.x * Bc + phc], sumSV);
         }
-        out = __expf(oldMax - newMax) * out +
-              sumSV[threadIdx.y * Bc + threadIdx.x];
+        out = std::fma(__expf(oldMax - newMax), out, sumSV);
+        // out = __expf(oldMax - newMax) * out + sumSV;
         oldMax = newMax;
 
         //__syncthreads();
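For reference (editor's sketch, not part of the commit), the newMax/newSum branch above is the standard online-softmax merge: a freshly reduced block statistic (localMax, localSum) is folded into the running pair so that newSum always equals the sum of exp(x - newMax) over everything seen so far, and only the side with the smaller maximum gets rescaled:

// Editor's sketch of the running-softmax update implemented above.
__device__ void onlineSoftmaxMerge(float &newMax, float &newSum,
                                   float localMax, float localSum) {
    if (newMax > localMax) {
        // Running max stays; rescale only the incoming sum.
        newSum = fmaf(localSum, __expf(localMax - newMax), newSum);
    } else {
        // Incoming max wins; rescale the running sum, then adopt localMax.
        newSum = fmaf(newSum, __expf(newMax - localMax), localSum);
        newMax = localMax;
    }
}

The partial output follows the same rule: out = std::fma(__expf(oldMax - newMax), out, sumSV) rescales the accumulated result to the new maximum before the next key/value tile is processed.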