forked from jiuyuan/InfiniTensor
modified the format from master
This commit is contained in:
parent 410844c058, commit b1a2d91aba
@@ -2,6 +2,7 @@
 #include "operators/unary.h"
 
 namespace infini {
-void attentionKernel(const float *inputQ, const float *inputK, const float *inputV, int N, int d, float *output);
+void attentionKernel(const float *inputQ, const float *inputK,
+                     const float *inputV, int N, int d, float *output);
 
 }; // namespace infini
@@ -19,7 +19,7 @@ class AttentionObj : public OperatorObj {
      * @param inputV The input tensor V.
      */
     AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK, Tensor inputV,
-        Tensor output);
+                 Tensor output);
     OP_CLONE(AttentionObj);
 
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
@@ -1,8 +1,7 @@
 #include "operators/attention.h"
+#include "cuda/cuda_attention.h"
 #include "cuda/cuda_kernel_wihtout_config.h"
 #include "cuda/cuda_runtime.h"
-#include "cuda/cuda_attention.h"
-
 
 namespace infini {
 
@@ -19,11 +18,11 @@ class AttentionCuda : public CudaKernelWithoutConfig {
         int d = op->getInputs(0)->getDims()[1];
 
         attentionKernel((float *)inputQData, (float *)inputKData,
-            (float *)inputVData, N, d, (float *)outputData);
+                        (float *)inputVData, N, d, (float *)outputData);
     }
 };
 
-REGISTER_KERNEL(Device::CUDA, OpType::Attention, DataType::Float32, AttentionCuda,
-                "Attention_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Attention, DataType::Float32,
+                AttentionCuda, "Attention_CUDA_Float32");
 
 }; // namespace infini
@@ -3,14 +3,16 @@
 #define BLOCK_DIM_x 2
 #define BLOCK_DIM_y 2
 
-#define max_function(a,b) ((a)>(b)?(a):(b))
+#define max_function(a, b) ((a) > (b) ? (a) : (b))
 
-__global__ void _attentionKernel(const float *inputQ, const float *inputK, const float *inputV, int N, int d, float *output){
-    int i = threadIdx.x + blockIdx.x * blockDim.x; //
-    int phNumN = (N + BLOCK_DIM_y - 1)/BLOCK_DIM_y;
+__global__ void _attentionKernel(const float *inputQ, const float *inputK,
+                                 const float *inputV, int N, int d,
+                                 float *output) {
+    int i = threadIdx.x + blockIdx.x * blockDim.x; //
+    int phNumN = (N + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
 
     __shared__ float block_sum[BLOCK_DIM_x][BLOCK_DIM_y];
     __shared__ float block_max[BLOCK_DIM_x][BLOCK_DIM_y];
     block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
     block_sum[threadIdx.x][threadIdx.y] = 0.0f;
     __shared__ float grid_sum[BLOCK_DIM_x];
@@ -20,92 +22,108 @@ __global__ void _attentionKernel(const float *inputQ, const float *inputK, cons
     grid_max_old[threadIdx.x] = -__FLT_MAX__;
     grid_sum[threadIdx.x] = 0.0f;
     __shared__ float S[BLOCK_DIM_x][BLOCK_DIM_y];
 
     __shared__ float Out_new[BLOCK_DIM_x][BLOCK_DIM_y];
     Out_new[threadIdx.x][threadIdx.y] = 0.0f;
-    for(int phn = 0; phn < phNumN; phn++){
-        int j = threadIdx.y + phn*BLOCK_DIM_y;
-        if(i < N && j < N){
+    for (int phn = 0; phn < phNumN; phn++) {
+        int j = threadIdx.y + phn * BLOCK_DIM_y;
+        if (i < N && j < N) {
             float sum_s = 0;
-            for(int index = 0; index < d; index++){
+            for (int index = 0; index < d; index++) {
                 sum_s += inputQ[i * d + index] * inputK[j * d + index];
             }
             S[threadIdx.x][threadIdx.y] = sum_s;
             block_sum[threadIdx.x][threadIdx.y] = 1.0f;
             block_max[threadIdx.x][threadIdx.y] = sum_s;
-        }
-        else{
+        } else {
             S[threadIdx.x][threadIdx.y] = 0.0f;
             block_sum[threadIdx.x][threadIdx.y] = 0.0f;
             block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
         }
 
         //----------------fix i, compute the max S[i,j] of this block
         __syncthreads();
-        for(int strip = BLOCK_DIM_y/2; strip > 0; strip = strip/2){
-            if (threadIdx.y < strip){
-                if(block_max[threadIdx.x][threadIdx.y] > block_max[threadIdx.x][threadIdx.y + strip]){
-                    block_sum[threadIdx.x][threadIdx.y] = block_sum[threadIdx.x][threadIdx.y] + block_sum[threadIdx.x][threadIdx.y + strip]*__expf(block_max[threadIdx.x][threadIdx.y + strip] - block_max[threadIdx.x][threadIdx.y]);
-                }
-                else{
-                    block_sum[threadIdx.x][threadIdx.y] = block_sum[threadIdx.x][threadIdx.y + strip] + block_sum[threadIdx.x][threadIdx.y]*__expf(block_max[threadIdx.x][threadIdx.y] - block_max[threadIdx.x][threadIdx.y + strip]);
-                    block_max[threadIdx.x][threadIdx.y] = block_max[threadIdx.x][threadIdx.y + strip];
+        for (int strip = BLOCK_DIM_y / 2; strip > 0; strip = strip / 2) {
+            if (threadIdx.y < strip) {
+                if (block_max[threadIdx.x][threadIdx.y] >
+                    block_max[threadIdx.x][threadIdx.y + strip]) {
+                    block_sum[threadIdx.x][threadIdx.y] =
+                        block_sum[threadIdx.x][threadIdx.y] +
+                        block_sum[threadIdx.x][threadIdx.y + strip] *
+                            __expf(block_max[threadIdx.x][threadIdx.y + strip] -
+                                   block_max[threadIdx.x][threadIdx.y]);
+                } else {
+                    block_sum[threadIdx.x][threadIdx.y] =
+                        block_sum[threadIdx.x][threadIdx.y + strip] +
+                        block_sum[threadIdx.x][threadIdx.y] *
+                            __expf(block_max[threadIdx.x][threadIdx.y] -
+                                   block_max[threadIdx.x][threadIdx.y + strip]);
+                    block_max[threadIdx.x][threadIdx.y] =
+                        block_max[threadIdx.x][threadIdx.y + strip];
                 }
             }
-        }//block_max[threadIdx.x][0]store the local max of this block
+        } // block_max[threadIdx.x][0]store the local max of this block
         __syncthreads();
-        if(threadIdx.y == 0){
-            if(grid_max[threadIdx.x] > block_max[threadIdx.x][0]){
-                grid_sum[threadIdx.x] = grid_sum[threadIdx.x] + block_sum[threadIdx.x][0] * __expf(block_max[threadIdx.x][0] - grid_max[threadIdx.x]);
-            }
-            else{
-                grid_sum[threadIdx.x] = block_sum[threadIdx.x][0] + grid_sum[threadIdx.x]*__expf(grid_max[threadIdx.x] - block_max[threadIdx.x][0]);
+        if (threadIdx.y == 0) {
+            if (grid_max[threadIdx.x] > block_max[threadIdx.x][0]) {
+                grid_sum[threadIdx.x] = grid_sum[threadIdx.x] +
+                                        block_sum[threadIdx.x][0] *
+                                            __expf(block_max[threadIdx.x][0] -
+                                                   grid_max[threadIdx.x]);
+            } else {
+                grid_sum[threadIdx.x] =
+                    block_sum[threadIdx.x][0] +
+                    grid_sum[threadIdx.x] * __expf(grid_max[threadIdx.x] -
+                                                   block_max[threadIdx.x][0]);
                 grid_max[threadIdx.x] = block_max[threadIdx.x][0];
-            }//compare the max between the different blocks, when the loop end, grid_max store the global max
-
-        }
+            } // compare the max between the different blocks, when the loop
+              // end, grid_max store the global max
+        }
         __syncthreads();
-        S[threadIdx.x][threadIdx.y] = __expf(S[threadIdx.x][threadIdx.y] - grid_max[threadIdx.x]); //softmax(s)*L
+        S[threadIdx.x][threadIdx.y] =
+            __expf(S[threadIdx.x][threadIdx.y] -
+                   grid_max[threadIdx.x]); // softmax(s)*L
 
         __syncthreads();
         int vj = threadIdx.y + blockIdx.y * blockDim.y;
-        //do not write vj = threadIdx.y + ph * blockDim.y
+        // do not write vj = threadIdx.y + ph * blockDim.y
         float sum_o;
-        if(vj < d){
+        if (vj < d) {
             sum_o = 0;
-            for(int vid = 0; vid < BLOCK_DIM_y; vid++){
-                if(vid + phn * BLOCK_DIM_y < N){
-                    sum_o += S[threadIdx.x][vid]*inputV[(vid + phn * BLOCK_DIM_y) * d + vj];
+            for (int vid = 0; vid < BLOCK_DIM_y; vid++) {
+                if (vid + phn * BLOCK_DIM_y < N) {
+                    sum_o += S[threadIdx.x][vid] *
+                             inputV[(vid + phn * BLOCK_DIM_y) * d + vj];
                 }
             }
-            Out_new[threadIdx.x][threadIdx.y] = __expf(grid_max_old[threadIdx.x] - grid_max[threadIdx.x])*Out_new[threadIdx.x][threadIdx.y] + sum_o;
+            Out_new[threadIdx.x][threadIdx.y] =
+                __expf(grid_max_old[threadIdx.x] - grid_max[threadIdx.x]) *
+                    Out_new[threadIdx.x][threadIdx.y] +
+                sum_o;
             grid_max_old[threadIdx.x] = grid_max[threadIdx.x];
         }
 
     }
     __syncthreads();
 
     int j = threadIdx.y + blockIdx.y * blockDim.y;
-    if(i < N && j < d){
+    if (i < N && j < d) {
 
-        output[i * d + j] = Out_new[threadIdx.x][threadIdx.y]*__fdividef(1.0F, grid_sum[threadIdx.x]);
+        output[i * d + j] = Out_new[threadIdx.x][threadIdx.y] *
+                            __fdividef(1.0F, grid_sum[threadIdx.x]);
     }
 
 }
 namespace infini {
-void attentionKernel(const float *inputQ, const float *inputK, const float *inputV, int N, int d, float *output) {
-    int num_block_y = (max_function(N,d) + BLOCK_DIM_y - 1)/BLOCK_DIM_y;
-    int num_block_x = (N + BLOCK_DIM_x - 1)/BLOCK_DIM_x;
-    int share_mem = (5*BLOCK_DIM_y + 2)*BLOCK_DIM_x*sizeof(float);
-    dim3 block_dim(BLOCK_DIM_x,BLOCK_DIM_y,1);
-    dim3 grid_dim(num_block_x,num_block_y,1);
-    _attentionKernel<<<grid_dim, block_dim, share_mem>>>(inputQ, inputK, inputV, N, d, output);
+void attentionKernel(const float *inputQ, const float *inputK,
+                     const float *inputV, int N, int d, float *output) {
+    int num_block_y = (max_function(N, d) + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+    int num_block_x = (N + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
+    int share_mem = (5 * BLOCK_DIM_y + 2) * BLOCK_DIM_x * sizeof(float);
+    dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
+    dim3 grid_dim(num_block_x, num_block_y, 1);
+    _attentionKernel<<<grid_dim, block_dim, share_mem>>>(inputQ, inputK, inputV,
+                                                         N, d, output);
 }
 } // namespace infini
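Note: the kernel above implements a streaming ("online") softmax. For each output row i it keeps a running maximum and a running sum of exponentials; partial results computed over different column tiles are merged by rescaling the sums with the difference of their maxima (the block_max/block_sum reduction, and again when a block result is folded into grid_max/grid_sum), and the partially accumulated output row Out_new is rescaled by __expf(grid_max_old - grid_max) before new V contributions are added. A minimal CPU sketch of that merge rule, written only to illustrate the arithmetic (the struct and function names are illustrative, not from the repository):

    #include <algorithm>
    #include <cmath>

    // Running statistics of one softmax row over the columns seen so far.
    struct SoftmaxState {
        float max_val; // running maximum of the scores
        float sum_exp; // sum of exp(score - max_val)
    };

    // Merge two partial states computed over disjoint column ranges.
    // This is the same rule the kernel's block and grid reductions apply.
    SoftmaxState merge(SoftmaxState a, SoftmaxState b) {
        float m = std::max(a.max_val, b.max_val);
        float s = a.sum_exp * std::exp(a.max_val - m) +
                  b.sum_exp * std::exp(b.max_val - m);
        return {m, s};
    }

As a worked example of the launch configuration: with BLOCK_DIM_x = BLOCK_DIM_y = 2 and the first test case below (N = 6, d = 5), num_block_x = (6 + 1) / 2 = 3, num_block_y = (max(6, 5) + 1) / 2 = 3, and share_mem = (5 * 2 + 2) * 2 * sizeof(float) = 96 bytes.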
@@ -4,13 +4,14 @@
 namespace infini {
 
 AttentionObj::AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK,
-    Tensor inputV, Tensor output)
+                           Tensor inputV, Tensor output)
     : OperatorObj(OpType::Attention, TensorVec{inputQ, inputK, inputV},
                   {output}) {
     IT_ASSERT(checkValid(graph));
 }
 
-optional<vector<Shape>> AttentionObj::inferShape(const TensorVec &inputs) const {
+optional<vector<Shape>>
+AttentionObj::inferShape(const TensorVec &inputs) const {
     auto shapeQ = inputs[0]->getDims();
     auto shapeK = inputs[1]->getDims();
     auto shapeV = inputs[2]->getDims();
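Note: the rest of inferShape lies outside this hunk. As a purely hypothetical sketch (names and checks below are illustrative, not the repository's code), shape inference for this operator could require Q, K and V to share the same N x d shape, which is how the tests in this commit call it, and return Q's shape for the output:

    #include <optional>
    #include <vector>

    using Shape = std::vector<int>; // illustrative stand-in for infini::Shape

    // Hypothetical sketch only; the actual implementation may differ.
    std::optional<std::vector<Shape>>
    inferShapeSketch(const std::vector<Shape> &in) {
        const Shape &q = in[0], &k = in[1], &v = in[2];
        if (q != k || q != v)
            return std::nullopt;      // mismatched inputs: no valid shape
        return std::vector<Shape>{q}; // output shares Q's N x d shape
    }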
@@ -37,6 +38,8 @@ vector<int> AttentionObj::getWorkloadVector() const {
     return ret;
 }
 
-vector<int> AttentionObj::getOpAttrVector() const { return {type.underlying()}; }
+vector<int> AttentionObj::getOpAttrVector() const {
+    return {type.underlying()};
+}
 
 } // namespace infini
@@ -9,9 +9,9 @@
 namespace infini {
 
 void test_attention(const Shape &outputShape, const vector<float> &inputQData,
-        const vector<float> &inputKData,
-        const vector<float> &inputVData,
-        const vector<float> &ExpectData) {
+                    const vector<float> &inputKData,
+                    const vector<float> &inputVData,
+                    const vector<float> &ExpectData) {
     Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph gCpu = make_ref<GraphObj>(runtime);
     auto inputV = gCpu->addTensor(outputShape, DataType::Float32);
@@ -31,7 +31,7 @@ void test_attention(const Shape &outputShape, const vector<float> &inputQData,
     auto inputKGpu = gCuda->cloneTensor(inputK);
 
     auto op = gCuda->addOp<AttentionObj>(inputQGpu, inputKGpu, inputVGpu,
-        nullptr); // AttentionObj
+                                         nullptr); // AttentionObj
     gCuda->dataMalloc();
     inputVGpu->copyin(inputVData);
     inputQGpu->copyin(inputQData);
@@ -45,25 +45,32 @@ void test_attention(const Shape &outputShape, const vector<float> &inputQData,
 
 TEST(CUDA_Attention, run) {
     test_attention(
-        Shape{6,5}, vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
-                      2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
-        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
-                      2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
-        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
-                      2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
-        vector<float>{6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00, 6.507058e-03,
-                      1.004909e+00, 1.999979e+00, 2.986577e+00, 8.536250e-03, 1.004909e+00,
-                      2.017291e+00, 2.945395e+00, 1.997352e-02, 1.017340e+00, 2.017291e+00,
-                      2.999871e+00, 3.741202e-04, 9.998805e-01, 1.999874e+00, 2.999871e+00,
-                      6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00, 6.507058e-03,
-                      1.004909e+00, 1.999979e+00, 2.986577e+00, 8.536250e-03, 1.004909e+00});
+        Shape{6, 5}, vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
+                                   2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
+                                   0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
+        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
+                      2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
+                      0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
+        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
+                      2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
+                      0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
+        vector<float>{6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00,
+                      6.507058e-03, 1.004909e+00, 1.999979e+00, 2.986577e+00,
+                      8.536250e-03, 1.004909e+00, 2.017291e+00, 2.945395e+00,
+                      1.997352e-02, 1.017340e+00, 2.017291e+00, 2.999871e+00,
+                      3.741202e-04, 9.998805e-01, 1.999874e+00, 2.999871e+00,
+                      6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00,
+                      6.507058e-03, 1.004909e+00, 1.999979e+00, 2.986577e+00,
+                      8.536250e-03, 1.004909e+00});
 
-    test_attention(Shape{4, 3}, // inputQ
-        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputK
-        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputV
-        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.},
-        vector<float>{0.9640308, 1.9546683, 2.9292183, 2.9460413, 0.0886370, 1.0179861,
-                      1.9941283, 2.9905086, 0.0210545, 1.0006673, 1.9993325, 2.9894698});
+    test_attention(
+        Shape{4, 3}, // inputQ
+        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputK
+        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputV
+        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.},
+        vector<float>{0.9640308, 1.9546683, 2.9292183, 2.9460413, 0.0886370,
+                      1.0179861, 1.9941283, 2.9905086, 0.0210545, 1.0006673,
+                      1.9993325, 2.9894698});
 
 } // python output
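Note: the expected vectors in the test come from a Python reference (see the trailing "// python output" comment). For re-checking the numbers, a self-contained CPU reference is sketched below; it assumes the operator computes row-wise softmax(Q * K^T) * V over N x d inputs with no 1/sqrt(d) scaling and no masking, which is what the kernel in this commit does. The function name is illustrative, not part of the repository.

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // CPU reference: out[i,:] = softmax(Q[i,:] * K^T) * V, row-major N x d.
    std::vector<float> attentionReference(const std::vector<float> &Q,
                                          const std::vector<float> &K,
                                          const std::vector<float> &V,
                                          int N, int d) {
        std::vector<float> out(N * d, 0.0f);
        for (int i = 0; i < N; ++i) {
            std::vector<float> score(N, 0.0f);
            float maxScore = -INFINITY;
            for (int j = 0; j < N; ++j) {
                for (int k = 0; k < d; ++k)
                    score[j] += Q[i * d + k] * K[j * d + k]; // Q[i,:] . K[j,:]
                maxScore = std::max(maxScore, score[j]);
            }
            float sum = 0.0f;
            for (int j = 0; j < N; ++j) { // numerically stable softmax
                score[j] = std::exp(score[j] - maxScore);
                sum += score[j];
            }
            for (int k = 0; k < d; ++k)
                for (int j = 0; j < N; ++j)
                    out[i * d + k] += score[j] / sum * V[j * d + k];
        }
        return out;
    }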