RoPE and attention ops support multiple batches/sequences.

xiaonans 2024-04-08 12:01:08 +08:00
parent eb3a2d123d
commit c01e64db50
8 changed files with 162 additions and 156 deletions

View File

@@ -3,14 +3,16 @@
#include <cstdio>
struct AttentionKVCacheMetadata {
int dimSize[4];
int stride[4];
int head_dim;
int num_heads;
int num_seqs;
int max_kv_seqlen;
};
namespace infini {
void attention_kvcache_kernel(int dType, void *input_k_cache,
void *input_v_cache, void *input_q, void *input_k,
void *input_v, int *position_id,
void *input_v, int64_t *position_id,
void *output_matmul,
const AttentionKVCacheMetadata &compMeta,
float *output_O_temp, float *output_sum_temp);
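
The metadata no longer carries raw dimSize/stride arrays; it names the quantities the kernels actually need. A minimal sketch of how the fields are meant to be filled, assuming the KV cache keeps the [batch, num_heads, max_kv_seqlen, head_dim] layout that initAttentionKVCacheMetadata reads further down in this commit (makeMetadata is an illustrative helper, not part of the code):

// kv_dims: shape of the K/V cache, [batch, num_heads, max_kv_seqlen, head_dim]
// position_id_count: element count of the position_id tensor (one entry per sequence)
inline AttentionKVCacheMetadata makeMetadata(const int kv_dims[4],
                                             int position_id_count) {
    AttentionKVCacheMetadata m;
    m.num_heads     = kv_dims[1];        // heads live on dim 1
    m.max_kv_seqlen = kv_dims[2];        // cache capacity along the sequence dim
    m.head_dim      = kv_dims[3];        // per-head feature size (128 for these kernels)
    m.num_seqs      = position_id_count; // one current position per decoded sequence
    return m;
}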

View File

@@ -5,8 +5,7 @@
namespace infini {
void rope_kernel(int dType, int *pos, void *input, void *output, int size,
int dim_model, int dim_head, int hidden_stride,
int pos_stride);
void rope_kernel(int dType, int64_t *pos, void *input, void *output,
int dim_model, int dim_head, int batchsize, int pos_stride);
}; // namespace infini
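
For orientation, a hypothetical call-site sketch against the new declaration; the buffer shapes and hyper-parameters below are illustrative, and only the argument order comes from the signature above (assumes the declaring header and <cuda_runtime.h> are included):

void rope_example() {
    const int batchsize = 2, pos_stride = 16;    // two sequences, 16 tokens each
    const int dim_model = 4096, dim_head = 128;  // hidden size, per-head size
    const int dType = 1;                         // selects the float32 path in this codebase

    int64_t *pos;     // [batchsize, pos_stride] absolute positions
    float *in, *out;  // [batchsize, pos_stride, dim_model] activations
    cudaMalloc((void **)&pos, sizeof(int64_t) * batchsize * pos_stride);
    cudaMalloc((void **)&in, sizeof(float) * batchsize * pos_stride * dim_model);
    cudaMalloc((void **)&out, sizeof(float) * batchsize * pos_stride * dim_model);

    infini::rope_kernel(dType, pos, in, out, dim_model, dim_head,
                        batchsize, pos_stride);
}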

View File

@@ -20,7 +20,7 @@ class RoPEObj : public OperatorObj {
int numInputs() const override { return 2; }
int numOutputs() const override { return 1; }
DataType getDType() const { return getInputs(1)->getDType(); }
vector<DataType> inferDataType(const TensorVec &inputs) const override {
return {inputs[1]->getDType()};
};

View File

@@ -208,8 +208,9 @@ class OnnxStub:
op[1],
)
elif node.op_type == "MatMul":
if to_array(data[node.input[1]]).dtype == np.float32 \
and type(runtime) == backend.CudaRuntime \
if node.input[1] in data.keys() \
and to_array(data[node.input[1]]).dtype == np.float32 \
and 'cuda_runtime' in dir(backend) \
and tensors[node.input[0]].shape()[0] == 1 \
and tensors[node.input[0]].shape()[1] == 1 \
and len(tensors[node.input[1]].shape()) == 2 \

View File

@@ -7,15 +7,19 @@ namespace infini {
class AttentionKVCacheCompute {
void initAttentionKVCacheMetadata(AttentionKVCacheMetadata &metadata,
Tensor tensor) const {
int nDims = tensor->getRank();
auto strides = tensor->getStride();
Tensor input_v_cache,
Tensor position_id) const {
int nDims = input_v_cache->getRank();
auto strides = input_v_cache->getStride();
IT_ASSERT(nDims == 4);
IT_ASSERT(strides.size() == (size_t)nDims);
for (int i = 0; i < nDims; ++i) {
metadata.dimSize[i] = tensor->getDims().at(i);
metadata.stride[i] = strides.at(i);
int dim_position_id = position_id->getRank();
metadata.num_seqs = 1;
for (int i = 0; i < dim_position_id; i++) {
metadata.num_seqs *= position_id->getDims().at(i);
}
metadata.head_dim = input_v_cache->getDims().at(3);
metadata.num_heads = input_v_cache->getDims().at(1);
metadata.max_kv_seqlen = input_v_cache->getDims().at(2);
}
public:
@@ -24,17 +28,16 @@ class AttentionKVCacheCompute {
Tensor position_id, Tensor output_matmul,
CudaPtr p_workspace) const {
AttentionKVCacheMetadata metadata;
initAttentionKVCacheMetadata(metadata, input_v_cache);
initAttentionKVCacheMetadata(metadata, input_v_cache, position_id);
attention_kvcache_kernel(dType, input_k_cache->getRawDataPtr<void *>(),
input_v_cache->getRawDataPtr<void *>(),
input_q->getRawDataPtr<void *>(),
input_k->getRawDataPtr<void *>(),
input_v->getRawDataPtr<void *>(),
position_id->getRawDataPtr<int *>(),
output_matmul->getRawDataPtr<void *>(),
metadata, (float *)p_workspace,
(float *)(p_workspace + (1ll << 30)));
attention_kvcache_kernel(
dType, input_k_cache->getRawDataPtr<void *>(),
input_v_cache->getRawDataPtr<void *>(),
input_q->getRawDataPtr<void *>(), input_k->getRawDataPtr<void *>(),
input_v->getRawDataPtr<void *>(),
position_id->getRawDataPtr<int64_t *>(),
output_matmul->getRawDataPtr<void *>(), metadata,
(float *)p_workspace, (float *)(p_workspace + (1ll << 30)));
}
};
@@ -44,15 +47,15 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
const RuntimeObj *_context) const override {
auto op = as<AttentionKVCacheObj>(_op);
int dType = op->getDType().getIndex();
IT_ASSERT(dType == 1 || dType == 10);
int position_idx_dtype = op->getInputs()[5]->getDTypeIndex();
IT_ASSERT(dType == 1 || dType == 10 || position_idx_dtype == 7);
size_t workspaceSize = 2ll << 30;
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
CudaPtr idxWsData = context->getWorkspace(workspaceSize);
do_compute(dType, op->getInputs()[0], op->getInputs()[1],
op->getInputs()[2], op->getInputs()[3],
op->getInputs()[4], op->getInputs()[5],
op->getOutputs()[0], idxWsData);
op->getInputs()[2], op->getInputs()[3], op->getInputs()[4],
op->getInputs()[5], op->getOutputs()[0], idxWsData);
}
};
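
To make the 1 GiB offset above less magic: the runtime hands the kernel a 2 GiB scratch buffer that is split in half, the first half holding the per-partition attention outputs (stored as half) and the second half the per-partition softmax sums that the second kernel pass reads back. A sketch of the split, with sizes as in the code above:

size_t workspaceSize = 2ll << 30;                     // 2 GiB scratch from the runtime
CudaPtr ws = context->getWorkspace(workspaceSize);
float *output_O_temp   = (float *)ws;                 // partial outputs, written as half
float *output_sum_temp = (float *)(ws + (1ll << 30)); // partial softmax sums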

View File

@@ -1,8 +1,9 @@
#include "cuda/cuda_common.h"
#include "cuda/cuda_attention_kvcache.h"
#define WARP_SIZE 32
#define BLOCKSIZE WARP_SIZE*2
#define SEQ_UNIT 8
#define SEQ_UNIT 16
#define BLOCKSIZE_2 WARP_SIZE*4
#define MAX_PARTITION 1024
// ASSUME SEQ_LEN OF Q IS 1
template <class T>
@@ -11,190 +12,186 @@ __global__ void _attention_kvcache_kernel_128_1(T* input_k_cache,
T* input_q,
T* input_k,
T* input_v,
int* position_id,
int64_t* position_id,
AttentionKVCacheMetadata compMeta,
half* output_O_temp,
float* output_sum_temp) {
int seq_length = position_id[0] + 1;
int seq_length = position_id[blockIdx.y] + 1;
int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT;
if(blockIdx.y >= stride)
if(blockIdx.z >= stride)
return;
int lane_id_x2 = threadIdx.x % WARP_SIZE * 2;
int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
int idx_seq = blockIdx.y * SEQ_UNIT;
int parallel_idx = blockIdx.x + blockIdx.y * gridDim.x;
if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
return;
int idx_seq = blockIdx.z * SEQ_UNIT;
half ptr_V[4]; // V
half ptr_K[4]; // K
half ptr_Q[4]; // Q
float ptr_P[SEQ_UNIT] = {0};
half reg_V[4];
half reg_K[4];
half reg_Q[4];
float reg_P;
float ptr_O[4] = {0};
float ptr_sum[1] = {0};
float reg_O[4] = {0};
float reg_sum = 0;
float temp[4];
bool is_fp16 = sizeof(T) == 2 ? true : false;
int idx_qkv = lane_id_x2 + parallel_idx * compMeta.stride[2];
int idx_qkv = lane_id_x2 + parallel_idx * compMeta.head_dim;
// readin Q
if(!is_fp16){
#pragma unroll
for(int i = 0; i < 4; i += 2){
(float2 &)temp[i] = (float2 &)input_q[idx_qkv + i*WARP_SIZE];
*((half2*)(&ptr_Q[i])) = __float22half2_rn(*((float2*)(&temp[i])));
*((half2*)(&reg_Q[i])) = __float22half2_rn(*((float2*)(&temp[i])));
}
}
else{
#pragma unroll
for(int i = 0; i < 4; i += 2){
(half2 &)ptr_Q[i] = (half2 &)input_q[idx_qkv + i*WARP_SIZE];
(half2 &)reg_Q[i] = (half2 &)input_q[idx_qkv + i*WARP_SIZE];
}
}
int common_idx = lane_id_x2 + (parallel_idx * compMeta.stride[1]);
int common_idx = lane_id_x2 + (parallel_idx * compMeta.max_kv_seqlen * compMeta.head_dim);
// Q*K
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
int idx_kvcache = common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2]);
reg_P = 0;
int idx_kvcache = common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.head_dim);
// readin K & V
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
#pragma unroll
for(int i = 0; i < 4; i += 2){
*((half2*)(&ptr_K[i])) = *((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE]));
*((half2*)(&reg_K[i])) = *((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE]));
*((half2*)(&reg_V[i])) = *((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE]));
}
}
else{
if(!is_fp16){
#pragma unroll
for(int i = 0; i < 4; i += 2){
(float2 &)temp[i] = (float2 &) input_k[idx_qkv + i*WARP_SIZE];
*((half2*)(&ptr_K[i])) = __float22half2_rn(*((float2*)(&temp[i])));
}
}
else{
for(int i = 0; i < 4; i += 2){
(half2 &)ptr_K[i] = (half2 &)input_k[idx_qkv + i*WARP_SIZE];
}
}
for(int i = 0; i < 4; i += 2){
*((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&ptr_K[i]));
}
}
// * V
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
for(int i = 0; i < 4; i += 2){
*((half2*)(&ptr_V[i])) = *((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE]));
}
}
else{
if(!is_fp16){
for(int i = 0; i < 4; i += 2){
*((half2*)(&reg_K[i])) = __float22half2_rn(*((float2*)(&temp[i])));
*((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_K[i]));
(float2 &)temp[i] = (float2 &) input_v[idx_qkv + i*WARP_SIZE];
*((half2*)(&ptr_V[i])) = __float22half2_rn(*((float2*)(&temp[i])));
*((half2*)(&reg_V[i])) = __float22half2_rn(*((float2*)(&temp[i])));
*((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_V[i]));
}
}
else{
#pragma unroll
for(int i = 0; i < 4; i += 2){
(half2 &)ptr_V[i] = (half2 &)input_v[idx_qkv + i*WARP_SIZE];
(half2 &)reg_K[i] = (half2 &)input_k[idx_qkv + i*WARP_SIZE];
*((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_K[i]));
(half2 &)reg_V[i] = (half2 &)input_v[idx_qkv + i*WARP_SIZE];
*((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_V[i]));
}
}
for(int i = 0; i < 4; i += 2){
*((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&ptr_V[i]));
}
}
// Q*K
#pragma unroll
for (int i = 0; i < 4; i ++){
ptr_K[i] = ptr_Q[i] * ptr_K[i];
for (int i = 0; i < 4; i += 2){
(half2 &)reg_K[i] = (half2 &)reg_Q[i] * (half2 &)reg_K[i];
#pragma unroll
for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
ptr_K[i] += __shfl_xor_sync(0xffffffff, ptr_K[i], offset);
(half2 &)reg_K[i] += __shfl_xor_sync(0xffffffff, (half2 &)reg_K[i], offset);
}
ptr_P[idx_SEQ_UNIT] += __half2float(ptr_K[i]);
(float2 &) temp[i] = __half22float2((half2 &)reg_K[i]);
reg_P += (temp[i] + temp[i+1]);
(float2 &) temp[i] = __half22float2((half2 &)reg_V[i]);
}
// div sqrt(d)
ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);
reg_P /= sqrt(128.0);
// softmax
ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]);
ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
reg_P = expf(reg_P);
reg_sum += reg_P;
#pragma unroll
for (int i = 0; i < 4; i ++)
ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], __half2float(ptr_V[i]), ptr_O[i]);
reg_O[i] = fmaf(reg_P, temp[i], reg_O[i]);
}
#pragma unroll
for (int i = 0; i < 4; i ++)
ptr_O[i] /= ptr_sum[0];
reg_O[i] /= reg_sum;
#pragma unroll
for(int i = 0; i < 4; i += 2)
(half2 &)output_O_temp[(lane_id_x2 + i*WARP_SIZE) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = __float22half2_rn((float2 &)ptr_O[i]);
(half2 &)output_O_temp[(lane_id_x2 + i*WARP_SIZE) + (blockIdx.z * compMeta.head_dim) + (parallel_idx * compMeta.head_dim * stride)] = __float22half2_rn((float2 &)reg_O[i]);
if(lane_id_x2 == 0){
output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0];
output_sum_temp[blockIdx.z + parallel_idx * stride] = reg_sum;
}
}
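
The Q*K step above is a butterfly warp reduction: each of the 32 lanes owns 4 of the 128 head elements, multiplies its slice, and __shfl_xor_sync folds the partial products so every lane ends up with the full dot product. A simplified float rendering of that pattern (the kernel itself operates on half2 pairs):

__device__ float warp_dot_128(const float q[4], const float k[4]) {
    float v = 0.0f;
    for (int i = 0; i < 4; ++i)
        v += q[i] * k[i];                                // local slice of the dot product
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
        v += __shfl_xor_sync(0xffffffff, v, offset);     // butterfly sum across the warp
    return v;                                            // identical in all 32 lanes
}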
template <class T>
__global__ void _attention_kvcache_kernel_128_2(int* position_id,
__global__ void _attention_kvcache_kernel_128_2(int64_t* position_id,
T* output_matmul,
AttentionKVCacheMetadata compMeta,
half* output_O_temp,
float* output_sum_temp) {
int lane_id_x4 = threadIdx.x % WARP_SIZE * 4;
int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
int lane_id = threadIdx.x % WARP_SIZE;
int parallel_idx = blockIdx.x;
int offset = parallel_idx * compMeta.head_dim;
int size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT;
bool is_fp16 = sizeof(T) == 2 ? true : false;
if(size > 1){
float ptr_O[4] = {0};
float ptr_O_sum[4] = {0};
float ptr_sum = 0;
float ptr_sum_temp;
half temp_half[4];
#pragma unroll
for(int i = 0; i < size; i ++){
(float2 &)temp_half[0]
= (float2 &)output_O_temp[lane_id_x4 + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size];
for(int j = 0; j < 4; j += 2)
(float2 &)ptr_O[j] = __half22float2((half2 &)temp_half[j]);
ptr_sum_temp = output_sum_temp[i + parallel_idx * size];
if(size == 1){
if(!is_fp16){
#pragma unroll
for(int k = 0; k < 4; k ++)
ptr_O_sum[k] = fma(ptr_O[k], ptr_sum_temp, ptr_O_sum[k]);
ptr_sum += ptr_sum_temp;
for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x)
output_matmul[i + offset]
= __half2float(output_O_temp[i + offset]);
}
else{
#pragma unroll
for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x)
output_matmul[i + offset]
= output_O_temp[i + offset];
}
return;
}
__shared__ float shm_sum_temp[MAX_PARTITION];
__shared__ float shm_sum[WARP_SIZE];
float temp_sum = 0;
#pragma unroll
for(int i = threadIdx.x; i < size; i += blockDim.x){
shm_sum_temp[i] = output_sum_temp[i + parallel_idx * size];
temp_sum += shm_sum_temp[i];
}
#pragma unroll
for(int offset = WARP_SIZE/2; offset > 0; offset /= 2)
temp_sum += __shfl_down_sync(0xffffffff, temp_sum, offset);
if(lane_id == 0)
shm_sum[threadIdx.x/WARP_SIZE] = temp_sum;
__syncthreads();
temp_sum = lane_id < (size + WARP_SIZE - 1) / WARP_SIZE ? shm_sum[lane_id] : 0;
#pragma unroll
for(int offset = WARP_SIZE/2; offset > 0; offset /= 2)
temp_sum += __shfl_xor_sync(0xffffffff, temp_sum, offset);
temp_sum = __fdividef(1.0f, temp_sum + 1e-6f);
#pragma unroll
for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x){
float acc = 0.0f;
for(int j = 0; j < size; j ++){
acc = fma(__half2float(output_O_temp[i + (j * compMeta.head_dim) + offset * size]) * shm_sum_temp[j], temp_sum, acc);
}
#pragma unroll
for(int k = 0; k < 4; k ++)
ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum;
if(!is_fp16){
(float4 &)output_matmul[lane_id_x4 + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0];
output_matmul[i + offset] = acc;
}
else{
for(int j = 0; j < 4; j += 2)
(half2 &)temp_half[j] = __float22half2_rn((float2 &)ptr_O_sum[j]);
(float2 &)output_matmul[lane_id_x4 + (parallel_idx * compMeta.dimSize[3])] = (float2 &)temp_half[0];
}
}
else{
half temp_half[4];
float temp_float[4];
if(!is_fp16){
(float2 &)temp_half[0]
= (float2 &)output_O_temp[lane_id_x4 + parallel_idx * compMeta.dimSize[3]];
for(int i = 0; i < 4; i += 2)
(float2 &)temp_float[i] = __half22float2((half2 &)temp_half[i]);
(float4 &)output_matmul[lane_id_x4 + parallel_idx * compMeta.dimSize[3]] = (float4 &)temp_float[0];
}
else{
(float2 &)output_matmul[lane_id_x4 + parallel_idx * compMeta.dimSize[3]]
= (float2 &)output_O_temp[lane_id_x4 + parallel_idx * compMeta.dimSize[3]];
output_matmul[i + offset] = __float2half(acc);
}
}
}
@@ -203,14 +200,14 @@ __global__ void _attention_kvcache_kernel_128_2(int* position_id,
namespace infini {
void attention_kvcache_kernel(int dType, void *input_k_cache, void *input_v_cache,
void *input_q, void *input_k,
void *input_v, int *position_id, void *output_matmul,
void *input_v, int64_t *position_id, void *output_matmul,
const AttentionKVCacheMetadata &compMeta,
float *output_O_temp, float *output_sum_temp) {
IT_ASSERT(compMeta.dimSize[3] == 128 && (dType == 1 || dType == 10));
IT_ASSERT(dType == 1 || dType == 10);
int gridsize_y = (compMeta.dimSize[2] - 1 + SEQ_UNIT) / SEQ_UNIT;
dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y);
dim3 blockDim(BLOCKSIZE, 1);
int gridsize_y = (compMeta.max_kv_seqlen - 1 + SEQ_UNIT) / SEQ_UNIT;
dim3 gridDim(compMeta.num_heads, compMeta.num_seqs, gridsize_y);
dim3 blockDim(WARP_SIZE, 1);
if(dType == 1){
_attention_kvcache_kernel_128_1<float>
@@ -219,7 +216,7 @@ void attention_kvcache_kernel(int dType, void *input_k_cache, void *input_v_cach
position_id, compMeta, (half*)output_O_temp, output_sum_temp);
_attention_kvcache_kernel_128_2<float>
<<<compMeta.dimSize[0]*compMeta.dimSize[1], WARP_SIZE,
<<<compMeta.num_seqs*compMeta.num_heads, BLOCKSIZE_2,
0, CUDAStream::getCurrentStream()>>>
(position_id, (float*)output_matmul, compMeta, (half*)output_O_temp, output_sum_temp);
}
@@ -230,7 +227,7 @@ void attention_kvcache_kernel(int dType, void *input_k_cache, void *input_v_cach
position_id, compMeta, (half*)output_O_temp, output_sum_temp);
_attention_kvcache_kernel_128_2<half>
<<<compMeta.dimSize[0]*compMeta.dimSize[1], WARP_SIZE,
<<<compMeta.num_seqs*compMeta.num_heads, BLOCKSIZE_2,
0, CUDAStream::getCurrentStream()>>>
(position_id, (half*)output_matmul, compMeta, (half*)output_O_temp, output_sum_temp);
}
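
Taken together, the two kernels implement a split-KV decode: the first pass emits, for every (sequence, head) and every SEQ_UNIT-sized slice of the cached keys/values, an output already normalized by that slice's softmax sum, plus the sum itself; the second pass re-weights each slice by its sum and divides by the global total. A CPU sketch of that combination step (single head, illustrative only):

#include <vector>

// O[j]: output of KV slice j (head_dim floats); s[j]: softmax sum of slice j
std::vector<float> combine_partitions(const std::vector<std::vector<float>> &O,
                                      const std::vector<float> &s) {
    const size_t head_dim = O.empty() ? 0 : O[0].size();
    float total = 1e-6f;                          // same epsilon as the kernel
    for (float sj : s)
        total += sj;
    std::vector<float> out(head_dim, 0.0f);
    for (size_t j = 0; j < O.size(); ++j)
        for (size_t d = 0; d < head_dim; ++d)
            out[d] += O[j][d] * s[j] / total;     // undo per-slice normalization, apply the global one
    return out;
}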

View File

@@ -18,17 +18,19 @@ class RoPECuda : public CudaKernelWithoutConfig {
const auto &inputShape = input->getDims();
int nDims = input->getDims().size();
int size = input->size();
IT_ASSERT(nDims == 3 && pos->getDims().size() == 2);
IT_ASSERT(inputShape[1] == pos->getDims()[1]);
IT_ASSERT(inputShape[0] == pos->getDims()[0] &&
inputShape[1] == pos->getDims()[1]);
int position_idx_dtype = op->getInputs()[0]->getDTypeIndex();
IT_ASSERT(position_idx_dtype == 7);
int dim_model = inputShape[2];
int dim_head = 128;
int hidden_stride = dim_model * inputShape[1];
int dim_head = 128; // TODO: get dim_head from the framework
int pos_stride = inputShape[1];
int batchsize = inputShape[0];
const int dType = op->getDType().getIndex();
rope_kernel(dType, pos->getRawDataPtr<int *>(), inputData, outputData,
size, dim_model, dim_head, hidden_stride, pos_stride);
rope_kernel(dType, pos->getRawDataPtr<int64_t *>(), inputData,
outputData, dim_model, dim_head, batchsize, pos_stride);
}
};

View File

@@ -4,13 +4,15 @@
#include "utils/small_array.h"
template <class T>
__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model,
int dim_head, int hidden_stride, int pos_stride) {
__global__ void _rope_kernel(int64_t* pos, void *in, void *out, int dim_model,
int dim_head, int batchsize, int pos_stride) {
int batch_id = blockIdx.x;
int target_pos = pos[batch_id * pos_stride + blockIdx.y];
int ith = blockIdx.z * blockDim.x + threadIdx.x;
int col = ith % dim_head;
int offset = batch_id * hidden_stride + blockIdx.y * dim_model;
int batch_stride = pos_stride * dim_model;
int offset = batch_id * batch_stride + blockIdx.y * dim_model;
if (ith >= dim_model)
return;
@@ -34,7 +36,7 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
#define CASE(T) \
_rope_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
(pos, input, output, size, dim_model, dim_head, hidden_stride, pos_stride);
(pos, input, output, dim_model, dim_head, batchsize, pos_stride);
#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
@@ -79,10 +81,10 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
}
namespace infini {
void rope_kernel(int dType, int * pos, void *input, void *output, int size,
int dim_model, int dim_head, int hidden_stride, int pos_stride) {
void rope_kernel(int dType, int64_t * pos, void *input, void *output,
int dim_model, int dim_head, int batchsize, int pos_stride) {
dim3 blocksize = dim3(32,1,1);
dim3 gridsize = dim3(1, 1, dim_model/32);
dim3 gridsize = dim3(batchsize, pos_stride, dim_model/32);
SWITCH_DTYPE(dType)
}
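
The new launch gives the kernel one block per (sequence, token, 32-element chunk of the hidden dim). The rotation body itself sits outside this hunk; the following is a sketch of a standard rotate-half RoPE written against the same indexing, where the 10000 frequency base and the pairing convention are assumptions rather than something read from this commit:

#include <cstdint>

__global__ void rope_sketch(const int64_t *pos, const float *in, float *out,
                            int dim_model, int dim_head, int pos_stride) {
    int batch_id = blockIdx.x;                        // one sequence per blockIdx.x
    int token_id = blockIdx.y;                        // one token per blockIdx.y
    int ith = blockIdx.z * blockDim.x + threadIdx.x;  // element index in the hidden dim
    if (ith >= dim_model)
        return;

    int64_t p = pos[batch_id * pos_stride + token_id];
    int offset = (batch_id * pos_stride + token_id) * dim_model;

    int col = ith % dim_head;                         // position inside one head
    int half_dim = dim_head / 2;
    float angle = (float)p * powf(10000.0f, -2.0f * (col % half_dim) / dim_head);
    float c = cosf(angle), s = sinf(angle);

    // Rotate the (x, y) pair; the partner element sits half a head away.
    float x = in[offset + ith];
    float y = (col < half_dim) ? in[offset + ith + half_dim] : in[offset + ith - half_dim];
    out[offset + ith] = (col < half_dim) ? x * c - y * s : y * s + x * c;
}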