forked from jiuyuan/InfiniTensor
Compare commits
4 Commits
master...accelerate
Author | SHA1 | Date
---|---|---
xiaonans | 936797b960 |
xiaonans | 17bd98d453 |
xiaonans | 8cc6af0a83 |
xiaonans | c04910f118 |
@@ -5,6 +5,10 @@
#include <cstdint>
#include <iostream>

#ifdef USE_CUDA
#include "cuda/cuda_runtime.h"
#endif

namespace infini {

class GraphHandlerObj {
@@ -32,6 +36,7 @@ class GraphHandlerObj {
float momentum, float eps, bool training);
Tensor layerNormalization(Tensor input, Tensor scale, Tensor output,
Tensor bias, float eps, int axis, int stash_type);
Tensor rmsNorm(Tensor input, Tensor weight, Tensor output);

Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
int ph, int pw, int sh, int sw, int ceilMode);
@@ -137,6 +142,12 @@ class GraphHandlerObj {
inline void run() { g->getRuntime()->run(g); }

inline double get_perf_time() { return g->getRuntime()->getPerfTime(g); }

#ifdef USE_CUDA
inline void run_with_cudagraph() {
(as<CudaRuntimeObj>(g->getRuntime()))->runWithCudaGraph(g);
}
#endif
};

} // namespace infini
@@ -158,6 +158,7 @@ struct OpType {
RoiAlign,
RoPE, // Fusion
Round, // Unary
RMSNorm, // Fusion
STFT,
Scan,
Scatter,
@@ -5,6 +5,7 @@
#include <cuda_profiler_api.h>
#include <cudnn.h>
#include <curand.h>
#include <memory>

#define checkCudaError(call) \
if (auto err = call; err != cudaSuccess) \
@@ -111,4 +112,20 @@ inline const char *curandGetErrorString(curandStatus_t error) {

using CudaPtr = void *;

class CUDAStream {
public:
CUDAStream(const CUDAStream &) = delete;
CUDAStream(CUDAStream &&) = delete;
void operator=(const CUDAStream &) = delete;
void operator=(CUDAStream &&) = delete;
static cudaStream_t getCurrentStream() { return _stream; }
static void Init() { CUDAStream::_stream = 0; };
static void createStream() { checkCudaError(cudaStreamCreate(&_stream)); }
static void destroyStream() { checkCudaError(cudaStreamDestroy(_stream)); }

private:
CUDAStream(){};
static cudaStream_t _stream;
};

} // namespace infini
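Note: nearly every kernel change below follows one pattern — the launch gains an explicit fourth configuration argument, CUDAStream::getCurrentStream(), so the work lands on the runtime's stream and can be captured into a CUDA graph. A minimal sketch of that convention (scaleKernel and launchScale are illustrative names, not part of this diff):

// Sketch only: how an arbitrary kernel launch joins the runtime's current stream.
#include "cuda/cuda_common.h"

__global__ void scaleKernel(float *data, int n, float factor) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        data[i] *= factor;
}

void launchScale(float *data, int n, float factor) {
    dim3 block(256);
    dim3 grid((n + block.x - 1) / block.x);
    // Third launch argument: dynamic shared memory (none here); fourth: the
    // stream, so the launch is recorded whenever that stream is being captured.
    scaleKernel<<<grid, block, 0, infini::CUDAStream::getCurrentStream()>>>(
        data, n, factor);
}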
@@ -0,0 +1,10 @@
#pragma once

#include "operators/rms_norm.h"

namespace infini {

void rmsnorm_kernel(int dType, void *input, void *weight, void *output,
int num_tokens, int hidden_size);

}; // namespace infini
@@ -14,6 +14,9 @@ class CudaRuntimeObj : public RuntimeObj {
std::unique_ptr<CommunicatorObj> comm;
CudaPtr workspace;
size_t workspaceSize;
bool isCudaGraphCreated;
cudaGraph_t cudaGraph;
cudaGraphExec_t cudaGraphInstance;

public:
explicit CudaRuntimeObj(int deviceId = 0)
@@ -26,9 +29,16 @@ class CudaRuntimeObj : public RuntimeObj {
// size_t longformerNum = 3lu * (1 << 30);
workspaceSize = 7ll << 30; // 7 GB
workspace = alloc(workspaceSize);
isCudaGraphCreated = false;
CUDAStream::Init();
}
virtual ~CudaRuntimeObj() {
try {
if (isCudaGraphCreated) {
checkCudaError(cudaGraphExecDestroy(cudaGraphInstance));
checkCudaError(cudaGraphDestroy(cudaGraph));
CUDAStream::destroyStream();
}
dealloc(workspace);
checkCudnnError(cudnnDestroy(cudnn));
checkCublasError(cublasDestroy(cublas));
@@ -75,6 +85,8 @@ class CudaRuntimeObj : public RuntimeObj {

void runWithoutSync(const Graph &graph) const;

void runWithCudaGraph(const Graph &graph);

// init communicator
void initComm(const string &name, int worldSize, int rank) final;
@@ -0,0 +1,34 @@
#pragma once
#include "core/operator.h"

namespace infini {
/**
* @brief Fused RMSNorm Operator
*
*/
class RMSNormObj : public OperatorObj {
int dim;

public:
/**
* @brief Construct a new RMSNorm object.
*
* @param graph The computation graph that this operator belongs to.
* @param input The input tensor.
* @param weight The weight tensor.
* @param output The output tensor.
*/
RMSNormObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output);
OP_CLONE(RMSNormObj);

optional<vector<Shape>> inferShape(const TensorVec &inputs) override;

std::string toString() const override;
int numInputs() const override { return 2; }
int numOutputs() const override { return 1; }
int getDim() const { return dim; }

private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini
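For reference, the fused operator declared above computes standard RMS normalization over the last (hidden) dimension. With hidden size H, learned weight w, and the epsilon hard-coded in the CUDA kernel later in this changeset (1e-5), the per-element output is:

$$ y_i = \frac{x_i}{\sqrt{\tfrac{1}{H}\sum_{j=1}^{H} x_j^{2} + \varepsilon}} \, w_i $$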
@@ -39,15 +39,15 @@ class OnnxStub:

    def __init__(self, model: ModelProto, runtime, use_naive_allocator: bool = False):
        # We use some user-defined operators for distributed inference
        try:
            # onnx simplifier performs inplace simplify
            model_simp, check = simplify(copy.deepcopy(model))
            if check:
                model = model_simp
        except ValidationError:
            pass
        except RuntimeError:
            pass
        # try:
        #     # onnx simplifier performs inplace simplify
        #     model_simp, check = simplify(copy.deepcopy(model))
        #     if check:
        #         model = model_simp
        #     except ValidationError:
        #     pass
        #     except RuntimeError:
        #     pass

        self.inputs: Dict[str, backend.Tensor] = {}
        self.outputs: Dict[str, backend.Tensor] = {}
@@ -277,6 +277,12 @@ class OnnxStub:
                    axis,
                    stash_type,
                )
            elif node.op_type == "RMSNorm":
                tensors[node.output[0]] = self.handler.RMSNorm(
                    tensors[node.input[0]],
                    tensors[node.input[1]],
                    tensors.get(node.output[0]),
                )
            elif node.op_type == "MaxPool":
                attributes = _parse_attribute(
                    node,
@@ -1376,6 +1382,9 @@ class OnnxStub:
    def run(self) -> None:
        self.handler.run()

    def run_with_cudagraph(self) -> None:
        self.handler.run_with_cudagraph()

    def get_perf_time(self) -> float:
        return self.handler.get_perf_time()
@@ -18,6 +18,7 @@
#include "operators/reduce.h"
#include "operators/reshape.h"
#include "operators/resize.h"
#include "operators/rms_norm.h"
#include "operators/rope.h"
#include "operators/send.h"
#include "operators/slice.h"
@@ -122,6 +123,16 @@ Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
}
}

Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) {
if (output) {
g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight), output);
return output;
} else {
return g->addOp<RMSNormObj>(std::move(input), std::move(weight), output)
->getOutput();
}
}

Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh, int sw,
int ceilMode) {
@@ -19,7 +19,6 @@ void CHECK_CUDA_KERNEL_ERROR(infini::Operator op) {
}

namespace infini {

void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
const auto &kernelRegistry = KernelRegistry::getInstance();
auto &perfEngine = PerfEngine::getInstance();
@@ -39,6 +38,27 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
}
}

void CudaRuntimeObj::runWithCudaGraph(const Graph &graph) {
if (!isCudaGraphCreated) {
CUDAStream::createStream();
checkCudnnError(cudnnSetStream(cudnn, CUDAStream::getCurrentStream()));
checkCublasError(
cublasSetStream(cublas, CUDAStream::getCurrentStream()));
checkCudaError(cudaStreamBeginCapture(CUDAStream::getCurrentStream(),
cudaStreamCaptureModeGlobal));
runWithoutSync(graph);
checkCudaError(
cudaStreamEndCapture(CUDAStream::getCurrentStream(), &cudaGraph));
checkCudaError(
cudaGraphInstantiate(&cudaGraphInstance, cudaGraph, NULL, NULL, 0));
isCudaGraphCreated = true;
} else {
checkCudaError(
cudaGraphLaunch(cudaGraphInstance, CUDAStream::getCurrentStream()));
}
checkCudaError(cudaStreamSynchronize(CUDAStream::getCurrentStream()));
}

void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
const auto &kernelRegistry = KernelRegistry::getInstance();
auto &perfEngine = PerfEngine::getInstance();
@@ -102,4 +122,5 @@ void CudaRuntimeObj::initComm(const string &name, int worldSize, int rank) {
#endif
}

cudaStream_t CUDAStream::_stream = 0;
} // namespace infini
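The capture-then-replay flow in runWithCudaGraph above is the standard CUDA Graphs API sequence. A stripped-down, generic sketch of that lifecycle (doWork is a placeholder for enqueuing kernels on the stream, not an InfiniTensor function):

// Generic illustration only: stream capture followed by graph replay.
#include <cuda_runtime.h>

void captureAndReplay(cudaStream_t stream, void (*doWork)(cudaStream_t)) {
    cudaGraph_t graph;
    cudaGraphExec_t instance;

    // First pass: the stream's work is recorded into a graph; kernels are
    // captured, not executed.
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    doWork(stream);
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);

    // Later passes: replay the whole recorded DAG with a single launch call.
    cudaGraphLaunch(instance, stream);
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(instance);
    cudaGraphDestroy(graph);
}

Because the capture pass only records work, the runtime above replays the instantiated graph on every call after the first one.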
@@ -16,7 +16,8 @@ __global__ void cudaPrintFloatImpl(float *x, int len) {
namespace infini {

void cudaPrintFloat(float *x, int len) {
cudaPrintFloatImpl<<<1, 1>>>(x, len);
cudaPrintFloatImpl
<<<1, 1, 0, CUDAStream::getCurrentStream()>>>(x, len);
cudaDeviceSynchronize();
}
@@ -504,6 +504,7 @@ void init_graph_builder(py::module &m) {
.def("matmul", &Handler::matmul, policy::move)
.def("batchNormalization", &Handler::batchNormalization, policy::move)
.def("layerNormalization", &Handler::layerNormalization, policy::move)
.def("RMSNorm", &Handler::rmsNorm, policy::move)
.def("maxPool", &Handler::maxPool, policy::move)
.def("avgPool", &Handler::avgPool, policy::move)
.def("add", &Handler::add, policy::move)
@@ -571,6 +572,10 @@ void init_graph_builder(py::module &m) {
.def("get_perf_time", &Handler::get_perf_time, policy::automatic)
.def("tune", &Handler::tune, policy::automatic)
.def("run", &Handler::run, policy::automatic)
#ifdef USE_CUDA
.def("run_with_cudagraph", &Handler::run_with_cudagraph,
policy::automatic)
#endif
.def("shape_infer", &Handler::shape_infer, policy::automatic)
.def("change_shape", &Handler::change_shape, policy::automatic)
.def("getDims", &Handler::getDims, policy::automatic)
@@ -28,9 +28,8 @@ class AllReduceNCCL : public CudaKernelWithoutConfig {
ncclComm_t comm =
dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
.getNcclComm();
// TODO: Using default stream 0 for now.
checkNcclError(
ncclAllReduce(input, output, count, ncclType, getRedOp(), comm, 0));
checkNcclError(ncclAllReduce(input, output, count, ncclType, getRedOp(),
comm, CUDAStream::getCurrentStream()));
}

virtual ncclRedOp_t getRedOp() const = 0;
@@ -2,7 +2,7 @@
#include "cuda/cuda_attention_kvcache.h"
#define WARP_SIZE 32
#define BLOCKSIZE WARP_SIZE
#define SEQ_UNIT 32
#define SEQ_UNIT 16

// ASSUME SEQ_LEN OF Q IS 1
__global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
@@ -103,7 +103,7 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
ptr_O[i] /= ptr_sum[0];

(float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0];
if(threadIdx.x == 0){
if(lane_id == 0){
output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0];
}

@@ -157,13 +157,15 @@ void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y);
dim3 blockDim(BLOCKSIZE, 1);

assert(compMeta.dimSize[3] == 128);
_attention_kvcache_kernel_128_1<<<gridDim, blockDim>>>(
input_k_cache, input_v_cache, input_q, input_k, input_v, position_id,
_attention_kvcache_kernel_128_1
<<<gridDim, blockDim, 0, CUDAStream::getCurrentStream()>>>
(input_k_cache, input_v_cache, input_q, input_k, input_v, position_id,
compMeta, output_O_temp, output_sum_temp);
_attention_kvcache_kernel_128_2<<<compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), WARP_SIZE>>>(
position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);

_attention_kvcache_kernel_128_2
<<<compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), WARP_SIZE,
0, CUDAStream::getCurrentStream()>>>
(position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);
}

} // namespace infini
@@ -25,8 +25,9 @@ void clip_kernel(float *input, float *output, int num, float minValue,
float maxValue) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_clip_kernel<<<gridsize, blocksize>>>(input, output, num, minValue,
maxValue);
_clip_kernel
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
input, output, num, minValue, maxValue);
}

}; // namespace infini
@@ -131,8 +131,9 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
}

#define CASE(OP, T) \
_##OP##_kernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>( \
a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
_##OP##_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);

#define SWITCH_DTYPE(OP, DTYPE) \
switch (DTYPE) { \
@@ -202,11 +203,13 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size();
if (dType == 1) {
_pow_kernel<float><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0,
b1, b2, b3, c0, c1, c2, c3);
_pow_kernel<float>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
} else if (dType == 3) {
_pow_kernel<int8_t><<<gridsize, blocksize>>>(
a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
_pow_kernel<int8_t>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
} else if (dType == 10) {
int a_size = a0 * a1 * a2 * a3;
int b_size = b0 * b1 * b2 * b3;
@@ -220,8 +223,9 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
for (int i = 0; i < b_size; ++i) {
b_float[i] = __half2float(((half *)b)[i]);
}
_pow_kernel<float><<<gridsize, blocksize>>>(
a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3, b0,
_pow_kernel<float>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3, b0,
b1, b2, b3, c0, c1, c2, c3);
for (int i = 0; i < c_size; ++i) {
((half *)c)[i] = __float2half(c_float[i]);
@@ -42,7 +42,8 @@ __global__ void _expandKernel(void *input, void *output, int nDims,
namespace infini {

#define CASE(T) \
_expandKernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>( \
_expandKernel<DT_CUDA<T>::t><<<gridsize, blocksize, \
0, CUDAStream::getCurrentStream()>>>( \
input, output, nDims, outputsize, inputShape, outputShape);

#define SWITCH_DTYPE(DTYPE) \
@@ -19,7 +19,8 @@ void extend_kernel(float *in, float *out, int blockSize, int blockSizeOuter,
int oSize) {
int blocksize = 32 * 16;
int gridsize = (oSize + blocksize - 1) / blocksize;
_extend_kernel<<<gridsize, blocksize>>>(in, out, blockSize, blockSizeOuter,
oSize);
_extend_kernel
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
in, out, blockSize, blockSizeOuter, oSize);
}
} // namespace infini
@@ -45,9 +45,12 @@ void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num) {
int gridSize = (num + blockSize - 1) / blockSize;
if (metaData.indexType == DataType::Int64) {
_gather_kernel<T, int64_t>
<<<gridSize, blockSize>>>(in, out, metaData, num);
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>
(in, out, metaData, num);
} else {
_gather_kernel<T, int><<<gridSize, blockSize>>>(in, out, metaData, num);
_gather_kernel<T, int>
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>
(in, out, metaData, num);
}
}
template void gather_kernel<float>(float *in, float *out,
@@ -40,22 +40,26 @@ void gather_elements_kernel(void *in, void *out, GatherMetaData metaData,
int gridSize = (num + blockSize - 1) / blockSize;
if (metaData.dataType == DataType::Float32 &&
metaData.indexType == DataType::Int64) {
_gather_elements_kernel<float, int64_t><<<gridSize, blockSize>>>(
_gather_elements_kernel<float, int64_t>
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>(
reinterpret_cast<float *>(in), reinterpret_cast<float *>(out),
metaData, num);
} else if (metaData.dataType == DataType::Int32 &&
metaData.indexType == DataType::Int64) {
_gather_elements_kernel<int, int64_t><<<gridSize, blockSize>>>(
_gather_elements_kernel<int, int64_t>
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>(
reinterpret_cast<int *>(in), reinterpret_cast<int *>(out), metaData,
num);
} else if (metaData.dataType == DataType::Float32 &&
metaData.indexType == DataType::Int32) {
_gather_elements_kernel<float, int><<<gridSize, blockSize>>>(
_gather_elements_kernel<float, int>
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>(
reinterpret_cast<float *>(in), reinterpret_cast<float *>(out),
metaData, num);
} else if (metaData.dataType == DataType::Int32 &&
metaData.indexType == DataType::Int32) {
_gather_elements_kernel<int, int><<<gridSize, blockSize>>>(
_gather_elements_kernel<int, int>
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>(
reinterpret_cast<int *>(in), reinterpret_cast<int *>(out), metaData,
num);
} else {
@@ -344,8 +344,8 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
int BLOCK_DIM = 1024;

blockLaynormKernel<float, 1024>
<<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output,
eps, scaleSize, bias, biasSize);
<<<num_block, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, bias, biasSize);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
@@ -353,9 +353,10 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);

warpLaynormKernel<float, 32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
warpLaynormKernel<float, 32, 32>
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
@@ -363,8 +364,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);

warpLaynormKernel<float, 16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
warpLaynormKernel<float, 16, 64>
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
@@ -373,8 +375,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);

warpLaynormKernel<float, 8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
warpLaynormKernel<float, 8, 128>
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else {
int BLOCK_DIM_x = 4;
@@ -383,8 +386,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);

warpLaynormKernel<float, 4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
warpLaynormKernel<float, 4, 256>
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
}
}
@@ -396,8 +400,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
if (dimsize > 1024) {
int BLOCK_DIM = 1024;

blockLaynormKernel<float, 1024><<<num_block, BLOCK_DIM>>>(
input, scale, dimsize, stride, output, eps, scaleSize);
blockLaynormKernel<float, 1024>
<<<num_block, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
@@ -405,8 +410,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);

warpLaynormKernel<float, 32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
warpLaynormKernel<float, 32, 32>
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
@@ -414,8 +420,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);

warpLaynormKernel<float, 16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
warpLaynormKernel<float, 16, 64>
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
@@ -423,8 +430,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);

warpLaynormKernel<float, 8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
warpLaynormKernel<float, 8, 128>
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
@@ -432,8 +440,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);

warpLaynormKernel<float, 4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
warpLaynormKernel<float, 4, 256>
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
}
}
//-----------------
@@ -445,8 +454,8 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
int BLOCK_DIM = 1024;

blockLaynormKernel<half, 1024>
<<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output,
eps, scaleSize, bias, biasSize);
<<<num_block, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, bias, biasSize);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
@@ -454,8 +463,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);

warpLaynormKernel<half, 32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
warpLaynormKernel<half, 32, 32>
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
@@ -464,8 +474,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);

warpLaynormKernel<half, 16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
warpLaynormKernel<half, 16, 64>
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
@@ -474,8 +485,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);

warpLaynormKernel<half, 8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
warpLaynormKernel<half, 8, 128>
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else {
int BLOCK_DIM_x = 4;
@@ -484,8 +496,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);

warpLaynormKernel<half, 4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
warpLaynormKernel<half, 4, 256>
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
}
}
@@ -497,8 +510,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
if (dimsize > 1024) {
int BLOCK_DIM = 1024;

blockLaynormKernel<half, 1024><<<num_block, BLOCK_DIM>>>(
input, scale, dimsize, stride, output, eps, scaleSize);
blockLaynormKernel<half, 1024>
<<<num_block, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
@@ -506,8 +520,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);

warpLaynormKernel<half, 32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
warpLaynormKernel<half, 32, 32>
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
@@ -515,8 +530,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);

warpLaynormKernel<half, 16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
warpLaynormKernel<half, 16, 64>
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
@@ -524,8 +540,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);

warpLaynormKernel<half, 8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
warpLaynormKernel<half, 8, 128>
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
@@ -533,8 +550,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);

warpLaynormKernel<half, 4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
warpLaynormKernel<half, 4, 256>
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
}
}
} // namespace infini
@@ -48,8 +48,9 @@ __global__ void _pad_slice_kernel(void *part, void *whole,

namespace infini {
#define CASE(T) \
_pad_slice_kernel<DT_CUDA<T>::t><<<gridSize, blockSize>>>( \
partData, wholeData, metadata, nDims, num, isPad);
_pad_slice_kernel<DT_CUDA<T>::t> \
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>> \
(partData, wholeData, metadata, nDims, num, isPad);

#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
@@ -7,7 +7,8 @@ class CopyCuda : public CudaKernelWithoutConfig {
auto inData = op->getInputs(0)->getRawDataPtr<void *>();
auto outData = op->getOutputs()[0]->getRawDataPtr<void *>();
cudaMemcpyAsync(outData, inData, op->getInputs(0)->getBytes(),
cudaMemcpyDeviceToDevice);
cudaMemcpyDeviceToDevice,
CUDAStream::getCurrentStream());
}
};
// reshape/flatten/identity all act as copying from input to output.
@@ -213,8 +213,9 @@ void resize_kernel_nearest(float *in, float *out, const MetaData &metaData,
sizeof(p_cooridnate_trans_mode_func[0]));
IT_ASSERT(nearestMode <
sizeof(p_nearest_mode_fun) / sizeof(p_nearest_mode_fun[0]));
_resize_kernel_nearest<<<gridsize, blocksize>>>(
in, out, metaData, num, coordinateMode, nearestMode);
_resize_kernel_nearest
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(in, out, metaData, num, coordinateMode, nearestMode);
}

void resize_kernel_linear(float *in, float *out, const MetaData &metaData,
@@ -223,8 +224,9 @@ void resize_kernel_linear(float *in, float *out, const MetaData &metaData,
auto gridsize = (num + blocksize - 1) / blocksize;
IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
sizeof(p_cooridnate_trans_mode_func[0]));
_resize_kernel_linear_coeff<<<gridsize, blocksize>>>(in, out, metaData, num,
coordinateMode);
_resize_kernel_linear_coeff
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(in, out, metaData, num, coordinateMode);
}

void resize_kernel_cubic(float *in, float *out, const MetaData &metaData,
@@ -233,7 +235,8 @@ void resize_kernel_cubic(float *in, float *out, const MetaData &metaData,
auto gridsize = (num + blocksize - 1) / blocksize;
IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
sizeof(p_cooridnate_trans_mode_func[0]));
_resize_kernel_cubic_coeff<<<gridsize, blocksize>>>(in, out, metaData, num,
coordinateMode);
_resize_kernel_cubic_coeff
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(in, out, metaData, num, coordinateMode);
}
} // namespace infini
@@ -0,0 +1,34 @@
#include "operators/rms_norm.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_rmsnorm.h"
#include "cuda/cuda_runtime.h"

namespace infini {

class RMSNormCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<RMSNormObj>(_op);

auto input = op->getInputs(0);
auto weight = op->getInputs(1);
auto output = op->getOutput();
void *const inputData = input->getRawDataPtr<void *>();
void *const weightData = weight->getRawDataPtr<void *>();
void *const outputData = output->getRawDataPtr<void *>();
const auto &inputShape = input->getDims();
int nDims = input->getDims().size();

int hidden_size = inputShape[nDims - 1];
int num_tokens = input->size() / hidden_size;
IT_ASSERT(hidden_size == (int)weight->size());

const int dType = op->getDType().getIndex();
rmsnorm_kernel(dType, inputData, weightData, outputData, num_tokens,
hidden_size);
}
};

REGISTER_KERNEL(Device::CUDA, OpType::RMSNorm, RMSNormCuda, "RMSNorm_CUDA");

} // namespace infini
@@ -0,0 +1,112 @@
#include "core/common.h"
#include "cuda/cuda_common.h"
#include "cuda/cuda_utility.h"
#include "utils/small_array.h"

template<class T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(uint32_t(-1), val, mask);
return val;
}

/* Calculate the sum of all elements in a block */
template<class T>
__inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;

val = warpReduceSum<T>(val);

if (lane == 0)
shared[wid] = val;

__syncthreads();

// Changed from blockDim.x << 5 to blockDim.x / 32. so that this also
// works when blockDim.x is not divisible by 32.
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}

template <class T>
__global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_tokens, int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;

for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
const float x = ((float*) in)[blockIdx.x * hidden_size + idx];
variance += x * x;
}
variance = blockReduceSum<float>(variance);
if(threadIdx.x == 0){
s_variance = rsqrtf(variance / hidden_size + 0.00001f);
}
__syncthreads();

for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
float x = ((float*) in)[blockIdx.x * hidden_size + idx];
((T*)out)[blockIdx.x * hidden_size + idx] = ((T)(x * s_variance)) * ((T*)weight)[idx];
}
}

#define CASE(T) \
_rmsnorm_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
(input, weight, output, num_tokens, hidden_size);

#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
case 1: \
CASE(1) \
break; \
case 2: \
CASE(2) \
break; \
case 3: \
CASE(3) \
break; \
case 4: \
CASE(4) \
break; \
case 5: \
CASE(5) \
break; \
case 6: \
CASE(6) \
break; \
case 7: \
CASE(7) \
break; \
case 10: \
CASE(10) \
break; \
case 11: \
CASE(11) \
break; \
case 12: \
CASE(12) \
break; \
case 13: \
CASE(13) \
break; \
case 16: \
CASE(16) \
break; \
default: \
IT_TODO_HALT(); \
}

namespace infini {
void rmsnorm_kernel(int dType, void *input, void *weight, void *output,
int num_tokens, int hidden_size) {
dim3 blocksize = dim3(std::min(hidden_size, 1024));
dim3 gridsize = dim3(num_tokens);
SWITCH_DTYPE(dType)
}

} // namespace infini
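A minimal host-side sketch of driving this entry point for float32 data; the buffer sizes are illustrative, and dtype index 1 is taken to mean Float32, matching the dispatch used by pow_kernel elsewhere in this changeset:

// Illustrative only: allocate device buffers and call rmsnorm_kernel.
#include "cuda/cuda_common.h"
#include "cuda/cuda_rmsnorm.h"

void rmsnormExample() {
    const int num_tokens = 4, hidden_size = 1024;
    float *d_in, *d_weight, *d_out;
    checkCudaError(cudaMalloc((void **)&d_in, num_tokens * hidden_size * sizeof(float)));
    checkCudaError(cudaMalloc((void **)&d_weight, hidden_size * sizeof(float)));
    checkCudaError(cudaMalloc((void **)&d_out, num_tokens * hidden_size * sizeof(float)));
    // ... copy input activations into d_in and weights into d_weight ...
    infini::rmsnorm_kernel(/*dType=*/1, d_in, d_weight, d_out, num_tokens,
                           hidden_size);
    checkCudaError(
        cudaStreamSynchronize(infini::CUDAStream::getCurrentStream()));
    cudaFree(d_in);
    cudaFree(d_weight);
    cudaFree(d_out);
}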
@@ -22,7 +22,7 @@ class RoPECuda : public CudaKernelWithoutConfig {
IT_ASSERT(nDims == 3 && pos->getDims().size() == 2);
IT_ASSERT(inputShape[1] == pos->getDims()[1]);
int dim_model = inputShape[2];
int dim_head = dim_model / 32;
int dim_head = 128;
int hidden_stride = dim_model * inputShape[1];
int pos_stride = inputShape[1];
@@ -3,13 +3,9 @@
#include "cuda/cuda_utility.h"
#include "utils/small_array.h"

constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }

// gridDim (batch, seq_len, dim_model / 1024), blockDim (1024, 1, 1)
template <class T>
__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride) {
__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model,
int dim_head, int hidden_stride, int pos_stride) {
int batch_id = blockIdx.x;
int target_pos = pos[batch_id * pos_stride + blockIdx.y];
int ith = blockIdx.z * blockDim.x + threadIdx.x;
@@ -36,8 +32,9 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo

#define CASE(T) \
_rope_kernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>( \
pos, input, output, size, dim_model, dim_head, hidden_stride, pos_stride);
_rope_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
(pos, input, output, size, dim_model, dim_head, hidden_stride, pos_stride);

#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
@@ -82,9 +79,10 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
}

namespace infini {
void rope_kernel(int dType, int * pos, void *input, void *output, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride) {
dim3 blocksize = dim3(1024,1,1);
dim3 gridsize = dim3(1, 1, 4);
void rope_kernel(int dType, int * pos, void *input, void *output, int size,
int dim_model, int dim_head, int hidden_stride, int pos_stride) {
dim3 blocksize = dim3(32,1,1);
dim3 gridsize = dim3(1, 1, dim_model/32);
SWITCH_DTYPE(dType)
}
@@ -246,32 +246,38 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,

int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 64) {

int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 128>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 32) {

int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 64>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 16) {

int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 32>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 4) {

int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 16>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024) {

int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 4>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
@@ -280,7 +286,8 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
dim3 grid_dim(num_block_x, 1, 1);

_warpSoftmaxKernel<float, 32, 32, 32>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
@@ -289,7 +296,8 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
dim3 grid_dim(num_block_x, 1, 1);

_warpSoftmaxKernel<float, 16, 64, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
@@ -298,7 +306,8 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
dim3 grid_dim(num_block_x, 1, 1);

_warpSoftmaxKernel<float, 8, 128, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
@@ -307,7 +316,8 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
dim3 grid_dim(num_block_x, 1, 1);

_warpSoftmaxKernel<float, 4, 256, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
}
}
//------------------
@@ -318,32 +328,38 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,

int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 64) {

int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 128>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 32) {

int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 64>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 16) {

int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 32>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 4) {

int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 16>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024) {

int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 4>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
@@ -352,7 +368,8 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
dim3 grid_dim(num_block_x, 1, 1);

_warpSoftmaxKernel<half, 32, 32, 32>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
@@ -361,7 +378,8 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
dim3 grid_dim(num_block_x, 1, 1);

_warpSoftmaxKernel<half, 16, 64, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
@@ -370,7 +388,8 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
dim3 grid_dim(num_block_x, 1, 1);

_warpSoftmaxKernel<half, 8, 128, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
@@ -379,7 +398,8 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
dim3 grid_dim(num_block_x, 1, 1);

_warpSoftmaxKernel<half, 4, 256, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
}
}
} // namespace infini
@@ -70,7 +70,8 @@ class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig {
_op->getOutputs()[0]->getRawDataPtr<void *>();
cudaMemcpyAsync(outData, inData,
_op->getInputs(1 - i)->getBytes(),
cudaMemcpyDeviceToDevice);
cudaMemcpyDeviceToDevice,
CUDAStream::getCurrentStream());
return;
}
}
@@ -63,8 +63,9 @@ void split_concat_kernel(const ElementTensorMetadata<float> &eleMeta,
// each y is a split among the batch
dim3 gridSize(gridDimX, batchSize);

_split_concat_kernel<<<gridSize, blockSize>>>(eleMeta, compMeta, dim, nDims,
isSplit);
_split_concat_kernel
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>
(eleMeta, compMeta, dim, nDims, isSplit);
}
void split_concat_kernel(const ElementTensorMetadata<half> &eleMeta,
const ComposedTensorMetadata<half> &compMeta, int dim,
@@ -77,8 +78,9 @@ void split_concat_kernel(const ElementTensorMetadata<half> &eleMeta,
// each y is a split among the batch
dim3 gridSize(gridDimX, batchSize);

_split_concat_kernel<<<gridSize, blockSize>>>(eleMeta, compMeta, dim, nDims,
isSplit);
_split_concat_kernel
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>
(eleMeta, compMeta, dim, nDims, isSplit);
}

} // namespace infini
@@ -23,8 +23,9 @@ __global__ void _transpose_kernel(void *input, void *output, int nDims,
}
}
#define CASE(T) \
_transpose_kernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>( \
input, output, nDims, size, strides, outputShape);
_transpose_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
(input, output, nDims, size, strides, outputShape);

#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
@@ -148,78 +148,104 @@ template <typename T> void softmax_kernel(T *input, T *output, size_t num) {

int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_softmax_kernel1<T><<<1, 1>>>(input, output, num);
_softmax_kernel2<T><<<gridsize, blocksize>>>(input, output, num);
_softmax_kernel1<T>
<<<1, 1, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
_softmax_kernel2<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
}
template <typename T> void relu_kernel(T *input, T *output, size_t num) {

int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_relu_kernel<T><<<gridsize, blocksize>>>(input, output, num);
_relu_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
}
template <typename T> void sigmoid_kernel(T *input, T *output, size_t num) {

int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_sigmoid_kernel<T><<<gridsize, blocksize>>>(input, output, num);
_sigmoid_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
}
template <typename T>
void hard_sigmoid_kernel(T *input, T *output, size_t num) {

int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_hard_sigmoid_kernel<T><<<gridsize, blocksize>>>(input, output, num);
_hard_sigmoid_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
}
template <typename T> void hard_swish_kernel(T *input, T *output, size_t num) {

int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_hard_swish_kernel<T><<<gridsize, blocksize>>>(input, output, num);
_hard_swish_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
}
template <typename T> void tanh_kernel(T *input, T *output, size_t num) {

int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_tanh_kernel<T><<<gridsize, blocksize>>>(input, output, num);
_tanh_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
}
template <typename T> void abs_kernel(T *input, T *output, size_t num) {

int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_abs_kernel<T><<<gridsize, blocksize>>>(input, output, num);
_abs_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
}
template <typename T> void sqrt_kernel(T *input, T *output, size_t num) {

int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_sqrt_kernel<<<gridsize, blocksize>>>((T *)input, (T *)output, num);
_sqrt_kernel
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
((T *)input, (T *)output, num);
}

template <typename T> void gelu_kernel(T *input, T *output, size_t num) {

int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_gelu_kernel<T><<<gridsize, blocksize>>>(input, output, num);
_gelu_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
}

template <typename T> void silu_kernel(T *input, T *output, size_t num) {

int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_silu_kernel<T><<<gridsize, blocksize>>>(input, output, num);
_silu_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
}

template <typename T> void erf_kernel(T *input, T *output, size_t num) {

int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_erf_kernel<T><<<gridsize, blocksize>>>(input, output, num);
_erf_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
}
template <typename T> void neg_kernel(T *input, T *output, size_t num) {

int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_neg_kernel<T><<<gridsize, blocksize>>>(input, output, num);
_neg_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
}

void unary_kernel(const Operator &_op) {
@@ -317,7 +343,9 @@ void cast_kernel(INPUT *input, OUTPUT *output, size_t num) {

int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_cast_kernel<INPUT, OUTPUT><<<gridsize, blocksize>>>(input, output, num);
_cast_kernel<INPUT, OUTPUT>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
}

template void cast_kernel<float, half>(float *input, half *output, size_t num);
@@ -61,7 +61,8 @@ void whereKernel(const float *inputX, const float *inputY,
blocksize = 32;
}
int gridsize = (outputsize + blocksize - 1) / blocksize;
_whereKernel<float><<<gridsize, blocksize>>>(
_whereKernel<float>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
inputX, inputY, condition, output, nDims, outputsize, inputXShape,
inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
}
@@ -85,7 +86,8 @@ void whereKernel(const half *inputX, const half *inputY,
blocksize = 32;
}
int gridsize = (outputsize + blocksize - 1) / blocksize;
_whereKernel<half><<<gridsize, blocksize>>>(
_whereKernel<half>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
inputX, inputY, condition, output, nDims, outputsize, inputXShape,
inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
}
@@ -0,0 +1,36 @@
#include "operators/rms_norm.h"

namespace infini {
RMSNormObj::RMSNormObj(GraphObj *graph, Tensor input, Tensor weight,
Tensor output)
: OperatorObj(OpType::RMSNorm, {input, weight}, {output}) {
IT_ASSERT(checkValid(graph));
}

optional<vector<Shape>> RMSNormObj::inferShape(const TensorVec &inputs) {
const auto A = inputs[0];
auto input_dim = A->getDims();
auto output_dim = input_dim;
return {{output_dim}};
}

std::string RMSNormObj::toString() const {
std::ostringstream os;
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}

vector<int> RMSNormObj::getWorkloadVector() const {
vector<int> ret{type.underlying()};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}

vector<int> RMSNormObj::getOpAttrVector() const { return {type.underlying()}; }

}; // namespace infini
@@ -0,0 +1,70 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/attention_kvcache.h"

#include "test.h"

namespace infini {

TEST(TestCudaRuntime, CudaGraph) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();

Graph gCpu = make_ref<GraphObj>(runtime);

auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);

auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
auto input_q_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
auto input_k_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
auto input_v_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);

auto op = gCuda->addOp<AttentionKVCacheObj>(
input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d,
position_id_d, nullptr);
auto op1 = gCuda->addOp<AttentionKVCacheObj>(
input_k_cache_d, input_v_cache_d, op->getOutputs()[0], input_k_d,
input_v_d, position_id_d, nullptr);
auto op2 = gCuda->addOp<AttentionKVCacheObj>(
input_k_cache_d, input_v_cache_d, op1->getOutputs()[0], input_k_d,
input_v_d, position_id_d, nullptr);
gCuda->dataMalloc();

input_q_d->setData(OneGenerator());
input_k_d->setData(OneGenerator());
input_v_d->setData(OneGenerator());
position_id_d->setData(IncrementalGenerator());

cudaRuntime->run(gCuda);

cudaEvent_t start, stop;
float milliseconds_1 = 0, milliseconds_2 = 0;
cudaEventCreate(&start);
cudaEventCreate(&stop);

cudaDeviceSynchronize();
cudaEventRecord(start);
cudaRuntime->run(gCuda);
cudaEventRecord(stop);
cudaEventSynchronize(stop);
cudaEventElapsedTime(&milliseconds_1, start, stop);
printf("without cudaGraph, latency: %f ms\n", milliseconds_1);

cudaRuntime->runWithCudaGraph(gCuda);
cudaRuntime->runWithCudaGraph(gCuda);

cudaDeviceSynchronize();
cudaEventRecord(start);
cudaRuntime->runWithCudaGraph(gCuda);
cudaEventRecord(stop);
cudaEventSynchronize(stop);
cudaEventElapsedTime(&milliseconds_2, start, stop);
printf("with cudaGraph, latency: %f ms\n", milliseconds_2);
EXPECT_GE(milliseconds_1, milliseconds_2);
}

} // namespace infini