Compare commits

...

12 Commits

Author SHA1 Message Date
xgqdut2016 54f4265296 modified logic 2023-11-17 17:43:52 +08:00
xgqdut2016 3629881dfa blockreduce matrix times 2023-10-08 15:15:39 +08:00
xgqdut2016 79dd3364df modified threadIdx.y to threadIdx.x 2023-10-08 11:33:16 +08:00
xgqdut2016 56e2c87c9b modified reduce,8ms 2023-10-07 18:14:12 +08:00
xgqdut2016 819484eda2 matrix reduce,threadIdx.x=0,17ms 2023-09-28 16:49:01 +08:00
xgqdut2016 ddbec7d60a BLOCK_DIM_x=1,num_block_x=N 2023-09-28 12:37:10 +08:00
xgqdut2016 b640ab1689 modified attention.cu,BLOCK_DIM_x must leq 32 2023-09-26 14:53:02 +08:00
xgqdut2016 ec391674ac 2D block, share S 2023-09-25 13:02:25 +08:00
xgqdut2016 f1dc440a3c 1D attention ,global S matrix 2023-09-22 16:59:42 +08:00
xgqdut2016 b1a2d91aba modified the format from master 2023-09-21 15:03:57 +08:00
xgqdut2016 410844c058 modified error in onnx.py 2023-09-21 14:59:12 +08:00
xgqdut2016 3f5178d069 the baseline of flash attention 2023-09-21 14:31:43 +08:00
11 changed files with 424 additions and 14 deletions

View File

@@ -73,6 +73,8 @@ class GraphHandlerObj {
Tensor cast(Tensor input, Tensor output, int to);
Tensor expand(Tensor input, Tensor output, Shape dims);
Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);
Tensor attention(Tensor inputQ, Tensor inputK, Tensor inputV,
Tensor output);
Tensor allReduceSum(Tensor input, Tensor output);
Tensor allReduceProd(Tensor input, Tensor output);

View File

@@ -15,16 +15,17 @@ struct OpType {
// elements.
enum : underlying_t {
Unknown,
Abs, // Unary
Acos, // Unary
Acosh, // Unary
Add, // Binary
And, // Binary
ArgMax, //
Asin, // Binary
Asinh, // Binary
Atan, // Binary
Atanh, // Binary
Attention,
AveragePool, // Pool
BatchNormalization, //
Bernoulli, //

View File

@@ -0,0 +1,8 @@
#pragma once
#include "operators/unary.h"
namespace infini {
void attentionKernel(const float *inputQ, const float *inputK,
const float *inputV, int N, int d, float *output);
}; // namespace infini

View File

@@ -0,0 +1,36 @@
#pragma once
#include "core/operator.h"
namespace infini {
/**
* @brief Compute the attention output softmax(Q * K^T) * V from the query, key and value tensors.
*
*/
class AttentionObj : public OperatorObj {
public:
/**
* @brief Construct a new Attention object.
*
* @param graph The computation graph that this operator belongs to.
* @param inputQ The input tensor Q.
* @param inputK The input tensor K.
* @param inputV The input tensor V.
* @param output The output tensor.
*/
AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK, Tensor inputV,
Tensor output);
OP_CLONE(AttentionObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;
int numInputs() const override { return inputs.size(); }
int numOutputs() const override { return 1; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini

View File

@@ -32,6 +32,7 @@ class OnnxStub:
The Onnx model imported into infinitensor.
It can be generated from an Onnx model object.
"""
def __init__(self, model: ModelProto, runtime):
self.inputs: Dict[str, backend.Tensor] = {}
self.outputs: Dict[str, backend.Tensor] = {}
@@ -60,7 +61,6 @@ class OnnxStub:
dims, output.type.tensor_type.elem_type
)
node_name = []
new_node_name = []
for node in model.graph.node:
@@ -632,6 +632,13 @@ class OnnxStub:
),
):
tensors[name] = tensor
elif node.op_type == "Attention":
tensors[node.output[0]] = self.handler.attention(
tensors[node.input[0]],
tensors[node.input[1]],
tensors[node.input[2]],
tensors.get(node.output[0]),
)
elif node.op_type == "Broadcast":
tensors[node.output[0]] = self.handler.broadcast(
tensors[node.input[0]],
@@ -678,7 +685,6 @@ class OnnxStub:
for output in model.graph.output:
tensors[output.name].set_output()
################################
# Allocate memory space for data
################################
@@ -1002,6 +1008,10 @@ class OnnxStub:
assert len(inputs) == 3, "Check Where Op must have three inputs."
new_inputs = [inputs[2], inputs[0], inputs[1]]
ctx.push_node(make_node(ty.name, new_inputs, outputs, name))
elif ty == backend.OpTypeId.Attention:
assert len(inputs) == 3, "Check Attention Op must have three inputs."
new_inputs = [inputs[0], inputs[1], inputs[2]]
ctx.push_node(make_node(ty.name, new_inputs, outputs, name))
elif ty == backend.OpTypeId.Expand:
shape = backend.expand_shape_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, shape=shape))
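
For reference, a minimal sketch of an ONNX model that would reach the new "Attention" import branch above. The tensor names, the 4x3 shapes, and the final OnnxStub call are illustrative assumptions, not part of this change:

import onnx
from onnx import TensorProto, helper

N, d = 4, 3  # assumed sequence length and feature dimension
q = helper.make_tensor_value_info("Q", TensorProto.FLOAT, [N, d])
k = helper.make_tensor_value_info("K", TensorProto.FLOAT, [N, d])
v = helper.make_tensor_value_info("V", TensorProto.FLOAT, [N, d])
out = helper.make_tensor_value_info("Out", TensorProto.FLOAT, [N, d])

# One custom "Attention" node with inputs Q, K, V; the importer above
# dispatches on node.op_type == "Attention" and calls handler.attention().
node = helper.make_node("Attention", ["Q", "K", "V"], ["Out"])
graph = helper.make_graph([node], "attention_demo", [q, k, v], [out])
model = helper.make_model(graph)

# stub = OnnxStub(model, runtime)  # runtime: a backend runtime object, e.g. CUDA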

View File

@@ -1,6 +1,7 @@
#include "core/graph_handler.h"
#include "operators/all_gather.h"
#include "operators/all_reduce.h"
#include "operators/attention.h"
#include "operators/batch_norm.h"
#include "operators/broadcast.h"
#include "operators/concat.h"
@@ -406,7 +407,19 @@ Tensor GraphHandlerObj::where(Tensor inputX, Tensor inputY, Tensor condition,
->getOutput();
}
}
Tensor GraphHandlerObj::attention(Tensor inputQ, Tensor inputK, Tensor inputV,
Tensor output) {
if (output) {
g->addOpWithOutputs<AttentionObj>(std::move(inputQ), std::move(inputK),
std::move(inputV), output);
return output;
} else {
return g
->addOp<AttentionObj>(std::move(inputQ), std::move(inputK),
std::move(inputV), output)
->getOutput();
}
}
static CastType inferCastType(Tensor input, int to) {
auto iType = input->getDType();
auto oType = DataType(to);

View File

@@ -1,5 +1,6 @@
#include "core/data_type.h"
#include "core/graph_handler.h"
#include "operators/attention.h"
#include "operators/batch_norm.h"
#include "operators/concat.h"
#include "operators/conv.h"
@@ -67,6 +68,7 @@ void export_values(py::module &m) {
.def(py::init<decltype(OpType::type)>())
.def("id", getId, policy::automatic);
py::enum_<decltype(OpType::type)>(m, "OpTypeId")
.VALUE(OpType, Attention)
.VALUE(OpType, Conv)
.VALUE(OpType, MatMul)
.VALUE(OpType, ConvTranspose)
@@ -424,6 +426,7 @@ void init_graph_builder(py::module &m) {
policy::reference);
py::class_<Handler>(m, "GraphHandler")
.def(py::init<Runtime>())
.def("attention", &Handler::attention, policy::move)
.def("tensor", &Handler::tensor, policy::move)
.def("conv", &Handler::conv, policy::move)
.def("convTransposed2d", &Handler::convTransposed2d, policy::move)

View File

@@ -0,0 +1,28 @@
#include "operators/attention.h"
#include "cuda/cuda_attention.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
namespace infini {
class AttentionCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<AttentionObj>(_op);
void *const inputQData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const inputKData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const inputVData = (op->getInputs(2)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
int N = op->getInputs(0)->getDims()[0];
int d = op->getInputs(0)->getDims()[1];
attentionKernel((float *)inputQData, (float *)inputKData,
(float *)inputVData, N, d, (float *)outputData);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Attention, DataType::Float32,
AttentionCuda, "Attention_CUDA_Float32");
}; // namespace infini

View File

@@ -0,0 +1,187 @@
#include "cuda/cuda_common.h"
template <int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void _attentionKernel(const float *__restrict inputQ,
const float *__restrict inputK,
const float *__restrict inputV, int N, int d,
float *__restrict output) {
int i = blockIdx.y; // i must < N,Q[i]
int phd = threadIdx.x + blockIdx.x * blockDim.x; // V[:,d]
int phNumN = (N + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
__shared__ float inputS[BLOCK_DIM_x][BLOCK_DIM_y];
float newMax;
float oldMax;
float newSum;
newMax = -__FLT_MAX__;
oldMax = -__FLT_MAX__;
newSum = 0.0f;
float out;
out = 0.0f;
//---------
__shared__ float block_sum[BLOCK_DIM_x][BLOCK_DIM_y];
__shared__ float sum_partial[BLOCK_DIM_x][BLOCK_DIM_y];
int extra = d % BLOCK_DIM_x;
int step = (d - extra) / BLOCK_DIM_x;
for (int phn = 0; phn < phNumN; phn++) {
int j = threadIdx.y + phn * BLOCK_DIM_y;
float sum_r = 0.0f;
__syncthreads();
if (threadIdx.x < extra) {
for (int ind = threadIdx.x * (step + 1);
ind < (threadIdx.x + 1) * (step + 1); ind++) {
sum_r += inputQ[i * d + ind] * inputK[j * d + ind];
}
} else {
for (int ind = extra * (step + 1) + (threadIdx.x - extra) * step;
ind < extra * (step + 1) + (threadIdx.x - extra + 1) * step;
ind++) {
sum_r += inputQ[i * d + ind] * inputK[j * d + ind];
}
}
if (j < N) {
sum_partial[threadIdx.x][threadIdx.y] = sum_r;
} else {
sum_partial[threadIdx.x][threadIdx.y] = 0.0f;
}
__syncthreads();
for (int strip = BLOCK_DIM_x / 2; strip > 0; strip /= 2) {
if (threadIdx.x < strip) {
sum_partial[threadIdx.x][threadIdx.y] +=
sum_partial[threadIdx.x + strip][threadIdx.y];
}
__syncthreads();
}
float sum_s = sum_partial[0][threadIdx.y];
if (j < N) {
block_sum[threadIdx.x][threadIdx.y] = 1.0f;
} else {
sum_partial[0][threadIdx.y] = -__FLT_MAX__;
block_sum[threadIdx.x][threadIdx.y] = 0.0f;
}
__syncthreads();
for (int strip = BLOCK_DIM_y / 2; strip > 0; strip /= 2) {
if (threadIdx.y < strip) {
if (sum_partial[0][threadIdx.y] >
sum_partial[0][threadIdx.y + strip]) {
block_sum[threadIdx.x][threadIdx.y] =
block_sum[threadIdx.x][threadIdx.y] +
block_sum[threadIdx.x][threadIdx.y + strip] *
__expf(sum_partial[0][threadIdx.y + strip] -
sum_partial[0][threadIdx.y]);
} else {
block_sum[threadIdx.x][threadIdx.y] =
block_sum[threadIdx.x][threadIdx.y + strip] +
block_sum[threadIdx.x][threadIdx.y] *
__expf(sum_partial[0][threadIdx.y] -
sum_partial[0][threadIdx.y + strip]);
sum_partial[0][threadIdx.y] =
sum_partial[0][threadIdx.y + strip];
}
}
__syncthreads();
}
if (newMax > sum_partial[0][0]) {
newSum = newSum + block_sum[threadIdx.x][0] *
__expf(sum_partial[0][0] - newMax);
} else {
newSum = block_sum[threadIdx.x][0] +
newSum * __expf(newMax - sum_partial[0][0]);
newMax = sum_partial[0][0];
}
if (j < N && phd < d) {
inputS[threadIdx.x][threadIdx.y] =
__expf(sum_s - newMax) *
inputV[(threadIdx.y + phn * BLOCK_DIM_y) * d + phd];
} else {
inputS[threadIdx.x][threadIdx.y] = 0.0f;
}
__syncthreads();
for (int strip = BLOCK_DIM_y / 2; strip > 0; strip /= 2) {
if (threadIdx.y < strip) {
inputS[threadIdx.x][threadIdx.y] +=
inputS[threadIdx.x][threadIdx.y + strip];
}
__syncthreads();
}
if (j < N && phd < d) {
out = __expf(oldMax - newMax) * out + inputS[threadIdx.x][0];
}
oldMax = newMax;
}
if (threadIdx.y + (phNumN - 1) * BLOCK_DIM_y < N && phd < d) {
output[i * d + phd] = out * __fdividef(1.0F, newSum);
}
}
namespace infini {
void attentionKernel(const float *inputQ, const float *inputK,
const float *inputV, int N, int d, float *output) {
int num_block_y = N;
if (d > 512) {
int BLOCK_DIM_x = 1024;
int BLOCK_DIM_y = 1;
int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<1024, 1>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
} else if (d > 256) {
int BLOCK_DIM_x = 512;
int BLOCK_DIM_y = 2;
int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<512, 2>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
} else if (d > 128) {
int BLOCK_DIM_x = 256;
int BLOCK_DIM_y = 4;
int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<256, 4>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
} else if (d > 64) {
int BLOCK_DIM_x = 128;
int BLOCK_DIM_y = 8;
int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<128, 8>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
} else if (d > 32) {
int BLOCK_DIM_x = 64;
int BLOCK_DIM_y = 16;
int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<64, 16>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
} else if (d > 16) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<32, 32>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
} else {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<16, 64>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
}
}
} // namespace infini
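
The tile loop above appears to implement the standard online-softmax recurrence; as a hedged reading of the code (not part of the change): for one output row, with running max $m$, running normalizer $l$, unnormalized accumulator $o$, and per-tile scores $s_j = q^\top k_j$,

$m' = \max\bigl(m, \max_j s_j\bigr), \qquad l' = l\, e^{m - m'} + \sum_j e^{s_j - m'}, \qquad o' = o\, e^{m - m'} + \sum_j e^{s_j - m'} v_j,$

and the final row is $o / l$ once all tiles of K and V have been folded in. Here sum_partial and block_sum carry the per-tile maximum and per-tile sum of exponentials; note the scores are plain dot products, with no $1/\sqrt{d}$ scaling.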

View File

@@ -0,0 +1,45 @@
#include "operators/attention.h"
#include "utils/operator_utils.h"
namespace infini {
AttentionObj::AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK,
Tensor inputV, Tensor output)
: OperatorObj(OpType::Attention, TensorVec{inputQ, inputK, inputV},
{output}) {
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>>
AttentionObj::inferShape(const TensorVec &inputs) const {
auto shapeQ = inputs[0]->getDims();
auto shapeK = inputs[1]->getDims();
auto shapeV = inputs[2]->getDims();
auto retQK = infer_broadcast(shapeQ, shapeK);
auto ret = infer_broadcast(retQK, shapeV);
return {{ret}};
}
std::string AttentionObj::toString() const {
std::ostringstream os;
os << "Attention[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[2]->getDims()) << ",";
os << "inputQ=" << inputs[0]->getGuid() << ",";
os << "inputK=" << inputs[1]->getGuid() << ",";
os << "inputV=" << inputs[2]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> AttentionObj::getWorkloadVector() const {
vector<int> ret = getOutput()->getDims();
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> AttentionObj::getOpAttrVector() const {
return {type.underlying()};
}
} // namespace infini

View File

@@ -0,0 +1,77 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/attention.h"
#include "test.h"
namespace infini {
void test_attention(const Shape &outputShape, const vector<float> &inputQData,
const vector<float> &inputKData,
const vector<float> &inputVData,
const vector<float> &ExpectData) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
auto inputV = gCpu->addTensor(outputShape, DataType::Float32);
auto inputQ = gCpu->addTensor(outputShape, DataType::Float32);
auto inputK = gCpu->addTensor(outputShape, DataType::Float32);
gCpu->dataMalloc();
inputV->copyin(inputVData); //
inputQ->copyin(inputQData);
inputK->copyin(inputKData); //
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto inputVGpu = gCuda->cloneTensor(inputV);
auto inputQGpu = gCuda->cloneTensor(inputQ);
auto inputKGpu = gCuda->cloneTensor(inputK);
auto op = gCuda->addOp<AttentionObj>(inputQGpu, inputKGpu, inputVGpu,
nullptr); // AttentionObj
gCuda->dataMalloc();
inputVGpu->copyin(inputVData);
inputQGpu->copyin(inputQData);
inputKGpu->copyin(inputKData);
cudaRuntime->run(gCuda);
auto oCpu = gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
oCpu->printData(); //->printData
EXPECT_TRUE(oCpu->equalData(ExpectData));
}
TEST(CUDA_Attention, run) {
test_attention(
Shape{6, 5}, vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
vector<float>{6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00,
6.507058e-03, 1.004909e+00, 1.999979e+00, 2.986577e+00,
8.536250e-03, 1.004909e+00, 2.017291e+00, 2.945395e+00,
1.997352e-02, 1.017340e+00, 2.017291e+00, 2.999871e+00,
3.741202e-04, 9.998805e-01, 1.999874e+00, 2.999871e+00,
6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00,
6.507058e-03, 1.004909e+00, 1.999979e+00, 2.986577e+00,
8.536250e-03, 1.004909e+00});
test_attention(
Shape{4, 3}, // inputQ
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputK
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputV
vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.},
vector<float>{0.9640308, 1.9546683, 2.9292183, 2.9460413, 0.0886370,
1.0179861, 1.9941283, 2.9905086, 0.0210545, 1.0006673,
1.9993325, 2.9894698});
} // python output
} // namespace infini
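
The expected values above are annotated "// python output"; below is a minimal NumPy sketch of the reference they appear to correspond to (function and variable names are assumptions; the kernel applies no 1/sqrt(d) scaling, so none is applied here either):

import numpy as np

def attention_reference(Q, K, V):
    # Row-wise softmax of Q @ K^T, then a weighted sum of the rows of V.
    S = Q @ K.T
    S = S - S.max(axis=1, keepdims=True)  # subtract the row max for stability
    P = np.exp(S)
    P = P / P.sum(axis=1, keepdims=True)
    return P @ V

# Reproduces the 4x3 test case: Q = K = V = [0, 1, 2, 3] tiled three times.
data = np.array([0.0, 1.0, 2.0, 3.0] * 3, dtype=np.float32).reshape(4, 3)
print(attention_reference(data, data, data).reshape(-1))
# ~ [0.9640, 1.9547, 2.9292, 2.9460, 0.0886, 1.0180,
#    1.9941, 2.9905, 0.0211, 1.0007, 1.9993, 2.9895]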