modified the format from master

xgqdut2016 2023-09-21 15:03:57 +08:00
parent 410844c058
commit b1a2d91aba
6 changed files with 120 additions and 92 deletions

View File

@@ -2,6 +2,7 @@
#include "operators/unary.h"

namespace infini {
void attentionKernel(const float *inputQ, const float *inputK,
                     const float *inputV, int N, int d, float *output);
}; // namespace infini

View File

@@ -19,7 +19,7 @@ class AttentionObj : public OperatorObj {
     * @param inputV The input tensor V.
     */
    AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK, Tensor inputV,
                 Tensor output);
    OP_CLONE(AttentionObj);

    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@@ -1,8 +1,7 @@
#include "operators/attention.h"
#include "cuda/cuda_attention.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"

namespace infini {
@@ -19,11 +18,11 @@ class AttentionCuda : public CudaKernelWithoutConfig {
        int d = op->getInputs(0)->getDims()[1];

        attentionKernel((float *)inputQData, (float *)inputKData,
                        (float *)inputVData, N, d, (float *)outputData);
    }
};

REGISTER_KERNEL(Device::CUDA, OpType::Attention, DataType::Float32,
                AttentionCuda, "Attention_CUDA_Float32");

}; // namespace infini

View File

@@ -3,14 +3,16 @@
#define BLOCK_DIM_x 2
#define BLOCK_DIM_y 2
#define max_function(a, b) ((a) > (b) ? (a) : (b))

__global__ void _attentionKernel(const float *inputQ, const float *inputK,
                                 const float *inputV, int N, int d,
                                 float *output) {
    int i = threadIdx.x + blockIdx.x * blockDim.x; //
    int phNumN = (N + BLOCK_DIM_y - 1) / BLOCK_DIM_y;

    __shared__ float block_sum[BLOCK_DIM_x][BLOCK_DIM_y];
    __shared__ float block_max[BLOCK_DIM_x][BLOCK_DIM_y];
    block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
    block_sum[threadIdx.x][threadIdx.y] = 0.0f;
    __shared__ float grid_sum[BLOCK_DIM_x];
@@ -20,92 +22,108 @@ __global__ void _attentionKernel(const float *inputQ, const float *inputK, const float *inputV, int N, int d, float *output) {
    grid_max_old[threadIdx.x] = -__FLT_MAX__;
    grid_sum[threadIdx.x] = 0.0f;
    __shared__ float S[BLOCK_DIM_x][BLOCK_DIM_y];
    __shared__ float Out_new[BLOCK_DIM_x][BLOCK_DIM_y];
    Out_new[threadIdx.x][threadIdx.y] = 0.0f;
    for (int phn = 0; phn < phNumN; phn++) {
        int j = threadIdx.y + phn * BLOCK_DIM_y;
        if (i < N && j < N) {
            float sum_s = 0;
            for (int index = 0; index < d; index++) {
                sum_s += inputQ[i * d + index] * inputK[j * d + index];
            }
            S[threadIdx.x][threadIdx.y] = sum_s;
            block_sum[threadIdx.x][threadIdx.y] = 1.0f;
            block_max[threadIdx.x][threadIdx.y] = sum_s;
        } else {
            S[threadIdx.x][threadIdx.y] = 0.0f;
            block_sum[threadIdx.x][threadIdx.y] = 0.0f;
            block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
        }
        //----------------fix i, compute the max S[i,j] of this block
        __syncthreads();
        for (int strip = BLOCK_DIM_y / 2; strip > 0; strip = strip / 2) {
            if (threadIdx.y < strip) {
                if (block_max[threadIdx.x][threadIdx.y] >
                    block_max[threadIdx.x][threadIdx.y + strip]) {
                    block_sum[threadIdx.x][threadIdx.y] =
                        block_sum[threadIdx.x][threadIdx.y] +
                        block_sum[threadIdx.x][threadIdx.y + strip] *
                            __expf(block_max[threadIdx.x][threadIdx.y + strip] -
                                   block_max[threadIdx.x][threadIdx.y]);
                } else {
                    block_sum[threadIdx.x][threadIdx.y] =
                        block_sum[threadIdx.x][threadIdx.y + strip] +
                        block_sum[threadIdx.x][threadIdx.y] *
                            __expf(block_max[threadIdx.x][threadIdx.y] -
                                   block_max[threadIdx.x][threadIdx.y + strip]);
                    block_max[threadIdx.x][threadIdx.y] =
                        block_max[threadIdx.x][threadIdx.y + strip];
                }
            }
        } // block_max[threadIdx.x][0] stores the local max of this block
        __syncthreads();
        if (threadIdx.y == 0) {
            if (grid_max[threadIdx.x] > block_max[threadIdx.x][0]) {
                grid_sum[threadIdx.x] = grid_sum[threadIdx.x] +
                                        block_sum[threadIdx.x][0] *
                                            __expf(block_max[threadIdx.x][0] -
                                                   grid_max[threadIdx.x]);
            } else {
                grid_sum[threadIdx.x] =
                    block_sum[threadIdx.x][0] +
                    grid_sum[threadIdx.x] * __expf(grid_max[threadIdx.x] -
                                                   block_max[threadIdx.x][0]);
                grid_max[threadIdx.x] = block_max[threadIdx.x][0];
            } // compare the max between the different blocks; when the loop
              // ends, grid_max stores the global max
        }
        __syncthreads();
        S[threadIdx.x][threadIdx.y] =
            __expf(S[threadIdx.x][threadIdx.y] -
                   grid_max[threadIdx.x]); // softmax(s)*L
        __syncthreads();
        int vj = threadIdx.y + blockIdx.y * blockDim.y;
        // do not write vj = threadIdx.y + ph * blockDim.y
        float sum_o;
        if (vj < d) {
            sum_o = 0;
            for (int vid = 0; vid < BLOCK_DIM_y; vid++) {
                if (vid + phn * BLOCK_DIM_y < N) {
                    sum_o += S[threadIdx.x][vid] *
                             inputV[(vid + phn * BLOCK_DIM_y) * d + vj];
                }
            }
            Out_new[threadIdx.x][threadIdx.y] =
                __expf(grid_max_old[threadIdx.x] - grid_max[threadIdx.x]) *
                    Out_new[threadIdx.x][threadIdx.y] +
                sum_o;
            grid_max_old[threadIdx.x] = grid_max[threadIdx.x];
        }
    }
    __syncthreads();
    int j = threadIdx.y + blockIdx.y * blockDim.y;
    if (i < N && j < d) {
        output[i * d + j] = Out_new[threadIdx.x][threadIdx.y] *
                            __fdividef(1.0F, grid_sum[threadIdx.x]);
    }
}

namespace infini {
void attentionKernel(const float *inputQ, const float *inputK,
                     const float *inputV, int N, int d, float *output) {
    int num_block_y = (max_function(N, d) + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
    int num_block_x = (N + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
    int share_mem = (5 * BLOCK_DIM_y + 2) * BLOCK_DIM_x * sizeof(float);
    dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
    dim3 grid_dim(num_block_x, num_block_y, 1);
    _attentionKernel<<<grid_dim, block_dim, share_mem>>>(inputQ, inputK, inputV,
                                                         N, d, output);
}
} // namespace infini
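
Note for readers following the kernel above: the two reduction stages implement an online (streaming) softmax over tiles of K and V. block_max/block_sum hold a tile's partial row maximum and partial exponential sum, and grid_max/grid_sum absorb them tile by tile, so earlier tiles never have to be revisited. A sketch of the recurrence, using m, s, and O as my own symbols for the running maximum, running sum, and running unnormalized output row (these names are not from the code):

$$m_{\mathrm{new}} = \max(m_{\mathrm{old}}, m_{\mathrm{tile}}), \qquad s_{\mathrm{new}} = s_{\mathrm{old}}\, e^{m_{\mathrm{old}} - m_{\mathrm{new}}} + s_{\mathrm{tile}}\, e^{m_{\mathrm{tile}} - m_{\mathrm{new}}}$$

$$O_{\mathrm{new}} = e^{m_{\mathrm{old}} - m_{\mathrm{new}}}\, O_{\mathrm{old}} + \sum_{j \in \mathrm{tile}} e^{S_{ij} - m_{\mathrm{new}}}\, V_j, \qquad \mathrm{output}_i = O_{\mathrm{final}} / s_{\mathrm{final}}$$

In the code, Out_new plays the role of O, grid_max_old carries m_old across iterations of the phn loop, and the final multiplication by __fdividef(1.0F, grid_sum[threadIdx.x]) is the division by s.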

View File

@@ -4,13 +4,14 @@
namespace infini {

AttentionObj::AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK,
                           Tensor inputV, Tensor output)
    : OperatorObj(OpType::Attention, TensorVec{inputQ, inputK, inputV},
                  {output}) {
    IT_ASSERT(checkValid(graph));
}

optional<vector<Shape>>
AttentionObj::inferShape(const TensorVec &inputs) const {
    auto shapeQ = inputs[0]->getDims();
    auto shapeK = inputs[1]->getDims();
    auto shapeV = inputs[2]->getDims();
@@ -37,6 +38,8 @@ vector<int> AttentionObj::getWorkloadVector() const {
    return ret;
}

vector<int> AttentionObj::getOpAttrVector() const {
    return {type.underlying()};
}

} // namespace infini

View File

@@ -9,9 +9,9 @@
namespace infini {

void test_attention(const Shape &outputShape, const vector<float> &inputQData,
                    const vector<float> &inputKData,
                    const vector<float> &inputVData,
                    const vector<float> &ExpectData) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);
    auto inputV = gCpu->addTensor(outputShape, DataType::Float32);
@@ -31,7 +31,7 @@ void test_attention(const Shape &outputShape, const vector<float> &inputQData,
    auto inputKGpu = gCuda->cloneTensor(inputK);

    auto op = gCuda->addOp<AttentionObj>(inputQGpu, inputKGpu, inputVGpu,
                                         nullptr); // AttentionObj
    gCuda->dataMalloc();
    inputVGpu->copyin(inputVData);
    inputQGpu->copyin(inputQData);
@@ -45,25 +45,32 @@ void test_attention(const Shape &outputShape, const vector<float> &inputQData,
TEST(CUDA_Attention, run) {
    test_attention(
        Shape{6, 5}, vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
                                   2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
                                   0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
                      2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
                      0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
                      2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
                      0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
        vector<float>{6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00,
                      6.507058e-03, 1.004909e+00, 1.999979e+00, 2.986577e+00,
                      8.536250e-03, 1.004909e+00, 2.017291e+00, 2.945395e+00,
                      1.997352e-02, 1.017340e+00, 2.017291e+00, 2.999871e+00,
                      3.741202e-04, 9.998805e-01, 1.999874e+00, 2.999871e+00,
                      6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00,
                      6.507058e-03, 1.004909e+00, 1.999979e+00, 2.986577e+00,
                      8.536250e-03, 1.004909e+00});

    test_attention(
        Shape{4, 3},                                                   // inputQ
        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputK
        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputV
        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.},
        vector<float>{0.9640308, 1.9546683, 2.9292183, 2.9460413, 0.0886370,
                      1.0179861, 1.9941283, 2.9905086, 0.0210545, 1.0006673,
                      1.9993325, 2.9894698});
} // python output
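
The expected vectors above were produced outside the test (the trailing comment points at a Python script). As an independent sanity check, below is a minimal, hypothetical CPU reference for the same formula the kernel computes, softmax(Q K^T) V with no 1/sqrt(d) scaling; attentionRef and its main() are my own names and are not part of this repository.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Q, K, V are N x d, row-major; returns softmax(Q * K^T) * V, also N x d.
std::vector<float> attentionRef(const std::vector<float> &Q,
                                const std::vector<float> &K,
                                const std::vector<float> &V, int N, int d) {
    std::vector<float> out(N * d, 0.0f);
    for (int i = 0; i < N; ++i) {
        // scores S[j] = dot(Q row i, K row j); track the row max for stability
        std::vector<float> S(N);
        float rowMax = -INFINITY;
        for (int j = 0; j < N; ++j) {
            float s = 0.0f;
            for (int k = 0; k < d; ++k)
                s += Q[i * d + k] * K[j * d + k];
            S[j] = s;
            rowMax = std::max(rowMax, s);
        }
        // softmax over j, then the weighted sum of V rows
        float denom = 0.0f;
        for (int j = 0; j < N; ++j) {
            S[j] = std::exp(S[j] - rowMax);
            denom += S[j];
        }
        for (int j = 0; j < N; ++j)
            for (int k = 0; k < d; ++k)
                out[i * d + k] += (S[j] / denom) * V[j * d + k];
    }
    return out;
}

int main() {
    // First test case above: N = 6, d = 5, and Q, K, V share the same data.
    std::vector<float> Q{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
                         2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
                         0., 1., 2., 3., 0., 1., 2., 3., 0., 1.};
    for (float v : attentionRef(Q, Q, Q, 6, 5))
        std::printf("%e ", v);
    std::printf("\n");
    return 0;
}

Run with the first test case (N = 6, d = 5, Q = K = V), this should reproduce values such as 6.507058e-03 and 1.001569e+00 at the start of the first output row, up to float rounding.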