the baseline of flash attention

This commit is contained in:
xgqdut2016 2023-09-21 14:31:43 +08:00
parent 8f2597a508
commit 3f5178d069
11 changed files with 338 additions and 14 deletions

View File

@ -73,6 +73,8 @@ class GraphHandlerObj {
Tensor cast(Tensor input, Tensor output, int to); Tensor cast(Tensor input, Tensor output, int to);
Tensor expand(Tensor input, Tensor output, Shape dims); Tensor expand(Tensor input, Tensor output, Shape dims);
Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output); Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);
Tensor attention(Tensor inputQ, Tensor inputK, Tensor inputV,
Tensor output);
Tensor allReduceSum(Tensor input, Tensor output); Tensor allReduceSum(Tensor input, Tensor output);
Tensor allReduceProd(Tensor input, Tensor output); Tensor allReduceProd(Tensor input, Tensor output);

View File

@ -15,16 +15,17 @@ struct OpType {
// elements. // elements.
enum : underlying_t { enum : underlying_t {
Unknown, Unknown,
Abs, // Unary Abs, // Unary
Acos, // Unary Acos, // Unary
Acosh, // Unary Acosh, // Unary
Add, // Binary Add, // Binary
And, // Binary And, // Binary
ArgMax, // ArgMax, //
Asin, // Binary Asin, // Binary
Asinh, // Binary Asinh, // Binary
Atan, // Binary Atan, // Binary
Atanh, // Binary Atanh, // Binary
Attention,
AveragePool, // Pool AveragePool, // Pool
BatchNormalization, // BatchNormalization, //
Bernoulli, // Bernoulli, //

View File

@ -0,0 +1,7 @@
#pragma once
#include "operators/unary.h"
namespace infini {
void attentionKernel(const float *inputQ, const float *inputK, const float *inputV, int N, int d, float *output);
}; // namespace infini

View File

@ -0,0 +1,36 @@
#pragma once
#include "core/operator.h"
namespace infini {
/**
* @brief Return elements, either from X or Y, depending on condition.
*
*/
class AttentionObj : public OperatorObj {
public:
/**
* @brief Construct a new Attention object.
*
* @param graph The computation graph that this operator belongs to.
* @param inputX The input tensor Q.
* @param inputY The input tensor K.
* @param output The output tensor.
* @param inputV The input tensor V.
*/
AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK, Tensor inputV,
Tensor output);
OP_CLONE(AttentionObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;
int numInputs() const override { return inputs.size(); }
int numOutputs() const override { return 1; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini

View File

@ -32,6 +32,7 @@ class OnnxStub:
The Onnx model imported into infinitensor. The Onnx model imported into infinitensor.
It can be generated from an Onnx model object. It can be generated from an Onnx model object.
""" """
def __init__(self, model: ModelProto, runtime): def __init__(self, model: ModelProto, runtime):
self.inputs: Dict[str, backend.Tensor] = {} self.inputs: Dict[str, backend.Tensor] = {}
self.outputs: Dict[str, backend.Tensor] = {} self.outputs: Dict[str, backend.Tensor] = {}
@ -60,7 +61,6 @@ class OnnxStub:
dims, output.type.tensor_type.elem_type dims, output.type.tensor_type.elem_type
) )
node_name = [] node_name = []
new_node_name = [] new_node_name = []
for node in model.graph.node: for node in model.graph.node:
@ -632,6 +632,13 @@ class OnnxStub:
), ),
): ):
tensors[name] = tensor tensors[name] = tensor
elif node.op_type == "Attention":
tensors[node.output[0]] = self.handler.attention(
tensors[node.input[0]],
tensors[node.input[1]],
tensors[node.input[2]],
tensors.get(node.output[0]),
)
elif node.op_type == "Broadcast": elif node.op_type == "Broadcast":
tensors[node.output[0]] = self.handler.broadcast( tensors[node.output[0]] = self.handler.broadcast(
tensors[node.input[0]], tensors[node.input[0]],
@ -674,11 +681,10 @@ class OnnxStub:
for input in model.graph.input: for input in model.graph.input:
tensors[input.name].set_input() tensors[input.name].set_input()
for output in model.graph.output: for output in model.graph.output:
tensors[output.name].set_output() tensors[output.name].set_output()
################################ ################################
# Allocate memory space for data # Allocate memory space for data
################################ ################################
@ -1002,6 +1008,10 @@ class OnnxStub:
assert len(inputs) == 3, "Check Where Op must have three inputs." assert len(inputs) == 3, "Check Where Op must have three inputs."
new_inputs = [inputs[2], inputs[0], inputs[1]] new_inputs = [inputs[2], inputs[0], inputs[1]]
ctx.push_node(make_node(ty.name, new_inputs, outputs, name)) ctx.push_node(make_node(ty.name, new_inputs, outputs, name))
elif ty == backend.OpTypeId.Attention:
assert len(inputs) == 3, "Check Attention Op must have three inputs."
new_inputs = [inputs[2], inputs[0], inputs[1]]
ctx.push_node(make_node(ty.name, new_inputs, outputs, name))
elif ty == backend.OpTypeId.Expand: elif ty == backend.OpTypeId.Expand:
shape = backend.expand_shape_of(op) shape = backend.expand_shape_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, shape=shape)) ctx.push_node(make_node(ty.name, inputs, outputs, name, shape=shape))

