forked from jiuyuan/InfiniTensor
modify rope op
This commit is contained in:
parent
8cc6af0a83
commit
17bd98d453
|
@ -39,15 +39,15 @@ class OnnxStub:
|
|||
|
||||
def __init__(self, model: ModelProto, runtime, use_naive_allocator: bool = False):
|
||||
# We use some user-defined operators for distributed inference
|
||||
try:
|
||||
# onnx simplifier performs inplace simplify
|
||||
model_simp, check = simplify(copy.deepcopy(model))
|
||||
if check:
|
||||
model = model_simp
|
||||
except ValidationError:
|
||||
pass
|
||||
except RuntimeError:
|
||||
pass
|
||||
# try:
|
||||
# # onnx simplifier performs inplace simplify
|
||||
# model_simp, check = simplify(copy.deepcopy(model))
|
||||
# if check:
|
||||
# model = model_simp
|
||||
# except ValidationError:
|
||||
# pass
|
||||
# except RuntimeError:
|
||||
# pass
|
||||
|
||||
self.inputs: Dict[str, backend.Tensor] = {}
|
||||
self.outputs: Dict[str, backend.Tensor] = {}
|
||||
|
|
|
@ -22,7 +22,7 @@ class RoPECuda : public CudaKernelWithoutConfig {
|
|||
IT_ASSERT(nDims == 3 && pos->getDims().size() == 2);
|
||||
IT_ASSERT(inputShape[1] == pos->getDims()[1]);
|
||||
int dim_model = inputShape[2];
|
||||
int dim_head = dim_model / 32;
|
||||
int dim_head = 128;
|
||||
int hidden_stride = dim_model * inputShape[1];
|
||||
int pos_stride = inputShape[1];
|
||||
|
||||
|
|
|
@ -3,11 +3,6 @@
|
|||
#include "cuda/cuda_utility.h"
|
||||
#include "utils/small_array.h"
|
||||
|
||||
constexpr unsigned int num_threads() { return 32 * 4; }
|
||||
constexpr int thread_work_size() { return 4; }
|
||||
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
|
||||
|
||||
// gridDim (batch, seq_len, dim_model / 1024), blockDim (1024, 1, 1)
|
||||
template <class T>
|
||||
__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model,
|
||||
int dim_head, int hidden_stride, int pos_stride) {
|
||||
|
@ -86,8 +81,8 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
|
|||
namespace infini {
|
||||
void rope_kernel(int dType, int * pos, void *input, void *output, int size,
|
||||
int dim_model, int dim_head, int hidden_stride, int pos_stride) {
|
||||
dim3 blocksize = dim3(1024,1,1);
|
||||
dim3 gridsize = dim3(1, 1, 4);
|
||||
dim3 blocksize = dim3(32,1,1);
|
||||
dim3 gridsize = dim3(1, 1, dim_model/32);
|
||||
SWITCH_DTYPE(dType)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue