# File: CPM-9G-8B/9G-Train/cpm/utils/gradient_shrink.py
# Viewer metadata: 17 lines, 382 B, Python; history entry 2024-02-27 14:33:33 +08:00
import torch
class OpGradientShrink(torch.autograd.Function):
    """Autograd op that is the identity going forward and scales gradients going back.

    The forward pass hands back its input unchanged; the backward pass
    multiplies the incoming gradient by the ``alpha`` given to ``forward``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, alpha: float) -> torch.Tensor:
        """Return ``x`` unchanged and remember ``alpha`` for the backward pass.

        Args:
            ctx: Autograd context object used to carry state to ``backward``.
            x: Input tensor, passed through as-is.
            alpha: Factor applied to the gradient in ``backward``.

        Returns:
            The input ``x``.
        """
        # alpha is annotated as a plain float (not a tensor), so it is kept
        # as an attribute on ctx for backward() to read.
        ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Scale the incoming gradient by the stored ``alpha``.

        Args:
            ctx: Autograd context carrying ``alpha`` from ``forward``.
            grad_output: Gradient with respect to the output of ``forward``.

        Returns:
            A 2-tuple with one entry per ``forward`` argument:
            ``grad_output * alpha`` for ``x`` and ``None`` for ``alpha``.
        """
        return grad_output * ctx.alpha, None
def gradient_shrink(x: torch.Tensor, alpha: float = 0.1) -> torch.Tensor:
    """Apply ``OpGradientShrink`` to ``x`` with the given ``alpha``.

    Args:
        x: Tensor passed to ``OpGradientShrink.apply``.
        alpha: Factor forwarded to ``OpGradientShrink``; defaults to ``0.1``.

    Returns:
        Whatever ``OpGradientShrink.apply(x, alpha)`` returns.
    """
    return OpGradientShrink.apply(x, alpha)