Fix bugs in the RMSNorm op: the kernel now reads its input as type T instead of float

This commit is contained in:
xiaonans 2024-02-20 10:50:47 +08:00
parent 0f1c04d864
commit 83be7fa373
3 changed files with 6 additions and 5 deletions

View File

@ -156,8 +156,8 @@ struct OpType {
Resize, Resize,
ReverseSequence, ReverseSequence,
RoiAlign, RoiAlign,
RoPE, // Fusion RoPE, // Fusion
Round, // Unary Round, // Unary
RMSNorm, // Fusion RMSNorm, // Fusion
STFT, STFT,
Scan, Scan,

View File

@ -125,7 +125,8 @@ Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) { Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) {
if (output) { if (output) {
g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight), output); g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight),
output);
return output; return output;
} else { } else {
return g->addOp<RMSNormObj>(std::move(input), std::move(weight), output) return g->addOp<RMSNormObj>(std::move(input), std::move(weight), output)

View File

@ -38,7 +38,7 @@ __global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_token
float variance = 0.0f; float variance = 0.0f;
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){ for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
const float x = ((float*) in)[blockIdx.x * hidden_size + idx]; const float x = ((T*) in)[blockIdx.x * hidden_size + idx];
variance += x * x; variance += x * x;
} }
variance = blockReduceSum<float>(variance); variance = blockReduceSum<float>(variance);
@ -48,7 +48,7 @@ __global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_token
__syncthreads(); __syncthreads();
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){ for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
float x = ((float*) in)[blockIdx.x * hidden_size + idx]; float x = ((T*) in)[blockIdx.x * hidden_size + idx];
((T*)out)[blockIdx.x * hidden_size + idx] = ((T)(x * s_variance)) * ((T*)weight)[idx]; ((T*)out)[blockIdx.x * hidden_size + idx] = ((T)(x * s_variance)) * ((T*)weight)[idx];
} }
} }