forked from jiuyuan/InfiniTensor
fix bugs in rmsnorm op
This commit is contained in:
parent
0f1c04d864
commit
83be7fa373
|
@ -125,7 +125,8 @@ Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
|
|||
|
||||
Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight), output);
|
||||
g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight),
|
||||
output);
|
||||
return output;
|
||||
} else {
|
||||
return g->addOp<RMSNormObj>(std::move(input), std::move(weight), output)
|
||||
|
|
|
@ -38,7 +38,7 @@ __global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_token
|
|||
float variance = 0.0f;
|
||||
|
||||
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
|
||||
const float x = ((float*) in)[blockIdx.x * hidden_size + idx];
|
||||
const float x = ((T*) in)[blockIdx.x * hidden_size + idx];
|
||||
variance += x * x;
|
||||
}
|
||||
variance = blockReduceSum<float>(variance);
|
||||
|
@ -48,7 +48,7 @@ __global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_token
|
|||
__syncthreads();
|
||||
|
||||
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
|
||||
float x = ((float*) in)[blockIdx.x * hidden_size + idx];
|
||||
float x = ((T*) in)[blockIdx.x * hidden_size + idx];
|
||||
((T*)out)[blockIdx.x * hidden_size + idx] = ((T)(x * s_variance)) * ((T*)weight)[idx];
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue