Compare commits

...

12 Commits

Author SHA1 Message Date
xgqdut2016 54f4265296 modified logic 2023-11-17 17:43:52 +08:00
xgqdut2016 3629881dfa blockreduce matrix times 2023-10-08 15:15:39 +08:00
xgqdut2016 79dd3364df modified threadIdx.y to threadIdx.x 2023-10-08 11:33:16 +08:00
xgqdut2016 56e2c87c9b modified reduce,8ms 2023-10-07 18:14:12 +08:00
xgqdut2016 819484eda2 matrix reduce,threadIdx.x=0,17ms 2023-09-28 16:49:01 +08:00
xgqdut2016 ddbec7d60a BLOCK_DIM_x=1,num_block_x=N 2023-09-28 12:37:10 +08:00
xgqdut2016 b640ab1689 modified attention.cu,BLOCK_DIM_x must leq 32 2023-09-26 14:53:02 +08:00
xgqdut2016 ec391674ac 2D block, share S 2023-09-25 13:02:25 +08:00
xgqdut2016 f1dc440a3c 1D attention ,global S matrix 2023-09-22 16:59:42 +08:00
xgqdut2016 b1a2d91aba modified the format from master 2023-09-21 15:03:57 +08:00
xgqdut2016 410844c058 modified error in onnx.py 2023-09-21 14:59:12 +08:00
xgqdut2016 3f5178d069 the baseline of flash attention 2023-09-21 14:31:43 +08:00
11 changed files with 424 additions and 14 deletions

View File

@ -73,6 +73,8 @@ class GraphHandlerObj {
Tensor cast(Tensor input, Tensor output, int to);
Tensor expand(Tensor input, Tensor output, Shape dims);
Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);
Tensor attention(Tensor inputQ, Tensor inputK, Tensor inputV,
Tensor output);
Tensor allReduceSum(Tensor input, Tensor output);
Tensor allReduceProd(Tensor input, Tensor output);

View File

