diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 7f514ebd..d111391c 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -73,6 +73,8 @@ class GraphHandlerObj {
     Tensor cast(Tensor input, Tensor output, int to);
     Tensor expand(Tensor input, Tensor output, Shape dims);
     Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);
+    Tensor attention(Tensor inputQ, Tensor inputK, Tensor inputV,
+                     Tensor output);
 
     Tensor allReduceSum(Tensor input, Tensor output);
     Tensor allReduceProd(Tensor input, Tensor output);
diff --git a/include/core/op_type.h b/include/core/op_type.h
index e0146c5f..8b616ecc 100644
--- a/include/core/op_type.h
+++ b/include/core/op_type.h
@@ -15,16 +15,17 @@ struct OpType {
     // elements.
     enum : underlying_t {
         Unknown,
-        Abs,                // Unary
-        Acos,               // Unary
-        Acosh,              // Unary
-        Add,                // Binary
-        And,                // Binary
-        ArgMax,             //
-        Asin,               // Binary
-        Asinh,              // Binary
-        Atan,               // Binary
-        Atanh,              // Binary
+        Abs,       // Unary
+        Acos,      // Unary
+        Acosh,     // Unary
+        Add,       // Binary
+        And,       // Binary
+        ArgMax,    //
+        Asin,      // Binary
+        Asinh,     // Binary
+        Atan,      // Binary
+        Atanh,     // Binary
+        Attention,
         AveragePool,        // Pool
         BatchNormalization, //
         Bernoulli,          //
diff --git a/include/cuda/cuda_attention.h b/include/cuda/cuda_attention.h
new file mode 100644
index 00000000..dac8c86a
--- /dev/null
+++ b/include/cuda/cuda_attention.h
@@ -0,0 +1,7 @@
+#pragma once
+#include "operators/unary.h"
+
+namespace infini {
+void attentionKernel(const float *inputQ, const float *inputK, const float *inputV, int N, int d, float *output);
+
+}; // namespace infini
diff --git a/include/operators/attention.h b/include/operators/attention.h
new file mode 100644
index 00000000..64653531
--- /dev/null
+++ b/include/operators/attention.h
@@ -0,0 +1,36 @@
+#pragma once
+#include "core/operator.h"
+
+namespace infini {
+/**
+ * @brief Fused attention: computes output = softmax(Q * K^T) * V for the
+ * given query, key and value tensors.
+ */
+class AttentionObj : public OperatorObj {
+
+  public:
+    /**
+     * @brief Construct a new Attention object.
+     *
+     * @param graph The computation graph that this operator belongs to.
+     * @param inputQ The input tensor Q (query).
+     * @param inputK The input tensor K (key).
+     * @param inputV The input tensor V (value).
+     * @param output The output tensor.
+     */
+    AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK, Tensor inputV,
+                 Tensor output);
+    OP_CLONE(AttentionObj);
+
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+
+    std::string toString() const override;
+    int numInputs() const override { return inputs.size(); }
+    int numOutputs() const override { return 1; }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+
+} // namespace infini
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 813a5e8e..b8762b69 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -32,6 +32,7 @@ class OnnxStub:
     The Onnx model imported into infinitensor.
     It can be generated from an Onnx model object.
""" + def __init__(self, model: ModelProto, runtime): self.inputs: Dict[str, backend.Tensor] = {} self.outputs: Dict[str, backend.Tensor] = {} @@ -60,7 +61,6 @@ class OnnxStub: dims, output.type.tensor_type.elem_type ) - node_name = [] new_node_name = [] for node in model.graph.node: @@ -632,6 +632,13 @@ class OnnxStub: ), ): tensors[name] = tensor + elif node.op_type == "Attention": + tensors[node.output[0]] = self.handler.attention( + tensors[node.input[0]], + tensors[node.input[1]], + tensors[node.input[2]], + tensors.get(node.output[0]), + ) elif node.op_type == "Broadcast": tensors[node.output[0]] = self.handler.broadcast( tensors[node.input[0]], @@ -674,11 +681,10 @@ class OnnxStub: for input in model.graph.input: tensors[input.name].set_input() - + for output in model.graph.output: tensors[output.name].set_output() - ################################ # Allocate memory space for data ################################ @@ -1002,6 +1008,10 @@ class OnnxStub: assert len(inputs) == 3, "Check Where Op must have three inputs." new_inputs = [inputs[2], inputs[0], inputs[1]] ctx.push_node(make_node(ty.name, new_inputs, outputs, name)) + elif ty == backend.OpTypeId.Attention: + assert len(inputs) == 3, "Check Attention Op must have three inputs." + new_inputs = [inputs[2], inputs[0], inputs[1]] + ctx.push_node(make_node(ty.name, new_inputs, outputs, name)) elif ty == backend.OpTypeId.Expand: shape = backend.expand_shape_of(op) ctx.push_node(make_node(ty.name, inputs, outputs, name, shape=shape)) diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index a804a8c7..5b75ec49 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -1,6 +1,7 @@ #include "core/graph_handler.h" #include "operators/all_gather.h" #include "operators/all_reduce.h" +#include "operators/attention.h" #include "operators/batch_norm.h" #include "operators/broadcast.h" #include "operators/concat.h" @@ -406,7 +407,19 @@ Tensor GraphHandlerObj::where(Tensor inputX, Tensor inputY, Tensor condition, ->getOutput(); } } - +Tensor GraphHandlerObj::attention(Tensor inputQ, Tensor inputK, Tensor inputV, + Tensor output) { + if (output) { + g->addOpWithOutputs(std::move(inputQ), std::move(inputK), + std::move(inputV), output); + return output; + } else { + return g + ->addOp(std::move(inputQ), std::move(inputK), + std::move(inputV), output) + ->getOutput(); + } +} static CastType inferCastType(Tensor input, int to) { auto iType = input->getDType(); auto oType = DataType(to); diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index bea3f4bc..f6d8d056 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -1,5 +1,6 @@ #include "core/data_type.h" #include "core/graph_handler.h" +#include "operators/attention.h" #include "operators/batch_norm.h" #include "operators/concat.h" #include "operators/conv.h" @@ -67,6 +68,7 @@ void export_values(py::module &m) { .def(py::init()) .def("id", getId, policy::automatic); py::enum_(m, "OpTypeId") + .VALUE(OpType, Attention) .VALUE(OpType, Conv) .VALUE(OpType, MatMul) .VALUE(OpType, ConvTranspose) @@ -424,6 +426,7 @@ void init_graph_builder(py::module &m) { policy::reference); py::class_(m, "GraphHandler") .def(py::init()) + .def("attention", &Handler::attention, policy::move) .def("tensor", &Handler::tensor, policy::move) .def("conv", &Handler::conv, policy::move) .def("convTransposed2d", &Handler::convTransposed2d, policy::move) diff --git a/src/kernels/cuda/attention.cc b/src/kernels/cuda/attention.cc new file mode 100644 index 
--- /dev/null
+++ b/src/kernels/cuda/attention.cc
@@ -0,0 +1,29 @@
+#include "operators/attention.h"
+#include "cuda/cuda_kernel_wihtout_config.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_attention.h"
+
+
+namespace infini {
+
+class AttentionCuda : public CudaKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<AttentionObj>(_op);
+
+        void *const inputQData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const inputKData = (op->getInputs(1)->getRawDataPtr<void *>());
+        void *const inputVData = (op->getInputs(2)->getRawDataPtr<void *>());
+        void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
+        int N = op->getInputs(0)->getDims()[0];
+        int d = op->getInputs(0)->getDims()[1];
+
+        attentionKernel((float *)inputQData, (float *)inputKData,
+                        (float *)inputVData, N, d, (float *)outputData);
+    }
+};
+
+REGISTER_KERNEL(Device::CUDA, OpType::Attention, DataType::Float32, AttentionCuda,
+                "Attention_CUDA_Float32");
+
+}; // namespace infini
diff --git a/src/kernels/cuda/attention.cu b/src/kernels/cuda/attention.cu
new file mode 100644
index 00000000..504cc373
--- /dev/null
+++ b/src/kernels/cuda/attention.cu
@@ -0,0 +1,111 @@
+#include "cuda/cuda_common.h"
+
+#define BLOCK_DIM_x 2
+#define BLOCK_DIM_y 2
+
+#define max_function(a,b) ((a)>(b)?(a):(b))
+
+__global__ void _attentionKernel(const float *inputQ, const float *inputK, const float *inputV, int N, int d, float *output){
+    int i = threadIdx.x + blockIdx.x * blockDim.x; // row index of Q handled by this thread
+    int phNumN = (N + BLOCK_DIM_y - 1)/BLOCK_DIM_y; // number of key/value tiles along the row
+
+    __shared__ float block_sum[BLOCK_DIM_x][BLOCK_DIM_y];
+    __shared__ float block_max[BLOCK_DIM_x][BLOCK_DIM_y];
+    block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+    block_sum[threadIdx.x][threadIdx.y] = 0.0f;
+    __shared__ float grid_sum[BLOCK_DIM_x];
+    __shared__ float grid_max[BLOCK_DIM_x];
+    __shared__ float grid_max_old[BLOCK_DIM_x];
+    grid_max[threadIdx.x] = -__FLT_MAX__;
+    grid_max_old[threadIdx.x] = -__FLT_MAX__;
+    grid_sum[threadIdx.x] = 0.0f;
+    __shared__ float S[BLOCK_DIM_x][BLOCK_DIM_y];
+
+    __shared__ float Out_new[BLOCK_DIM_x][BLOCK_DIM_y];
+    Out_new[threadIdx.x][threadIdx.y] = 0.0f;
+    for(int phn = 0; phn < phNumN; phn++){
+        int j = threadIdx.y + phn*BLOCK_DIM_y;
+        if(i < N && j < N){
+            float sum_s = 0;
+            for(int index = 0; index < d; index++){
+                sum_s += inputQ[i * d + index] * inputK[j * d + index];
+            }
+
+            S[threadIdx.x][threadIdx.y] = sum_s;
+            block_sum[threadIdx.x][threadIdx.y] = 1.0f;
+            block_max[threadIdx.x][threadIdx.y] = sum_s;
+        }
+        else{
+            S[threadIdx.x][threadIdx.y] = 0.0f;
+            block_sum[threadIdx.x][threadIdx.y] = 0.0f;
+            block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+        }
+
+        // For a fixed row i, reduce the max and the rescaled sum of S[i, j] over this tile.
+        __syncthreads();
+
+        for(int strip = BLOCK_DIM_y/2; strip > 0; strip = strip/2){
+            if (threadIdx.y < strip){
+                if(block_max[threadIdx.x][threadIdx.y] > block_max[threadIdx.x][threadIdx.y + strip]){
+                    block_sum[threadIdx.x][threadIdx.y] = block_sum[threadIdx.x][threadIdx.y] + block_sum[threadIdx.x][threadIdx.y + strip]*__expf(block_max[threadIdx.x][threadIdx.y + strip] - block_max[threadIdx.x][threadIdx.y]);
+                }
+                else{
+                    block_sum[threadIdx.x][threadIdx.y] = block_sum[threadIdx.x][threadIdx.y + strip] + block_sum[threadIdx.x][threadIdx.y]*__expf(block_max[threadIdx.x][threadIdx.y] - block_max[threadIdx.x][threadIdx.y + strip]);
+                    block_max[threadIdx.x][threadIdx.y] = block_max[threadIdx.x][threadIdx.y + strip];
+                }
+            }
+        } // block_max[threadIdx.x][0] now holds this tile's max and block_sum[threadIdx.x][0] its matching sum
+
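+        // Note (explanatory comment, added for clarity): the reduction above and
+        // the grid-level update below both apply the standard online-softmax
+        // merge rule. For two partial results over disjoint column ranges, with
+        // running maxima m_a, m_b and rescaled sums s_a, s_b:
+        //     m_new = max(m_a, m_b)
+        //     s_new = s_a * expf(m_a - m_new) + s_b * expf(m_b - m_new)
+        // The branches avoid computing m_new explicitly by always rescaling the
+        // operand with the smaller maximum.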
+        __syncthreads();
+        if(threadIdx.y == 0){
+            if(grid_max[threadIdx.x] > block_max[threadIdx.x][0]){
+                grid_sum[threadIdx.x] = grid_sum[threadIdx.x] + block_sum[threadIdx.x][0] * __expf(block_max[threadIdx.x][0] - grid_max[threadIdx.x]);
+            }
+            else{
+                grid_sum[threadIdx.x] = block_sum[threadIdx.x][0] + grid_sum[threadIdx.x]*__expf(grid_max[threadIdx.x] - block_max[threadIdx.x][0]);
+                grid_max[threadIdx.x] = block_max[threadIdx.x][0];
+            } // merge this tile into the running (grid_max, grid_sum); after the last tile they hold the row's global max and sum
+
+        }
+        __syncthreads();
+
+        S[threadIdx.x][threadIdx.y] = __expf(S[threadIdx.x][threadIdx.y] - grid_max[threadIdx.x]); // exponentiate against the current running max; normalization is deferred to the end
+
+        __syncthreads();
+        int vj = threadIdx.y + blockIdx.y * blockDim.y;
+        // vj is the output column: use blockIdx.y here, not the tile index phn
+        float sum_o;
+        if(vj < d){
+            sum_o = 0;
+            for(int vid = 0; vid < BLOCK_DIM_y; vid++){
+                if(vid + phn * BLOCK_DIM_y < N){
+                    sum_o += S[threadIdx.x][vid]*inputV[(vid + phn * BLOCK_DIM_y) * d + vj];
+                }
+            }
+            Out_new[threadIdx.x][threadIdx.y] = __expf(grid_max_old[threadIdx.x] - grid_max[threadIdx.x])*Out_new[threadIdx.x][threadIdx.y] + sum_o;
+            grid_max_old[threadIdx.x] = grid_max[threadIdx.x];
+        }
+
+
+    }
+    __syncthreads();
+
+
+
+
+    int j = threadIdx.y + blockIdx.y * blockDim.y;
+    if(i < N && j < d){
+
+        output[i * d + j] = Out_new[threadIdx.x][threadIdx.y]*__fdividef(1.0F, grid_sum[threadIdx.x]);
+    }
+
+}
+namespace infini {
+void attentionKernel(const float *inputQ, const float *inputK, const float *inputV, int N, int d, float *output) {
+    int num_block_y = (max_function(N,d) + BLOCK_DIM_y - 1)/BLOCK_DIM_y;
+    int num_block_x = (N + BLOCK_DIM_x - 1)/BLOCK_DIM_x;
+    int share_mem = (5*BLOCK_DIM_y + 2)*BLOCK_DIM_x*sizeof(float);
+    dim3 block_dim(BLOCK_DIM_x,BLOCK_DIM_y,1);
+    dim3 grid_dim(num_block_x,num_block_y,1);
+    _attentionKernel<<<grid_dim, block_dim, share_mem>>>(inputQ, inputK, inputV, N, d, output);
+}
+} // namespace infini
diff --git a/src/operators/attention.cc b/src/operators/attention.cc
new file mode 100644
index 00000000..8ffd0c78
--- /dev/null
+++ b/src/operators/attention.cc
@@ -0,0 +1,42 @@
+#include "operators/attention.h"
+#include "utils/operator_utils.h"
+
+namespace infini {
+
+AttentionObj::AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK,
+                           Tensor inputV, Tensor output)
+    : OperatorObj(OpType::Attention, TensorVec{inputQ, inputK, inputV},
+                  {output}) {
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>> AttentionObj::inferShape(const TensorVec &inputs) const {
+    auto shapeQ = inputs[0]->getDims();
+    auto shapeK = inputs[1]->getDims();
+    auto shapeV = inputs[2]->getDims();
+    auto retQK = infer_broadcast(shapeQ, shapeK);
+    auto ret = infer_broadcast(retQK, shapeV);
+    return {{ret}};
+}
+
+std::string AttentionObj::toString() const {
+    std::ostringstream os;
+    os << "Attention[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(inputs[2]->getDims()) << ",";
+    os << "inputQ=" << inputs[0]->getGuid() << ",";
+    os << "inputK=" << inputs[1]->getGuid() << ",";
+    os << "inputV=" << inputs[2]->getGuid() << ",";
+    os << "output=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+vector<int> AttentionObj::getWorkloadVector() const {
+    vector<int> ret = getOutput()->getDims();
+    ret.emplace(ret.begin(), type.underlying());
+    return ret;
+}
+
+vector<int> AttentionObj::getOpAttrVector() const { return {type.underlying()}; }
+
+} // namespace infini
diff --git a/test/kernels/cuda/test_cuda_attention.cc b/test/kernels/cuda/test_cuda_attention.cc
new file mode 100644
index 00000000..7a82b848
--- /dev/null
+++ b/test/kernels/cuda/test_cuda_attention.cc
@@ -0,0 +1,70 @@
+#include "core/graph.h"
+#include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "operators/attention.h"
+
+#include "test.h"
+
+namespace infini {
+
+void test_attention(const Shape &outputShape, const vector<float> &inputQData,
+                    const vector<float> &inputKData,
+                    const vector<float> &inputVData,
+                    const vector<float> &ExpectData) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(runtime);
+    auto inputV = gCpu->addTensor(outputShape, DataType::Float32);
+    auto inputQ = gCpu->addTensor(outputShape, DataType::Float32);
+    auto inputK = gCpu->addTensor(outputShape, DataType::Float32);
+
+    gCpu->dataMalloc();
+    inputV->copyin(inputVData);
+    inputQ->copyin(inputQData);
+    inputK->copyin(inputKData);
+
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+
+    auto inputVGpu = gCuda->cloneTensor(inputV);
+    auto inputQGpu = gCuda->cloneTensor(inputQ);
+    auto inputKGpu = gCuda->cloneTensor(inputK);
+
+    auto op = gCuda->addOp<AttentionObj>(inputQGpu, inputKGpu, inputVGpu,
+                                         nullptr);
+    gCuda->dataMalloc();
+    inputVGpu->copyin(inputVData);
+    inputQGpu->copyin(inputQData);
+    inputKGpu->copyin(inputKData);
+    cudaRuntime->run(gCuda);
+
+    auto oCpu = gCpu->cloneTensor(op->getOutput()); // copy the result back to the CPU graph
+    oCpu->printData();
+    EXPECT_TRUE(oCpu->equalData(ExpectData));
+}
+
+TEST(CUDA_Attention, run) {
+    test_attention(
+        Shape{6, 5},
+        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
+                      2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
+        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
+                      2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
+        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
+                      2., 3., 0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
+        vector<float>{6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00, 6.507058e-03,
+                      1.004909e+00, 1.999979e+00, 2.986577e+00, 8.536250e-03, 1.004909e+00,
+                      2.017291e+00, 2.945395e+00, 1.997352e-02, 1.017340e+00, 2.017291e+00,
+                      2.999871e+00, 3.741202e-04, 9.998805e-01, 1.999874e+00, 2.999871e+00,
+                      6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00, 6.507058e-03,
+                      1.004909e+00, 1.999979e+00, 2.986577e+00, 8.536250e-03, 1.004909e+00});
+
+    test_attention(Shape{4, 3},
+                   vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputQ
+                   vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputK
+                   vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputV
+                   vector<float>{0.9640308, 1.9546683, 2.9292183, 2.9460413, 0.0886370, 1.0179861,
+                                 1.9941283, 2.9905086, 0.0210545, 1.0006673, 1.9993325, 2.9894698});
+
+} // expected values generated with a Python/NumPy reference (see the sketch below)
+
+} // namespace infini
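
Note on the test data (not part of the diff): the comment on the closing brace of the test points to a Python reference for the expected values. The sketch below is a minimal NumPy version of what the CUDA kernel computes, a row-wise softmax of Q·K^T (with no 1/sqrt(d) scaling) multiplied by V; the helper name attention_ref is illustrative only and does not exist in the repository.

# Minimal NumPy reference for the kernel above (assumed generator of the
# expected test values): output = softmax(Q @ K.T, axis=-1) @ V, no scaling.
import numpy as np

def attention_ref(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    s = q @ k.T                         # (N, N) raw scores
    s = s - s.max(axis=-1, keepdims=True)  # subtract the row max for stability
    p = np.exp(s)
    p = p / p.sum(axis=-1, keepdims=True)  # row-wise softmax
    return p @ v                        # (N, d) attention output

if __name__ == "__main__":
    # Same layout as the second test case: a (4, 3) tensor filled row-major
    # with the repeating pattern 0, 1, 2, 3.
    data = np.array([0., 1., 2., 3.] * 3, dtype=np.float32).reshape(4, 3)
    print(attention_ref(data, data, data))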