forked from jiuyuan/InfiniTensor
Compare commits
8 Commits
- 6458093da4
- 61f6954c99
- 815d0ebf44
- 2fb1c8cf32
- 86877509c1
- 0adac91385
- 269e4ea40c
- 2436ccb868
@@ -6,7 +6,7 @@ option(USE_INTELCPU "Support INTELCPU" OFF)
option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" ON)
option(USE_PROTOBUF "Serialize and deserialize tensors" OFF)
option(BUILD_NNET "Build nnet" OFF)
option(BUILD_DIST "Build project for distributed running" OFF)
option(BUILD_DIST "Build project for distributed running" ON)
option(BUILD_TEST "Build tests" OFF)

if(USE_CUDA)
@@ -5,6 +5,9 @@
#include <cstdint>
#include <iostream>

#ifdef USE_CUDA
#include "cuda/cuda_runtime.h"
#endif
namespace infini {

class GraphHandlerObj {

@@ -64,6 +67,10 @@ class GraphHandlerObj {
    Tensor transpose(Tensor data, Tensor transposed, Shape perm);
    Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
    Tensor concat(TensorVec inputs, Tensor output, int dim);
    TensorVec attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
                               Tensor input_q, Tensor input_k, Tensor input_v,
                               Tensor position_id, Tensor output_matmul,
                               Tensor output_k_cache, Tensor output_v_cache);
    TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
                    int num_outputs);
    Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);

@@ -95,12 +102,15 @@ class GraphHandlerObj {

    //------ runtime

    inline void data_malloc() { g->dataMalloc(); }
    inline void data_malloc() { g->dataMalloc(true); }

    inline void tune() { g->getRuntime()->run(g, true); }

    inline void run() { g->getRuntime()->run(g); }

#ifdef USE_CUDA
    inline void run_with_cudagraph() { (as<CudaRuntimeObj>(g->getRuntime()))->runWithCudaGraph(g); }
#endif
    inline double get_perf_time() { return g->getRuntime()->getPerfTime(g); }
};
@@ -25,6 +25,7 @@ struct OpType {
    Asinh,              // Unary
    Atan,               // Unary
    Atanh,              // Unary
    AttentionKVCache,   // Fusion
    AveragePool,        // Pool
    BatchNormalization, //
    Bernoulli,          //
@@ -0,0 +1,16 @@
#pragma once
#include <cstdio>

struct AttentionKVCacheMetadata {
    int dimSize[4];
    int stride[4];
};

namespace infini {
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
                              float *input_q, float *input_k, float *input_v,
                              int *position_id, float *output_matmul,
                              const AttentionKVCacheMetadata &compMeta,
                              float *output_O_temp, float *output_sum_temp);

} // namespace infini
@@ -1,6 +1,7 @@
#pragma once
#include "core/common.h"
#include <cublas_v2.h>
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_profiler_api.h>
#include <cudnn.h>

@@ -111,4 +112,9 @@ inline const char *curandGetErrorString(curandStatus_t error) {

using CudaPtr = void *;

class CUDAStream {
  public:
    static cudaStream_t stream;
};

} // namespace infini
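For readers unfamiliar with the pattern: `CUDAStream::stream` is a static data member, so this header only declares it and exactly one translation unit must still define it, which a later hunk in this comparison does with `cudaStream_t infini::CUDAStream::stream;`. A minimal sketch of that idiom, with the class name reused from the diff and everything else assumed:

```cpp
#include <cuda_runtime.h>

// header: declares the shared stream but allocates no storage
class CUDAStream {
  public:
    static cudaStream_t stream;
};

// exactly one .cc/.cu file: the single out-of-line definition
cudaStream_t CUDAStream::stream;
```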
@@ -11,10 +11,15 @@ class CudaRuntimeObj : public RuntimeObj {
  private:
    cudnnHandle_t cudnn;
    cublasHandle_t cublas;
    cublasLtHandle_t cublaslt;
    std::unique_ptr<CommunicatorObj> comm;
    CudaPtr workspace;
    size_t workspaceSize;

    bool cudaGraphCreated = false;
    cudaGraph_t cudaGraph;
    cudaGraphExec_t cudaGraphInstance;

  public:
    explicit CudaRuntimeObj(int deviceId = 0)
        : RuntimeObj(Device::CUDA, deviceId) {

@@ -22,16 +27,26 @@ class CudaRuntimeObj : public RuntimeObj {
        checkCudaError(cudaSetDevice(deviceId));
        checkCudnnError(cudnnCreate(&cudnn));
        checkCublasError(cublasCreate(&cublas));
        checkCublasError(cublasLtCreate(&cublaslt));
        // 10GB for Longformer
        // size_t longformerNum = 3lu * (1 << 30);
        workspaceSize = 7ll << 30; // 7 GB
        workspace = alloc(workspaceSize);
        checkCudaError(cudaStreamCreate(&CUDAStream::stream));
        checkCudnnError(cudnnSetStream(cudnn, CUDAStream::stream));
        checkCublasError(cublasSetStream(cublas, CUDAStream::stream));
    }
    virtual ~CudaRuntimeObj() {
        try {
            if (cudaGraphCreated) {
                checkCudaError(cudaGraphExecDestroy(cudaGraphInstance));
                checkCudaError(cudaGraphDestroy(cudaGraph));
                checkCudaError(cudaStreamDestroy(CUDAStream::stream));
            }
            dealloc(workspace);
            checkCudnnError(cudnnDestroy(cudnn));
            checkCublasError(cublasDestroy(cublas));
            checkCublasError(cublasLtDestroy(cublaslt));
        } catch (const std::exception &e) {
            std::cerr << "Error in ~CudaRuntimeObj: " << e.what() << std::endl;
        }

@@ -52,6 +67,7 @@ class CudaRuntimeObj : public RuntimeObj {
    void dealloc(void *ptr) override { checkCudaError(cudaFree(ptr)); }
    cudnnHandle_t cudnnHandle() const { return cudnn; }
    cublasHandle_t cublasHandle() const { return cublas; }
    cublasLtHandle_t cublasLtHandle() const { return cublaslt; }
    size_t getWorkspaceSize() const { return workspaceSize; }
    CudaPtr getWorkspace(size_t size) const {
        IT_ASSERT(size <= workspaceSize);

@@ -75,6 +91,8 @@ class CudaRuntimeObj : public RuntimeObj {

    void runWithoutSync(const Graph &graph) const;

    void runWithCudaGraph(const Graph &graph);

    // init communicator
    void initComm(const string &name, int worldSize, int rank) final;
@@ -0,0 +1,46 @@
#pragma once
#include "core/operator.h"

namespace infini {
/**
 * @brief Fused Attention with KVCache input operator. All the input and output
 * tensors should have the same rank except for the position_id.
 *
 */
class AttentionKVCacheObj : public OperatorObj {
    int dim;

  public:
    /**
     * @brief Construct a new AttentionKVCache object.
     *
     * @param graph The computation graph that this operator belongs to.
     * @param input_k_cache The k_cache input tensor.
     * @param input_v_cache The v_cache input tensor.
     * @param input_q The query input tensor.
     * @param input_k The key input tensor.
     * @param input_v The value input tensor.
     * @param position_id The position id of the query.
     * @param output_matmul The query output tensor.
     * @param output_k_cache The output k_cache tensor.
     * @param output_v_cache The output v_cache tensor.
     */
    AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
                        Tensor input_v_cache, Tensor input_q, Tensor input_k,
                        Tensor input_v, Tensor position_id,
                        Tensor output_matmul, Tensor output_k_cache,
                        Tensor output_v_cache);
    OP_CLONE(AttentionKVCacheObj);

    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

    std::string toString() const override;
    int numInputs() const override { return 6; }
    int numOutputs() const override { return 3; }
    int getDim() const { return dim; }

  private:
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;
};
} // namespace infini
@@ -23,12 +23,13 @@ from onnx.checker import (
    ValidationError,
)
from onnx.shape_inference import infer_shapes
from onnx.numpy_helper import to_array
from onnx.numpy_helper import to_array, from_array
from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
from functools import reduce
from onnxsim import simplify
import copy
import warnings
import numpy as np


class OnnxStub:

@@ -46,6 +47,9 @@ class OnnxStub:
            model = model_simp
        except ValidationError:
            pass
        except RuntimeError:
            pass

        self.inputs: Dict[str, backend.Tensor] = {}
        self.outputs: Dict[str, backend.Tensor] = {}
        self.initializer: Dict[int, TensorProto] = {}

@@ -183,15 +187,33 @@ class OnnxStub:
                    op[1],
                )
            elif node.op_type == "MatMul":
                tensors[node.output[0]] = self.handler.matmul(
                    tensors[node.input[0]],
                    tensors[node.input[1]],
                    tensors.get(node.output[0]),
                    False,
                    False,
                    None,
                    backend.ActType.Linear,
                )
                if False:
                    if tensors[node.input[0]].shape()[0] == 1 and tensors[node.input[0]].shape()[1] == 1 \
                            and len(tensors[node.input[1]].shape()) == 2 and node.input[1] in data.keys():
                        data[node.input[1]] = from_array(
                            np.transpose(to_array(data[node.input[1]])))
                        tensors[node.input[1]] = self.handler.tensor(
                            [tensors[node.input[1]].shape()[1], tensors[node.input[1]].shape()[0]],
                            tensors[node.input[1]].dtype())
                    tensors[node.output[0]] = self.handler.matmul(
                        tensors[node.input[0]],
                        tensors[node.input[1]],
                        tensors.get(node.output[0]),
                        False,
                        True,
                        None,
                        backend.ActType.Linear,
                    )
                else:
                    tensors[node.output[0]] = self.handler.matmul(
                        tensors[node.input[0]],
                        tensors[node.input[1]],
                        tensors.get(node.output[0]),
                        False,
                        False,
                        None,
                        backend.ActType.Linear,
                    )
            elif node.op_type == "Gemm":
                attributes = _parse_attribute(
                    node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}

@@ -545,6 +567,18 @@ class OnnxStub:
                        (attr.i for attr in node.attribute if attr.name == "axis")
                    ),
                )
            elif node.op_type == "AttentionKVCache":
                tensors[node.output[0]], tensors[node.output[1]], tensors[node.output[2]] = self.handler.attentionKVCache(
                    tensors[node.input[0]],
                    tensors[node.input[1]],
                    tensors[node.input[2]],
                    tensors[node.input[3]],
                    tensors[node.input[4]],
                    tensors[node.input[5]],
                    tensors.get(node.output[0]),
                    tensors.get(node.output[1]),
                    tensors.get(node.output[2]),
                )
            elif node.op_type == "Split":
                for name, tensor in zip(
                    node.output,

@@ -1090,6 +1124,9 @@ class OnnxStub:
    def run(self) -> None:
        self.handler.run()

    def run_with_cudagraph(self) -> None:
        self.handler.run_with_cudagraph()

    def get_perf_time(self) -> float:
        self.handler.get_perf_time()
@@ -66,50 +66,33 @@ string GraphObj::toString() const {
    }
    return oss.str();
}

bool GraphObj::topo_sort() {
    if (this->sorted)
    if (this->sorted) {
        return true;

    // std::unordered_set<Tensor> inputs;
    std::unordered_set<Operator> waiting(this->ops.begin(), this->ops.end());
    }
    std::vector<Operator> sorted;

    while (!waiting.empty()) {
    std::unordered_set<OperatorObj *> flags;
    while (sorted.size() < ops.size()) {
        // Any node is move to sorted in this loop.
        auto modified = false;
        // Find head nodes.
        for (auto it = waiting.begin(); it != waiting.end();) {
            const auto &this_inputs = (*it)->getInputs();
            // If none of the input tensors is in waiting list,
            // this node is a head node.
            const auto is_head = std::all_of(
                this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
                    auto src = input->getSource();
                    return src // If the source node is in the waiting list,
                               // means that this node is not the head node.
                               ? waiting.find(src) == waiting.end()
                               // This tensor has no source node,
                               // it must be a input tensor.
                               : (/*inputs.insert(input),*/ true);
                });
            // Moves head node to sorted.
            if (is_head) {
        for (auto const &op : ops) {
            if (flags.find(op.get()) != flags.end()) {
                continue;
            }
            if (const auto &inputs = op->getInputs(); std::all_of(
                    inputs.begin(), inputs.end(), [&](const auto &input) {
                        auto src = input->getSource().get();
                        return !src || flags.find(src) != flags.end();
                    })) {
                modified = true;
                sorted.emplace_back(std::move(*it));
                it = waiting.erase(it);
            } else {
                ++it;
                sorted.emplace_back(op);
                flags.insert(op.get());
            }
        }
        // Waiting list never modifies during a pass,
        // sorting fails.
        if (!modified) {
            return false;
        }
    }

    // Done.
    this->ops = std::move(sorted);
    return this->sorted = true;
}
@@ -1,6 +1,7 @@
#include "core/graph_handler.h"
#include "operators/all_gather.h"
#include "operators/all_reduce.h"
#include "operators/attention_kvcache.h"
#include "operators/batch_norm.h"
#include "operators/broadcast.h"
#include "operators/concat.h"

@@ -239,6 +240,28 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
    }
}

TensorVec GraphHandlerObj::attentionKVCache(
    Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, Tensor input_k,
    Tensor input_v, Tensor position_id, Tensor output_matmul,
    Tensor output_k_cache, Tensor output_v_cache) {
    if (output_matmul && output_k_cache && output_v_cache) {
        g->addOpWithOutputs<AttentionKVCacheObj>(
            std::move(input_k_cache), std::move(input_v_cache),
            std::move(input_q), std::move(input_k), std::move(input_v),
            std::move(position_id), output_matmul, output_k_cache,
            output_v_cache);
        return {output_matmul, output_k_cache, output_v_cache};
    } else {
        return g
            ->addOp<AttentionKVCacheObj>(
                std::move(input_k_cache), std::move(input_v_cache),
                std::move(input_q), std::move(input_k), std::move(input_v),
                std::move(position_id), output_matmul, output_k_cache,
                output_v_cache)
            ->getOutputs();
    }
}

TensorVec GraphHandlerObj::split(Tensor input, std::optional<TensorVec> outputs,
                                 int axis, int num_outputs) {
    if (outputs) {

@@ -427,6 +450,7 @@ Tensor GraphHandlerObj::where(Tensor inputX, Tensor inputY, Tensor condition,

static CastType inferCastType(Tensor input, int to) {
    auto iType = input->getDType();

    auto oType = DataType(to);
    if (iType == DataType::Float32 && oType == DataType::Float16) {
        return CastType::Float2Float16;
@@ -65,8 +65,10 @@ bool OperatorObj::checkValid(GraphObj *graph) {
    if (graph) { // if graph != nullptr, outputs should be created
        auto dataTypes = inferDataType();
        for (size_t i = 0; i < outputs.size(); i++) {
            IT_ASSERT(!outputs[i], "Find empty output while operator creation");
            outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
            if (!outputs[i])
                outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
            else if (shapes[i] != outputs[i]->getDims())
                return false;
        }
    } else { // if outputs have been created, check their shapes
        for (size_t i = 0; i < shapes.size(); ++i) {
@@ -2,6 +2,7 @@
#include "core/kernel.h"
#include "core/perf_engine.h"
#include "core/runtime.h"
#include "cuda/cuda_common.h"
#ifdef INFINI_USE_NCCL
#include "cuda/nccl_communicator.h"
#endif

@@ -17,6 +18,7 @@ void CHECK_CUDA_KERNEL_ERROR(infini::Operator op) {
        exit(EXIT_FAILURE);
    }
}
cudaStream_t infini::CUDAStream::stream;

namespace infini {

@@ -40,6 +42,20 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
    }
}

void CudaRuntimeObj::runWithCudaGraph(const Graph &graph) {
    if (!cudaGraphCreated) {
        cudaStreamBeginCapture(CUDAStream::stream, cudaStreamCaptureModeGlobal);
        runWithoutSync(graph);
        cudaStreamEndCapture(CUDAStream::stream, &cudaGraph);
        cudaGraphInstantiate(&cudaGraphInstance, cudaGraph, NULL, NULL, 0);
        cudaGraphCreated = true;
    } else {
        cudaGraphLaunch(cudaGraphInstance, CUDAStream::stream);
    }
    cudaStreamSynchronize(CUDAStream::stream);
}

void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
    const auto &kernelRegistry = KernelRegistry::getInstance();
    auto &perfEngine = PerfEngine::getInstance();
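For orientation, `runWithCudaGraph` follows the standard CUDA Graph capture-and-replay idiom: work issued on the capturing stream is recorded rather than executed, the recorded graph is instantiated once, and later calls replay it with a single launch. A self-contained sketch under that reading; the kernel, buffer and sizes below are invented for illustration, only the cudaStream/cudaGraph API calls mirror the patch:

```cpp
#include <cuda_runtime.h>

__global__ void addOne(float *x, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] += 1.0f;
}

int main() {
    const int n = 1 << 20;
    float *d;
    cudaMalloc(&d, n * sizeof(float));
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // First pass: record all work issued on `stream` into a graph.
    cudaGraph_t graph;
    cudaGraphExec_t instance;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    addOne<<<(n + 255) / 256, 256, 0, stream>>>(d, n);
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0);

    // Later passes: replay the whole recorded graph with one launch.
    cudaGraphLaunch(instance, stream);
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(instance);
    cudaGraphDestroy(graph);
    cudaStreamDestroy(stream);
    cudaFree(d);
    return 0;
}
```

Replaying an instantiated graph avoids per-kernel launch overhead, which is why the test added at the end of this comparison expects the cudaGraph timing to be no slower than the eager run.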
@@ -479,6 +479,7 @@ void init_graph_builder(py::module &m) {
        .def("transpose", &Handler::transpose, policy::move)
        .def("reshape", &Handler::reshape, policy::move)
        .def("concat", &Handler::concat, policy::move)
        .def("attentionKVCache", &Handler::attentionKVCache, policy::move)
        .def("split", &Handler::split, policy::move)
        .def("gather", &Handler::gather, policy::move)
        .def("gatherElements", &Handler::gatherElements, policy::move)

@@ -503,6 +504,9 @@ void init_graph_builder(py::module &m) {
        .def("get_perf_time", &Handler::get_perf_time, policy::automatic)
        .def("tune", &Handler::tune, policy::automatic)
        .def("run", &Handler::run, policy::automatic)
#ifdef USE_CUDA
        .def("run_with_cudagraph", &Handler::run_with_cudagraph, policy::automatic)
#endif
        .def("get_perf_time", &Handler::get_perf_time, policy::automatic);
}
@@ -21,7 +21,7 @@ class AllReduceNCCL : public CudaKernelWithoutConfig {
                        .getNcclComm();
        // TODO: Using default stream 0 for now.
        checkNcclError(ncclAllReduce(input, output, count, ncclFloat,
                                     getRedOp(), comm, 0));
                                     getRedOp(), comm, CUDAStream::stream));
    }

    virtual ncclRedOp_t getRedOp() const = 0;
@@ -0,0 +1,55 @@
#include "operators/attention_kvcache.h"
#include "cuda/cuda_attention_kvcache.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include <functional>

namespace infini {

class AttentionKVCacheCompute {
    void initAttentionKVCacheMetadata(AttentionKVCacheMetadata &metadata,
                                      Tensor tensor) const {
        int nDims = tensor->getRank();
        auto strides = tensor->getStride();
        IT_ASSERT(nDims == 4);
        IT_ASSERT(strides.size() == (size_t)nDims);
        for (int i = 0; i < nDims; ++i) {
            metadata.dimSize[i] = tensor->getDims().at(i);
            metadata.stride[i] = strides.at(i);
        }
    }

  public:
    void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
                    Tensor input_k, Tensor input_v, Tensor position_id,
                    Tensor output_matmul, Tensor output_temp_O, Tensor output_temp_sum) const {
        AttentionKVCacheMetadata metadata;
        initAttentionKVCacheMetadata(metadata, input_v_cache);

        attention_kvcache_kernel(input_k_cache->getRawDataPtr<float *>(),
                                 input_v_cache->getRawDataPtr<float *>(),
                                 input_q->getRawDataPtr<float *>(),
                                 input_k->getRawDataPtr<float *>(),
                                 input_v->getRawDataPtr<float *>(),
                                 position_id->getRawDataPtr<int *>(),
                                 output_matmul->getRawDataPtr<float *>(),
                                 metadata,
                                 output_temp_O->getRawDataPtr<float *>(),
                                 output_temp_sum->getRawDataPtr<float *>());
    }
};

class AttentionKVCacheCuda : private AttentionKVCacheCompute,
                             public CudaKernelWithoutConfig {
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        do_compute(_op->getInputs()[0], _op->getInputs()[1],
                   _op->getInputs()[2], _op->getInputs()[3],
                   _op->getInputs()[4], _op->getInputs()[5],
                   _op->getOutputs()[0], _op->getOutputs()[1],
                   _op->getOutputs()[2]);
    }
};

REGISTER_KERNEL(Device::CUDA, OpType::AttentionKVCache, DataType::Float32,
                AttentionKVCacheCuda, "AttentionKVCache_CUDA_Float32");
} // namespace infini
@@ -0,0 +1,272 @@
#include "cuda/cuda_common.h"
#include "cuda/cuda_attention_kvcache.h"
#define WARP_SIZE 32
#define BLOCKSIZE WARP_SIZE
#define SEQ_UNIT 32

__global__ void _attention_kvcache_kernel_64(float* input_k_cache,
                                             float* input_v_cache,
                                             float* input_q,
                                             float* input_k,
                                             float* input_v,
                                             int* position_id,
                                             float* output_matmul,
                                             AttentionKVCacheMetadata compMeta) {
    int lane_id = threadIdx.x % WARP_SIZE;
    int group_id = threadIdx.x / WARP_SIZE;
    int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;

    if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
        return;

    float ptr_V[SEQ_UNIT*2];
    float ptr_K[SEQ_UNIT*2];
    float ptr_Q[2];
    float ptr_P[SEQ_UNIT];

    float ptr_O[2];
    float ptr_max[1];
    float ptr_sum[1];

    float ptr_max_last[1];
    float ptr_sum_last[1];
    float ptr_O_last[2];

    (float2 &)ptr_Q[0] = (float2 &)input_q[(lane_id * 2) + (parallel_idx * 64)];

    int SEQ_LENGTH = position_id[0] + 1;

    int common_idx = (lane_id * 2) + (parallel_idx * compMeta.stride[1]);


    for (int idx_seq = 0; idx_seq < SEQ_LENGTH; idx_seq += SEQ_UNIT){
        ptr_max_last[0] = ptr_max[0];
        ptr_sum_last[0] = ptr_sum[0];
        (float2 &)ptr_O_last[0] = (float2 &)ptr_O[0];

        #pragma unroll
        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
            if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
                (float2 &)ptr_K[idx_SEQ_UNIT * 2]
                    = (float2 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
            }
            else{
                (float2 &)ptr_K[idx_SEQ_UNIT * 2]
                    = (float2 &) input_k[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
                (float2 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
                    (float2 &)ptr_K[idx_SEQ_UNIT * 2];
            }
            ptr_K[idx_SEQ_UNIT * 2] = ptr_Q[0] * ptr_K[idx_SEQ_UNIT * 2];
            ptr_K[idx_SEQ_UNIT * 2 + 1] = ptr_Q[1] * ptr_K[idx_SEQ_UNIT * 2 + 1];

            #pragma unroll
            for (int offset = 16; offset > 0; offset /= 2) {
                ptr_K[idx_SEQ_UNIT * 2] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 2], offset);
            }
            ptr_P[idx_SEQ_UNIT] = ptr_K[idx_SEQ_UNIT * 2];
            #pragma unroll
            for (int offset = 16; offset > 0; offset /= 2){
                ptr_K[((idx_SEQ_UNIT * 2) + 1)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 2) + 1)], offset);
            }
            ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 2) + 1)];
        }

        #pragma unroll
        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
            ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
            ptr_P[idx_SEQ_UNIT] /= 8;
            ptr_max[0] = (idx_SEQ_UNIT == 0) ? ptr_P[0] : max(ptr_max[0], ptr_P[idx_SEQ_UNIT]);
        }
        ptr_max[0] = (idx_seq == 0) ? ptr_max[0] : max(ptr_max[0], ptr_max_last[0]);

        ptr_sum[0] = 0;
        #pragma unroll
        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
            ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT] - ptr_max[0]);
            ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
        }
        ptr_sum[0] = (idx_seq == 0) ? ptr_sum[0] : expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] + ptr_sum[0];

        ptr_O[0] = 0;
        ptr_O[1] = 0;
        #pragma unroll
        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
            if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
                (float2 &)ptr_V[idx_SEQ_UNIT * 2]
                    = (float2 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
            }
            else{
                (float2 &)ptr_V[idx_SEQ_UNIT * 2]
                    = (float2 &) input_v[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
                (float2 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
                    (float2 &)ptr_V[idx_SEQ_UNIT * 2];
            }

            ptr_P[idx_SEQ_UNIT] /= ptr_sum[0];

            ptr_O[0] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2)], ptr_O[0]);
            ptr_O[1] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2) + 1], ptr_O[1]);
        }
        ptr_O[0] = (idx_seq == 0) ? ptr_O[0] : ptr_O[0] + ptr_O_last[0] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
        ptr_O[1] = (idx_seq == 0) ? ptr_O[1] : ptr_O[1] + ptr_O_last[1] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
    }
    (float2 &)output_matmul[(lane_id * 2) + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O[0];
}

__global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
                                                float* input_v_cache,
                                                float* input_q,
                                                float* input_k,
                                                float* input_v,
                                                int* position_id,
                                                AttentionKVCacheMetadata compMeta,
                                                float* output_O_temp,
                                                float* output_sum_temp) {
    int seq_length = position_id[0] + 1;
    int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT;
    if(blockIdx.y >= stride)
        return;

    int lane_id = threadIdx.x % WARP_SIZE;
    int group_id = threadIdx.x / WARP_SIZE;
    int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
    int idx_seq = blockIdx.y * SEQ_UNIT;

    if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
        return;

    float ptr_V[SEQ_UNIT*4];
    float ptr_K[SEQ_UNIT*4];
    float ptr_Q[4];
    float ptr_P[SEQ_UNIT] = {0};

    float ptr_O[4] = {0};
    float ptr_sum[1] = {0};

    (float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)];
    int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]);

    #pragma unroll
    for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
        if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
            (float4 &)ptr_K[idx_SEQ_UNIT * 4]
                = (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
        }
        else{
            (float4 &)ptr_K[idx_SEQ_UNIT * 4]
                = (float4 &) input_k[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
            (float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
                (float4 &)ptr_K[idx_SEQ_UNIT * 4];
        }

        #pragma unroll
        for (int i = 0; i < 4; i ++){
            ptr_K[idx_SEQ_UNIT * 4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT * 4 + i];
            #pragma unroll
            for (int offset = 16; offset > 0; offset /= 2) {
                ptr_K[idx_SEQ_UNIT * 4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4 + i], offset);
            }
            ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT * 4 + i];
        }
    }

    #pragma unroll
    for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
        ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
        ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);
    }

    #pragma unroll
    for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
        ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]);
        ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
    }

    #pragma unroll
    for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
        if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
            (float4 &)ptr_V[idx_SEQ_UNIT * 4]
                = (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
        }
        else{
            (float4 &)ptr_V[idx_SEQ_UNIT * 4]
                = (float4 &) input_v[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
            (float4 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]
                = (float4 &)ptr_V[idx_SEQ_UNIT * 4];
        }

        #pragma unroll
        for (int i = 0; i < 4; i ++)
            ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4 + i)], ptr_O[i]);
    }

    #pragma unroll
    for (int i = 0; i < 4; i ++)
        ptr_O[i] /= ptr_sum[0];

    (float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0];
    if(threadIdx.x == 0){
        output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0];
    }

}

__global__ void _attention_kvcache_kernel_128_2(int* position_id,
                                                float* output_matmul,
                                                AttentionKVCacheMetadata compMeta,
                                                float* output_O_temp,
                                                float* output_sum_temp) {
    int lane_id = threadIdx.x % WARP_SIZE;
    int group_id = threadIdx.x / WARP_SIZE;
    int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;

    float ptr_O[4] = {0};
    float ptr_O_sum[4] = {0};
    float ptr_sum = 0;
    float ptr_sum_temp;
    int size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT;

    #pragma unroll
    for(int i = 0; i < size; i ++){
        (float4 &)ptr_O[0]
            = (float4 &)output_O_temp[(lane_id * 4) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size];
        ptr_sum_temp = output_sum_temp[i + parallel_idx * size];

        #pragma unroll
        for(int k = 0; k < 4; k ++)
            ptr_O_sum[k] += ptr_O[k] * ptr_sum_temp;
        ptr_sum += ptr_sum_temp;
    }

    #pragma unroll
    for(int k = 0; k < 4; k ++)
        ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum;

    (float4 &)output_matmul[(lane_id * 4) + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0];

}

namespace infini {
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
                              float *input_q, float *input_k,
                              float *input_v, int *position_id, float *output_matmul,
                              const AttentionKVCacheMetadata &compMeta,
                              float *output_O_temp, float *output_sum_temp) {
    IT_ASSERT(compMeta.dimSize[3] == 64 || compMeta.dimSize[3] == 128);

    int gridsize_y = (compMeta.dimSize[2] - 1 + SEQ_UNIT) / SEQ_UNIT;
    dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y);
    dim3 blockDim(BLOCKSIZE, 1);

    if(compMeta.dimSize[3] == 64)
        _attention_kvcache_kernel_64<<<gridDim.x, blockDim, 0, CUDAStream::stream>>>(
            input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, output_matmul, compMeta);
    else{
        _attention_kvcache_kernel_128_1<<<gridDim, blockDim, 0, CUDAStream::stream>>>(
            input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, compMeta, output_O_temp, output_sum_temp);
        _attention_kvcache_kernel_128_2<<<compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), WARP_SIZE, 0, CUDAStream::stream>>>(
            position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);
    }
}

} // namespace infini
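Reading the 64-wide kernel above, the ptr_max/ptr_sum/ptr_O bookkeeping is the usual streaming ("online") softmax recurrence, processed SEQ_UNIT scores at a time. In math form (my gloss, not part of the patch), with the current block's scores $s_j$ and values $v_j$:

$$
m_t = \max\!\bigl(m_{t-1}, \max_j s_j\bigr), \qquad
\ell_t = e^{\,m_{t-1}-m_t}\,\ell_{t-1} + \sum_j e^{\,s_j-m_t},
$$
$$
o_t = \frac{e^{\,m_{t-1}-m_t}\,\ell_{t-1}}{\ell_t}\,o_{t-1} \;+\; \frac{1}{\ell_t}\sum_j e^{\,s_j-m_t}\,v_j .
$$

The 128-wide path splits the sequence across blockIdx.y instead: _attention_kvcache_kernel_128_1 writes per-chunk partial outputs and partial exponential sums, and _attention_kvcache_kernel_128_2 combines them by weighting each partial output with its chunk sum and renormalizing by the total.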
@@ -25,7 +25,7 @@ void clip_kernel(float *input, float *output, int num, float minValue,
                 float maxValue) {
    int blocksize = block_work_size();
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _clip_kernel<<<gridsize, blocksize>>>(input, output, num, minValue,
    _clip_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num, minValue,
                                          maxValue);
}
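The remaining launch-site edits in this comparison all follow the same mechanical pattern as the clip kernel above: the optional third and fourth launch-configuration arguments (dynamic shared-memory bytes and stream) are filled in so each kernel is enqueued on the shared CUDAStream::stream rather than the default stream, which is what makes the work capturable by cudaStreamBeginCapture. A generic, hypothetical sketch of that pattern:

```cpp
#include <cuda_runtime.h>

// Illustrative kernel, not from the repository.
__global__ void _scale_kernel(float *x, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        x[i] *= s;
}

void scale_kernel(float *x, float s, int n, cudaStream_t stream) {
    int blocksize = 256;
    int gridsize = (n + blocksize - 1) / blocksize;
    // <<<grid, block, dynamicSharedMem, stream>>>; 0 bytes of shared memory here.
    _scale_kernel<<<gridsize, blocksize, 0, stream>>>(x, s, n);
}
```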
@@ -71,7 +71,7 @@ void div_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
    int blocksize = block_work_size();
    int num = c0 * c1 * c2 * c3;
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _div_kernel<<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2,
    _div_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2,
                                         b3, c0, c1, c2, c3);
}
void pow_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,

@@ -80,7 +80,7 @@ void pow_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
    int blocksize = block_work_size();
    int num = c0 * c1 * c2 * c3;
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _pow_kernel<<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2,
    _pow_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2,
                                         b3, c0, c1, c2, c3);
}
@@ -42,7 +42,7 @@ void expandKernel(float *input, float *output, int nDims, int outputsize,
                  SmallArray inputShape, SmallArray outputShape) {
    int blocksize = block_work_size();
    int gridsize = (outputsize + block_work_size() - 1) / block_work_size();
    _expandKernel<<<gridsize, blocksize>>>(input, output, nDims, outputsize,
    _expandKernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, nDims, outputsize,
                                           inputShape, outputShape);
}
@@ -19,7 +19,7 @@ void extend_kernel(float *in, float *out, int blockSize, int blockSizeOuter,
                   int oSize) {
    int blocksize = 32 * 16;
    int gridsize = (oSize + blocksize - 1) / blocksize;
    _extend_kernel<<<gridsize, blocksize>>>(in, out, blockSize, blockSizeOuter,
    _extend_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(in, out, blockSize, blockSizeOuter,
                                            oSize);
}
} // namespace infini
@@ -46,9 +46,9 @@ void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num) {
    int gridSize = (num + blockSize - 1) / blockSize;
    if (metaData.indexType == DataType::Int64) {
        _gather_kernel<int64_t>
            <<<gridSize, blockSize>>>(in, out, metaData, num);
            <<<gridSize, blockSize, 0, CUDAStream::stream>>>(in, out, metaData, num);
    } else {
        _gather_kernel<int><<<gridSize, blockSize>>>(in, out, metaData, num);
        _gather_kernel<int><<<gridSize, blockSize, 0, CUDAStream::stream>>>(in, out, metaData, num);
    }
}
} // namespace infini
@@ -40,22 +40,22 @@ void gather_elements_kernel(void *in, void *out, GatherMetaData metaData,
    int gridSize = (num + blockSize - 1) / blockSize;
    if (metaData.dataType == DataType::Float32 &&
        metaData.indexType == DataType::Int64) {
        _gather_elements_kernel<float, int64_t><<<gridSize, blockSize>>>(
        _gather_elements_kernel<float, int64_t><<<gridSize, blockSize, 0, CUDAStream::stream>>>(
            reinterpret_cast<float *>(in), reinterpret_cast<float *>(out),
            metaData, num);
    } else if (metaData.dataType == DataType::Int32 &&
               metaData.indexType == DataType::Int64) {
        _gather_elements_kernel<int, int64_t><<<gridSize, blockSize>>>(
        _gather_elements_kernel<int, int64_t><<<gridSize, blockSize, 0, CUDAStream::stream>>>(
            reinterpret_cast<int *>(in), reinterpret_cast<int *>(out), metaData,
            num);
    } else if (metaData.dataType == DataType::Float32 &&
               metaData.indexType == DataType::Int32) {
        _gather_elements_kernel<float, int><<<gridSize, blockSize>>>(
        _gather_elements_kernel<float, int><<<gridSize, blockSize, 0, CUDAStream::stream>>>(
            reinterpret_cast<float *>(in), reinterpret_cast<float *>(out),
            metaData, num);
    } else if (metaData.dataType == DataType::Int32 &&
               metaData.indexType == DataType::Int32) {
        _gather_elements_kernel<int, int><<<gridSize, blockSize>>>(
        _gather_elements_kernel<int, int><<<gridSize, blockSize, 0, CUDAStream::stream>>>(
            reinterpret_cast<int *>(in), reinterpret_cast<int *>(out), metaData,
            num);
    } else {
@@ -33,6 +33,7 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
    CUBLAS_GEMM_ALGO21, CUBLAS_GEMM_ALGO22, CUBLAS_GEMM_ALGO23,
};
class matmulCublas : public Kernel {
    int a;
    bool do_compute(const Operator &_op, const PerfRecord &_record,
                    const RuntimeObj *_context) const {
        auto op = as<MatmulObj>(_op);

@@ -43,6 +44,8 @@ class matmulCublas : public Kernel {
        auto record = as<MatmulCublasPerfRecordObj>(_record);

        const auto [b, m, n, k] = op->getBMNK();
        // std::cout << b << " " << m << " " << n << " " << k << std::endl;
        // std::cout << op->getTransA() << " " << op->getTransB() << std::endl;
        auto opA =
            op->getTransA() ? CUBLAS_OP_T : CUBLAS_OP_N; // BLAS_N = col major
        auto opB = op->getTransB() ? CUBLAS_OP_T : CUBLAS_OP_N;

@@ -95,10 +98,41 @@ class matmulCublas : public Kernel {
                &beta, outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F,
                (cublasGemmAlgo_t)record->algo);
        } else {
            stat = cublasGemmEx(
                context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
                CUDA_R_32F, ldb, inAData, CUDA_R_32F, lda, &beta, outData,
                CUDA_R_32F, ldc, CUDA_R_32F, (cublasGemmAlgo_t)record->algo);
            // stat = cublasGemmEx(
            //     context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
            //     CUDA_R_32F, ldb, inAData, CUDA_R_32F, lda, &beta, outData,
            //     CUDA_R_32F, ldc, CUDA_R_32F, (cublasGemmAlgo_t)record->algo);
            cublasLtMatmulDesc_t matmulDesc;
            cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
            cudaDataType_t scaleType = CUDA_R_32F;
            cudaDataType_t dtype = CUDA_R_32F;
            cublasLtMatrixLayout_t layout_A, layout_B, layout_C;
            cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType);
            cublasLtMatrixLayoutCreate(&layout_A, dtype, k, m, k);
            cublasLtMatrixLayoutCreate(&layout_B, dtype, n, k, n);
            cublasLtMatrixLayoutCreate(&layout_C, dtype, n, m, n);

            stat = cublasLtMatmul(
                context->cublasLtHandle(),
                matmulDesc,
                &alpha,
                inBData,
                layout_B,
                inAData,
                layout_A,
                &beta,
                outData,
                layout_C,
                outData,
                layout_C,
                nullptr,
                NULL,
                0,
                CUDAStream::stream);
            cublasLtMatrixLayoutDestroy(layout_A);
            cublasLtMatrixLayoutDestroy(layout_B);
            cublasLtMatrixLayoutDestroy(layout_C);
            cublasLtMatmulDescDestroy(matmulDesc);
        }
        // if (stat != CUBLAS_STATUS_SUCCESS)
        //     cout << cublasGetErrorString(stat);
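The swapped A/B argument order in both the cublasGemmEx call and the new cublasLt path comes from the usual row-major versus column-major identity (my gloss, not stated in the patch):

$$
C = A\,B \quad\Longleftrightarrow\quad C^{\mathsf T} = B^{\mathsf T} A^{\mathsf T},
$$

so a row-major $m\times n$ result is obtained by asking the column-major library for $B^{\mathsf T} A^{\mathsf T}$. That is why the layouts are created as $(k\times m)$, $(n\times k)$ and $(n\times m)$ with leading dimensions $k$, $n$, $n$, and why `inBData` is passed before `inAData`.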
@@ -46,7 +46,7 @@ void pad_slice_kernel(float *partData, float *wholeData,
                      bool isPad) {
    int blockSize = 32 * 16;
    int gridSize = (num + blockSize - 1) / blockSize;
    _pad_slice_kernel<<<gridSize, blockSize>>>(partData, wholeData, metadata,
    _pad_slice_kernel<<<gridSize, blockSize, 0, CUDAStream::stream>>>(partData, wholeData, metadata,
                                               nDims, num, isPad);
}
} // namespace infini
@@ -7,7 +7,7 @@ class CopyCuda : public CudaKernelWithoutConfig {
        auto inData = op->getInputs(0)->getRawDataPtr<void *>();
        auto outData = op->getOutputs()[0]->getRawDataPtr<void *>();
        cudaMemcpyAsync(outData, inData, op->getInputs(0)->getBytes(),
                        cudaMemcpyDeviceToDevice);
                        cudaMemcpyDeviceToDevice, CUDAStream::stream);
    }
};
// reshape/flatten/identity all act as copying from input to output.
@@ -213,7 +213,7 @@ void resize_kernel_nearest(float *in, float *out, const MetaData &metaData,
              sizeof(p_cooridnate_trans_mode_func[0]));
    IT_ASSERT(nearestMode <
              sizeof(p_nearest_mode_fun) / sizeof(p_nearest_mode_fun[0]));
    _resize_kernel_nearest<<<gridsize, blocksize>>>(
    _resize_kernel_nearest<<<gridsize, blocksize, 0, CUDAStream::stream>>>(
        in, out, metaData, num, coordinateMode, nearestMode);
}

@@ -223,7 +223,7 @@ void resize_kernel_linear(float *in, float *out, const MetaData &metaData,
    auto gridsize = (num + blocksize - 1) / blocksize;
    IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
                                   sizeof(p_cooridnate_trans_mode_func[0]));
    _resize_kernel_linear_coeff<<<gridsize, blocksize>>>(in, out, metaData, num,
    _resize_kernel_linear_coeff<<<gridsize, blocksize, 0, CUDAStream::stream>>>(in, out, metaData, num,
                                                         coordinateMode);
}

@@ -233,7 +233,7 @@ void resize_kernel_cubic(float *in, float *out, const MetaData &metaData,
    auto gridsize = (num + blocksize - 1) / blocksize;
    IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
                                   sizeof(p_cooridnate_trans_mode_func[0]));
    _resize_kernel_cubic_coeff<<<gridsize, blocksize>>>(in, out, metaData, num,
    _resize_kernel_cubic_coeff<<<gridsize, blocksize, 0, CUDAStream::stream>>>(in, out, metaData, num,
                                                        coordinateMode);
}
} // namespace infini
@@ -141,7 +141,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,

        int BLOCK_DIM = 1024;
        _blockSoftmaxKernel<1024>
            <<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
            <<<num_blocks, BLOCK_DIM, 0, CUDAStream::stream>>>(input, output, size, dimsize, stride);
    } else if (dimsize > 31) {
        int BLOCK_DIM_x = 32;
        int BLOCK_DIM_y = 32;

@@ -150,7 +150,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
        dim3 grid_dim(num_block_x, 1, 1);

        _warpSoftmaxKernel<32, 32>
            <<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
            <<<grid_dim, block_dim, 0, CUDAStream::stream>>>(input, output, size, dimsize, stride);
    } else if (dimsize > 15) {
        int BLOCK_DIM_x = 16;
        int BLOCK_DIM_y = 64;

@@ -159,7 +159,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
        dim3 grid_dim(num_block_x, 1, 1);

        _warpSoftmaxKernel<16, 64>
            <<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
            <<<grid_dim, block_dim, 0, CUDAStream::stream>>>(input, output, size, dimsize, stride);
    } else if (dimsize > 7) {
        int BLOCK_DIM_x = 8;
        int BLOCK_DIM_y = 128;

@@ -168,7 +168,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
        dim3 grid_dim(num_block_x, 1, 1);

        _warpSoftmaxKernel<8, 128>
            <<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
            <<<grid_dim, block_dim, 0, CUDAStream::stream>>>(input, output, size, dimsize, stride);
    } else {
        int BLOCK_DIM_x = 4;
        int BLOCK_DIM_y = 256;

@@ -177,7 +177,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
        dim3 grid_dim(num_block_x, 1, 1);

        _warpSoftmaxKernel<4, 256>
            <<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
            <<<grid_dim, block_dim, 0, CUDAStream::stream>>>(input, output, size, dimsize, stride);
    }
}
} // namespace infini
@@ -59,11 +59,11 @@ void split_concat_kernel(const ElementTensorMetadata &eleMeta,
    // gridsize = max_n_elements / blockSize
    int max_n_elements =
        *std::max_element(eleMeta.nElements, eleMeta.nElements + batchSize);
    int gridDimX = (max_n_elements - 1) / (32 * 16) + 1;
    int gridDimX = (max_n_elements + 32 * 16 - 1) / (32 * 16);
    // each y is a split among the batch
    dim3 gridSize(gridDimX, batchSize);

    _split_concat_kernel<<<gridSize, blockSize>>>(eleMeta, compMeta, dim, nDims,
    _split_concat_kernel<<<gridSize, blockSize, 0, CUDAStream::stream>>>(eleMeta, compMeta, dim, nDims,
                                                  isSplit);
}
@@ -30,7 +30,7 @@ void transpose_kernel(float *input, float *output, int nDims, int size,
                      SmallArray strides, SmallArray outputShape) {
    int blocksize = block_work_size();
    int gridsize = (size + block_work_size() - 1) / block_work_size();
    _transpose_kernel<<<gridsize, blocksize>>>(input, output, nDims, size,
    _transpose_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, nDims, size,
                                               strides, outputShape);
}
@@ -114,67 +114,67 @@ void softmax_kernel(float *input, float *output, size_t num) {

    int blocksize = block_work_size();
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _softmax_kernel1<<<1, 1>>>(input, output, num);
    _softmax_kernel2<<<gridsize, blocksize>>>(input, output, num);
    _softmax_kernel1<<<1, 1, 0, CUDAStream::stream>>>(input, output, num);
    _softmax_kernel2<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void relu_kernel(float *input, float *output, size_t num) {

    int blocksize = block_work_size();
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _relu_kernel<<<gridsize, blocksize>>>(input, output, num);
    _relu_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void sigmoid_kernel(float *input, float *output, size_t num) {

    int blocksize = block_work_size();
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
    _sigmoid_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void hard_sigmoid_kernel(float *input, float *output, size_t num) {

    int blocksize = block_work_size();
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _hard_sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
    _hard_sigmoid_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void hard_swish_kernel(float *input, float *output, size_t num) {

    int blocksize = block_work_size();
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _hard_swish_kernel<<<gridsize, blocksize>>>(input, output, num);
    _hard_swish_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void tanh_kernel(float *input, float *output, size_t num) {

    int blocksize = block_work_size();
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _tanh_kernel<<<gridsize, blocksize>>>(input, output, num);
    _tanh_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void abs_kernel(float *input, float *output, size_t num) {

    int blocksize = block_work_size();
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _abs_kernel<<<gridsize, blocksize>>>(input, output, num);
    _abs_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void sqrt_kernel(float *input, float *output, size_t num) {

    int blocksize = block_work_size();
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _sqrt_kernel<<<gridsize, blocksize>>>(input, output, num);
    _sqrt_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void gelu_kernel(float *input, float *output, size_t num) {

    int blocksize = block_work_size();
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _gelu_kernel<<<gridsize, blocksize>>>(input, output, num);
    _gelu_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void erf_kernel(float *input, float *output, size_t num) {

    int blocksize = block_work_size();
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _erf_kernel<<<gridsize, blocksize>>>(input, output, num);
    _erf_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void neg_kernel(float *input, float *output, size_t num) {

    int blocksize = block_work_size();
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _neg_kernel<<<gridsize, blocksize>>>(input, output, num);
    _neg_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
}; // namespace infini
@@ -73,7 +73,7 @@ void whereKernel(const float *inputX, const float *inputY,
    }
    int blocksize = 32 * 16;
    int gridsize = (outputsize + blocksize - 1) / blocksize;
    _whereKernel<<<gridsize, blocksize>>>(
    _whereKernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(
        inputX, inputY, condition, output, nDims, outputsize, inputXShape,
        inputYShape, conditionShape, outputShape);
}
@@ -0,0 +1,54 @@
#include "operators/attention_kvcache.h"
#include "utils/operator_utils.h"

namespace infini {
AttentionKVCacheObj::AttentionKVCacheObj(
    GraphObj *graph, Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
    Tensor input_k, Tensor input_v, Tensor position_id, Tensor output_matmul,
    Tensor output_k_cache, Tensor output_v_cache)
    : OperatorObj(OpType::AttentionKVCache,
                  TensorVec{input_k_cache, input_v_cache, input_q, input_k,
                            input_v, position_id},
                  TensorVec{output_matmul, output_k_cache, output_v_cache}) {
    int rank = inputs[0]->getRank();
    IT_ASSERT(rank == 4);
    dim = 2;
    IT_ASSERT(checkValid(graph));
}

optional<vector<Shape>>
AttentionKVCacheObj::inferShape(const TensorVec &inputs) const {
    IT_ASSERT(inputs.size() == 6);
    Shape dims = inputs[0]->getDims();
    ShapeElem n = dims.at(dim);
    dims[dim] = n + 1;
    return {{inputs[2]->getDims(), dims, dims}};
}

std::string AttentionKVCacheObj::toString() const {
    std::ostringstream os;
    os << "AttentionKVCache[" << getGuid() << "]";
    os << "(";
    for (auto input : inputs)
        os << vecToString(input->getDims()) << ",";
    os << "dim=" << dim << ",";
    os << "input=";
    for (auto input : inputs)
        os << input->getGuid() << ",";
    os << "output=" << outputs[0]->getGuid() << ")";
    return os.str();
}

vector<int> AttentionKVCacheObj::getWorkloadVector() const {
    vector<int> ret = getOutputs()[0]->getDims();
    ret.emplace(ret.begin(), (int)inputs.size());
    ret.emplace(ret.begin(), dim);
    ret.emplace(ret.begin(), type.underlying());
    return ret;
}

vector<int> AttentionKVCacheObj::getOpAttrVector() const {
    return {type.underlying(), dim};
}

} // namespace infini
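To make inferShape concrete: the cached sequence axis (dim = 2) grows by one step per call, while the attention output keeps the query's shape. With a hypothetical batch/head/step/head-dim layout:

$$
\underbrace{(b,h,t,d)}_{\text{k/v cache}},\;\underbrace{(b,h,1,d)}_{q,\,k,\,v}\;\longmapsto\;\bigl[(b,h,1,d),\,(b,h,t{+}1,d),\,(b,h,t{+}1,d)\bigr]
$$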
@@ -0,0 +1,68 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/attention_kvcache.h"

#include "test.h"

namespace infini {

TEST(TestCudaRuntime, CudaGraph) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();

    Graph gCpu = make_ref<GraphObj>(runtime);

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

    auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto input_q_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto input_k_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto input_v_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);

    auto op = gCuda->addOp<AttentionKVCacheObj>(
        input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d,
        position_id_d, nullptr, nullptr, nullptr);
    auto op1 = gCuda->addOp<AttentionKVCacheObj>(
        input_k_cache_d, input_v_cache_d, op->getOutputs()[0], input_k_d, input_v_d,
        position_id_d, nullptr, nullptr, nullptr);
    auto op2 = gCuda->addOp<AttentionKVCacheObj>(
        input_k_cache_d, input_v_cache_d, op1->getOutputs()[0], input_k_d, input_v_d,
        position_id_d, nullptr, nullptr, nullptr);
    gCuda->dataMalloc();

    input_q_d->setData(OneGenerator());
    input_k_d->setData(OneGenerator());
    input_v_d->setData(OneGenerator());
    position_id_d->setData(IncrementalGenerator());

    cudaRuntime->run(gCuda);

    cudaEvent_t start, stop;
    float milliseconds_1 = 0, milliseconds_2 = 0;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaDeviceSynchronize();
    cudaEventRecord(start);
    cudaRuntime->run(gCuda);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    cudaEventElapsedTime(&milliseconds_1, start, stop);
    printf("without cudaGraph, latency: %f ms\n", milliseconds_1);

    cudaRuntime->runWithCudaGraph(gCuda);
    cudaRuntime->runWithCudaGraph(gCuda);

    cudaDeviceSynchronize();
    cudaEventRecord(start);
    cudaRuntime->runWithCudaGraph(gCuda);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    cudaEventElapsedTime(&milliseconds_2, start, stop);
    printf("with cudaGraph, latency: %f ms\n", milliseconds_2);
    EXPECT_GE(milliseconds_1, milliseconds_2);
}

} // namespace infini
@@ -0,0 +1,42 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/attention_kvcache.h"

#include "test.h"

namespace infini {
TEST(AttentionKVCache, Cuda) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();

    Graph gCpu = make_ref<GraphObj>(runtime);

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
    auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto input_q_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto input_k_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto input_v_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);

    auto op = gCuda->addOp<AttentionKVCacheObj>(
        input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d,
        position_id_d, nullptr, nullptr, nullptr);
    gCuda->dataMalloc();

    input_q_d->setData(OneGenerator());
    input_k_d->setData(OneGenerator());
    input_v_d->setData(OneGenerator());
    position_id_d->setData(IncrementalGenerator());
    cudaRuntime->run(gCuda);

    auto oCpu = gCpu->cloneTensor(op->getOutputs()[0]);
    EXPECT_TRUE(oCpu->equalData(vector<float>{
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
}

} // namespace infini