Add a baseline flash attention operator

xgqdut2016 2023-09-21 14:31:43 +08:00
parent 8f2597a508
commit 3f5178d069
11 changed files with 338 additions and 14 deletions

View File

@ -73,6 +73,8 @@ class GraphHandlerObj {
Tensor cast(Tensor input, Tensor output, int to);
Tensor expand(Tensor input, Tensor output, Shape dims);
Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);
Tensor attention(Tensor inputQ, Tensor inputK, Tensor inputV,
Tensor output);
Tensor allReduceSum(Tensor input, Tensor output);
Tensor allReduceProd(Tensor input, Tensor output);

View File

@ -15,16 +15,17 @@ struct OpType {
// elements.
enum : underlying_t {
Unknown,
Abs, // Unary
Acos, // Unary
Acosh, // Unary
Add, // Binary
And, // Binary
ArgMax, //
Asin, // Binary
Asinh, // Binary
Atan, // Binary
Atanh, // Binary
Attention,
AveragePool, // Pool
BatchNormalization, //
Bernoulli, //

View File

@ -0,0 +1,7 @@
#pragma once
#include "operators/unary.h"
namespace infini {
void attentionKernel(const float *inputQ, const float *inputK, const float *inputV, int N, int d, float *output);
}; // namespace infini
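For orientation, here is a minimal host-side sketch of how this entry point could be driven directly, outside the graph machinery. It assumes row-major N x d float buffers, that attention.cu is compiled and linked in, and that the header is reachable as "cuda/cuda_attention.h" (the include path used by the kernel registration file below); the buffer names and sizes are illustrative, not part of the commit.

#include "cuda/cuda_attention.h"
#include <cuda_runtime.h>
#include <vector>

int main() {
    const int N = 4, d = 3;
    // With every row of Q, K and V equal to all-ones, softmax(QK^T) is a
    // uniform average over the rows of V, so each output row is all-ones too.
    std::vector<float> hQ(N * d, 1.0f), hK(N * d, 1.0f), hV(N * d, 1.0f), hO(N * d);
    float *dQ, *dK, *dV, *dO;
    size_t bytes = N * d * sizeof(float);
    cudaMalloc(&dQ, bytes);
    cudaMalloc(&dK, bytes);
    cudaMalloc(&dV, bytes);
    cudaMalloc(&dO, bytes);
    cudaMemcpy(dQ, hQ.data(), bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(dK, hK.data(), bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(dV, hV.data(), bytes, cudaMemcpyHostToDevice);
    infini::attentionKernel(dQ, dK, dV, N, d, dO);            // launch the CUDA kernel
    cudaMemcpy(hO.data(), dO, bytes, cudaMemcpyDeviceToHost); // implicit synchronization
    cudaFree(dQ); cudaFree(dK); cudaFree(dV); cudaFree(dO);
    return 0;
}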

View File

@ -0,0 +1,36 @@
#pragma once
#include "core/operator.h"
namespace infini {
/**
* @brief Compute the attention output softmax(QK^T)V from the Q, K and V input tensors.
*
*/
class AttentionObj : public OperatorObj {
public:
/**
* @brief Construct a new Attention object.
*
* @param graph The computation graph that this operator belongs to.
* @param inputQ The input tensor Q (query).
* @param inputK The input tensor K (key).
* @param inputV The input tensor V (value).
* @param output The output tensor.
*/
AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK, Tensor inputV,
Tensor output);
OP_CLONE(AttentionObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;
int numInputs() const override { return inputs.size(); }
int numOutputs() const override { return 1; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini
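For reference, a sketch of the computation this operator is meant to perform on N x d inputs (note that this baseline applies no 1/sqrt(d) scaling to the scores):

    S = Q K^{\top} \in \mathbb{R}^{N \times N}, \qquad
    O_{i,:} = \sum_{j} \frac{e^{S_{ij} - m_i}}{\sum_{k} e^{S_{ik} - m_i}} \, V_{j,:}, \qquad
    m_i = \max_k S_{ik}.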

View File

@ -32,6 +32,7 @@ class OnnxStub:
The Onnx model imported into infinitensor.
It can be generated from an Onnx model object.
"""
def __init__(self, model: ModelProto, runtime):
self.inputs: Dict[str, backend.Tensor] = {}
self.outputs: Dict[str, backend.Tensor] = {}
@ -60,7 +61,6 @@ class OnnxStub:
dims, output.type.tensor_type.elem_type
)
node_name = []
new_node_name = []
for node in model.graph.node:
@ -632,6 +632,13 @@ class OnnxStub:
),
):
tensors[name] = tensor
elif node.op_type == "Attention":
tensors[node.output[0]] = self.handler.attention(
tensors[node.input[0]],
tensors[node.input[1]],
tensors[node.input[2]],
tensors.get(node.output[0]),
)
elif node.op_type == "Broadcast":
tensors[node.output[0]] = self.handler.broadcast(
tensors[node.input[0]],
@ -674,11 +681,10 @@ class OnnxStub:
for input in model.graph.input:
tensors[input.name].set_input()
for output in model.graph.output:
tensors[output.name].set_output()
################################
# Allocate memory space for data
################################
@ -1002,6 +1008,10 @@ class OnnxStub:
assert len(inputs) == 3, "Check Where Op must have three inputs."
new_inputs = [inputs[2], inputs[0], inputs[1]]
ctx.push_node(make_node(ty.name, new_inputs, outputs, name))
elif ty == backend.OpTypeId.Attention:
assert len(inputs) == 3, "Attention op must have three inputs."
# unlike Where above, Q, K, V are already stored in the expected order, so no reordering is needed
ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpTypeId.Expand:
shape = backend.expand_shape_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, shape=shape))

View File

@ -1,6 +1,7 @@
#include "core/graph_handler.h"
#include "operators/all_gather.h"
#include "operators/all_reduce.h"
#include "operators/attention.h"
#include "operators/batch_norm.h"
#include "operators/broadcast.h"
#include "operators/concat.h"
@ -406,7 +407,19 @@ Tensor GraphHandlerObj::where(Tensor inputX, Tensor inputY, Tensor condition,
->getOutput();
}
}
Tensor GraphHandlerObj::attention(Tensor inputQ, Tensor inputK, Tensor inputV,
Tensor output) {
if (output) {
g->addOpWithOutputs<AttentionObj>(std::move(inputQ), std::move(inputK),
std::move(inputV), output);
return output;
} else {
return g
->addOp<AttentionObj>(std::move(inputQ), std::move(inputK),
std::move(inputV), output)
->getOutput();
}
}
static CastType inferCastType(Tensor input, int to) {
auto iType = input->getDType();
auto oType = DataType(to);

View File

@ -1,5 +1,6 @@
#include "core/data_type.h"
#include "core/graph_handler.h"
#include "operators/attention.h"
#include "operators/batch_norm.h"
#include "operators/concat.h"
#include "operators/conv.h"
@ -67,6 +68,7 @@ void export_values(py::module &m) {
.def(py::init<decltype(OpType::type)>())
.def("id", getId, policy::automatic);
py::enum_<decltype(OpType::type)>(m, "OpTypeId")
.VALUE(OpType, Attention)
.VALUE(OpType, Conv)
.VALUE(OpType, MatMul)
.VALUE(OpType, ConvTranspose)
@ -424,6 +426,7 @@ void init_graph_builder(py::module &m) {
policy::reference);
py::class_<Handler>(m, "GraphHandler")
.def(py::init<Runtime>())
.def("attention", &Handler::attention, policy::move)
.def("tensor", &Handler::tensor, policy::move)
.def("conv", &Handler::conv, policy::move)
.def("convTransposed2d", &Handler::convTransposed2d, policy::move)

View File

@ -0,0 +1,29 @@
#include "operators/attention.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_attention.h"
namespace infini {
class AttentionCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<AttentionObj>(_op);
void *const inputQData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const inputKData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const inputVData = (op->getInputs(2)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
int N = op->getInputs(0)->getDims()[0];
int d = op->getInputs(0)->getDims()[1];
attentionKernel((float *)inputQData, (float *)inputKData,
(float *)inputVData, N, d, (float *)outputData);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Attention, DataType::Float32, AttentionCuda,
"Attention_CUDA_Float32");
}; // namespace infini

View File

@ -0,0 +1,111 @@
#include "cuda/cuda_common.h"
#define BLOCK_DIM_x 2
#define BLOCK_DIM_y 2
#define max_function(a,b) ((a)>(b)?(a):(b))
__global__ void _attentionKernel(const float *inputQ, const float *inputK, const float *inputV, int N, int d, float *output){
int i = threadIdx.x + blockIdx.x * blockDim.x; // row index into Q and into the output
int phNumN = (N + BLOCK_DIM_y - 1)/BLOCK_DIM_y;
__shared__ float block_sum[BLOCK_DIM_x][BLOCK_DIM_y];
__shared__ float block_max[BLOCK_DIM_x][BLOCK_DIM_y];
block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
block_sum[threadIdx.x][threadIdx.y] = 0.0f;
__shared__ float grid_sum[BLOCK_DIM_x];
__shared__ float grid_max[BLOCK_DIM_x];
__shared__ float grid_max_old[BLOCK_DIM_x];
grid_max[threadIdx.x] = -__FLT_MAX__;
grid_max_old[threadIdx.x] = -__FLT_MAX__;
grid_sum[threadIdx.x] = 0.0f;
__shared__ float S[BLOCK_DIM_x][BLOCK_DIM_y];
__shared__ float Out_new[BLOCK_DIM_x][BLOCK_DIM_y];
Out_new[threadIdx.x][threadIdx.y] = 0.0f;
for(int phn = 0; phn < phNumN; phn++){
int j = threadIdx.y + phn*BLOCK_DIM_y;
if(i < N && j < N){
float sum_s = 0;
for(int index = 0; index < d; index++){
sum_s += inputQ[i * d + index] * inputK[j * d + index];
}
S[threadIdx.x][threadIdx.y] = sum_s;
block_sum[threadIdx.x][threadIdx.y] = 1.0f;
block_max[threadIdx.x][threadIdx.y] = sum_s;
}
else{
S[threadIdx.x][threadIdx.y] = 0.0f;
block_sum[threadIdx.x][threadIdx.y] = 0.0f;
block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
}
//---------------- for a fixed row i, reduce over this block's j to get the block max and the matching sum of exponentials
__syncthreads();
for(int strip = BLOCK_DIM_y/2; strip > 0; strip = strip/2){
if (threadIdx.y < strip){
if(block_max[threadIdx.x][threadIdx.y] > block_max[threadIdx.x][threadIdx.y + strip]){
block_sum[threadIdx.x][threadIdx.y] = block_sum[threadIdx.x][threadIdx.y] + block_sum[threadIdx.x][threadIdx.y + strip]*__expf(block_max[threadIdx.x][threadIdx.y + strip] - block_max[threadIdx.x][threadIdx.y]);
}
else{
block_sum[threadIdx.x][threadIdx.y] = block_sum[threadIdx.x][threadIdx.y + strip] + block_sum[threadIdx.x][threadIdx.y]*__expf(block_max[threadIdx.x][threadIdx.y] - block_max[threadIdx.x][threadIdx.y + strip]);
block_max[threadIdx.x][threadIdx.y] = block_max[threadIdx.x][threadIdx.y + strip];
}
}
} // block_max[threadIdx.x][0] now holds this block's max and block_sum[threadIdx.x][0] the corresponding sum of exponentials
__syncthreads();
if(threadIdx.y == 0){
if(grid_max[threadIdx.x] > block_max[threadIdx.x][0]){
grid_sum[threadIdx.x] = grid_sum[threadIdx.x] + block_sum[threadIdx.x][0] * __expf(block_max[threadIdx.x][0] - grid_max[threadIdx.x]);
}
else{
grid_sum[threadIdx.x] = block_sum[threadIdx.x][0] + grid_sum[threadIdx.x]*__expf(grid_max[threadIdx.x] - block_max[threadIdx.x][0]);
grid_max[threadIdx.x] = block_max[threadIdx.x][0];
} // merge this block's max/sum into the running grid_max/grid_sum; after the last tile, grid_max holds the global row max
}
__syncthreads();
S[threadIdx.x][threadIdx.y] = __expf(S[threadIdx.x][threadIdx.y] - grid_max[threadIdx.x]); // unnormalized softmax numerator exp(S - running max)
__syncthreads();
int vj = threadIdx.y + blockIdx.y * blockDim.y;
// vj indexes columns of V and of the output, so it uses blockIdx.y; do not write vj = threadIdx.y + phn * blockDim.y
float sum_o;
if(vj < d){
sum_o = 0;
for(int vid = 0; vid < BLOCK_DIM_y; vid++){
if(vid + phn * BLOCK_DIM_y < N){
sum_o += S[threadIdx.x][vid]*inputV[(vid + phn * BLOCK_DIM_y) * d + vj];
}
}
Out_new[threadIdx.x][threadIdx.y] = __expf(grid_max_old[threadIdx.x] - grid_max[threadIdx.x])*Out_new[threadIdx.x][threadIdx.y] + sum_o;
grid_max_old[threadIdx.x] = grid_max[threadIdx.x];
}
}
__syncthreads();
int j = threadIdx.y + blockIdx.y * blockDim.y;
if(i < N && j < d){
output[i * d + j] = Out_new[threadIdx.x][threadIdx.y]*__fdividef(1.0F, grid_sum[threadIdx.x]);
}
}
namespace infini {
void attentionKernel(const float *inputQ, const float *inputK,
const float *inputV, int N, int d, float *output) {
// the grid covers the N rows of Q/output and max(N, d) columns (K rows / V columns)
int num_block_x = (N + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
int num_block_y = (max_function(N, d) + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
// all shared arrays in _attentionKernel are statically sized, so no dynamic
// shared memory needs to be requested at launch
_attentionKernel<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
}
} // namespace infini
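The merge steps above implement the standard online-softmax recurrence used by flash attention. For a fixed output row i, keep a running maximum m, a running denominator l and a running unnormalized output O; each tile of K/V rows contributes a tile maximum m_t, a tile sum l_t and a partial product, and the state is updated as follows (a sketch matching the grid_max / grid_sum / Out_new updates in the kernel):

    m_{\text{new}} = \max(m, m_t), \qquad
    l_{\text{new}} = l\, e^{m - m_{\text{new}}} + l_t\, e^{m_t - m_{\text{new}}}, \qquad
    O_{\text{new}} = O\, e^{m - m_{\text{new}}} + \sum_{j \in \text{tile}} e^{S_{ij} - m_{\text{new}}}\, V_{j,:}.

After the last tile, row i of the output is O_{\text{new}} / l_{\text{new}}, which is exactly what the final write computes via __fdividef(1.0F, grid_sum).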

View File

@ -0,0 +1,42 @@
#include "operators/attention.h"
#include "utils/operator_utils.h"
namespace infini {
AttentionObj::AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK,
Tensor inputV, Tensor output)
: OperatorObj(OpType::Attention, TensorVec{inputQ, inputK, inputV},
{output}) {
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> AttentionObj::inferShape(const TensorVec &inputs) const {
auto shapeQ = inputs[0]->getDims();
auto shapeK = inputs[1]->getDims();
auto shapeV = inputs[2]->getDims();
auto retQK = infer_broadcast(shapeQ, shapeK);
auto ret = infer_broadcast(retQK, shapeV);
return {{ret}};
}
std::string AttentionObj::toString() const {
std::ostringstream os;
os << "Attention[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[2]->getDims()) << ",";
os << "inputQ=" << inputs[0]->getGuid() << ",";
os << "inputK=" << inputs[1]->getGuid() << ",";
os << "inputV=" << inputs[2]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> AttentionObj::getWorkloadVector() const {
vector<int> ret = getOutput()->getDims();
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> AttentionObj::getOpAttrVector() const { return {type.underlying()}; }
} // namespace infini

View File

@ -0,0 +1,70 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/attention.h"
#include "test.h"
namespace infini {
void test_attention(const Shape &outputShape, const vector<float> &inputQData,
const vector<float> &inputKData,
const vector<float> &inputVData,
const vector<float> &ExpectData) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
auto inputV = gCpu->addTensor(outputShape, DataType::Float32);
auto inputQ = gCpu->addTensor(outputShape, DataType::Float32);
auto inputK = gCpu->addTensor(outputShape, DataType::Float32);
gCpu->dataMalloc();
inputV->copyin(inputVData);
inputQ->copyin(inputQData);
inputK->copyin(inputKData);
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto inputVGpu = gCuda->cloneTensor(inputV);
auto inputQGpu = gCuda->cloneTensor(inputQ);
auto inputKGpu = gCuda->cloneTensor(inputK);
auto op = gCuda->addOp<AttentionObj>(inputQGpu, inputKGpu, inputVGpu,
nullptr); // nullptr: let the graph create the output tensor
gCuda->dataMalloc();
inputVGpu->copyin(inputVData);
inputQGpu->copyin(inputQData);
inputKGpu->copyin(inputKData);
cudaRuntime->run(gCuda);
auto oCpu = gCpu->cloneTensor(op->getOutput()); // copy the output back from GPU to CPU
oCpu->printData();
EXPECT_TRUE(oCpu->equalData(ExpectData));
}
TEST(CUDA_Attention, run) {
test_attention(
Shape{6,5}, vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00, 6.507058e-03,
1.004909e+00, 1.999979e+00, 2.986577e+00, 8.536250e-03, 1.004909e+00,
2.017291e+00, 2.945395e+00, 1.997352e-02, 1.017340e+00, 2.017291e+00,
2.999871e+00, 3.741202e-04, 9.998805e-01, 1.999874e+00, 2.999871e+00,
6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00, 6.507058e-03,
1.004909e+00, 1.999979e+00, 2.986577e+00, 8.536250e-03, 1.004909e+00});
test_attention(Shape{4, 3}, // shape shared by Q, K, V and the output
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputQ
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputK
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputV
vector<float>{0.9640308, 1.9546683, 2.9292183, 2.9460413, 0.0886370, 1.0179861,
1.9941283, 2.9905086, 0.0210545, 1.0006673, 1.9993325, 2.9894698}); // expected values computed with a Python reference
}
} // namespace infini