Compare commits

...

8 Commits

Author SHA1 Message Date
xiaonans 6458093da4 fix graph topo & add cublaslt support & others 2023-12-20 16:33:49 +08:00
xiaonans 61f6954c99 [feat] add cudagraph support 2023-12-01 15:38:01 +08:00
xiaonans 815d0ebf44 cleaning 2023-11-28 16:29:48 +08:00
xiaonans 2fb1c8cf32 gemv2N to gemv2T 2023-11-27 16:19:01 +08:00
xiaonans 86877509c1 remove cudamalloc in attention op 2023-11-24 13:14:06 +08:00
xiaonans 0adac91385 kvcache_attention support reduce intra blocks 2023-11-21 17:30:21 +08:00
xiaonans 269e4ea40c add test to attention_kvcache op 2023-11-14 10:21:34 +08:00
xiaonans 2436ccb868 [feature] add fused attention_kvcache operator support 2023-11-10 11:14:38 +08:00
34 changed files with 776 additions and 88 deletions

View File

@@ -6,7 +6,7 @@ option(USE_INTELCPU "Support INTELCPU" OFF)
option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" ON)
option(USE_PROTOBUF "Serialize and deserialize tensors" OFF)
option(BUILD_NNET "Build nnet" OFF)
-option(BUILD_DIST "Build project for distributed running" OFF)
+option(BUILD_DIST "Build project for distributed running" ON)
option(BUILD_TEST "Build tests" OFF)
if(USE_CUDA)

View File

@@ -5,6 +5,9 @@
#include <cstdint>
#include <iostream>
#ifdef USE_CUDA
#include "cuda/cuda_runtime.h"
#endif
namespace infini {
class GraphHandlerObj {
@@ -64,6 +67,10 @@ class GraphHandlerObj {
Tensor transpose(Tensor data, Tensor transposed, Shape perm);
Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
Tensor concat(TensorVec inputs, Tensor output, int dim);
TensorVec attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
Tensor input_q, Tensor input_k, Tensor input_v,
Tensor position_id, Tensor output_matmul,
Tensor output_k_cache, Tensor output_v_cache);
TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
int num_outputs);
Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
@@ -95,12 +102,15 @@ class GraphHandlerObj {
//------ runtime
-inline void data_malloc() { g->dataMalloc(); }
+inline void data_malloc() { g->dataMalloc(true); }
inline void tune() { g->getRuntime()->run(g, true); }
inline void run() { g->getRuntime()->run(g); }
#ifdef USE_CUDA
inline void run_with_cudagraph() {(as<CudaRuntimeObj>(g->getRuntime()))->runWithCudaGraph(g);}
#endif
inline double get_perf_time() { return g->getRuntime()->getPerfTime(g); }
};

View File

@@ -25,6 +25,7 @@ struct OpType {
Asinh, // Unary
Atan, // Unary
Atanh, // Unary
AttentionKVCache, // Fusion
AveragePool, // Pool
BatchNormalization, //
Bernoulli, //

View File

@@ -0,0 +1,16 @@
#pragma once
#include <cstdio>
struct AttentionKVCacheMetadata {
int dimSize[4];
int stride[4];
};
namespace infini {
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
float *input_q, float *input_k, float *input_v,
int *position_id, float *output_matmul,
const AttentionKVCacheMetadata &compMeta,
float *output_O_temp, float *output_sum_temp);
} // namespace infini

View File

@@ -1,6 +1,7 @@
#pragma once
#include "core/common.h"
#include <cublas_v2.h>
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_profiler_api.h>
#include <cudnn.h>
@@ -111,4 +112,9 @@ inline const char *curandGetErrorString(curandStatus_t error) {
using CudaPtr = void *;
class CUDAStream{
public:
static cudaStream_t stream;
};
} // namespace infini

View File

@@ -11,10 +11,15 @@ class CudaRuntimeObj : public RuntimeObj {
private:
cudnnHandle_t cudnn;
cublasHandle_t cublas;
cublasLtHandle_t cublaslt;
std::unique_ptr<CommunicatorObj> comm;
CudaPtr workspace;
size_t workspaceSize;
bool cudaGraphCreated=false;
cudaGraph_t cudaGraph;
cudaGraphExec_t cudaGraphInstance;
public:
explicit CudaRuntimeObj(int deviceId = 0)
: RuntimeObj(Device::CUDA, deviceId) {
@@ -22,16 +27,26 @@ class CudaRuntimeObj : public RuntimeObj {
checkCudaError(cudaSetDevice(deviceId));
checkCudnnError(cudnnCreate(&cudnn));
checkCublasError(cublasCreate(&cublas));
checkCublasError(cublasLtCreate(&cublaslt));
// 10GB for Longformer
// size_t longformerNum = 3lu * (1 << 30);
workspaceSize = 7ll << 30; // 7 GB
workspace = alloc(workspaceSize);
checkCudaError(cudaStreamCreate(&CUDAStream::stream));
checkCudnnError(cudnnSetStream(cudnn, CUDAStream::stream));
checkCublasError(cublasSetStream(cublas, CUDAStream::stream));
}
virtual ~CudaRuntimeObj() {
try {
if(cudaGraphCreated){
checkCudaError(cudaGraphExecDestroy(cudaGraphInstance));
checkCudaError(cudaGraphDestroy(cudaGraph));
checkCudaError(cudaStreamDestroy(CUDAStream::stream));
}
dealloc(workspace);
checkCudnnError(cudnnDestroy(cudnn));
checkCublasError(cublasDestroy(cublas));
checkCublasError(cublasLtDestroy(cublaslt));
} catch (const std::exception &e) {
std::cerr << "Error in ~CudaRuntimeObj: " << e.what() << std::endl;
}
@@ -52,6 +67,7 @@ class CudaRuntimeObj : public RuntimeObj {
void dealloc(void *ptr) override { checkCudaError(cudaFree(ptr)); }
cudnnHandle_t cudnnHandle() const { return cudnn; }
cublasHandle_t cublasHandle() const { return cublas; }
cublasLtHandle_t cublasLtHandle() const { return cublaslt; }
size_t getWorkspaceSize() const { return workspaceSize; }
CudaPtr getWorkspace(size_t size) const {
IT_ASSERT(size <= workspaceSize);
@@ -75,6 +91,8 @@
void runWithoutSync(const Graph &graph) const;
void runWithCudaGraph(const Graph &graph);
// init communicator
void initComm(const string &name, int worldSize, int rank) final;

View File

@@ -0,0 +1,46 @@
#pragma once
#include "core/operator.h"
namespace infini {
/**
* @brief Fused Attention with KVCache input operator. All the input and output
* tensors should have the same rank except for the position_id.
*
*/
class AttentionKVCacheObj : public OperatorObj {
int dim;
public:
/**
* @brief Construct a new AttentionKVCache object.
*
* @param graph The computation graph that this operator belongs to.
* @param input_k_cache The k_cache input tensor.
* @param input_v_cache The v_cache input tensor.
* @param input_q The query input tensor.
* @param input_k The key input tensor.
* @param input_v The value input tensor.
* @param position_id The position id of the query.
* @param output_matmul The attention output tensor.
* @param output_k_cache The output k_cache tensor.
* @param output_v_cache The output v_cache tensor.
*/
AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
Tensor input_v_cache, Tensor input_q, Tensor input_k,
Tensor input_v, Tensor position_id,
Tensor output_matmul, Tensor output_k_cache,
Tensor output_v_cache);
OP_CLONE(AttentionKVCacheObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;
int numInputs() const override { return 6; }
int numOutputs() const override { return 3; }
int getDim() const { return dim; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini
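Reviewer note: outputs passed as nullptr are created by checkValid/inferShape, so callers only need to wire the six inputs. The unit tests added in this PR construct the operator exactly as in this sketch (shapes taken from those tests; g is a Graph bound to a CUDA runtime):

    // Hypothetical single-token decode step, head_dim = 64.
    auto kCache = g->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto vCache = g->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto q = g->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto k = g->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto v = g->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto pos = g->addTensor({1, 1}, DataType::UInt32);
    // Outputs passed as nullptr are allocated via shape inference.
    auto op = g->addOp<AttentionKVCacheObj>(kCache, vCache, q, k, v, pos,
                                            nullptr, nullptr, nullptr);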

View File

@@ -23,12 +23,13 @@ from onnx.checker import (
ValidationError,
)
from onnx.shape_inference import infer_shapes
-from onnx.numpy_helper import to_array
+from onnx.numpy_helper import to_array, from_array
from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
from functools import reduce
from onnxsim import simplify
import copy
import warnings
import numpy as np
class OnnxStub:
@@ -46,6 +47,9 @@ class OnnxStub:
model = model_simp
except ValidationError:
pass
except RuntimeError:
pass
self.inputs: Dict[str, backend.Tensor] = {}
self.outputs: Dict[str, backend.Tensor] = {}
self.initializer: Dict[int, TensorProto] = {}
@@ -183,15 +187,33 @@ class OnnxStub:
op[1],
)
elif node.op_type == "MatMul":
-tensors[node.output[0]] = self.handler.matmul(
-tensors[node.input[0]],
-tensors[node.input[1]],
-tensors.get(node.output[0]),
-False,
-False,
-None,
-backend.ActType.Linear,
-)
+if False:
+if tensors[node.input[0]].shape()[0] == 1 and tensors[node.input[0]].shape()[1] == 1 \
+and len(tensors[node.input[1]].shape()) == 2 and node.input[1] in data.keys():
+data[node.input[1]] = from_array(
+np.transpose(to_array(data[node.input[1]])))
+tensors[node.input[1]] = self.handler.tensor(
+[tensors[node.input[1]].shape()[1], tensors[node.input[1]].shape()[0]],
+tensors[node.input[1]].dtype())
+tensors[node.output[0]] = self.handler.matmul(
+tensors[node.input[0]],
+tensors[node.input[1]],
+tensors.get(node.output[0]),
+False,
+True,
+None,
+backend.ActType.Linear,
+)
+else:
+tensors[node.output[0]] = self.handler.matmul(
+tensors[node.input[0]],
+tensors[node.input[1]],
+tensors.get(node.output[0]),
+False,
+False,
+None,
+backend.ActType.Linear,
+)
elif node.op_type == "Gemm":
attributes = _parse_attribute(
node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
@@ -545,6 +567,18 @@ class OnnxStub:
(attr.i for attr in node.attribute if attr.name == "axis")
),
)
elif node.op_type == "AttentionKVCache":
tensors[node.output[0]], tensors[node.output[1]], tensors[node.output[2]] = self.handler.attentionKVCache(
tensors[node.input[0]],
tensors[node.input[1]],
tensors[node.input[2]],
tensors[node.input[3]],
tensors[node.input[4]],
tensors[node.input[5]],
tensors.get(node.output[0]),
tensors.get(node.output[1]),
tensors.get(node.output[2]),
)
elif node.op_type == "Split":
for name, tensor in zip(
node.output,
@@ -1090,6 +1124,9 @@ class OnnxStub:
def run(self) -> None:
self.handler.run()
def run_with_cudagraph(self) -> None:
self.handler.run_with_cudagraph()
def get_perf_time(self) -> float:
self.handler.get_perf_time()

View File

@@ -66,50 +66,33 @@ string GraphObj::toString() const {
}
return oss.str();
}
bool GraphObj::topo_sort() {
-if (this->sorted)
+if (this->sorted) {
return true;
-// std::unordered_set<Tensor> inputs;
-std::unordered_set<Operator> waiting(this->ops.begin(), this->ops.end());
+}
std::vector<Operator> sorted;
-while (!waiting.empty()) {
+std::unordered_set<OperatorObj *> flags;
+while (sorted.size() < ops.size()) {
// Whether any node is moved to sorted in this loop.
auto modified = false;
-// Find head nodes.
-for (auto it = waiting.begin(); it != waiting.end();) {
-const auto &this_inputs = (*it)->getInputs();
-// If none of the input tensors is in the waiting list,
-// this node is a head node.
-const auto is_head = std::all_of(
-this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
-auto src = input->getSource();
-return src // If the source node is in the waiting list,
-// this node is not a head node.
-? waiting.find(src) == waiting.end()
-// This tensor has no source node,
-// so it must be an input tensor.
-: (/*inputs.insert(input),*/ true);
-});
-// Move head nodes to sorted.
-if (is_head) {
+for (auto const &op : ops) {
+if (flags.find(op.get()) != flags.end()) {
+continue;
+}
+if (const auto &inputs = op->getInputs(); std::all_of(
+inputs.begin(), inputs.end(), [&](const auto &input) {
+auto src = input->getSource().get();
+return !src || flags.find(src) != flags.end();
+})) {
modified = true;
-sorted.emplace_back(std::move(*it));
-it = waiting.erase(it);
-} else {
-++it;
+sorted.emplace_back(op);
+flags.insert(op.get());
}
}
// If no node was moved during a whole pass, sorting fails.
if (!modified) {
return false;
}
}
// Done.
this->ops = std::move(sorted);
return this->sorted = true;
}
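Reviewer note: the rewritten topo_sort marks emitted operators in the flags set and rescans the whole op list until a pass makes no progress, so it runs in at most |ops| passes and returns false exactly when the remaining operators form a cycle. For comparison, a minimal standalone sketch of the same ordering via classic Kahn's algorithm with explicit in-degrees, on a hypothetical adjacency-list graph rather than the InfiniTensor API:

    #include <cstddef>
    #include <queue>
    #include <vector>

    // Kahn's algorithm: returns a topological order of nodes 0..n-1,
    // or an empty vector if the graph has a cycle.
    // adj[u] lists the successors of node u.
    std::vector<int> topoSort(const std::vector<std::vector<int>> &adj) {
        const int n = static_cast<int>(adj.size());
        std::vector<int> indeg(n, 0), order;
        for (const auto &succ : adj)
            for (int v : succ)
                ++indeg[v];
        std::queue<int> ready; // nodes whose inputs are all emitted
        for (int u = 0; u < n; ++u)
            if (indeg[u] == 0)
                ready.push(u);
        while (!ready.empty()) {
            int u = ready.front();
            ready.pop();
            order.push_back(u);
            for (int v : adj[u])
                if (--indeg[v] == 0)
                    ready.push(v);
        }
        if (order.size() != static_cast<size_t>(n))
            order.clear(); // cycle detected
        return order;
    }

The linear-time variant needs successor lists built up front; the flags-based rescan avoids that and keeps ops in a stable relative order, which is presumably why it was chosen here.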

View File

@@ -1,6 +1,7 @@
#include "core/graph_handler.h"
#include "operators/all_gather.h"
#include "operators/all_reduce.h"
#include "operators/attention_kvcache.h"
#include "operators/batch_norm.h"
#include "operators/broadcast.h"
#include "operators/concat.h"
@@ -239,6 +240,28 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
}
}
TensorVec GraphHandlerObj::attentionKVCache(
Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, Tensor input_k,
Tensor input_v, Tensor position_id, Tensor output_matmul,
Tensor output_k_cache, Tensor output_v_cache) {
if (output_matmul && output_k_cache && output_v_cache) {
g->addOpWithOutputs<AttentionKVCacheObj>(
std::move(input_k_cache), std::move(input_v_cache),
std::move(input_q), std::move(input_k), std::move(input_v),
std::move(position_id), output_matmul, output_k_cache,
output_v_cache);
return {output_matmul, output_k_cache, output_v_cache};
} else {
return g
->addOp<AttentionKVCacheObj>(
std::move(input_k_cache), std::move(input_v_cache),
std::move(input_q), std::move(input_k), std::move(input_v),
std::move(position_id), output_matmul, output_k_cache,
output_v_cache)
->getOutputs();
}
}
TensorVec GraphHandlerObj::split(Tensor input, std::optional<TensorVec> outputs,
int axis, int num_outputs) {
if (outputs) {
@@ -427,6 +450,7 @@ Tensor GraphHandlerObj::where(Tensor inputX, Tensor inputY, Tensor condition,
static CastType inferCastType(Tensor input, int to) {
auto iType = input->getDType();
auto oType = DataType(to);
if (iType == DataType::Float32 && oType == DataType::Float16) {
return CastType::Float2Float16;

View File

@@ -65,8 +65,10 @@ bool OperatorObj::checkValid(GraphObj *graph) {
if (graph) { // if graph != nullptr, outputs should be created
auto dataTypes = inferDataType();
for (size_t i = 0; i < outputs.size(); i++) {
-IT_ASSERT(!outputs[i], "Find empty output while operator creation");
-outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
+if (!outputs[i])
+outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
+else if (shapes[i] != outputs[i]->getDims())
+return false;
}
} else { // if outputs have been created, check their shapes
for (size_t i = 0; i < shapes.size(); ++i) {

View File

@@ -2,6 +2,7 @@
#include "core/kernel.h"
#include "core/perf_engine.h"
#include "core/runtime.h"
#include "cuda/cuda_common.h"
#ifdef INFINI_USE_NCCL
#include "cuda/nccl_communicator.h"
#endif
@@ -17,6 +18,7 @@ void CHECK_CUDA_KERNEL_ERROR(infini::Operator op) {
exit(EXIT_FAILURE);
}
}
cudaStream_t infini::CUDAStream::stream;
namespace infini {
@@ -40,6 +42,20 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
}
}
void CudaRuntimeObj::runWithCudaGraph(const Graph &graph) {
if(!cudaGraphCreated){
cudaStreamBeginCapture(CUDAStream::stream, cudaStreamCaptureModeGlobal);
runWithoutSync(graph);
cudaStreamEndCapture(CUDAStream::stream, &cudaGraph);
cudaGraphInstantiate(&cudaGraphInstance, cudaGraph, NULL, NULL, 0);
cudaGraphCreated=true;
}
else{
cudaGraphLaunch(cudaGraphInstance, CUDAStream::stream);
}
cudaStreamSynchronize(CUDAStream::stream);
}
void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
const auto &kernelRegistry = KernelRegistry::getInstance();
auto &perfEngine = PerfEngine::getInstance();
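Reviewer note: runWithCudaGraph captures the first invocation into a cudaGraph_t and replays the instantiated executable graph on later calls, removing per-kernel launch overhead. This is also why every kernel launch in this PR moves onto CUDAStream::stream: stream capture only records work submitted through the captured stream (the cudnnSetStream/cublasSetStream calls in the constructor serve the same purpose). A minimal standalone sketch of the capture-then-replay pattern, with a hypothetical device buffer and error handling elided:

    #include <cuda_runtime.h>

    void captureAndReplay(float *devBuf, size_t bytes) {
        cudaStream_t stream;
        cudaStreamCreate(&stream);

        // First run: record all work submitted to `stream` into a graph.
        cudaGraph_t graph;
        cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
        cudaMemsetAsync(devBuf, 0, bytes, stream); // stand-in for real kernels
        cudaStreamEndCapture(stream, &graph);

        cudaGraphExec_t exec;
        cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);

        // Later runs: replay the whole recorded graph with one launch call.
        cudaGraphLaunch(exec, stream);
        cudaStreamSynchronize(stream);

        cudaGraphExecDestroy(exec);
        cudaGraphDestroy(graph);
        cudaStreamDestroy(stream);
    }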

View File

@@ -479,6 +479,7 @@ void init_graph_builder(py::module &m) {
.def("transpose", &Handler::transpose, policy::move)
.def("reshape", &Handler::reshape, policy::move)
.def("concat", &Handler::concat, policy::move)
.def("attentionKVCache", &Handler::attentionKVCache, policy::move)
.def("split", &Handler::split, policy::move)
.def("gather", &Handler::gather, policy::move)
.def("gatherElements", &Handler::gatherElements, policy::move)
@@ -503,6 +504,9 @@
.def("get_perf_time", &Handler::get_perf_time, policy::automatic)
.def("tune", &Handler::tune, policy::automatic)
.def("run", &Handler::run, policy::automatic)
#ifdef USE_CUDA
.def("run_with_cudagraph", &Handler::run_with_cudagraph, policy::automatic)
#endif
.def("get_perf_time", &Handler::get_perf_time, policy::automatic);
}

View File

@@ -21,7 +21,7 @@ class AllReduceNCCL : public CudaKernelWithoutConfig {
.getNcclComm();
// TODO: Using default stream 0 for now.
checkNcclError(ncclAllReduce(input, output, count, ncclFloat,
-getRedOp(), comm, 0));
+getRedOp(), comm, CUDAStream::stream));
}
virtual ncclRedOp_t getRedOp() const = 0;

View File

@@ -0,0 +1,55 @@
#include "operators/attention_kvcache.h"
#include "cuda/cuda_attention_kvcache.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include <functional>
namespace infini {
class AttentionKVCacheCompute {
void initAttentionKVCacheMetadata(AttentionKVCacheMetadata &metadata,
Tensor tensor) const {
int nDims = tensor->getRank();
auto strides = tensor->getStride();
IT_ASSERT(nDims == 4);
IT_ASSERT(strides.size() == (size_t)nDims);
for (int i = 0; i < nDims; ++i) {
metadata.dimSize[i] = tensor->getDims().at(i);
metadata.stride[i] = strides.at(i);
}
}
public:
void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
Tensor input_k, Tensor input_v, Tensor position_id,
Tensor output_matmul, Tensor output_temp_O, Tensor output_temp_sum) const {
AttentionKVCacheMetadata metadata;
initAttentionKVCacheMetadata(metadata, input_v_cache);
attention_kvcache_kernel(input_k_cache->getRawDataPtr<float *>(),
input_v_cache->getRawDataPtr<float *>(),
input_q->getRawDataPtr<float *>(),
input_k->getRawDataPtr<float *>(),
input_v->getRawDataPtr<float *>(),
position_id->getRawDataPtr<int *>(),
output_matmul->getRawDataPtr<float *>(),
metadata,
output_temp_O->getRawDataPtr<float *>(),
output_temp_sum->getRawDataPtr<float *>());
}
};
class AttentionKVCacheCuda : private AttentionKVCacheCompute,
public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
do_compute(_op->getInputs()[0], _op->getInputs()[1],
_op->getInputs()[2], _op->getInputs()[3],
_op->getInputs()[4], _op->getInputs()[5],
_op->getOutputs()[0], _op->getOutputs()[1],
_op->getOutputs()[2]);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::AttentionKVCache, DataType::Float32,
AttentionKVCacheCuda, "AttentionKVCache_CUDA_Float32");
} // namespace infini

View File

@@ -0,0 +1,272 @@
#include "cuda/cuda_common.h"
#include "cuda/cuda_attention_kvcache.h"
#define WARP_SIZE 32
#define BLOCKSIZE WARP_SIZE
#define SEQ_UNIT 32
__global__ void _attention_kvcache_kernel_64(float* input_k_cache,
float* input_v_cache,
float* input_q,
float* input_k,
float* input_v,
int* position_id,
float* output_matmul,
AttentionKVCacheMetadata compMeta) {
int lane_id = threadIdx.x % WARP_SIZE;
int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
return;
float ptr_V[SEQ_UNIT*2];
float ptr_K[SEQ_UNIT*2];
float ptr_Q[2];
float ptr_P[SEQ_UNIT];
float ptr_O[2];
float ptr_max[1];
float ptr_sum[1];
float ptr_max_last[1];
float ptr_sum_last[1];
float ptr_O_last[2];
(float2 &)ptr_Q[0] = (float2 &)input_q[(lane_id * 2) + (parallel_idx * 64)];
int SEQ_LENGTH = position_id[0] + 1;
int common_idx = (lane_id * 2) + (parallel_idx * compMeta.stride[1]);
for (int idx_seq = 0; idx_seq < SEQ_LENGTH; idx_seq += SEQ_UNIT){
ptr_max_last[0] = ptr_max[0];
ptr_sum_last[0] = ptr_sum[0];
(float2 &)ptr_O_last[0] = (float2 &)ptr_O[0];
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
(float2 &)ptr_K[idx_SEQ_UNIT * 2]
= (float2 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
}
else{
(float2 &)ptr_K[idx_SEQ_UNIT * 2]
= (float2 &) input_k[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
(float2 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
(float2 &)ptr_K[idx_SEQ_UNIT * 2];
}
ptr_K[idx_SEQ_UNIT * 2] = ptr_Q[0] * ptr_K[idx_SEQ_UNIT * 2];
ptr_K[idx_SEQ_UNIT * 2 + 1] = ptr_Q[1] * ptr_K[idx_SEQ_UNIT * 2 + 1];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
ptr_K[idx_SEQ_UNIT * 2] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 2], offset);
}
ptr_P[idx_SEQ_UNIT] = ptr_K[idx_SEQ_UNIT * 2];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2){
ptr_K[((idx_SEQ_UNIT * 2) + 1)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 2) + 1)], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 2) + 1)];
}
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
ptr_P[idx_SEQ_UNIT] /= 8;
ptr_max[0] = (idx_SEQ_UNIT == 0) ? ptr_P[0] : max(ptr_max[0], ptr_P[idx_SEQ_UNIT]);
}
ptr_max[0] = (idx_seq == 0) ? ptr_max[0] : max(ptr_max[0], ptr_max_last[0]);
ptr_sum[0] = 0;
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT] - ptr_max[0]);
ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
}
ptr_sum[0] = (idx_seq == 0) ? ptr_sum[0] : expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] + ptr_sum[0];
ptr_O[0] = 0;
ptr_O[1] = 0;
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
(float2 &)ptr_V[idx_SEQ_UNIT * 2]
= (float2 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
}
else{
(float2 &)ptr_V[idx_SEQ_UNIT * 2]
= (float2 &) input_v[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
(float2 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
(float2 &)ptr_V[idx_SEQ_UNIT * 2];
}
ptr_P[idx_SEQ_UNIT] /= ptr_sum[0];
ptr_O[0] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2)], ptr_O[0]);
ptr_O[1] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2) + 1], ptr_O[1]);
}
ptr_O[0] = (idx_seq == 0) ? ptr_O[0] : ptr_O[0] + ptr_O_last[0] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
ptr_O[1] = (idx_seq == 0) ? ptr_O[1] : ptr_O[1] + ptr_O_last[1] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
}
(float2 &)output_matmul[(lane_id * 2) + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O[0];
}
__global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
float* input_v_cache,
float* input_q,
float* input_k,
float* input_v,
int* position_id,
AttentionKVCacheMetadata compMeta,
float* output_O_temp,
float* output_sum_temp) {
int seq_length = position_id[0] + 1;
int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT;
if(blockIdx.y >= stride)
return;
int lane_id = threadIdx.x % WARP_SIZE;
int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
int idx_seq = blockIdx.y * SEQ_UNIT;
if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
return;
float ptr_V[SEQ_UNIT*4];
float ptr_K[SEQ_UNIT*4];
float ptr_Q[4];
float ptr_P[SEQ_UNIT] = {0};
float ptr_O[4] = {0};
float ptr_sum[1] = {0};
(float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)];
int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]);
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
(float4 &)ptr_K[idx_SEQ_UNIT * 4]
= (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
}
else{
(float4 &)ptr_K[idx_SEQ_UNIT * 4]
= (float4 &) input_k[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
(float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
(float4 &)ptr_K[idx_SEQ_UNIT * 4];
}
#pragma unroll
for (int i = 0; i < 4; i ++){
ptr_K[idx_SEQ_UNIT * 4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT * 4 + i];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
ptr_K[idx_SEQ_UNIT * 4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4 + i], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT * 4 + i];
}
}
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);
}
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]);
ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
}
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
(float4 &)ptr_V[idx_SEQ_UNIT * 4]
= (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
}
else{
(float4 &)ptr_V[idx_SEQ_UNIT * 4]
= (float4 &) input_v[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
(float4 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]
= (float4 &)ptr_V[idx_SEQ_UNIT * 4];
}
#pragma unroll
for (int i = 0; i < 4; i ++)
ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4 + i)], ptr_O[i]);
}
#pragma unroll
for (int i = 0; i < 4; i ++)
ptr_O[i] /= ptr_sum[0];
(float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0];
if(threadIdx.x == 0){
output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0];
}
}
__global__ void _attention_kvcache_kernel_128_2(int* position_id,
float* output_matmul,
AttentionKVCacheMetadata compMeta,
float* output_O_temp,
float* output_sum_temp) {
int lane_id = threadIdx.x % WARP_SIZE;
int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
float ptr_O[4] = {0};
float ptr_O_sum[4] = {0};
float ptr_sum = 0;
float ptr_sum_temp;
int size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT;
#pragma unroll
for(int i = 0; i < size; i ++){
(float4 &)ptr_O[0]
= (float4 &)output_O_temp[(lane_id * 4) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size];
ptr_sum_temp = output_sum_temp[i + parallel_idx * size];
#pragma unroll
for(int k = 0; k < 4; k ++)
ptr_O_sum[k] += ptr_O[k] * ptr_sum_temp;
ptr_sum += ptr_sum_temp;
}
#pragma unroll
for(int k = 0; k < 4; k ++)
ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum;
(float4 &)output_matmul[(lane_id * 4) + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0];
}
namespace infini {
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
float *input_q, float *input_k,
float *input_v, int *position_id, float *output_matmul,
const AttentionKVCacheMetadata &compMeta,
float *output_O_temp, float *output_sum_temp) {
IT_ASSERT(compMeta.dimSize[3] == 64 || compMeta.dimSize[3] == 128);
int gridsize_y = (compMeta.dimSize[2] - 1 + SEQ_UNIT) / SEQ_UNIT;
dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y);
dim3 blockDim(BLOCKSIZE, 1);
if(compMeta.dimSize[3] == 64)
_attention_kvcache_kernel_64<<<gridDim.x, blockDim, 0, CUDAStream::stream>>>(
input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, output_matmul, compMeta);
else{
_attention_kvcache_kernel_128_1<<<gridDim, blockDim, 0, CUDAStream::stream>>>(
input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, compMeta, output_O_temp, output_sum_temp);
_attention_kvcache_kernel_128_2<<<compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), WARP_SIZE, 0, CUDAStream::stream>>>(
position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);
}
}
} // namespace infini
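Reviewer note: both kernel variants compute scores x_i = q·k_i / sqrt(d) (hence the division by 8 for d = 64 and by sqrt(128.0) for d = 128) and evaluate the softmax in a streaming fashion, so the full score row is never materialized. In the 64-wide kernel each SEQ_UNIT chunk folds into a running maximum m, normalizer s, and output o; the update implemented by the ptr_max / ptr_sum / ptr_O arithmetic above is

\[ m' = \max\bigl(m,\ \max_i x_i\bigr), \qquad s' = e^{m-m'}\,s + \sum_i e^{x_i-m'}, \qquad o' = \frac{e^{m-m'}\,s}{s'}\,o + \sum_i \frac{e^{x_i-m'}}{s'}\,v_i. \]

The 128-wide pair parallelizes over chunks instead: _attention_kvcache_kernel_128_1 writes, for each chunk b, an unshifted partial normalizer s_b = \sum_i e^{x_i} and a normalized partial output o_b (it skips the running-max subtraction, so it relies on the scaled scores staying within float range), and _attention_kvcache_kernel_128_2 recombines the partials as

\[ o = \frac{\sum_b s_b\, o_b}{\sum_b s_b}. \]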

View File

@@ -25,7 +25,7 @@ void clip_kernel(float *input, float *output, int num, float minValue,
float maxValue) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
-_clip_kernel<<<gridsize, blocksize>>>(input, output, num, minValue,
+_clip_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num, minValue,
maxValue);
}

View File

@@ -71,7 +71,7 @@ void div_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size();
-_div_kernel<<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2,
+_div_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2,
b3, c0, c1, c2, c3);
}
void pow_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
@@ -80,7 +80,7 @@ void pow_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size();
-_pow_kernel<<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2,
+_pow_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2,
b3, c0, c1, c2, c3);
}

View File

@@ -42,7 +42,7 @@ void expandKernel(float *input, float *output, int nDims, int outputsize,
SmallArray inputShape, SmallArray outputShape) {
int blocksize = block_work_size();
int gridsize = (outputsize + block_work_size() - 1) / block_work_size();
-_expandKernel<<<gridsize, blocksize>>>(input, output, nDims, outputsize,
+_expandKernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, nDims, outputsize,
inputShape, outputShape);
}

View File

@@ -19,7 +19,7 @@ void extend_kernel(float *in, float *out, int blockSize, int blockSizeOuter,
int oSize) {
int blocksize = 32 * 16;
int gridsize = (oSize + blocksize - 1) / blocksize;
-_extend_kernel<<<gridsize, blocksize>>>(in, out, blockSize, blockSizeOuter,
+_extend_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(in, out, blockSize, blockSizeOuter,
oSize);
}
} // namespace infini

View File

@@ -46,9 +46,9 @@ void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num) {
int gridSize = (num + blockSize - 1) / blockSize;
if (metaData.indexType == DataType::Int64) {
_gather_kernel<int64_t>
-<<<gridSize, blockSize>>>(in, out, metaData, num);
+<<<gridSize, blockSize, 0, CUDAStream::stream>>>(in, out, metaData, num);
} else {
-_gather_kernel<int><<<gridSize, blockSize>>>(in, out, metaData, num);
+_gather_kernel<int><<<gridSize, blockSize, 0, CUDAStream::stream>>>(in, out, metaData, num);
}
}
} // namespace infini

View File

@@ -40,22 +40,22 @@ void gather_elements_kernel(void *in, void *out, GatherMetaData metaData,
int gridSize = (num + blockSize - 1) / blockSize;
if (metaData.dataType == DataType::Float32 &&
metaData.indexType == DataType::Int64) {
-_gather_elements_kernel<float, int64_t><<<gridSize, blockSize>>>(
+_gather_elements_kernel<float, int64_t><<<gridSize, blockSize, 0, CUDAStream::stream>>>(
reinterpret_cast<float *>(in), reinterpret_cast<float *>(out),
metaData, num);
} else if (metaData.dataType == DataType::Int32 &&
metaData.indexType == DataType::Int64) {
-_gather_elements_kernel<int, int64_t><<<gridSize, blockSize>>>(
+_gather_elements_kernel<int, int64_t><<<gridSize, blockSize, 0, CUDAStream::stream>>>(
reinterpret_cast<int *>(in), reinterpret_cast<int *>(out), metaData,
num);
} else if (metaData.dataType == DataType::Float32 &&
metaData.indexType == DataType::Int32) {
-_gather_elements_kernel<float, int><<<gridSize, blockSize>>>(
+_gather_elements_kernel<float, int><<<gridSize, blockSize, 0, CUDAStream::stream>>>(
reinterpret_cast<float *>(in), reinterpret_cast<float *>(out),
metaData, num);
} else if (metaData.dataType == DataType::Int32 &&
metaData.indexType == DataType::Int32) {
-_gather_elements_kernel<int, int><<<gridSize, blockSize>>>(
+_gather_elements_kernel<int, int><<<gridSize, blockSize, 0, CUDAStream::stream>>>(
reinterpret_cast<int *>(in), reinterpret_cast<int *>(out), metaData,
num);
} else {

View File

@@ -33,6 +33,7 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
CUBLAS_GEMM_ALGO21, CUBLAS_GEMM_ALGO22, CUBLAS_GEMM_ALGO23,
};
class matmulCublas : public Kernel {
int a;
bool do_compute(const Operator &_op, const PerfRecord &_record,
const RuntimeObj *_context) const {
auto op = as<MatmulObj>(_op);
@@ -43,6 +44,8 @@ class matmulCublas : public Kernel {
auto record = as<MatmulCublasPerfRecordObj>(_record);
const auto [b, m, n, k] = op->getBMNK();
// std::cout << b << " " << m << " " << n << " " << k << std::endl;
// std::cout << op->getTransA() << " " << op->getTransB() << std::endl;
auto opA =
op->getTransA() ? CUBLAS_OP_T : CUBLAS_OP_N; // BLAS_N = col major
auto opB = op->getTransB() ? CUBLAS_OP_T : CUBLAS_OP_N;
@@ -95,10 +98,41 @@
&beta, outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F,
(cublasGemmAlgo_t)record->algo);
} else {
-stat = cublasGemmEx(
-context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
-CUDA_R_32F, ldb, inAData, CUDA_R_32F, lda, &beta, outData,
-CUDA_R_32F, ldc, CUDA_R_32F, (cublasGemmAlgo_t)record->algo);
+// stat = cublasGemmEx(
+//     context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
+//     CUDA_R_32F, ldb, inAData, CUDA_R_32F, lda, &beta, outData,
+//     CUDA_R_32F, ldc, CUDA_R_32F, (cublasGemmAlgo_t)record->algo);
cublasLtMatmulDesc_t matmulDesc;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
cudaDataType_t scaleType = CUDA_R_32F;
cudaDataType_t dtype = CUDA_R_32F;
cublasLtMatrixLayout_t layout_A, layout_B, layout_C;
cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType);
cublasLtMatrixLayoutCreate(&layout_A, dtype, k, m, k);
cublasLtMatrixLayoutCreate(&layout_B, dtype, n, k, n);
cublasLtMatrixLayoutCreate(&layout_C, dtype, n, m, n);
stat = cublasLtMatmul(
context->cublasLtHandle(),
matmulDesc,
&alpha,
inBData,
layout_B,
inAData,
layout_A,
&beta,
outData,
layout_C,
outData,
layout_C,
nullptr,
NULL,
0,
CUDAStream::stream);
cublasLtMatrixLayoutDestroy(layout_A);
cublasLtMatrixLayoutDestroy(layout_B);
cublasLtMatrixLayoutDestroy(layout_C);
cublasLtMatmulDescDestroy(matmulDesc);
}
// if (stat != CUBLAS_STATUS_SUCCESS)
// cout << cublasGetErrorString(stat);
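Reviewer note: like the cublasGemmEx call it replaces, the cublasLt path relies on the standard layout identity. cuBLAS assumes column-major storage, and a row-major matrix reinterpreted column-major is its transpose, so the row-major product is obtained from

\[ C = AB \iff C^{\mathsf{T}} = B^{\mathsf{T}} A^{\mathsf{T}}, \]

which is why inBData is supplied as the first operand and the layouts are created with swapped extents: cublasLtMatrixLayoutCreate takes (rows, cols, ld) of the column-major view, so layout_A = (k, m, k) describes the m×k row-major A, layout_B = (n, k, n), and layout_C = (n, m, n). One caveat worth flagging in review: the call passes no algorithm and no workspace (nullptr, NULL, 0), so the library falls back to a default heuristic; supplying a cublasLtMatmulAlgoGetHeuristic result plus a workspace would normally be expected for peak throughput.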

View File

@@ -46,7 +46,7 @@ void pad_slice_kernel(float *partData, float *wholeData,
bool isPad) {
int blockSize = 32 * 16;
int gridSize = (num + blockSize - 1) / blockSize;
-_pad_slice_kernel<<<gridSize, blockSize>>>(partData, wholeData, metadata,
+_pad_slice_kernel<<<gridSize, blockSize, 0, CUDAStream::stream>>>(partData, wholeData, metadata,
nDims, num, isPad);
}
} // namespace infini

View File

@@ -7,7 +7,7 @@ class CopyCuda : public CudaKernelWithoutConfig {
auto inData = op->getInputs(0)->getRawDataPtr<void *>();
auto outData = op->getOutputs()[0]->getRawDataPtr<void *>();
cudaMemcpyAsync(outData, inData, op->getInputs(0)->getBytes(),
-cudaMemcpyDeviceToDevice);
+cudaMemcpyDeviceToDevice, CUDAStream::stream);
}
};
// reshape/flatten/identity all act as copying from input to output.

View File

@@ -213,7 +213,7 @@ void resize_kernel_nearest(float *in, float *out, const MetaData &metaData,
sizeof(p_cooridnate_trans_mode_func[0]));
IT_ASSERT(nearestMode <
sizeof(p_nearest_mode_fun) / sizeof(p_nearest_mode_fun[0]));
-_resize_kernel_nearest<<<gridsize, blocksize>>>(
+_resize_kernel_nearest<<<gridsize, blocksize, 0, CUDAStream::stream>>>(
in, out, metaData, num, coordinateMode, nearestMode);
}
@@ -223,7 +223,7 @@ void resize_kernel_linear(float *in, float *out, const MetaData &metaData,
auto gridsize = (num + blocksize - 1) / blocksize;
IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
sizeof(p_cooridnate_trans_mode_func[0]));
-_resize_kernel_linear_coeff<<<gridsize, blocksize>>>(in, out, metaData, num,
+_resize_kernel_linear_coeff<<<gridsize, blocksize, 0, CUDAStream::stream>>>(in, out, metaData, num,
coordinateMode);
}
@@ -233,7 +233,7 @@ void resize_kernel_cubic(float *in, float *out, const MetaData &metaData,
auto gridsize = (num + blocksize - 1) / blocksize;
IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
sizeof(p_cooridnate_trans_mode_func[0]));
-_resize_kernel_cubic_coeff<<<gridsize, blocksize>>>(in, out, metaData, num,
+_resize_kernel_cubic_coeff<<<gridsize, blocksize, 0, CUDAStream::stream>>>(in, out, metaData, num,
coordinateMode);
}
} // namespace infini

View File

@@ -141,7 +141,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<1024>
-<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
+<<<num_blocks, BLOCK_DIM, 0, CUDAStream::stream>>>(input, output, size, dimsize, stride);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
@@ -150,7 +150,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<32, 32>
-<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
+<<<grid_dim, block_dim, 0, CUDAStream::stream>>>(input, output, size, dimsize, stride);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
@@ -159,7 +159,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<16, 64>
-<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
+<<<grid_dim, block_dim, 0, CUDAStream::stream>>>(input, output, size, dimsize, stride);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
@@ -168,7 +168,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<8, 128>
-<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
+<<<grid_dim, block_dim, 0, CUDAStream::stream>>>(input, output, size, dimsize, stride);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
@@ -177,7 +177,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<4, 256>
-<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
+<<<grid_dim, block_dim, 0, CUDAStream::stream>>>(input, output, size, dimsize, stride);
}
}
} // namespace infini

View File

@@ -59,11 +59,11 @@ void split_concat_kernel(const ElementTensorMetadata &eleMeta,
// gridsize = max_n_elements / blockSize
int max_n_elements =
*std::max_element(eleMeta.nElements, eleMeta.nElements + batchSize);
-int gridDimX = (max_n_elements - 1) / (32 * 16) + 1;
+int gridDimX = (max_n_elements + 32 * 16 - 1) / (32 * 16);
// each y is a split among the batch
dim3 gridSize(gridDimX, batchSize);
-_split_concat_kernel<<<gridSize, blockSize>>>(eleMeta, compMeta, dim, nDims,
+_split_concat_kernel<<<gridSize, blockSize, 0, CUDAStream::stream>>>(eleMeta, compMeta, dim, nDims,
isSplit);
}

View File

@@ -30,7 +30,7 @@ void transpose_kernel(float *input, float *output, int nDims, int size,
SmallArray strides, SmallArray outputShape) {
int blocksize = block_work_size();
int gridsize = (size + block_work_size() - 1) / block_work_size();
-_transpose_kernel<<<gridsize, blocksize>>>(input, output, nDims, size,
+_transpose_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, nDims, size,
strides, outputShape);
}

View File

@@ -114,67 +114,67 @@ void softmax_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
-_softmax_kernel1<<<1, 1>>>(input, output, num);
-_softmax_kernel2<<<gridsize, blocksize>>>(input, output, num);
+_softmax_kernel1<<<1, 1, 0, CUDAStream::stream>>>(input, output, num);
+_softmax_kernel2<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void relu_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
-_relu_kernel<<<gridsize, blocksize>>>(input, output, num);
+_relu_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void sigmoid_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
-_sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
+_sigmoid_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void hard_sigmoid_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
-_hard_sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
+_hard_sigmoid_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void hard_swish_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
-_hard_swish_kernel<<<gridsize, blocksize>>>(input, output, num);
+_hard_swish_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void tanh_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
-_tanh_kernel<<<gridsize, blocksize>>>(input, output, num);
+_tanh_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void abs_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
-_abs_kernel<<<gridsize, blocksize>>>(input, output, num);
+_abs_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void sqrt_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
-_sqrt_kernel<<<gridsize, blocksize>>>(input, output, num);
+_sqrt_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void gelu_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
-_gelu_kernel<<<gridsize, blocksize>>>(input, output, num);
+_gelu_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void erf_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
-_erf_kernel<<<gridsize, blocksize>>>(input, output, num);
+_erf_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
void neg_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
-_neg_kernel<<<gridsize, blocksize>>>(input, output, num);
+_neg_kernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(input, output, num);
}
}; // namespace infini

View File

@@ -73,7 +73,7 @@ void whereKernel(const float *inputX, const float *inputY,
}
int blocksize = 32 * 16;
int gridsize = (outputsize + blocksize - 1) / blocksize;
-_whereKernel<<<gridsize, blocksize>>>(
+_whereKernel<<<gridsize, blocksize, 0, CUDAStream::stream>>>(
inputX, inputY, condition, output, nDims, outputsize, inputXShape,
inputYShape, conditionShape, outputShape);
}

View File

@@ -0,0 +1,54 @@
#include "operators/attention_kvcache.h"
#include "utils/operator_utils.h"
namespace infini {
AttentionKVCacheObj::AttentionKVCacheObj(
GraphObj *graph, Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
Tensor input_k, Tensor input_v, Tensor position_id, Tensor output_matmul,
Tensor output_k_cache, Tensor output_v_cache)
: OperatorObj(OpType::AttentionKVCache,
TensorVec{input_k_cache, input_v_cache, input_q, input_k,
input_v, position_id},
TensorVec{output_matmul, output_k_cache, output_v_cache}) {
int rank = inputs[0]->getRank();
IT_ASSERT(rank == 4);
dim = 2;
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>>
AttentionKVCacheObj::inferShape(const TensorVec &inputs) const {
IT_ASSERT(inputs.size() == 6);
Shape dims = inputs[0]->getDims();
ShapeElem n = dims.at(dim);
dims[dim] = n + 1;
return {{inputs[2]->getDims(), dims, dims}};
}
std::string AttentionKVCacheObj::toString() const {
std::ostringstream os;
os << "AttentionKVCache[" << getGuid() << "]";
os << "(";
for (auto input : inputs)
os << vecToString(input->getDims()) << ",";
os << "dim=" << dim << ",";
os << "input=";
for (auto input : inputs)
os << input->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> AttentionKVCacheObj::getWorkloadVector() const {
vector<int> ret = getOutputs()[0]->getDims();
ret.emplace(ret.begin(), (int)inputs.size());
ret.emplace(ret.begin(), dim);
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> AttentionKVCacheObj::getOpAttrVector() const {
return {type.underlying(), dim};
}
} // namespace infini
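Reviewer note: a worked instance of the inferShape rule above, using the same hypothetical single-token decode shapes as the tests below (dim = 2 is the cached-sequence axis):

    #include <cassert>
    #include <vector>

    int main() {
        // Hypothetical decode step: caches hold n = 4 tokens, head_dim = 64.
        std::vector<int> k_cache{1, 1, 4, 64};
        std::vector<int> q{1, 1, 1, 64};
        const int dim = 2; // sequence axis that grows by one per step

        std::vector<int> out_cache = k_cache;
        out_cache[dim] += 1;             // cache gains the appended token
        std::vector<int> out_matmul = q; // output keeps input_q's shape

        assert((out_cache == std::vector<int>{1, 1, 5, 64}));
        assert((out_matmul == std::vector<int>{1, 1, 1, 64}));
        return 0;
    }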

View File

@@ -0,0 +1,68 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/attention_kvcache.h"
#include "test.h"
namespace infini {
TEST(TestCudaRuntime, CudaGraph) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
auto input_q_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
auto input_k_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
auto input_v_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);
auto op = gCuda->addOp<AttentionKVCacheObj>(
input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d,
position_id_d, nullptr, nullptr, nullptr);
auto op1 = gCuda->addOp<AttentionKVCacheObj>(input_k_cache_d, input_v_cache_d, op->getOutputs()[0], input_k_d, input_v_d,
position_id_d, nullptr, nullptr, nullptr);
auto op2 = gCuda->addOp<AttentionKVCacheObj>(input_k_cache_d, input_v_cache_d, op1->getOutputs()[0], input_k_d, input_v_d,
position_id_d, nullptr, nullptr, nullptr);
gCuda->dataMalloc();
input_q_d->setData(OneGenerator());
input_k_d->setData(OneGenerator());
input_v_d->setData(OneGenerator());
position_id_d->setData(IncrementalGenerator());
cudaRuntime->run(gCuda);
cudaEvent_t start, stop;
float milliseconds_1 = 0, milliseconds_2 = 0;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaDeviceSynchronize();
cudaEventRecord(start);
cudaRuntime->run(gCuda);
cudaEventRecord(stop);
cudaEventSynchronize(stop);
cudaEventElapsedTime(&milliseconds_1, start, stop);
printf("without cudaGraph, latency: %f ms\n", milliseconds_1);
cudaRuntime->runWithCudaGraph(gCuda);
cudaRuntime->runWithCudaGraph(gCuda);
cudaDeviceSynchronize();
cudaEventRecord(start);
cudaRuntime->runWithCudaGraph(gCuda);
cudaEventRecord(stop);
cudaEventSynchronize(stop);
cudaEventElapsedTime(&milliseconds_2, start, stop);
printf("with cudaGraph, latency: %f ms\n", milliseconds_2);
EXPECT_GE(milliseconds_1, milliseconds_2);
}
} // namespace infini

View File

@@ -0,0 +1,42 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/attention_kvcache.h"
#include "test.h"
namespace infini {
TEST(AttentionKVCache, Cuda) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
auto input_q_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
auto input_k_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
auto input_v_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);
auto op = gCuda->addOp<AttentionKVCacheObj>(
input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d,
position_id_d, nullptr, nullptr, nullptr);
gCuda->dataMalloc();
input_q_d->setData(OneGenerator());
input_k_d->setData(OneGenerator());
input_v_d->setData(OneGenerator());
position_id_d->setData(IncrementalGenerator());
cudaRuntime->run(gCuda);
auto oCpu = gCpu->cloneTensor(op->getOutputs()[0]);
EXPECT_TRUE(oCpu->equalData(vector<float>{
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
}
} // namespace infini