OpenDeltaMirror/opendelta/delta_models/layers/init.py

2022-02-14 21:19:03 +08:00
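import torch
import math

def glorot_normal(tensor: torch.Tensor) -> torch.Tensor:
    # In-place Xavier/Glorot normal init; gain sqrt(2) equals calculate_gain('relu').
    return torch.nn.init.xavier_normal_(tensor, gain=math.sqrt(2))

def glorot_uniform(tensor: torch.Tensor) -> torch.Tensor:
    # In-place Xavier/Glorot uniform init with the same sqrt(2) gain.
    return torch.nn.init.xavier_uniform_(tensor, gain=math.sqrt(2))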
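
As a rough illustration of the scale these two helpers produce, the sketch below reproduces the documented Xavier formulas in plain Python, without torch. The function names `xavier_std` and `xavier_uniform_bound` are hypothetical and not part of this module; the 64x64 shape is an arbitrary example, with fan_in and fan_out taken as the two dimensions of a 2-D weight.

```python
import math


def xavier_std(fan_in: int, fan_out: int, gain: float = math.sqrt(2)) -> float:
    """Standard deviation used by Xavier normal init: gain * sqrt(2 / (fan_in + fan_out))."""
    return gain * math.sqrt(2.0 / (fan_in + fan_out))


def xavier_uniform_bound(fan_in: int, fan_out: int, gain: float = math.sqrt(2)) -> float:
    """Bound a of U(-a, a) used by Xavier uniform init: sqrt(3) times the normal std."""
    return math.sqrt(3.0) * xavier_std(fan_in, fan_out, gain)


# Hypothetical 64x64 weight matrix, so fan_in = fan_out = 64.
std = xavier_std(64, 64)
bound = xavier_uniform_bound(64, 64)
print(f"normal std = {std:.4f}, uniform bound = {bound:.4f}")
```

Both distributions have the same variance for a given shape, so under these formulas choosing between `glorot_normal` and `glorot_uniform` changes the shape of the initial weight distribution, not its scale.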