fix bugs in rmsnorm op

xiaonans 2024-02-20 10:50:47 +08:00
parent 0f1c04d864
commit 83be7fa373
3 changed files with 6 additions and 5 deletions


@@ -125,7 +125,8 @@ Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
 Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) {
     if (output) {
-        g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight), output);
+        g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight),
+                                        output);
         return output;
     } else {
         return g->addOp<RMSNormObj>(std::move(input), std::move(weight), output)
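For context, this first change is purely a re-wrap of the addOpWithOutputs call, presumably to satisfy the project's line-length limit. A sketch of the whole handler follows; the hunk ends mid-statement, so the trailing ->getOutput() on the else branch is an assumption patterned on sibling handlers such as layerNormalization, not something this diff shows.

Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) {
    if (output) {
        // Caller supplied an output tensor: wire the op to it directly.
        g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight),
                                        output);
        return output;
    } else {
        // No output given: the op infers the shape and allocates its own.
        // ->getOutput() is assumed; the hunk cuts off before this line.
        return g->addOp<RMSNormObj>(std::move(input), std::move(weight),
                                    output)
            ->getOutput();
    }
}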


@@ -38,7 +38,7 @@ __global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_token
     float variance = 0.0f;
     for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
-        const float x = ((float*) in)[blockIdx.x * hidden_size + idx];
+        const float x = ((T*) in)[blockIdx.x * hidden_size + idx];
         variance += x * x;
     }
     variance = blockReduceSum<float>(variance);
@@ -48,7 +48,7 @@ __global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_token
     __syncthreads();
     for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
-        float x = ((float*) in)[blockIdx.x * hidden_size + idx];
+        float x = ((T*) in)[blockIdx.x * hidden_size + idx];
         ((T*)out)[blockIdx.x * hidden_size + idx] = ((T)(x * s_variance)) * ((T*)weight)[idx];
     }
 }
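The actual bug is in this second file: the kernel is templated on the element type T, but both loads from in went through a (float*) cast, so half-precision inputs had their raw bits reinterpreted as floats instead of being value-converted. Below is a minimal sketch of the corrected kernel assembled from the two hunks. Only the two (T*) casts are confirmed by the diff; the blockReduceSum stand-in, the shared s_variance, the 1e-6f epsilon, and the parameter name num_tokens (the hunk header truncates it to num_token) are assumptions reconstructed from the visible context.

#include <cuda_fp16.h>

// Stand-in for the file's own blockReduceSum: a warp-shuffle sum reduction
// across the block (assumes blockDim.x is a multiple of warpSize).
template <typename U> __inline__ __device__ U warpReduceSum(U val) {
    for (int offset = warpSize / 2; offset > 0; offset /= 2)
        val += __shfl_down_sync(0xffffffff, val, offset);
    return val;
}

template <typename U> __inline__ __device__ U blockReduceSum(U val) {
    static __shared__ U shared[32]; // one partial sum per warp
    int lane = threadIdx.x % warpSize;
    int wid = threadIdx.x / warpSize;
    val = warpReduceSum(val);
    if (lane == 0)
        shared[wid] = val;
    __syncthreads();
    val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : (U)0;
    if (wid == 0)
        val = warpReduceSum(val);
    return val;
}

template <typename T>
__global__ void _rmsnorm_kernel(void *in, void *weight, void *out,
                                int num_tokens, int hidden_size) {
    __shared__ float s_variance;

    // One block per token (num_tokens is implicit in gridDim.x here);
    // accumulate the sum of squares in fp32 even when T is half.
    float variance = 0.0f;
    for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
        // The fixed bug: the old code cast in to (float*), which on half
        // inputs reinterpreted the raw bits as floats. Casting to (T*)
        // loads a T and value-converts it to float.
        const float x = ((T *)in)[blockIdx.x * hidden_size + idx];
        variance += x * x;
    }
    variance = blockReduceSum<float>(variance);
    if (threadIdx.x == 0)
        s_variance = rsqrtf(variance / hidden_size + 1e-6f); // eps assumed
    __syncthreads();

    // Normalize in fp32, then cast back to T and scale by the weight.
    for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
        float x = ((T *)in)[blockIdx.x * hidden_size + idx];
        ((T *)out)[blockIdx.x * hidden_size + idx] =
            ((T)(x * s_variance)) * ((T *)weight)[idx];
    }
}

A launch would presumably use one block per token, e.g. _rmsnorm_kernel<half><<<num_tokens, 128>>>(in, w, out, num_tokens, hidden_size). Accumulating the sum of squares in fp32 and converting only on load and store is the usual way to keep the reduction numerically stable for half inputs.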