diff --git a/src/kernels/cuda/attention.cu b/src/kernels/cuda/attention.cu
index 4de1f3dd..2327b1c5 100644
--- a/src/kernels/cuda/attention.cu
+++ b/src/kernels/cuda/attention.cu
@@ -1,17 +1,78 @@
 #include "cuda/cuda_common.h"
+
+// Dot product of row indA of A with row indB of B (row-major, width d).
+template <int Br, int Bc>
+__device__ float matmul(const float *__restrict A, const float *__restrict B,
+                        int d, int indA, int indB) {
+  float sum_qk = 0.0f;
+  for (int index = 0; index < d; index++) {
+    sum_qk += A[indA * d + index] * B[indB * d + index];
+  }
+  return sum_qk;
+}
+// Same dot product, staged through the shared-memory tiles Qds and Kds;
+// each of the `width` phases covers one Bc-wide slice of the d dimension,
+// zero-padding elements that fall outside the N rows or d columns.
+template <int Br, int Bc>
+__device__ float matmulShare(const float *__restrict inputQ,
+                             const float *__restrict inputK, float *Qds,
+                             float *Kds, int N, int d, int width, int indQ,
+                             int indK) {
+  float sum_qk = 0.0f;
+  for (int ph = 0; ph < width; ph++) {
+    // Each thread loads one element of the Q tile for this phase.
+    if (indQ < N && threadIdx.x + ph * Bc < d) {
+      Qds[threadIdx.y * Bc + threadIdx.x] =
+          inputQ[indQ * d + threadIdx.x + ph * Bc];
+    } else {
+      Qds[threadIdx.y * Bc + threadIdx.x] = 0.0f;
+    }
+    // Only the first Bc rows of the thread block stage the K tile.
+    if (threadIdx.y < Bc) {
+      Kds[threadIdx.y * Bc + threadIdx.x] = 0.0f;
+    }
+    if (threadIdx.y < Bc) {
+      if (indK < N && threadIdx.y + ph * Bc < d) {
+        Kds[threadIdx.y * Bc + threadIdx.x] =
+            inputK[indK * d + threadIdx.y + ph * Bc];
+      }
+    }
+
+    __syncthreads();
+    // Accumulate the partial dot product for this Bc-wide slice.
+    for (int index = 0; index < Bc; index++) {
+      sum_qk = std::fma(Qds[threadIdx.y * Bc + index],
+                        Kds[index * Bc + threadIdx.x], sum_qk);
+    }
+    __syncthreads();
+  }
+  return sum_qk;
+}
+// Elementwise binary functors for the sum and max reductions.
+template <typename T> struct SumOp {
+  __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+    return a + b;
+  }
+};
+
+template <typename T> struct MaxOp {
+  __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+    return max(a, b);
+  }
+};
+template