forked from jiuyuan/InfiniTensor

Compare commits: master...cuda-atten (12 commits)

Author | SHA1 | Date
---|---|---
xgqdut2016 | 54f4265296 |
xgqdut2016 | 3629881dfa |
xgqdut2016 | 79dd3364df |
xgqdut2016 | 56e2c87c9b |
xgqdut2016 | 819484eda2 |
xgqdut2016 | ddbec7d60a |
xgqdut2016 | b640ab1689 |
xgqdut2016 | ec391674ac |
xgqdut2016 | f1dc440a3c |
xgqdut2016 | b1a2d91aba |
xgqdut2016 | 410844c058 |
xgqdut2016 | 3f5178d069 |
@@ -73,6 +73,8 @@ class GraphHandlerObj {
    Tensor cast(Tensor input, Tensor output, int to);
    Tensor expand(Tensor input, Tensor output, Shape dims);
    Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);
    Tensor attention(Tensor inputQ, Tensor inputK, Tensor inputV,
                     Tensor output);

    Tensor allReduceSum(Tensor input, Tensor output);
    Tensor allReduceProd(Tensor input, Tensor output);
@@ -15,16 +15,17 @@ struct OpType {
    // elements.
    enum : underlying_t {
        Unknown,
        Abs,    // Unary
        Acos,   // Unary
        Acosh,  // Unary
        Add,    // Binary
        And,    // Binary
        ArgMax, //
        Asin,   // Unary
        Asinh,  // Unary
        Atan,   // Unary
        Atanh,  // Unary
        Attention,
        AveragePool,        // Pool
        BatchNormalization, //
        Bernoulli,          //
@@ -0,0 +1,8 @@
#pragma once
#include "operators/unary.h"

namespace infini {
void attentionKernel(const float *inputQ, const float *inputK,
                     const float *inputV, int N, int d, float *output);

}; // namespace infini
@@ -0,0 +1,36 @@
#pragma once
#include "core/operator.h"

namespace infini {
/**
 * @brief Compute the attention output softmax(Q * K^T) * V from the query,
 * key and value tensors.
 *
 */
class AttentionObj : public OperatorObj {

  public:
    /**
     * @brief Construct a new Attention object.
     *
     * @param graph The computation graph that this operator belongs to.
     * @param inputQ The input tensor Q.
     * @param inputK The input tensor K.
     * @param inputV The input tensor V.
     * @param output The output tensor.
     */
    AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK, Tensor inputV,
                 Tensor output);
    OP_CLONE(AttentionObj);

    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

    std::string toString() const override;
    int numInputs() const override { return inputs.size(); }
    int numOutputs() const override { return 1; }

  private:
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;
};

} // namespace infini
@@ -32,6 +32,7 @@ class OnnxStub:
    The Onnx model imported into infinitensor.
    It can be generated from an Onnx model object.
    """

    def __init__(self, model: ModelProto, runtime):
        self.inputs: Dict[str, backend.Tensor] = {}
        self.outputs: Dict[str, backend.Tensor] = {}

@@ -60,7 +61,6 @@ class OnnxStub:
                dims, output.type.tensor_type.elem_type
            )

        node_name = []
        new_node_name = []
        for node in model.graph.node:
@@ -632,6 +632,13 @@ class OnnxStub:
                ),
            ):
                tensors[name] = tensor
            elif node.op_type == "Attention":
                tensors[node.output[0]] = self.handler.attention(
                    tensors[node.input[0]],
                    tensors[node.input[1]],
                    tensors[node.input[2]],
                    tensors.get(node.output[0]),
                )
            elif node.op_type == "Broadcast":
                tensors[node.output[0]] = self.handler.broadcast(
                    tensors[node.input[0]],
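For orientation, here is a minimal sketch of an ONNX graph that would reach this new "Attention" branch of the importer. The tensor names, shapes, and the commented-out OnnxStub construction are illustrative assumptions, not code taken from this repository:

    # Hypothetical example: build a tiny ONNX model carrying a plain three-input
    # "Attention" node, which OnnxStub would map onto handler.attention().
    from onnx import TensorProto, helper

    N, d = 4, 3  # sequence length and feature dimension, as in the unit test below
    q = helper.make_tensor_value_info("Q", TensorProto.FLOAT, [N, d])
    k = helper.make_tensor_value_info("K", TensorProto.FLOAT, [N, d])
    v = helper.make_tensor_value_info("V", TensorProto.FLOAT, [N, d])
    out = helper.make_tensor_value_info("Out", TensorProto.FLOAT, [N, d])

    node = helper.make_node("Attention", ["Q", "K", "V"], ["Out"])
    graph = helper.make_graph([node], "attention_demo", [q, k, v], [out])
    model = helper.make_model(graph)

    # stub = OnnxStub(model, runtime)  # runtime construction is backend-specific
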
@@ -678,7 +685,6 @@ class OnnxStub:
        for output in model.graph.output:
            tensors[output.name].set_output()

        ################################
        # Allocate memory space for data
        ################################
@@ -1002,6 +1008,10 @@ class OnnxStub:
                assert len(inputs) == 3, "Check Where Op must have three inputs."
                new_inputs = [inputs[2], inputs[0], inputs[1]]
                ctx.push_node(make_node(ty.name, new_inputs, outputs, name))
            elif ty == backend.OpTypeId.Attention:
                assert len(inputs) == 3, "Check Attention Op must have three inputs."
                new_inputs = [inputs[0], inputs[1], inputs[2]]
                ctx.push_node(make_node(ty.name, new_inputs, outputs, name))
            elif ty == backend.OpTypeId.Expand:
                shape = backend.expand_shape_of(op)
                ctx.push_node(make_node(ty.name, inputs, outputs, name, shape=shape))
@@ -1,6 +1,7 @@
#include "core/graph_handler.h"
#include "operators/all_gather.h"
#include "operators/all_reduce.h"
#include "operators/attention.h"
#include "operators/batch_norm.h"
#include "operators/broadcast.h"
#include "operators/concat.h"
@@ -406,7 +407,19 @@ Tensor GraphHandlerObj::where(Tensor inputX, Tensor inputY, Tensor condition,
            ->getOutput();
    }
}

Tensor GraphHandlerObj::attention(Tensor inputQ, Tensor inputK, Tensor inputV,
                                  Tensor output) {
    if (output) {
        g->addOpWithOutputs<AttentionObj>(std::move(inputQ), std::move(inputK),
                                          std::move(inputV), output);
        return output;
    } else {
        return g
            ->addOp<AttentionObj>(std::move(inputQ), std::move(inputK),
                                  std::move(inputV), output)
            ->getOutput();
    }
}

static CastType inferCastType(Tensor input, int to) {
    auto iType = input->getDType();
    auto oType = DataType(to);
@@ -1,5 +1,6 @@
#include "core/data_type.h"
#include "core/graph_handler.h"
#include "operators/attention.h"
#include "operators/batch_norm.h"
#include "operators/concat.h"
#include "operators/conv.h"

@@ -67,6 +68,7 @@ void export_values(py::module &m) {
        .def(py::init<decltype(OpType::type)>())
        .def("id", getId, policy::automatic);
    py::enum_<decltype(OpType::type)>(m, "OpTypeId")
        .VALUE(OpType, Attention)
        .VALUE(OpType, Conv)
        .VALUE(OpType, MatMul)
        .VALUE(OpType, ConvTranspose)

@@ -424,6 +426,7 @@ void init_graph_builder(py::module &m) {
                 policy::reference);
    py::class_<Handler>(m, "GraphHandler")
        .def(py::init<Runtime>())
        .def("attention", &Handler::attention, policy::move)
        .def("tensor", &Handler::tensor, policy::move)
        .def("conv", &Handler::conv, policy::move)
        .def("convTransposed2d", &Handler::convTransposed2d, policy::move)
@@ -0,0 +1,28 @@
#include "operators/attention.h"
#include "cuda/cuda_attention.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"

namespace infini {

class AttentionCuda : public CudaKernelWithoutConfig {
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<AttentionObj>(_op);

        void *const inputQData = (op->getInputs(0)->getRawDataPtr<void *>());
        void *const inputKData = (op->getInputs(1)->getRawDataPtr<void *>());
        void *const inputVData = (op->getInputs(2)->getRawDataPtr<void *>());
        void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
        int N = op->getInputs(0)->getDims()[0];
        int d = op->getInputs(0)->getDims()[1];

        attentionKernel((float *)inputQData, (float *)inputKData,
                        (float *)inputVData, N, d, (float *)outputData);
    }
};

REGISTER_KERNEL(Device::CUDA, OpType::Attention, DataType::Float32,
                AttentionCuda, "Attention_CUDA_Float32");

}; // namespace infini
@@ -0,0 +1,187 @@
#include "cuda/cuda_common.h"

template <int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void _attentionKernel(const float *__restrict inputQ,
                                 const float *__restrict inputK,
                                 const float *__restrict inputV, int N, int d,
                                 float *__restrict output) {
    int i = blockIdx.y; // i must < N,Q[i]
    int phd = threadIdx.x + blockIdx.x * blockDim.x; // V[:,d]

    int phNumN = (N + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
    __shared__ float inputS[BLOCK_DIM_x][BLOCK_DIM_y];
    float newMax;
    float oldMax;
    float newSum;

    newMax = -__FLT_MAX__;
    oldMax = -__FLT_MAX__;
    newSum = 0.0f;

    float out;
    out = 0.0f;
    //---------
    __shared__ float block_sum[BLOCK_DIM_x][BLOCK_DIM_y];

    __shared__ float sum_partial[BLOCK_DIM_x][BLOCK_DIM_y];
    int extra = d % BLOCK_DIM_x;
    int step = (d - extra) / BLOCK_DIM_x;
    for (int phn = 0; phn < phNumN; phn++) {

        int j = threadIdx.y + phn * BLOCK_DIM_y;

        float sum_r = 0.0f;
        __syncthreads();
        if (threadIdx.x < extra) {
            for (int ind = threadIdx.x * (step + 1);
                 ind < (threadIdx.x + 1) * (step + 1); ind++) {
                sum_r += inputQ[i * d + ind] * inputK[j * d + ind];
            }
        } else {
            for (int ind = extra * (step + 1) + (threadIdx.x - extra) * step;
                 ind < extra * (step + 1) + (threadIdx.x - extra + 1) * step;
                 ind++) {
                sum_r += inputQ[i * d + ind] * inputK[j * d + ind];
            }
        }
        if (j < N) {
            sum_partial[threadIdx.x][threadIdx.y] = sum_r;
        } else {
            sum_partial[threadIdx.x][threadIdx.y] = 0.0f;
        }
        __syncthreads();
        for (int strip = BLOCK_DIM_x / 2; strip > 0; strip /= 2) {
            if (threadIdx.x < strip) {
                sum_partial[threadIdx.x][threadIdx.y] +=
                    sum_partial[threadIdx.x + strip][threadIdx.y];
            }
            __syncthreads();
        }
        float sum_s = sum_partial[0][threadIdx.y];
        if (j < N) {

            block_sum[threadIdx.x][threadIdx.y] = 1.0f;
        } else {

            sum_partial[0][threadIdx.y] = -__FLT_MAX__;
            block_sum[threadIdx.x][threadIdx.y] = 0.0f;
        }
        __syncthreads();
        for (int strip = BLOCK_DIM_y / 2; strip > 0; strip /= 2) {
            if (threadIdx.y < strip) {
                if (sum_partial[0][threadIdx.y] >
                    sum_partial[0][threadIdx.y + strip]) {
                    block_sum[threadIdx.x][threadIdx.y] =
                        block_sum[threadIdx.x][threadIdx.y] +
                        block_sum[threadIdx.x][threadIdx.y + strip] *
                            __expf(sum_partial[0][threadIdx.y + strip] -
                                   sum_partial[0][threadIdx.y]);
                } else {
                    block_sum[threadIdx.x][threadIdx.y] =
                        block_sum[threadIdx.x][threadIdx.y + strip] +
                        block_sum[threadIdx.x][threadIdx.y] *
                            __expf(sum_partial[0][threadIdx.y] -
                                   sum_partial[0][threadIdx.y + strip]);
                    sum_partial[0][threadIdx.y] =
                        sum_partial[0][threadIdx.y + strip];
                }
            }
            __syncthreads();
        }
        if (newMax > sum_partial[0][0]) {
            newSum = newSum + block_sum[threadIdx.x][0] *
                                  __expf(sum_partial[0][0] - newMax);
        } else {
            newSum = block_sum[threadIdx.x][0] +
                     newSum * __expf(newMax - sum_partial[0][0]);
            newMax = sum_partial[0][0];
        }

        if (j < N && phd < d) {
            inputS[threadIdx.x][threadIdx.y] =
                __expf(sum_s - newMax) *
                inputV[(threadIdx.y + phn * BLOCK_DIM_y) * d + phd];
        } else {
            inputS[threadIdx.x][threadIdx.y] = 0.0f;
        }
        __syncthreads();
        for (int strip = BLOCK_DIM_y / 2; strip > 0; strip /= 2) {
            if (threadIdx.y < strip) {
                inputS[threadIdx.x][threadIdx.y] +=
                    inputS[threadIdx.x][threadIdx.y + strip];
            }
            __syncthreads();
        }
        if (j < N && phd < d) {
            out = __expf(oldMax - newMax) * out + inputS[threadIdx.x][0];
        }
        oldMax = newMax;
    }

    if (threadIdx.y + (phNumN - 1) * BLOCK_DIM_y < N && phd < d) {
        output[i * d + phd] = out * __fdividef(1.0F, newSum);
    }
}
namespace infini {
void attentionKernel(const float *inputQ, const float *inputK,
                     const float *inputV, int N, int d, float *output) {
    int num_block_y = N;
    if (d > 512) {
        int BLOCK_DIM_x = 1024;
        int BLOCK_DIM_y = 1;
        int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, num_block_y, 1);
        _attentionKernel<1024, 1>
            <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
    } else if (d > 256) {
        int BLOCK_DIM_x = 512;
        int BLOCK_DIM_y = 2;
        int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, num_block_y, 1);
        _attentionKernel<512, 2>
            <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
    } else if (d > 128) {
        int BLOCK_DIM_x = 256;
        int BLOCK_DIM_y = 4;
        int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, num_block_y, 1);
        _attentionKernel<256, 4>
            <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
    } else if (d > 64) {
        int BLOCK_DIM_x = 128;
        int BLOCK_DIM_y = 8;
        int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, num_block_y, 1);
        _attentionKernel<128, 8>
            <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
    } else if (d > 32) {
        int BLOCK_DIM_x = 64;
        int BLOCK_DIM_y = 16;
        int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, num_block_y, 1);
        _attentionKernel<64, 16>
            <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
    } else if (d > 16) {
        int BLOCK_DIM_x = 32;
        int BLOCK_DIM_y = 32;
        int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, num_block_y, 1);
        _attentionKernel<32, 32>
            <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
    } else {
        int BLOCK_DIM_x = 16;
        int BLOCK_DIM_y = 64;
        int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, num_block_y, 1);
        _attentionKernel<16, 64>
            <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
    }
}
} // namespace infini
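The kernel above computes one output row per blockIdx.y with a streaming (online) softmax: K and V are consumed in tiles of BLOCK_DIM_y rows while a running maximum (newMax) and running sum (newSum) rescale the partial output, so the full N x N score matrix is never materialized, and no 1/sqrt(d) scaling is applied. A minimal NumPy sketch of that recurrence (tile size, function and variable names are illustrative, not taken from the kernel):

    # Illustrative sketch of the streaming-softmax recurrence used by the kernel:
    # process K/V in tiles while keeping a running max (m) and running sum (s).
    import numpy as np

    def streaming_attention_row(q, K, V, tile=4):
        m, s = -np.inf, 0.0            # running max and running sum of exp(scores - m)
        out = np.zeros(V.shape[1])     # running (unnormalized) output row
        for start in range(0, K.shape[0], tile):
            Kt, Vt = K[start:start + tile], V[start:start + tile]
            scores = Kt @ q            # this tile's dot products q . k_j
            m_new = max(m, scores.max())
            s = s * np.exp(m - m_new) + np.exp(scores - m_new).sum()
            out = out * np.exp(m - m_new) + np.exp(scores - m_new) @ Vt
            m = m_new
        return out / s                 # equals softmax(q @ K.T) @ V for this row

For any row i, streaming_attention_row(Q[i], K, V) should agree with a direct softmax(Q[i] @ K.T) @ V up to floating-point error, regardless of the tile size.
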
@@ -0,0 +1,45 @@
#include "operators/attention.h"
#include "utils/operator_utils.h"

namespace infini {

AttentionObj::AttentionObj(GraphObj *graph, Tensor inputQ, Tensor inputK,
                           Tensor inputV, Tensor output)
    : OperatorObj(OpType::Attention, TensorVec{inputQ, inputK, inputV},
                  {output}) {
    IT_ASSERT(checkValid(graph));
}

optional<vector<Shape>>
AttentionObj::inferShape(const TensorVec &inputs) const {
    auto shapeQ = inputs[0]->getDims();
    auto shapeK = inputs[1]->getDims();
    auto shapeV = inputs[2]->getDims();
    auto retQK = infer_broadcast(shapeQ, shapeK);
    auto ret = infer_broadcast(retQK, shapeV);
    return {{ret}};
}

std::string AttentionObj::toString() const {
    std::ostringstream os;
    os << "Attention[" << getGuid() << "]";
    os << "(";
    os << vecToString(inputs[2]->getDims()) << ",";
    os << "inputQ=" << inputs[0]->getGuid() << ",";
    os << "inputK=" << inputs[1]->getGuid() << ",";
    os << "inputV=" << inputs[2]->getGuid() << ",";
    os << "output=" << outputs[0]->getGuid() << ")";
    return os.str();
}

vector<int> AttentionObj::getWorkloadVector() const {
    vector<int> ret = getOutput()->getDims();
    ret.emplace(ret.begin(), type.underlying());
    return ret;
}

vector<int> AttentionObj::getOpAttrVector() const {
    return {type.underlying()};
}

} // namespace infini
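inferShape makes the output shape the element-wise broadcast of the three input shapes, so in the tests below (Q, K, V all of the same shape) the output simply matches the inputs. A small check of that rule, assuming infer_broadcast follows NumPy broadcasting semantics (an assumption, not verified against operator_utils):

    # Assumed equivalence: chaining infer_broadcast over Q, K, V behaves like
    # NumPy's broadcast_shapes over the three shapes.
    import numpy as np

    print(np.broadcast_shapes((6, 5), (6, 5), (6, 5)))  # -> (6, 5)
    print(np.broadcast_shapes((4, 3), (4, 3), (4, 3)))  # -> (4, 3)
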
@@ -0,0 +1,77 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/attention.h"

#include "test.h"

namespace infini {

void test_attention(const Shape &outputShape, const vector<float> &inputQData,
                    const vector<float> &inputKData,
                    const vector<float> &inputVData,
                    const vector<float> &ExpectData) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);
    auto inputV = gCpu->addTensor(outputShape, DataType::Float32);
    auto inputQ = gCpu->addTensor(outputShape, DataType::Float32);
    auto inputK = gCpu->addTensor(outputShape, DataType::Float32);

    gCpu->dataMalloc();
    inputV->copyin(inputVData);
    inputQ->copyin(inputQData);
    inputK->copyin(inputKData);

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

    auto inputVGpu = gCuda->cloneTensor(inputV);
    auto inputQGpu = gCuda->cloneTensor(inputQ);
    auto inputKGpu = gCuda->cloneTensor(inputK);

    auto op = gCuda->addOp<AttentionObj>(inputQGpu, inputKGpu, inputVGpu,
                                         nullptr); // AttentionObj
    gCuda->dataMalloc();
    inputVGpu->copyin(inputVData);
    inputQGpu->copyin(inputQData);
    inputKGpu->copyin(inputKData);
    cudaRuntime->run(gCuda);

    auto oCpu = gCpu->cloneTensor(op->getOutput()); // move data from GPU to CPU
    oCpu->printData();
    EXPECT_TRUE(oCpu->equalData(ExpectData));
}

TEST(CUDA_Attention, run) {
    test_attention(
        Shape{6, 5}, // output shape; Q, K and V share this shape
        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
                      2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
                      0., 1., 2., 3., 0., 1., 2., 3., 0., 1.}, // inputQ
        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
                      2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
                      0., 1., 2., 3., 0., 1., 2., 3., 0., 1.}, // inputK
        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1.,
                      2., 3., 0., 1., 2., 3., 0., 1., 2., 3.,
                      0., 1., 2., 3., 0., 1., 2., 3., 0., 1.}, // inputV
        vector<float>{6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00,
                      6.507058e-03, 1.004909e+00, 1.999979e+00, 2.986577e+00,
                      8.536250e-03, 1.004909e+00, 2.017291e+00, 2.945395e+00,
                      1.997352e-02, 1.017340e+00, 2.017291e+00, 2.999871e+00,
                      3.741202e-04, 9.998805e-01, 1.999874e+00, 2.999871e+00,
                      6.507058e-03, 1.001569e+00, 2.000900e+00, 2.991024e+00,
                      6.507058e-03, 1.004909e+00, 1.999979e+00, 2.986577e+00,
                      8.536250e-03, 1.004909e+00}); // expected output

    test_attention(
        Shape{4, 3}, // output shape; Q, K and V share this shape
        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputQ
        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputK
        vector<float>{0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.}, // inputV
        vector<float>{0.9640308, 1.9546683, 2.9292183, 2.9460413, 0.0886370,
                      1.0179861, 1.9941283, 2.9905086, 0.0210545, 1.0006673,
                      1.9993325, 2.9894698}); // expected output (python reference)
}

} // namespace infini
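The expected vectors appear to come from a plain softmax(Q K^T) V reference (the original comment on the last vector reads "python output"); as in the kernel, no 1/sqrt(d) scaling is applied. A sketch of such a reference, assuming row-major [N, d] inputs as in the test (the helper name attention_ref is an illustrative assumption):

    # Reference computation the expected test values appear to correspond to:
    # row-wise softmax of Q @ K.T (no 1/sqrt(d) scaling), multiplied by V.
    import numpy as np

    def attention_ref(Q, K, V):
        scores = Q @ K.T                             # [N, N] dot products
        scores -= scores.max(axis=1, keepdims=True)  # numerical stability
        weights = np.exp(scores)
        weights /= weights.sum(axis=1, keepdims=True)  # row-wise softmax
        return weights @ V                           # [N, d] output

    data = np.array([0., 1., 2., 3., 0., 1., 2., 3., 0., 1., 2., 3.],
                    dtype=np.float32).reshape(4, 3)
    print(attention_ref(data, data, data))  # compare against the second ExpectData
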