forked from jiuyuan/InfiniTensor
Compare commits: master...cuda-atten (16 commits)
Commits in this comparison (SHA1):
80cd1c951e, a66ff430ec, 9b6c44dd40, 131d1cb6d0, 54f4265296, 3629881dfa, 79dd3364df, 56e2c87c9b, 819484eda2, ddbec7d60a, b640ab1689, ec391674ac, f1dc440a3c, b1a2d91aba, 410844c058, 3f5178d069
@@ -73,6 +73,8 @@ class GraphHandlerObj {
    Tensor cast(Tensor input, Tensor output, int to);
    Tensor expand(Tensor input, Tensor output, Shape dims);
    Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);
+   Tensor attention(Tensor inputQ, Tensor inputK, Tensor inputV,
+                    Tensor output);

    Tensor allReduceSum(Tensor input, Tensor output);
    Tensor allReduceProd(Tensor input, Tensor output);
@@ -15,16 +15,17 @@ struct OpType {
    // elements.
    enum : underlying_t {
        Unknown,
        Abs,    // Unary
        Acos,   // Unary
        Acosh,  // Unary
        Add,    // Binary
        And,    // Binary
        ArgMax, //
        Asin,   // Binary
        Asinh,  // Binary
        Atan,   // Binary
        Atanh,  // Binary
+       Attention,
        AveragePool,        // Pool
        BatchNormalization, //
        Bernoulli,          //
@@ -0,0 +1,8 @@
#pragma once
#include "operators/unary.h"

namespace infini {
void attentionKernel(const float *inputQ, const float *inputK,
                     const float *inputV, int N, int d, float *output);

}; // namespace infini
@@ -0,0 +1,36 @@
#pragma once
#include "core/operator.h"

namespace infini {
/**
 * @brief Compute the attention output softmax(Q*K^T)*V from the Q, K and V
 * tensors.
 */
class AttentionObj : public OperatorObj {

  public:
    /**
     * @brief Construct a new Attention object.
     *
     * @param graph The computation graph that this operator belongs to.
     * @param inputQ The input tensor Q.
     * @param inputK The input tensor K.
     * @param inputV The input tensor V.
     * @param output The output tensor.
     */
    AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK, Tensor inputV,
                 Tensor output);
    OP_CLONE(AttentionObj);

    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

    std::string toString() const override;
    int numInputs() const override { return inputs.size(); }
    int numOutputs() const override { return 1; }

  private:
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;
};

} // namespace infini
@@ -32,6 +32,7 @@ class OnnxStub:
    The Onnx model imported into infinitensor.
    It can be generated from an Onnx model object.
    """
+
    def __init__(self, model: ModelProto, runtime):
        self.inputs: Dict[str, backend.Tensor] = {}
        self.outputs: Dict[str, backend.Tensor] = {}
@@ -60,7 +61,6 @@ class OnnxStub:
                dims, output.type.tensor_type.elem_type
            )

-
        node_name = []
        new_node_name = []
        for node in model.graph.node:
@@ -632,6 +632,13 @@ class OnnxStub:
                    ),
                ):
                    tensors[name] = tensor
+            elif node.op_type == "Attention":
+                tensors[node.output[0]] = self.handler.attention(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors[node.input[2]],
+                    tensors.get(node.output[0]),
+                )
            elif node.op_type == "Broadcast":
                tensors[node.output[0]] = self.handler.broadcast(
                    tensors[node.input[0]],
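The new branch above maps any ONNX node whose `op_type` is `"Attention"` onto `handler.attention`. Below is a minimal sketch of building such a model and importing it; only the `onnx.helper` calls are standard API, while the `pyinfinitensor` import path and runtime helper in the trailing comment are assumptions based on the surrounding code, not something this diff shows.

```python
from onnx import TensorProto, helper

N, d = 4, 3
q, k, v = (helper.make_tensor_value_info(name, TensorProto.FLOAT, [N, d])
           for name in ("Q", "K", "V"))
out = helper.make_tensor_value_info("O", TensorProto.FLOAT, [N, d])
# A custom (non-standard) "Attention" node, matching the op_type string
# checked in the frontend branch above.
node = helper.make_node("Attention", ["Q", "K", "V"], ["O"], name="attn0")
graph = helper.make_graph([node], "attn_graph", [q, k, v], [out])
model = helper.make_model(graph)

# Assumed usage (not shown in this diff):
# from pyinfinitensor.onnx import OnnxStub, backend
# stub = OnnxStub(model, backend.cuda_runtime())
```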
@@ -678,7 +685,6 @@ class OnnxStub:
        for output in model.graph.output:
            tensors[output.name].set_output()

-
        ################################
        # Allocate memory space for data
        ################################
@@ -1002,6 +1008,10 @@ class OnnxStub:
                assert len(inputs) == 3, "Check Where Op must have three inputs."
                new_inputs = [inputs[2], inputs[0], inputs[1]]
                ctx.push_node(make_node(ty.name, new_inputs, outputs, name))
+            elif ty == backend.OpTypeId.Attention:
+                assert len(inputs) == 3, "Check Attention Op must have three inputs."
+                new_inputs = [inputs[0], inputs[1], inputs[2]]
+                ctx.push_node(make_node(ty.name, new_inputs, outputs, name))
            elif ty == backend.OpTypeId.Expand:
                shape = backend.expand_shape_of(op)
                ctx.push_node(make_node(ty.name, inputs, outputs, name, shape=shape))
@@ -1,6 +1,7 @@
#include "core/graph_handler.h"
#include "operators/all_gather.h"
#include "operators/all_reduce.h"
+#include "operators/attention.h"
#include "operators/batch_norm.h"
#include "operators/broadcast.h"
#include "operators/concat.h"
@@ -406,7 +407,19 @@ Tensor GraphHandlerObj::where(Tensor inputX, Tensor inputY, Tensor condition,
        ->getOutput();
    }
}
+Tensor GraphHandlerObj::attention(Tensor inputQ, Tensor inputK, Tensor inputV,
+                                  Tensor output) {
+    if (output) {
+        g->addOpWithOutputs<AttentionObj>(std::move(inputQ), std::move(inputK),
+                                          std::move(inputV), output);
+        return output;
+    } else {
+        return g
+            ->addOp<AttentionObj>(std::move(inputQ), std::move(inputK),
+                                  std::move(inputV), output)
+            ->getOutput();
+    }
+}
static CastType inferCastType(Tensor input, int to) {
    auto iType = input->getDType();
    auto oType = DataType(to);
@@ -1,5 +1,6 @@
#include "core/data_type.h"
#include "core/graph_handler.h"
+#include "operators/attention.h"
#include "operators/batch_norm.h"
#include "operators/concat.h"
#include "operators/conv.h"
@@ -67,6 +68,7 @@ void export_values(py::module &m) {
        .def(py::init<decltype(OpType::type)>())
        .def("id", getId, policy::automatic);
    py::enum_<decltype(OpType::type)>(m, "OpTypeId")
+        .VALUE(OpType, Attention)
        .VALUE(OpType, Conv)
        .VALUE(OpType, MatMul)
        .VALUE(OpType, ConvTranspose)
@@ -424,6 +426,7 @@ void init_graph_builder(py::module &m) {
                           policy::reference);
    py::class_<Handler>(m, "GraphHandler")
        .def(py::init<Runtime>())
+        .def("attention", &Handler::attention, policy::move)
        .def("tensor", &Handler::tensor, policy::move)
        .def("conv", &Handler::conv, policy::move)
        .def("convTransposed2d", &Handler::convTransposed2d, policy::move)
@@ -0,0 +1,28 @@
#include "operators/attention.h"
#include "cuda/cuda_attention.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"

namespace infini {

class AttentionCuda : public CudaKernelWithoutConfig {
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<AttentionObj>(_op);

        void *const inputQData = (op->getInputs(0)->getRawDataPtr<void *>());
        void *const inputKData = (op->getInputs(1)->getRawDataPtr<void *>());
        void *const inputVData = (op->getInputs(2)->getRawDataPtr<void *>());
        void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
        int N = op->getInputs(0)->getDims()[0];
        int d = op->getInputs(0)->getDims()[1];

        attentionKernel((float *)inputQData, (float *)inputKData,
                        (float *)inputVData, N, d, (float *)outputData);
    }
};

REGISTER_KERNEL(Device::CUDA, OpType::Attention, DataType::Float32,
                AttentionCuda, "Attention_CUDA_Float32");

}; // namespace infini
@@ -0,0 +1,252 @@
#include "cuda/cuda_common.h"
const int Rq = 4;
const int Rv = 8; // must be a multiple of 4
const int Br = 16;
const int Bc = 16;
const int Bk = 4; // must be a multiple of 4

template <int Br, int Bc, int Rq>
__device__ void matmulRQK(const float *__restrict inputQ,
                          const float *__restrict inputK, float *shareQK,
                          float *shareVK, int N, int d, int width, int indQ,
                          int indK, float *val) {
    float a[4];
    for (int ph = 0; ph < width; ph++) {
        for (int index_k = 0; index_k < Bk; index_k++) {
            (float4 &)a[0] = (float4 &)
                inputK[(indK + index_k) * d + (threadIdx.y + ph * Bc) * Bk];
            for (int idk = 0; idk < Bk; idk++) {
                if (threadIdx.y < Bc) {
                    shareVK[(threadIdx.y * Bk + idk) * Bc * Bk +
                            threadIdx.x * Bk + index_k] = a[idk];
                    if (indK + index_k >= N ||
                        (threadIdx.y + ph * Bc) * Bk + idk >= d) {

                        shareVK[(threadIdx.y * Bk + idk) * Bc * Bk +
                                threadIdx.x * Bk + index_k] = 0.0f;
                    }
                }
            }
        }

        for (int index_q = 0; index_q < Rq; index_q++) {
            (float4 &)shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk +
                              threadIdx.x * Bk] = (float4 &)
                inputQ[(indQ + index_q) * d + (threadIdx.x + ph * Bc) * Bk];
            for (int idk = 0; idk < Bk; idk++) {
                if (indQ + index_q >= N ||
                    (threadIdx.x + ph * Bc) * Bk + idk >= d) {
                    shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk +
                            threadIdx.x * Bk + idk] = 0.0f;
                }
            }
        }
        __syncthreads();

        for (int index = 0; index < Bc * Bk; index++) {
            for (int index_q = 0; index_q < Rq; index_q++) {
                for (int index_k = 0; index_k < Bk; index_k++) {
                    val[index_q * Bk + index_k] = std::fma(
                        shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk + index],
                        shareVK[index * Bc * Bk + threadIdx.x * Bk + index_k],
                        val[index_q * Bk + index_k]);
                }
            }
        }
        __syncthreads();
    }
}
template <int Br, int Bc, int Rq, int Rv>
__device__ void matmulSV(float *shareQK, const float *__restrict inputV,
                         float *shareVK, int N, int d, int j, int indQ,
                         int indK, int indV, float *val, float *newMax,
                         float *sumSV) {
    if (threadIdx.y < Bc) {
        for (int index_k = 0; index_k < Bk; index_k++) {
            for (int id = 0; id < (int)(Rv / 4); id++) {
                (float4 &)shareVK[(threadIdx.y * Bk + index_k) * Bc * Rv +
                                  threadIdx.x * Rv + id * 4] = (float4 &)
                    inputV[((threadIdx.y + j * Bc) * Bk + index_k) * d + indV +
                           id * 4];
            }
            for (int index_v = 0; index_v < Rv; index_v++) {
                if ((threadIdx.y + j * Bc) * Bk + index_k >= N ||
                    indV + index_v >= d) {
                    shareVK[(threadIdx.y * Bk + index_k) * Bc * Rv +
                            threadIdx.x * Rv + index_v] = 0.0f;
                }
            }
        }
    }
    for (int index_q = 0; index_q < Rq; index_q++) {
        for (int index_k = 0; index_k < Bk; index_k++) {
            if (indQ + index_q < N && indK + index_k < N) {
                shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk +
                        threadIdx.x * Bk + index_k] =
                    __expf(val[index_q * Bk + index_k] - newMax[index_q]);
            } else {

                shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk +
                        threadIdx.x * Bk + index_k] = 0.0f;
            }
        }
    }
    __syncthreads();

    for (int phc = 0; phc < Bc * Bk; phc++) {
        for (int index_q = 0; index_q < Rq; index_q++) {

            for (int index_v = 0; index_v < Rv; index_v++) {
                sumSV[index_q * Rv + index_v] +=
                    shareQK[(threadIdx.y * Rq + index_q) * Bc * Bk + phc] *
                    shareVK[phc * Bc * Rv + threadIdx.x * Rv + index_v];
            }
        }
    }
}
template <typename T> struct SumOp {
    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
        return a + b;
    }
};

template <typename T> struct MaxOp {
    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
        return max(a, b);
    }
};
template <template <typename> class ReductionOp, typename T,
          int thread_group_width = 32>
__inline__ __device__ T WarpAllReduce(T val) {
    for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
        val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
    }

    return val;
}

template <int Br, int Bc, int Rq, int Rv>
__global__ void _attentionKernel(const float *__restrict inputQ,
                                 const float *__restrict inputK,
                                 const float *__restrict inputV, int N, int d,
                                 float *__restrict output) {

    __shared__ float shareQK[Rq * Br * Bc * Bk];
    __shared__ float shareVK[Bk * Bc * Bc * Rv];

    float sumSV[Rq * Rv] = {0.0f};
    float newMax[Rq];
    float oldMax[Rq];
    float newSum[Rq] = {0.0f};

    float val[Rq * Bk];

    int indV = Rv * (threadIdx.x + blockIdx.x * blockDim.x);
    int indQ = Rq * (threadIdx.y + blockIdx.y * blockDim.y);

    for (int index_q = 0; index_q < Rq; index_q++) {
        newMax[index_q] = -__FLT_MAX__;
        oldMax[index_q] = -__FLT_MAX__;
    }

    int Tc = (N + Bc * Bk - 1) / (Bc * Bk);

    int width = (d + Bc * Bk - 1) / (Bc * Bk);
    for (int j = 0; j < Tc; j++) {

        int indK = Bk * (threadIdx.x + j * Bc);
        for (int index_q = 0; index_q < Rq; index_q++) {
            for (int index_k = 0; index_k < Bk; index_k++) {

                val[index_q * Bk + index_k] = 0.0f;
            }
        }
        matmulRQK<Br, Bc, Rq>(inputQ, inputK, shareQK, shareVK, N, d, width,
                              indQ, indK, val);
        for (int index_q = 0; index_q < Rq; index_q++) {
            float tmpReduceMax = -__FLT_MAX__;
            for (int index_k = 0; index_k < Bk; index_k++) {
                if (indQ + index_q < N && indK + index_k < N) {

                    tmpReduceMax =
                        max(tmpReduceMax, val[index_q * Bk + index_k]);
                }
            }
            __syncthreads();
            tmpReduceMax = WarpAllReduce<MaxOp, float, Bc>(tmpReduceMax);
            if (threadIdx.x == 0) {
                shareQK[threadIdx.y * Rq + index_q] = tmpReduceMax;
            }
            __syncthreads();
            float tmpReduceSum = 0.0f;
            for (int index_k = 0; index_k < Bk; index_k++) {
                if (indQ + index_q < N && indK + index_k < N) {
                    tmpReduceSum += __expf(val[index_q * Bk + index_k] -
                                           shareQK[threadIdx.y * Rq + index_q]);
                }
            }
            __syncthreads();
            tmpReduceSum = WarpAllReduce<SumOp, float, Bc>(tmpReduceSum);
            if (threadIdx.x == 0) {
                shareQK[threadIdx.y * Rq + index_q + Rq * Br] = tmpReduceSum;
            }
            __syncthreads();
            if (newMax[index_q] > shareQK[threadIdx.y * Rq + index_q]) {
                newSum[index_q] =
                    std::fma(shareQK[threadIdx.y * Rq + index_q + Rq * Br],
                             __expf(shareQK[threadIdx.y * Rq + index_q] -
                                    newMax[index_q]),
                             newSum[index_q]);
            } else {
                newSum[index_q] =
                    std::fma(newSum[index_q],
                             __expf(newMax[index_q] -
                                    shareQK[threadIdx.y * Rq + index_q]),
                             shareQK[threadIdx.y * Rq + index_q + Rq * Br]);

                newMax[index_q] = shareQK[threadIdx.y * Rq + index_q];
            }
            // PV
            for (int index_v = 0; index_v < Rv; index_v++) {
                sumSV[index_q * Rv + index_v] *=
                    __expf(oldMax[index_q] - newMax[index_q]);
            }
        }

        matmulSV<Br, Bc, Rq, Rv>(shareQK, inputV, shareVK, N, d, j, indQ, indK,
                                 indV, val, newMax, sumSV);

        for (int index_q = 0; index_q < Rq; index_q++) {
            oldMax[index_q] = newMax[index_q];
        }

        //__syncthreads();
    }
    for (int index_q = 0; index_q < Rq; index_q++) {
        float inv = __fdividef(1.0F, newSum[index_q]);
        for (int index_v = 0; index_v < Rv; index_v++) {
            sumSV[index_q * Rv + index_v] = sumSV[index_q * Rv + index_v] * inv;
        }
    }
    for (int index_q = 0; index_q < Rq; index_q++) {

        for (int id = 0; id < (int)(Rv / 4); id++) {
            if (indQ + index_q < N) {
                (float4 &)output[(indQ + index_q) * d + indV + id * 4] =
                    (float4 &)sumSV[index_q * Rv + id * 4];
            }
        }
    }
}
namespace infini {
void attentionKernel(const float *inputQ, const float *inputK,
                     const float *inputV, int N, int d, float *output) {
    int num_block_x = (d + Rv * Bc - 1) / (Rv * Bc);
    int num_block_y = (N + Rq * Br - 1) / (Rq * Br);
    dim3 grid_dim(num_block_x, num_block_y, 1);
    dim3 block_dim(Bc, Br, 1);

    _attentionKernel<Br, Bc, Rq, Rv>
        <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
}
} // namespace infini
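The kernel above tiles over K/V, keeps a running row maximum and normalizer (`newMax`, `newSum`), and rescales the partial output `sumSV` before accumulating each new tile. This appears to be the standard online-softmax (FlashAttention-style) recurrence; a sketch in my own notation, not taken from the PR:

```latex
% For one output row: m is the running max, \ell the running normalizer,
% o the running (unnormalized) output; s_j are the current tile's scores.
\[
m' = \max\bigl(m,\ \max_j s_j\bigr), \qquad
\ell' = \ell\, e^{m - m'} + \sum_j e^{s_j - m'}, \qquad
o' = o\, e^{m - m'} + \sum_j e^{s_j - m'}\, v_j,
\]
\[
\text{and after the last tile the row of the output is } O = o' / \ell'.
\]
```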
@@ -0,0 +1,45 @@
#include "operators/attention.h"
#include "utils/operator_utils.h"

namespace infini {

AttentionObj::AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK,
                           Tensor inputV, Tensor output)
    : OperatorObj(OpType::Attention, TensorVec{inputQ, inputK, inputV},
                  {output}) {
    IT_ASSERT(checkValid(graph));
}

optional<vector<Shape>>
AttentionObj::inferShape(const TensorVec &inputs) const {
    auto shapeQ = inputs[0]->getDims();
    auto shapeK = inputs[1]->getDims();
    auto shapeV = inputs[2]->getDims();
    auto retQK = infer_broadcast(shapeQ, shapeK);
    auto ret = infer_broadcast(retQK, shapeV);
    return {{ret}};
}

std::string AttentionObj::toString() const {
    std::ostringstream os;
    os << "Attention[" << getGuid() << "]";
    os << "(";
    os << vecToString(inputs[2]->getDims()) << ",";
    os << "inputQ=" << inputs[0]->getGuid() << ",";
    os << "inputK=" << inputs[1]->getGuid() << ",";
    os << "inputV=" << inputs[2]->getGuid() << ",";
    os << "output=" << outputs[0]->getGuid() << ")";
    return os.str();
}

vector<int> AttentionObj::getWorkloadVector() const {
    vector<int> ret = getOutput()->getDims();
    ret.emplace(ret.begin(), type.underlying());
    return ret;
}

vector<int> AttentionObj::getOpAttrVector() const {
    return {type.underlying()};
}

} // namespace infini
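Note that `inferShape` above simply broadcasts the three input shapes (Q with K, then with V), so for the `(N, d)` inputs the CUDA kernel expects, the output shape is again `(N, d)`. A quick NumPy-style check of that rule, my own sketch rather than part of the PR:

```python
import numpy as np

# infer_broadcast(Q, K) followed by a broadcast with V, as in
# AttentionObj::inferShape; identical shapes broadcast to themselves.
print(np.broadcast_shapes((6, 5), (6, 5), (6, 5)))  # -> (6, 5)
```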
@@ -0,0 +1,77 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/attention.h"

#include "test.h"

namespace infini {

void test_attention(const Shape &outputShape, const vector<float> &inputQData,
                    const vector<float> &inputKData,
                    const vector<float> &inputVData,
                    const vector<float> &ExpectData) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);
    auto inputV = gCpu->addTensor(outputShape, DataType::Float32);
    auto inputQ = gCpu->addTensor(outputShape, DataType::Float32);
    auto inputK = gCpu->addTensor(outputShape, DataType::Float32);

    gCpu->dataMalloc();
    inputV->copyin(inputVData);
    inputQ->copyin(inputQData);
    inputK->copyin(inputKData);

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

    auto inputVGpu = gCuda->cloneTensor(inputV);
    auto inputQGpu = gCuda->cloneTensor(inputQ);
    auto inputKGpu = gCuda->cloneTensor(inputK);

    auto op = gCuda->addOp<AttentionObj>(inputQGpu, inputKGpu, inputVGpu,
                                         nullptr);
    gCuda->dataMalloc();
    inputVGpu->copyin(inputVData);
    inputQGpu->copyin(inputQData);
    inputKGpu->copyin(inputKData);
    cudaRuntime->run(gCuda);

    auto oCpu = gCpu->cloneTensor(op->getOutput()); // move data from GPU to CPU
    oCpu->printData();
    EXPECT_TRUE(oCpu->equalData(ExpectData));
}

TEST(CUDA_Attention, run) {
    test_attention(
        Shape{6, 5}, vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
                                   2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
                                   0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
                      2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
                      0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
                      2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
                      0., 1., 2., 3., 0., 1., 2., 3., 0., 1.},
        vector<float>{6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00,
                      6.507058e-03, 1.004909e+00, 1.999979e+00, 2.986577e+00,
                      8.536250e-03, 1.004909e+00, 2.017291e+00, 2.945395e+00,
                      1.997352e-02, 1.017340e+00, 2.017291e+00, 2.999871e+00,
                      3.741202e-04, 9.998805e-01, 1.999874e+00, 2.999871e+00,
                      6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00,
                      6.507058e-03, 1.004909e+00, 1.999979e+00, 2.986577e+00,
                      8.536250e-03, 1.004909e+00});

    test_attention(
        Shape{4, 3}, // shape shared by Q, K, V and the output
        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputQ
        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputK
        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputV
        vector<float>{0.9640308, 1.9546683, 2.9292183, 2.9460413, 0.0886370,
                      1.0179861, 1.9941283, 2.9905086, 0.0210545, 1.0006673,
                      1.9993325, 2.9894698}); // python output

}

} // namespace infini
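The "// python output" comment suggests the expected values were generated with a Python reference. A small NumPy sketch (mine, not from the PR) reproduces the second test case, assuming the operator computes softmax(Q Kᵀ) V with no 1/√d scaling, which matches the kernel above:

```python
import numpy as np

def attention_ref(q, k, v):
    s = q @ k.T                            # similarity scores, shape (N, N)
    s = s - s.max(axis=-1, keepdims=True)  # subtract row max for stability
    p = np.exp(s)
    p /= p.sum(axis=-1, keepdims=True)     # row-wise softmax
    return p @ v

q = k = v = np.array([0., 1., 2., 3.] * 3).reshape(4, 3)
print(attention_ref(q, k, v))
# first row is approximately [0.9640308, 1.9546683, 2.9292183],
# matching the ExpectData vector above
```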