modified the format from master

xgqdut2016 2023-09-21 15:03:57 +08:00
parent 410844c058
commit b1a2d91aba
6 changed files with 120 additions and 92 deletions

View File

@@ -2,6 +2,7 @@
#include "operators/unary.h"
namespace infini {
void attentionKernel(const float *inputQ, const float *inputK, const float *inputV, int N, int d, float *output);
void attentionKernel(const float *inputQ, const float *inputK,
const float *inputV, int N, int d, float *output);
}; // namespace infini
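
Aside (not part of this commit): the header above declares the host-side launcher infini::attentionKernel(inputQ, inputK, inputV, N, d, output), where all three inputs and the output are row-major N x d float buffers on the device. A minimal calling sketch follows; the helper name runAttentionExample and the buffer names are assumptions, and error checking is omitted.

#include <cstddef>
#include <vector>
#include <cuda_runtime.h>
#include "cuda/cuda_attention.h" // declares infini::attentionKernel (see above)

// Hypothetical helper: copy host data to the GPU, run the launcher,
// and copy the N x d result back to the host.
void runAttentionExample(const std::vector<float> &hQ,
                         const std::vector<float> &hK,
                         const std::vector<float> &hV, int N, int d,
                         std::vector<float> &hOut) {
    const size_t bytes = static_cast<size_t>(N) * d * sizeof(float);
    float *dQ = nullptr, *dK = nullptr, *dV = nullptr, *dOut = nullptr;
    cudaMalloc((void **)&dQ, bytes);
    cudaMalloc((void **)&dK, bytes);
    cudaMalloc((void **)&dV, bytes);
    cudaMalloc((void **)&dOut, bytes);
    cudaMemcpy(dQ, hQ.data(), bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(dK, hK.data(), bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(dV, hV.data(), bytes, cudaMemcpyHostToDevice);
    infini::attentionKernel(dQ, dK, dV, N, d, dOut);
    hOut.resize(static_cast<size_t>(N) * d);
    cudaMemcpy(hOut.data(), dOut, bytes, cudaMemcpyDeviceToHost);
    cudaFree(dQ);
    cudaFree(dK);
    cudaFree(dV);
    cudaFree(dOut);
}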

View File

@@ -19,7 +19,7 @@ class AttentionObj : public OperatorObj {
* @param inputV The input tensor V.
*/
AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK, Tensor inputV,
Tensor output);
Tensor output);
OP_CLONE(AttentionObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@@ -1,8 +1,7 @@
#include "operators/attention.h"
#include "cuda/cuda_attention.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_attention.h"
namespace infini {
@@ -19,11 +18,11 @@ class AttentionCuda : public CudaKernelWithoutConfig {
int d = op->getInputs(0)->getDims()[1];
attentionKernel((float *)inputQData, (float *)inputKData,
(float *)inputVData, N, d, (float *)outputData);
(float *)inputVData, N, d, (float *)outputData);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Attention, DataType::Float32, AttentionCuda,
"Attention_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Attention, DataType::Float32,
AttentionCuda, "Attention_CUDA_Float32");
}; // namespace infini

View File

@@ -3,14 +3,16 @@
#define BLOCK_DIM_x 2
#define BLOCK_DIM_y 2
#define max_function(a,b) ((a)>(b)?(a):(b))
#define max_function(a, b) ((a) > (b) ? (a) : (b))
__global__ void _attentionKernel(const float *inputQ, const float *inputK,
const float *inputV, int N, int d,
float *output) {
int i = threadIdx.x + blockIdx.x * blockDim.x; //
int phNumN = (N + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
__global__ void _attentionKernel(const float *inputQ, const float *inputK, const float *inputV, int N, int d, float *output){
int i = threadIdx.x + blockIdx.x * blockDim.x; //
int phNumN = (N + BLOCK_DIM_y - 1)/BLOCK_DIM_y;
__shared__ float block_sum[BLOCK_DIM_x][BLOCK_DIM_y];
__shared__ float block_max[BLOCK_DIM_x][BLOCK_DIM_y];
__shared__ float block_max[BLOCK_DIM_x][BLOCK_DIM_y];
block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
block_sum[threadIdx.x][threadIdx.y] = 0.0f;
__shared__ float grid_sum[BLOCK_DIM_x];
@@ -20,92 +22,108 @@ __global__ void _attentionKernel(const float *inputQ, const float *inputK, cons
grid_max_old[threadIdx.x] = -__FLT_MAX__;
grid_sum[threadIdx.x] = 0.0f;
__shared__ float S[BLOCK_DIM_x][BLOCK_DIM_y];
__shared__ float Out_new[BLOCK_DIM_x][BLOCK_DIM_y];
Out_new[threadIdx.x][threadIdx.y] = 0.0f;
for(int phn = 0; phn < phNumN; phn++){
int j = threadIdx.y + phn*BLOCK_DIM_y;
if(i < N && j < N){
for (int phn = 0; phn < phNumN; phn++) {
int j = threadIdx.y + phn * BLOCK_DIM_y;
if (i < N && j < N) {
float sum_s = 0;
for(int index = 0; index < d; index++){
for (int index = 0; index < d; index++) {
sum_s += inputQ[i * d + index] * inputK[j * d + index];
}
S[threadIdx.x][threadIdx.y] = sum_s;
block_sum[threadIdx.x][threadIdx.y] = 1.0f;
block_max[threadIdx.x][threadIdx.y] = sum_s;
}
else{
} else {
S[threadIdx.x][threadIdx.y] = 0.0f;
block_sum[threadIdx.x][threadIdx.y] = 0.0f;
block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
}
//----------------fix i, compute the max S[i,j] of this block
__syncthreads();
for(int strip = BLOCK_DIM_y/2; strip > 0; strip = strip/2){
if (threadIdx.y < strip){
if(block_max[threadIdx.x][threadIdx.y] > block_max[threadIdx.x][threadIdx.y + strip]){
block_sum[threadIdx.x][threadIdx.y] = block_sum[threadIdx.x][threadIdx.y] + block_sum[threadIdx.x][threadIdx.y + strip]*__expf(block_max[threadIdx.x][threadIdx.y + strip] - block_max[threadIdx.x][threadIdx.y]);
}
else{
block_sum[threadIdx.x][threadIdx.y] = block_sum[threadIdx.x][threadIdx.y + strip] + block_sum[threadIdx.x][threadIdx.y]*__expf(block_max[threadIdx.x][threadIdx.y] - block_max[threadIdx.x][threadIdx.y + strip]);
block_max[threadIdx.x][threadIdx.y] = block_max[threadIdx.x][threadIdx.y + strip];
for (int strip = BLOCK_DIM_y / 2; strip > 0; strip = strip / 2) {
if (threadIdx.y < strip) {
if (block_max[threadIdx.x][threadIdx.y] >
block_max[threadIdx.x][threadIdx.y + strip]) {
block_sum[threadIdx.x][threadIdx.y] =
block_sum[threadIdx.x][threadIdx.y] +
block_sum[threadIdx.x][threadIdx.y + strip] *
__expf(block_max[threadIdx.x][threadIdx.y + strip] -
block_max[threadIdx.x][threadIdx.y]);
} else {
block_sum[threadIdx.x][threadIdx.y] =
block_sum[threadIdx.x][threadIdx.y + strip] +
block_sum[threadIdx.x][threadIdx.y] *
__expf(block_max[threadIdx.x][threadIdx.y] -
block_max[threadIdx.x][threadIdx.y + strip]);
block_max[threadIdx.x][threadIdx.y] =
block_max[threadIdx.x][threadIdx.y + strip];
}
}
}//block_max[threadIdx.x][0]store the local max of this block
} // block_max[threadIdx.x][0]store the local max of this block
__syncthreads();
if(threadIdx.y == 0){
if(grid_max[threadIdx.x] > block_max[threadIdx.x][0]){
grid_sum[threadIdx.x] = grid_sum[threadIdx.x] + block_sum[threadIdx.x][0] * __expf(block_max[threadIdx.x][0] - grid_max[threadIdx.x]);
}
else{
grid_sum[threadIdx.x] = block_sum[threadIdx.x][0] + grid_sum[threadIdx.x]*__expf(grid_max[threadIdx.x] - block_max[threadIdx.x][0]);
if (threadIdx.y == 0) {
if (grid_max[threadIdx.x] > block_max[threadIdx.x][0]) {
grid_sum[threadIdx.x] = grid_sum[threadIdx.x] +
block_sum[threadIdx.x][0] *
__expf(block_max[threadIdx.x][0] -
grid_max[threadIdx.x]);
} else {
grid_sum[threadIdx.x] =
block_sum[threadIdx.x][0] +
grid_sum[threadIdx.x] * __expf(grid_max[threadIdx.x] -
block_max[threadIdx.x][0]);
grid_max[threadIdx.x] = block_max[threadIdx.x][0];
}//compare the max between the different blocks, when the loop end, grid_max store the global max
}
} // compare the max between the different blocks, when the loop
// end, grid_max store the global max
}
__syncthreads();
S[threadIdx.x][threadIdx.y] = __expf(S[threadIdx.x][threadIdx.y] - grid_max[threadIdx.x]); //softmax(s)*L
S[threadIdx.x][threadIdx.y] =
__expf(S[threadIdx.x][threadIdx.y] -
grid_max[threadIdx.x]); // softmax(s)*L
__syncthreads();
int vj = threadIdx.y + blockIdx.y * blockDim.y;
//do not write vj = threadIdx.y + ph * blockDim.y
// do not write vj = threadIdx.y + ph * blockDim.y
float sum_o;
if(vj < d){
if (vj < d) {
sum_o = 0;
for(int vid = 0; vid < BLOCK_DIM_y; vid++){
if(vid + phn * BLOCK_DIM_y < N){
sum_o += S[threadIdx.x][vid]*inputV[(vid + phn * BLOCK_DIM_y) * d + vj];
for (int vid = 0; vid < BLOCK_DIM_y; vid++) {
if (vid + phn * BLOCK_DIM_y < N) {
sum_o += S[threadIdx.x][vid] *
inputV[(vid + phn * BLOCK_DIM_y) * d + vj];
}
}
Out_new[threadIdx.x][threadIdx.y] = __expf(grid_max_old[threadIdx.x] - grid_max[threadIdx.x])*Out_new[threadIdx.x][threadIdx.y] + sum_o;
Out_new[threadIdx.x][threadIdx.y] =
__expf(grid_max_old[threadIdx.x] - grid_max[threadIdx.x]) *
Out_new[threadIdx.x][threadIdx.y] +
sum_o;
grid_max_old[threadIdx.x] = grid_max[threadIdx.x];
}
}
__syncthreads();
int j = threadIdx.y + blockIdx.y * blockDim.y;
if(i < N && j < d){
output[i * d + j] = Out_new[threadIdx.x][threadIdx.y]*__fdividef(1.0F, grid_sum[threadIdx.x]);
if (i < N && j < d) {
output[i * d + j] = Out_new[threadIdx.x][threadIdx.y] *
__fdividef(1.0F, grid_sum[threadIdx.x]);
}
}
namespace infini {
void attentionKernel(const float *inputQ, const float *inputK, const float *inputV, int N, int d, float *output) {
int num_block_y = (max_function(N,d) + BLOCK_DIM_y - 1)/BLOCK_DIM_y;
int num_block_x = (N + BLOCK_DIM_x - 1)/BLOCK_DIM_x;
int share_mem = (5*BLOCK_DIM_y + 2)*BLOCK_DIM_x*sizeof(float);
dim3 block_dim(BLOCK_DIM_x,BLOCK_DIM_y,1);
dim3 grid_dim(num_block_x,num_block_y,1);
_attentionKernel<<<grid_dim, block_dim, share_mem>>>(inputQ, inputK, inputV, N, d, output);
void attentionKernel(const float *inputQ, const float *inputK,
const float *inputV, int N, int d, float *output) {
int num_block_y = (max_function(N, d) + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
int num_block_x = (N + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
int share_mem = (5 * BLOCK_DIM_y + 2) * BLOCK_DIM_x * sizeof(float);
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<<<grid_dim, block_dim, share_mem>>>(inputQ, inputK, inputV,
N, d, output);
}
} // namespace infini
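
Aside (not part of this commit): stated plainly, the kernel above computes output = softmax(Q * K^T) * V row by row, with Q, K, V and output all row-major N x d. The loop over phn walks the key/value rows in blocks of BLOCK_DIM_y while keeping a running row maximum (grid_max) and running denominator (grid_sum); whenever the maximum changes, the partial output (Out_new) and the running sum are rescaled by exp(old_max - new_max), and the final result is divided by grid_sum. A straightforward CPU reference is sketched below for checking the CUDA result up to floating-point rounding; attentionReference is an assumed name and this code does not appear in the commit.

#include <cmath>
#include <limits>
#include <vector>

// Illustrative CPU reference: out[i] = softmax(Q[i] * K^T) * V, using the
// usual max subtraction for numerical stability. Row-major N x d layout.
std::vector<float> attentionReference(const std::vector<float> &Q,
                                      const std::vector<float> &K,
                                      const std::vector<float> &V, int N,
                                      int d) {
    std::vector<float> out(static_cast<size_t>(N) * d, 0.0f);
    std::vector<float> s(N);
    for (int i = 0; i < N; ++i) {
        float rowMax = -std::numeric_limits<float>::infinity();
        for (int j = 0; j < N; ++j) {
            float dot = 0.0f; // score S[i][j] = Q[i] . K[j]
            for (int k = 0; k < d; ++k)
                dot += Q[i * d + k] * K[j * d + k];
            s[j] = dot;
            rowMax = rowMax > dot ? rowMax : dot;
        }
        float rowSum = 0.0f;
        for (int j = 0; j < N; ++j) {
            s[j] = std::exp(s[j] - rowMax); // stable softmax numerator
            rowSum += s[j];
        }
        for (int j = 0; j < N; ++j)
            for (int k = 0; k < d; ++k)
                out[i * d + k] += (s[j] / rowSum) * V[j * d + k];
    }
    return out;
}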

View File

@@ -4,13 +4,14 @@
namespace infini {
AttentionObj::AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK,
Tensor inputV, Tensor output)
Tensor inputV, Tensor output)
: OperatorObj(OpType::Attention, TensorVec{inputQ, inputK, inputV},
{output}) {
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> AttentionObj::inferShape(const TensorVec &inputs) const {
optional<vector<Shape>>
AttentionObj::inferShape(const TensorVec &inputs) const {
auto shapeQ = inputs[0]->getDims();
auto shapeK = inputs[1]->getDims();
auto shapeV = inputs[2]->getDims();
@@ -37,6 +38,8 @@ vector<int> AttentionObj::getWorkloadVector() const {
return ret;
}
vector<int> AttentionObj::getOpAttrVector() const { return {type.underlying()}; }
vector<int> AttentionObj::getOpAttrVector() const {
return {type.underlying()};
}
} // namespace infini

View File

@@ -9,9 +9,9 @@
namespace infini {
void test_attention(const Shape &outputShape, const vector<float> &inputQData,
const vector<float> &inputKData,
const vector<float> &inputVData,
const vector<float> &ExpectData) {
const vector<float> &inputKData,
const vector<float> &inputVData,
const vector<float> &ExpectData) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
auto inputV = gCpu->addTensor(outputShape, DataType::Float32);
@@ -31,7 +31,7 @@ void test_attention(const Shape &outputShape, const vector<float> &inputQData,
auto inputKGpu = gCuda->cloneTensor(inputK);
auto op = gCuda->addOp<AttentionObj>(inputQGpu, inputKGpu, inputVGpu,
nullptr); // AttentionObj
nullptr); // AttentionObj
gCuda->dataMalloc();
inputVGpu->copyin(inputVData);
inputQGpu->copyin(inputQData);
@@ -45,25 +45,32 @@ void test_attention(const Shape &outputShape, const vector<float> &inputQData,
TEST(CUDA_Attention, run) {
test_attention(
Shape{6,5}, vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00, 6.507058e-03,
1.004909e+00, 1.999979e+00, 2.986577e+00, 8.536250e-03, 1.004909e+00,
2.017291e+00, 2.945395e+00, 1.997352e-02, 1.017340e+00, 2.017291e+00,
2.999871e+00, 3.741202e-04, 9.998805e-01, 1.999874e+00, 2.999871e+00,
6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00, 6.507058e-03,
1.004909e+00, 1.999979e+00, 2.986577e+00, 8.536250e-03, 1.004909e+00});
Shape{6, 5}, vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00,
6.507058e-03, 1.004909e+00, 1.999979e+00, 2.986577e+00,
8.536250e-03, 1.004909e+00, 2.017291e+00, 2.945395e+00,
1.997352e-02, 1.017340e+00, 2.017291e+00, 2.999871e+00,
3.741202e-04, 9.998805e-01, 1.999874e+00, 2.999871e+00,
6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00,
6.507058e-03, 1.004909e+00, 1.999979e+00, 2.986577e+00,
8.536250e-03, 1.004909e+00});
test_attention(Shape{4, 3}, // inputQ
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputK
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputV
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.},
vector<float>{0.9640308, 1.9546683, 2.9292183, 2.9460413, 0.0886370, 1.0179861,
1.9941283, 2.9905086, 0.0210545, 1.0006673, 1.9993325, 2.9894698});
test_attention(
Shape{4, 3}, // inputQ
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputK
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputV
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.},
vector<float>{0.9640308, 1.9546683, 2.9292183, 2.9460413, 0.0886370,
1.0179861, 1.9941283, 2.9905086, 0.0210545, 1.0006673,
1.9993325, 2.9894698});
} // python output
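
Aside (not part of this commit): the expected vectors above, marked "python output", agree up to rounding with softmax(Q * K^T) * V computed directly on the listed inputs, e.g. with the CPU reference sketched after the kernel. As a spot check for the 4 x 3 case, the first row of scores Q[0] . K[j] is {5, 2, 3, 8}; its softmax weights applied to the rows of V give approximately {0.9640, 1.9547, 2.9292}, matching the first three expected values.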