forked from jiuyuan/CPM-9G-8B
17 lines
382 B
Python
17 lines
382 B
Python
|
import torch
|
||
|
|
||
|
|
||
|
class OpGradientShrink(torch.autograd.Function):
    """Autograd function that forwards its input and rescales the gradient.

    The forward pass hands ``x`` back as received. The backward pass
    multiplies the incoming gradient by the ``alpha`` that was given to the
    matching forward call.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, alpha: float):
        """Record ``alpha`` on ``ctx`` for the backward pass and return ``x``."""
        # alpha is a plain attribute on ctx; backward reads it back from there.
        ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` scaled by the stored ``alpha``.

        The second element of the returned pair is ``None``: it is the
        gradient slot for the ``alpha`` argument of ``forward``.
        """
        scaled_grad = grad_output * ctx.alpha
        return scaled_grad, None
|
||
|
|
||
|
|
||
|
def gradient_shrink(x: torch.Tensor, alpha: float = 0.1):
    """Pass ``x`` through ``OpGradientShrink`` with gradient scale ``alpha``.

    ``alpha`` defaults to 0.1 when the caller does not supply one. The return
    value is whatever ``OpGradientShrink.apply(x, alpha)`` produces.
    """
    shrink_op = OpGradientShrink.apply
    return shrink_op(x, alpha)
|