From 17bd98d453cfc4b3ba71bf3716cf6dee5e4f1eaf Mon Sep 17 00:00:00 2001 From: xiaonans Date: Tue, 6 Feb 2024 17:04:05 +0800 Subject: [PATCH] modify rope op --- pyinfinitensor/src/pyinfinitensor/onnx.py | 18 +++++++++--------- src/kernels/cuda/rope.cc | 2 +- src/kernels/cuda/rope.cu | 9 ++------- 3 files changed, 12 insertions(+), 17 deletions(-) diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 1a2e28a7..085ff0c3 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -39,15 +39,15 @@ class OnnxStub: def __init__(self, model: ModelProto, runtime, use_naive_allocator: bool = False): # We use some user-defined operators for distributed inference - try: - # onnx simplifier performs inplace simplify - model_simp, check = simplify(copy.deepcopy(model)) - if check: - model = model_simp - except ValidationError: - pass - except RuntimeError: - pass + # try: + # # onnx simplifier performs inplace simplify + # model_simp, check = simplify(copy.deepcopy(model)) + # if check: + # model = model_simp + # except ValidationError: + # pass + # except RuntimeError: + # pass self.inputs: Dict[str, backend.Tensor] = {} self.outputs: Dict[str, backend.Tensor] = {} diff --git a/src/kernels/cuda/rope.cc b/src/kernels/cuda/rope.cc index 1ec5cca2..27fc83f4 100644 --- a/src/kernels/cuda/rope.cc +++ b/src/kernels/cuda/rope.cc @@ -22,7 +22,7 @@ class RoPECuda : public CudaKernelWithoutConfig { IT_ASSERT(nDims == 3 && pos->getDims().size() == 2); IT_ASSERT(inputShape[1] == pos->getDims()[1]); int dim_model = inputShape[2]; - int dim_head = dim_model / 32; + int dim_head = 128; int hidden_stride = dim_model * inputShape[1]; int pos_stride = inputShape[1]; diff --git a/src/kernels/cuda/rope.cu b/src/kernels/cuda/rope.cu index 8d35026f..6e947f5c 100644 --- a/src/kernels/cuda/rope.cu +++ b/src/kernels/cuda/rope.cu @@ -3,11 +3,6 @@ #include "cuda/cuda_utility.h" #include "utils/small_array.h" -constexpr unsigned int num_threads() { return 32 * 4; } -constexpr int thread_work_size() { return 4; } -constexpr int block_work_size() { return thread_work_size() * num_threads(); } - -// gridDim (batch, seq_len, dim_model / 1024), blockDim (1024, 1, 1) template __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride) { @@ -86,8 +81,8 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo namespace infini { void rope_kernel(int dType, int * pos, void *input, void *output, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride) { - dim3 blocksize = dim3(1024,1,1); - dim3 gridsize = dim3(1, 1, 4); + dim3 blocksize = dim3(32,1,1); + dim3 gridsize = dim3(1, 1, dim_model/32); SWITCH_DTYPE(dType) }