# Source: CPM-9G-8B/9G-Train/cpm/native_layers/linear.py
# Repository viewer metadata: 45 lines, 1.3 KiB, Python; dated 2024-02-27 14:33:33 +08:00.
import math
import torch
import torch.nn.functional as F
class Linear(torch.nn.Module):
    """Bias-free linear projection with optional ``1 / sqrt(dim_in)`` scaling.

    The layer owns a single ``weight`` parameter of shape ``(dim_out, dim_in)``
    that is initialised from a normal distribution. There is no bias term.

    Args:
        dim_in: Size of the last dimension of the input.
        dim_out: Size of the last dimension of the output.
        dtype: Data type of the weight parameter.
        init_mean: Mean of the normal distribution used to initialise the weight.
        init_std: Standard deviation of that normal distribution.
        scale: If true, ``forward`` divides by ``sqrt(dim_in)``.
        scale_before: Only consulted when ``scale`` is true. If true the input
            is divided before the matrix multiply; otherwise the output of the
            matrix multiply is divided.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        dtype: torch.dtype = torch.half,
        init_mean: float = 0.0,
        init_std: float = 1,
        scale: bool = True,
        scale_before: bool = False,
    ):
        super().__init__()
        # Each size is stored under two names; ``in_features``/``out_features``
        # are the attribute names ``torch.nn.Linear`` uses.
        self.dim_in = self.in_features = dim_in
        self.dim_out = self.out_features = dim_out
        self.scale = scale
        self.scale_before = scale_before
        # Allocate uninitialised storage, then fill it in place below.
        self.weight = torch.nn.parameter.Parameter(torch.empty((dim_out, dim_in), dtype=dtype))
        torch.nn.init.normal_(self.weight, mean=init_mean, std=init_std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
        Returns:
            :obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
        """  # noqa: E501
        if not self.scale:
            return F.linear(x, self.weight)

        denom = math.sqrt(self.dim_in)
        if self.scale_before:
            # Divide the input first, then project. The order is kept distinct
            # from the branch below because the two can round differently in
            # low-precision dtypes.
            return F.linear(x / denom, self.weight)
        # Project first, then divide the result.
        return F.linear(x, self.weight) / denom