modify rope op

This commit is contained in:
xiaonans 2024-02-06 17:04:05 +08:00
parent 8cc6af0a83
commit 17bd98d453
3 changed files with 12 additions and 17 deletions

View File

@@ -39,15 +39,15 @@ class OnnxStub:
def __init__(self, model: ModelProto, runtime, use_naive_allocator: bool = False):
# We use some user-defined operators for distributed inference
try:
# onnx simplifier performs inplace simplify
model_simp, check = simplify(copy.deepcopy(model))
if check:
model = model_simp
except ValidationError:
pass
except RuntimeError:
pass
# try:
# # onnx simplifier performs inplace simplify
# model_simp, check = simplify(copy.deepcopy(model))
# if check:
# model = model_simp
# except ValidationError:
# pass
# except RuntimeError:
# pass
self.inputs: Dict[str, backend.Tensor] = {}
self.outputs: Dict[str, backend.Tensor] = {}

View File

@@ -22,7 +22,7 @@ class RoPECuda : public CudaKernelWithoutConfig {
IT_ASSERT(nDims == 3 && pos->getDims().size() == 2);
IT_ASSERT(inputShape[1] == pos->getDims()[1]);
int dim_model = inputShape[2];
int dim_head = dim_model / 32;
int dim_head = 128;
int hidden_stride = dim_model * inputShape[1];
int pos_stride = inputShape[1];

View File

@@ -3,11 +3,6 @@
#include "cuda/cuda_utility.h"
#include "utils/small_array.h"
constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
// gridDim (batch, seq_len, dim_model / 1024), blockDim (1024, 1, 1)
template <class T>
__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model,
int dim_head, int hidden_stride, int pos_stride) {
@@ -86,8 +81,8 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
namespace infini {
void rope_kernel(int dType, int * pos, void *input, void *output, int size,
int dim_model, int dim_head, int hidden_stride, int pos_stride) {
dim3 blocksize = dim3(1024,1,1);
dim3 gridsize = dim3(1, 1, 4);
dim3 blocksize = dim3(32,1,1);
dim3 gridsize = dim3(1, 1, dim_model/32);
SWITCH_DTYPE(dType)
}