View File

@ -1,6 +1,7 @@
#include "core/graph_handler.h" #include "core/graph_handler.h"
#include "operators/all_gather.h" #include "operators/all_gather.h"
#include "operators/all_reduce.h" #include "operators/all_reduce.h"
#include "operators/attention.h"
#include "operators/batch_norm.h" #include "operators/batch_norm.h"
#include "operators/broadcast.h" #include "operators/broadcast.h"
#include "operators/concat.h" #include "operators/concat.h"
@ -406,7 +407,19 @@ Tensor GraphHandlerObj::where(Tensor inputX, Tensor inputY, Tensor condition,
->getOutput(); ->getOutput();
} }
} }
Tensor GraphHandlerObj::attention(Tensor inputQ, Tensor inputK, Tensor inputV,
Tensor output) {
if (output) {
g->addOpWithOutputs<AttentionObj>(std::move(inputQ), std::move(inputK),
std::move(inputV), output);
return output;
} else {
return g
->addOp<AttentionObj>(std::move(inputQ), std::move(inputK),
std::move(inputV), output)
->getOutput();
}
}
static CastType inferCastType(Tensor input, int to) { static CastType inferCastType(Tensor input, int to) {
auto iType = input->getDType(); auto iType = input->getDType();
auto oType = DataType(to); auto oType = DataType(to);

View File

@ -1,5 +1,6 @@
#include "core/data_type.h" #include "core/data_type.h"
#include "core/graph_handler.h" #include "core/graph_handler.h"
#include "operators/attention.h"
#include "operators/batch_norm.h" #include "operators/batch_norm.h"
#include "operators/concat.h" #include "operators/concat.h"
#include "operators/conv.h" #include "operators/conv.h"
@ -67,6 +68,7 @@ void export_values(py::module &m) {
.def(py::init<decltype(OpType::type)>()) .def(py::init<decltype(OpType::type)>())
.def("id", getId, policy::automatic); .def("id", getId, policy::automatic);
py::enum_<decltype(OpType::type)>(m, "OpTypeId") py::enum_<decltype(OpType::type)>(m, "OpTypeId")
.VALUE(OpType, Attention)
.VALUE(OpType, Conv) .VALUE(OpType, Conv)
.VALUE(OpType, MatMul) .VALUE(OpType, MatMul)
.VALUE(OpType, ConvTranspose) .VALUE(OpType, ConvTranspose)
@ -424,6 +426,7 @@ void init_graph_builder(py::module &m) {
policy::reference); policy::reference);
py::class_<Handler>(m, "GraphHandler") py::class_<Handler>(m, "GraphHandler")
.def(py::init<Runtime>()) .def(py::init<Runtime>())
.def("attention", &Handler::attention, policy::move)
.def("tensor", &Handler::tensor, policy::move) .def("tensor", &Handler::tensor, policy::move)
.def("conv", &Handler::conv, policy::move) .def("conv", &Handler::conv, policy::move)
.def("convTransposed2d", &Handler::convTransposed2d, policy::move) .def("convTransposed2d", &Handler::convTransposed2d, policy::move)

View File

@ -0,0 +1,29 @@
#include "operators/attention.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_attention.h"
namespace infini {
class AttentionCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<AttentionObj>(_op);
void *const inputQData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const inputKData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const inputVData = (op->getInputs(2)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
int N = op->getInputs(0)->getDims()[0];
int d = op->getInputs(0)->getDims()[1];
attentionKernel((float *)inputQData, (float *)inputKData,
(float *)inputVData, N, d, (float *)outputData);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Attention, DataType::Float32, AttentionCuda,
"Attention_CUDA_Float32");
}; // namespace infini

View File

@ -0,0 +1,111 @@
#include "cuda/cuda_common.h"
#define BLOCK_DIM_x 2
#define BLOCK_DIM_y 2
#define max_function(a,b) ((a)>(b)?(a):(b))
__global__ void _attentionKernel(const float *inputQ, const float *inputK, const float *inputV, int N, int d, float *output){
int i = threadIdx.x + blockIdx.x * blockDim.x; //
int phNumN = (N + BLOCK_DIM_y - 1)/BLOCK_DIM_y;
__shared__ float block_sum[BLOCK_DIM_x][BLOCK_DIM_y];
__shared__ float block_max[BLOCK_DIM_x][BLOCK_DIM_y];
block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
block_sum[threadIdx.x][threadIdx.y] = 0.0f;
__shared__ float grid_sum[BLOCK_DIM_x];
__shared__ float grid_max[BLOCK_DIM_x];
__shared__ float grid_max_old[BLOCK_DIM_x];
grid_max[threadIdx.x] = -__FLT_MAX__;
grid_max_old[threadIdx.x] = -__FLT_MAX__;
grid_sum[threadIdx.x] = 0.0f;
__shared__ float S[BLOCK_DIM_x][BLOCK_DIM_y];
__shared__ float Out_new[BLOCK_DIM_x][BLOCK_DIM_y];
Out_new[threadIdx.x][threadIdx.y] = 0.0f;
for(int phn = 0; phn < phNumN; phn++){
int j = threadIdx.y + phn*BLOCK_DIM_y;
if(i < N && j < N){
float sum_s = 0;
for(int index = 0; index < d; index++){
sum_s += inputQ[i * d + index] * inputK[j * d + index];
}
S[threadIdx.x][threadIdx.y] = sum_s;
block_sum[threadIdx.x][threadIdx.y] = 1.0f;
block_max[threadIdx.x][threadIdx.y] = sum_s;
}
else{
S[threadIdx.x][threadIdx.y] = 0.0f;
block_sum[threadIdx.x][threadIdx.y] = 0.0f;
block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
}
//----------------fix i, compute the max S[i,j] of this block
__syncthreads();
for(int strip = BLOCK_DIM_y/2; strip > 0; strip = strip/2){
if (threadIdx.y < strip){
if(block_max[threadIdx.x][threadIdx.y] > block_max[threadIdx.x][threadIdx.y + strip]){
block_sum[threadIdx.x][threadIdx.y] = block_sum[threadIdx.x][threadIdx.y] + block_sum[threadIdx.x][threadIdx.y + strip]*__expf(block_max[threadIdx.x][threadIdx.y + strip] - block_max[threadIdx.x][threadIdx.y]);
}
else{
block_sum[threadIdx.x][threadIdx.y] = block_sum[threadIdx.x][threadIdx.y + strip] + block_sum[threadIdx.x][threadIdx.y]*__expf(block_max[threadIdx.x][threadIdx.y] - block_max[threadIdx.x][threadIdx.y + strip]);
block_max[threadIdx.x][threadIdx.y] = block_max[threadIdx.x][threadIdx.y + strip];
}
}
}//block_max[threadIdx.x][0]store the local max of this block
__syncthreads();
if(threadIdx.y == 0){
if(grid_max[threadIdx.x] > block_max[threadIdx.x][0]){
grid_sum[threadIdx.x] = grid_sum[threadIdx.x] + block_sum[threadIdx.x][0] * __expf(block_max[threadIdx.x][0] - grid_max[threadIdx.x]);
}
else{
grid_sum[threadIdx.x] = block_sum[threadIdx.x][0] + grid_sum[threadIdx.x]*__expf(grid_max[threadIdx.x] - block_max[threadIdx.x][0]);
grid_max[threadIdx.x] = block_max[threadIdx.x][0];
}//compare the max between the different blocks, when the loop end, grid_max store the global max
}
__syncthreads();
S[threadIdx.x][threadIdx.y] = __expf(S[threadIdx.x][threadIdx.y] - grid_max[threadIdx.x]); //softmax(s)*L
__syncthreads();
int vj = threadIdx.y + blockIdx.y * blockDim.y;
//do not write vj = threadIdx.y + ph * blockDim.y
float sum_o;
if(vj < d){
sum_o = 0;
for(int vid = 0; vid < BLOCK_DIM_y; vid++){
if(vid + phn * BLOCK_DIM_y < N){
sum_o += S[threadIdx.x][vid]*inputV[(vid + phn * BLOCK_DIM_y) * d + vj];
}
}
Out_new[threadIdx.x][threadIdx.y] = __expf(grid_max_old[threadIdx.x] - grid_max[threadIdx.x])*Out_new[threadIdx.x][threadIdx.y] + sum_o;
grid_max_old[threadIdx.x] = grid_max[threadIdx.x];
}
}
__syncthreads();
int j = threadIdx.y + blockIdx.y * blockDim.y;
if(i < N && j < d){
output[i * d + j] = Out_new[threadIdx.x][threadIdx.y]*__fdividef(1.0F, grid_sum[threadIdx.x]);
}
}
namespace infini {
void attentionKernel(const float *inputQ, const float *inputK, const float *inputV, int N, int d, float *output) {
int num_block_y = (max_function(N,d) + BLOCK_DIM_y - 1)/BLOCK_DIM_y;
int num_block_x = (N + BLOCK_DIM_x - 1)/BLOCK_DIM_x;
int share_mem = (5*BLOCK_DIM_y + 2)*BLOCK_DIM_x*sizeof(float);
dim3 block_dim(BLOCK_DIM_x,BLOCK_DIM_y,1);
dim3 grid_dim(num_block_x,num_block_y,1);
_attentionKernel<<<grid_dim, block_dim, share_mem>>>(inputQ, inputK, inputV, N, d, output);
}
} // namespace infini

View File

@ -0,0 +1,42 @@
#include "operators/attention.h"
#include "utils/operator_utils.h"
namespace infini {
AttentionObj::AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK,
Tensor inputV, Tensor output)
: OperatorObj(OpType::Attention, TensorVec{inputQ, inputK, inputV},
{output}) {
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> AttentionObj::inferShape(const TensorVec &inputs) const {
auto shapeQ = inputs[0]->getDims();
auto shapeK = inputs[1]->getDims();
auto shapeV = inputs[2]->getDims();
auto retQK = infer_broadcast(shapeQ, shapeK);
auto ret = infer_broadcast(retQK, shapeV);
return {{ret}};
}
std::string AttentionObj::toString() const {
std::ostringstream os;
os << "Attention[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[2]->getDims()) << ",";
os << "inputQ=" << inputs[0]->getGuid() << ",";
os << "inputK=" << inputs[1]->getGuid() << ",";
os << "inputV=" << inputs[2]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> AttentionObj::getWorkloadVector() const {
vector<int> ret = getOutput()->getDims();
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> AttentionObj::getOpAttrVector() const { return {type.underlying()}; }
} // namespace infini

View File

@ -0,0 +1,70 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/attention.h"
#include "test.h"
namespace infini {
void test_attention(const Shape &outputShape, const vector<float> &inputQData,
const vector<float> &inputKData,
const vector<float> &inputVData,
const vector<float> &ExpectData) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
auto inputV = gCpu->addTensor(outputShape, DataType::Float32);
auto inputQ = gCpu->addTensor(outputShape, DataType::Float32);
auto inputK = gCpu->addTensor(outputShape, DataType::Float32);
gCpu->dataMalloc();
inputV->copyin(inputVData); //
inputQ->copyin(inputQData);
inputK->copyin(inputKData); //
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto inputVGpu = gCuda->cloneTensor(inputV);
auto inputQGpu = gCuda->cloneTensor(inputQ);
auto inputKGpu = gCuda->cloneTensor(inputK);
auto op = gCuda->addOp<AttentionObj>(inputQGpu, inputKGpu, inputVGpu,
nullptr); // AttentionObj
gCuda->dataMalloc();
inputVGpu->copyin(inputVData);
inputQGpu->copyin(inputQData);
inputKGpu->copyin(inputKData);
cudaRuntime->run(gCuda);
auto oCpu = gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
oCpu->printData(); //->printData
EXPECT_TRUE(oCpu->equalData(ExpectData));
}
TEST(CUDA_Attention, run) {
test_attention(
Shape{6,5}, vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00, 6.507058e-03,
1.004909e+00, 1.999979e+00, 2.986577e+00, 8.536250e-03, 1.004909e+00,
2.017291e+00, 2.945395e+00, 1.997352e-02, 1.017340e+00, 2.017291e+00,
2.999871e+00, 3.741202e-04, 9.998805e-01, 1.999874e+00, 2.999871e+00,
6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00, 6.507058e-03,
1.004909e+00, 1.999979e+00, 2.986577e+00, 8.536250e-03, 1.004909e+00});
test_attention(Shape{4, 3}, // inputQ
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputK
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputV
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.},
vector<float>{0.9640308, 1.9546683, 2.9292183, 2.9460413, 0.0886370, 1.0179861,
1.9941283, 2.9905086, 0.0210545, 1.0006673, 1.9993325, 2.9894698});
} // python output
} // namespace infini