CPM-9G-8B/9G-Train/cpm/utils/gradient_shrink.py

17 lines
382 B
Python

import torch
class OpGradientShrink(torch.autograd.Function):
    """Autograd op that passes its input through and rescales the gradient.

    The forward pass returns ``x`` as given; the backward pass multiplies
    the incoming gradient by the ``alpha`` that was supplied to ``forward``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, alpha: float) -> torch.Tensor:
        """Remember ``alpha`` for the backward pass and return ``x`` unchanged.

        Args:
            ctx: Autograd context object; used here only to carry ``alpha``.
            x: Input tensor, returned as-is.
            alpha: Factor applied to the gradient in ``backward``.
        """
        # alpha is a plain Python value, so it is kept as a ctx attribute
        # rather than going through ctx.save_for_backward.
        ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Return the gradient for ``x`` scaled by ``alpha``.

        One value is returned per ``forward`` argument: the scaled gradient
        for ``x`` and ``None`` for ``alpha``, which receives no gradient.
        """
        return grad_output * ctx.alpha, None
def gradient_shrink(x: torch.Tensor, alpha: float = 0.1) -> torch.Tensor:
    """Apply ``OpGradientShrink`` to ``x`` with the given ``alpha``.

    The forward value is ``x`` itself; during backpropagation the gradient
    flowing through this call is multiplied by ``alpha``.

    Args:
        x: Input tensor.
        alpha: Gradient scaling factor; defaults to ``0.1``.

    Returns:
        The result of ``OpGradientShrink.apply(x, alpha)``.
    """
    return OpGradientShrink.apply(x, alpha)