forked from jiuyuan/InfiniTensor
fix bugs in rmsnorm op
This commit is contained in:
parent
0f1c04d864
commit
83be7fa373
|
@ -156,8 +156,8 @@ struct OpType {
|
||||||
Resize,
|
Resize,
|
||||||
ReverseSequence,
|
ReverseSequence,
|
||||||
RoiAlign,
|
RoiAlign,
|
||||||
RoPE, // Fusion
|
RoPE, // Fusion
|
||||||
Round, // Unary
|
Round, // Unary
|
||||||
RMSNorm, // Fusion
|
RMSNorm, // Fusion
|
||||||
STFT,
|
STFT,
|
||||||
Scan,
|
Scan,
|
||||||
|
|
|
@ -125,7 +125,8 @@ Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
|
||||||
|
|
||||||
Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) {
|
Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) {
|
||||||
if (output) {
|
if (output) {
|
||||||
g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight), output);
|
g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight),
|
||||||
|
output);
|
||||||
return output;
|
return output;
|
||||||
} else {
|
} else {
|
||||||
return g->addOp<RMSNormObj>(std::move(input), std::move(weight), output)
|
return g->addOp<RMSNormObj>(std::move(input), std::move(weight), output)
|
||||||
|
|
|
@ -38,7 +38,7 @@ __global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_token
|
||||||
float variance = 0.0f;
|
float variance = 0.0f;
|
||||||
|
|
||||||
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
|
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
|
||||||
const float x = ((float*) in)[blockIdx.x * hidden_size + idx];
|
const float x = ((T*) in)[blockIdx.x * hidden_size + idx];
|
||||||
variance += x * x;
|
variance += x * x;
|
||||||
}
|
}
|
||||||
variance = blockReduceSum<float>(variance);
|
variance = blockReduceSum<float>(variance);
|
||||||
|
@ -48,7 +48,7 @@ __global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_token
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
|
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
|
||||||
float x = ((float*) in)[blockIdx.x * hidden_size + idx];
|
float x = ((T*) in)[blockIdx.x * hidden_size + idx];
|
||||||
((T*)out)[blockIdx.x * hidden_size + idx] = ((T)(x * s_variance)) * ((T*)weight)[idx];
|
((T*)out)[blockIdx.x * hidden_size + idx] = ((T)(x * s_variance)) * ((T*)weight)[idx];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue