diff --git a/include/core/op_type.h b/include/core/op_type.h index 3b4045c4..e624877b 100644 --- a/include/core/op_type.h +++ b/include/core/op_type.h @@ -156,8 +156,8 @@ struct OpType { Resize, ReverseSequence, RoiAlign, - RoPE, // Fusion - Round, // Unary + RoPE, // Fusion + Round, // Unary RMSNorm, // Fusion STFT, Scan, diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 596910f1..0afc5ed9 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -125,7 +125,8 @@ Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale, Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) { if (output) { - g->addOpWithOutputs(std::move(input), std::move(weight), output); + g->addOpWithOutputs(std::move(input), std::move(weight), + output); return output; } else { return g->addOp(std::move(input), std::move(weight), output) diff --git a/src/kernels/cuda/rms_norm.cu b/src/kernels/cuda/rms_norm.cu index 9eca738f..530a42ce 100644 --- a/src/kernels/cuda/rms_norm.cu +++ b/src/kernels/cuda/rms_norm.cu @@ -38,7 +38,7 @@ __global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_token float variance = 0.0f; for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){ - const float x = ((float*) in)[blockIdx.x * hidden_size + idx]; + const float x = ((T*) in)[blockIdx.x * hidden_size + idx]; variance += x * x; } variance = blockReduceSum(variance); @@ -48,7 +48,7 @@ __global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_token __syncthreads(); for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){ - float x = ((float*) in)[blockIdx.x * hidden_size + idx]; + float x = ((T*) in)[blockIdx.x * hidden_size + idx]; ((T*)out)[blockIdx.x * hidden_size + idx] = ((T)(x * s_variance)) * ((T*)weight)[idx]; } }