@ -15,16 +15,17 @@ struct OpType {
// elements.
enum : underlying_t {
Unknown,
Abs, // Unary
Acos, // Unary
Acosh, // Unary
Add, // Binary
And, // Binary
ArgMax, //
Asin, // Binary
Asinh, // Binary
Atan, // Binary
Atanh, // Binary
Abs, // Unary
Acos, // Unary
Acosh, // Unary
Add, // Binary
And, // Binary
ArgMax, //
Asin, // Binary
Asinh, // Binary
Atan, // Binary
Atanh, // Binary
Attention,
AveragePool, // Pool
BatchNormalization, //
Bernoulli, //

View File

@ -0,0 +1,8 @@
#pragma once
#include "operators/unary.h"
namespace infini {
void attentionKernel(const float *inputQ, const float *inputK,
const float *inputV, int N, int d, float *output);
}; // namespace infini

View File

@ -0,0 +1,36 @@
#pragma once
#include "core/operator.h"
namespace infini {
/**
* @brief Return elements, either from X or Y, depending on condition.
*
*/
class AttentionObj : public OperatorObj {
public:
/**
* @brief Construct a new Attention object.
*
* @param graph The computation graph that this operator belongs to.
* @param inputX The input tensor Q.
* @param inputY The input tensor K.
* @param output The output tensor.
* @param inputV The input tensor V.
*/
AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK, Tensor inputV,
Tensor output);
OP_CLONE(AttentionObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;
int numInputs() const override { return inputs.size(); }
int numOutputs() const override { return 1; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini

View File

@ -32,6 +32,7 @@ class OnnxStub:
The Onnx model imported into infinitensor.
It can be generated from an Onnx model object.
"""
def __init__(self, model: ModelProto, runtime):
self.inputs: Dict[str, backend.Tensor] = {}
self.outputs: Dict[str, backend.Tensor] = {}
@ -60,7 +61,6 @@ class OnnxStub:
dims, output.type.tensor_type.elem_type
)
node_name = []
new_node_name = []
for node in model.graph.node:
@ -632,6 +632,13 @@ class OnnxStub:
),
):
tensors[name] = tensor
elif node.op_type == "Attention":
tensors[node.output[0]] = self.handler.attention(
tensors[node.input[0]],
tensors[node.input[1]],
tensors[node.input[2]],
tensors.get(node.output[0]),
)
elif node.op_type == "Broadcast":
tensors[node.output[0]] = self.handler.broadcast(
tensors[node.input[0]],
@ -678,7 +685,6 @@ class OnnxStub:
for output in model.graph.output:
tensors[output.name].set_output()
################################
# Allocate memory space for data
################################
@ -1002,6 +1008,10 @@ class OnnxStub:
assert len(inputs) == 3, "Check Where Op must have three inputs."
new_inputs = [inputs[2], inputs[0], inputs[1]]
ctx.push_node(make_node(ty.name, new_inputs, outputs, name))
elif ty == backend.OpTypeId.Attention:
assert len(inputs) == 3, "Check Attention Op must have three inputs."
new_inputs = [inputs[0], inputs[1], inputs[2]]
ctx.push_node(make_node(ty.name, new_inputs, outputs, name))
elif ty == backend.OpTypeId.Expand:
shape = backend.expand_shape_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, shape=shape))

View File

@ -1,6 +1,7 @@
#include "core/graph_handler.h"
#include "operators/all_gather.h"
#include "operators/all_reduce.h"
#include "operators/attention.h"
#include "operators/batch_norm.h"
#include "operators/broadcast.h"
#include "operators/concat.h"
@ -406,7 +407,19 @@ Tensor GraphHandlerObj::where(Tensor inputX, Tensor inputY, Tensor condition,
->getOutput();
}
}
Tensor GraphHandlerObj::attention(Tensor inputQ, Tensor inputK, Tensor inputV,
Tensor output) {
if (output) {
g->addOpWithOutputs<AttentionObj>(std::move(inputQ), std::move(inputK),
std::move(inputV), output);
return output;
} else {
return g
->addOp<AttentionObj>(std::move(inputQ), std::move(inputK),
std::move(inputV), output)
->getOutput();
}
}
static CastType inferCastType(Tensor input, int to) {
auto iType = input->getDType();
auto oType = DataType(to);

View File

@ -1,5 +1,6 @@
#include "core/data_type.h"
#include "core/graph_handler.h"
#include "operators/attention.h"
#include "operators/batch_norm.h"
#include "operators/concat.h"
#include "operators/conv.h"
@ -67,6 +68,7 @@ void export_values(py::module &m) {
.def(py::init<decltype(OpType::type)>())
.def("id", getId, policy::automatic);
py::enum_<decltype(OpType::type)>(m, "OpTypeId")
.VALUE(OpType, Attention)
.VALUE(OpType, Conv)
.VALUE(OpType, MatMul)
.VALUE(OpType, ConvTranspose)
@ -424,6 +426,7 @@ void init_graph_builder(py::module &m) {
policy::reference);
py::class_<Handler>(m, "GraphHandler")
.def(py::init<Runtime>())
.def("attention", &Handler::attention, policy::move)
.def("tensor", &Handler::tensor, policy::move)
.def("conv", &Handler::conv, policy::move)
.def("convTransposed2d", &Handler::convTransposed2d, policy::move)

View File

@ -0,0 +1,28 @@
#include "operators/attention.h"
#include "cuda/cuda_attention.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
namespace infini {
class AttentionCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<AttentionObj>(_op);
void *const inputQData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const inputKData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const inputVData = (op->getInputs(2)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
int N = op->getInputs(0)->getDims()[0];
int d = op->getInputs(0)->getDims()[1];
attentionKernel((float *)inputQData, (float *)inputKData,
(float *)inputVData, N, d, (float *)outputData);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Attention, DataType::Float32,
AttentionCuda, "Attention_CUDA_Float32");
}; // namespace infini

View File

@ -0,0 +1,187 @@
#include "cuda/cuda_common.h"
template <int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void _attentionKernel(const float *__restrict inputQ,
const float *__restrict inputK,
const float *__restrict inputV, int N, int d,
float *__restrict output) {
int i = blockIdx.y; // i must < N,Q[i]
int phd = threadIdx.x + blockIdx.x * blockDim.x; // V[:,d]
int phNumN = (N + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
__shared__ float inputS[BLOCK_DIM_x][BLOCK_DIM_y];
float newMax;
float oldMax;
float newSum;
newMax = -__FLT_MAX__;
oldMax = -__FLT_MAX__;
newSum = 0.0f;
float out;
out = 0.0f;
//---------
__shared__ float block_sum[BLOCK_DIM_x][BLOCK_DIM_y];
__shared__ float sum_partial[BLOCK_DIM_x][BLOCK_DIM_y];
int extra = d % BLOCK_DIM_x;
int step = (d - extra) / BLOCK_DIM_x;
for (int phn = 0; phn < phNumN; phn++) {
int j = threadIdx.y + phn * BLOCK_DIM_y;
float sum_r = 0.0f;
__syncthreads();
if (threadIdx.x < extra) {
for (int ind = threadIdx.x * (step + 1);
ind < (threadIdx.x + 1) * (step + 1); ind++) {
sum_r += inputQ[i * d + ind] * inputK[j * d + ind];
}
} else {
for (int ind = extra * (step + 1) + (threadIdx.x - extra) * step;
ind < extra * (step + 1) + (threadIdx.x - extra + 1) * step;
ind++) {
sum_r += inputQ[i * d + ind] * inputK[j * d + ind];
}
}
if (j < N) {
sum_partial[threadIdx.x][threadIdx.y] = sum_r;
} else {
sum_partial[threadIdx.x][threadIdx.y] = 0.0f;
}
__syncthreads();
for (int strip = BLOCK_DIM_x / 2; strip > 0; strip /= 2) {
if (threadIdx.x < strip) {
sum_partial[threadIdx.x][threadIdx.y] +=
sum_partial[threadIdx.x + strip][threadIdx.y];
}
__syncthreads();
}
float sum_s = sum_partial[0][threadIdx.y];
if (j < N) {
block_sum[threadIdx.x][threadIdx.y] = 1.0f;
} else {
sum_partial[0][threadIdx.y] = -__FLT_MAX__;
block_sum[threadIdx.x][threadIdx.y] = 0.0f;
}
__syncthreads();
for (int strip = BLOCK_DIM_y / 2; strip > 0; strip /= 2) {
if (threadIdx.y < strip) {
if (sum_partial[0][threadIdx.y] >
sum_partial[0][threadIdx.y + strip]) {
block_sum[threadIdx.x][threadIdx.y] =
block_sum[threadIdx.x][threadIdx.y] +
block_sum[threadIdx.x][threadIdx.y + strip] *
__expf(sum_partial[0][threadIdx.y + strip] -
sum_partial[0][threadIdx.y]);
} else {
block_sum[threadIdx.x][threadIdx.y] =
block_sum[threadIdx.x][threadIdx.y + strip] +
block_sum[threadIdx.x][threadIdx.y] *
__expf(sum_partial[0][threadIdx.y] -
sum_partial[0][threadIdx.y + strip]);
sum_partial[0][threadIdx.y] =
sum_partial[0][threadIdx.y + strip];
}
}
__syncthreads();
}
if (newMax > sum_partial[0][0]) {
newSum = newSum + block_sum[threadIdx.x][0] *
__expf(sum_partial[0][0] - newMax);
} else {
newSum = block_sum[threadIdx.x][0] +
newSum * __expf(newMax - sum_partial[0][0]);
newMax = sum_partial[0][0];
}
if (j < N && phd < d) {
inputS[threadIdx.x][threadIdx.y] =
__expf(sum_s - newMax) *
inputV[(threadIdx.y + phn * BLOCK_DIM_y) * d + phd];
} else {
inputS[threadIdx.x][threadIdx.y] = 0.0f;
}
__syncthreads();
for (int strip = BLOCK_DIM_y / 2; strip > 0; strip /= 2) {
if (threadIdx.y < strip) {
inputS[threadIdx.x][threadIdx.y] +=
inputS[threadIdx.x][threadIdx.y + strip];
}
__syncthreads();
}
if (j < N && phd < d) {
out = __expf(oldMax - newMax) * out + inputS[threadIdx.x][0];
}
oldMax = newMax;
}
if (threadIdx.y + (phNumN - 1) * BLOCK_DIM_y < N && phd < d) {
output[i * d + phd] = out * __fdividef(1.0F, newSum);
}
}
namespace infini {
void attentionKernel(const float *inputQ, const float *inputK,
const float *inputV, int N, int d, float *output) {
int num_block_y = N;
if (d > 512) {
int BLOCK_DIM_x = 1024;
int BLOCK_DIM_y = 1;
int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<1024, 1>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
} else if (d > 256) {
int BLOCK_DIM_x = 512;
int BLOCK_DIM_y = 2;
int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<512, 2>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
} else if (d > 128) {
int BLOCK_DIM_x = 256;
int BLOCK_DIM_y = 4;
int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<256, 4>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
} else if (d > 64) {
int BLOCK_DIM_x = 128;
int BLOCK_DIM_y = 8;
int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<128, 8>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
} else if (d > 32) {
int BLOCK_DIM_x = 64;
int BLOCK_DIM_y = 16;
int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<64, 16>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
} else if (d > 16) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<32, 32>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
} else {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<16, 64>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
}
}
} // namespace infini

View File

@ -0,0 +1,45 @@
#include "operators/attention.h"
#include "utils/operator_utils.h"
namespace infini {
AttentionObj::AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK,
Tensor inputV, Tensor output)
: OperatorObj(OpType::Attention, TensorVec{inputQ, inputK, inputV},
{output}) {
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>>
AttentionObj::inferShape(const TensorVec &inputs) const {
auto shapeQ = inputs[0]->getDims();
auto shapeK = inputs[1]->getDims();
auto shapeV = inputs[2]->getDims();
auto retQK = infer_broadcast(shapeQ, shapeK);
auto ret = infer_broadcast(retQK, shapeV);
return {{ret}};
}
std::string AttentionObj::toString() const {
std::ostringstream os;
os << "Attention[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[2]->getDims()) << ",";
os << "inputQ=" << inputs[0]->getGuid() << ",";
os << "inputK=" << inputs[1]->getGuid() << ",";
os << "inputV=" << inputs[2]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> AttentionObj::getWorkloadVector() const {
vector<int> ret = getOutput()->getDims();
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> AttentionObj::getOpAttrVector() const {
return {type.underlying()};
}
} // namespace infini

View File

@ -0,0 +1,77 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/attention.h"
#include "test.h"
namespace infini {
void test_attention(const Shape &outputShape, const vector<float> &inputQData,
const vector<float> &inputKData,
const vector<float> &inputVData,
const vector<float> &ExpectData) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
auto inputV = gCpu->addTensor(outputShape, DataType::Float32);
auto inputQ = gCpu->addTensor(outputShape, DataType::Float32);
auto inputK = gCpu->addTensor(outputShape, DataType::Float32);
gCpu->dataMalloc();
inputV->copyin(inputVData); //
inputQ->copyin(inputQData);
inputK->copyin(inputKData); //
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto inputVGpu = gCuda->cloneTensor(inputV);
auto inputQGpu = gCuda->cloneTensor(inputQ);
auto inputKGpu = gCuda->cloneTensor(inputK);
auto op = gCuda->addOp<AttentionObj>(inputQGpu, inputKGpu, inputVGpu,
nullptr); // AttentionObj
gCuda->dataMalloc();
inputVGpu->copyin(inputVData);
inputQGpu->copyin(inputQData);
inputKGpu->copyin(inputKData);
cudaRuntime->run(gCuda);
auto oCpu = gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
oCpu->printData(); //->printData
EXPECT_TRUE(oCpu->equalData(ExpectData));
}
TEST(CUDA_Attention, run) {
test_attention(
Shape{6, 5}, vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00,
6.507058e-03, 1.004909e+00, 1.999979e+00, 2.986577e+00,
8.536250e-03, 1.004909e+00, 2.017291e+00, 2.945395e+00,
1.997352e-02, 1.017340e+00, 2.017291e+00, 2.999871e+00,
3.741202e-04, 9.998805e-01, 1.999874e+00, 2.999871e+00,
6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00,
6.507058e-03, 1.004909e+00, 1.999979e+00, 2.986577e+00,
8.536250e-03, 1.004909e+00});
test_attention(
Shape{4, 3}, // inputQ
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputK
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputV
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.},
vector<float>{0.9640308, 1.9546683, 2.9292183, 2.9460413, 0.0886370,
1.0179861, 1.9941283, 2.9905086, 0.0210545, 1.0006673,
1.9993325, 2.9894698});
} // python output
} // namespace infini