CPM-9G-8B/9G-Train/cpm/native_layers/transformer.py

from typing import List, Optional, Tuple

import torch

from .blocks import TransformerBlock
from .layernorm import LayerNorm


class Encoder(torch.nn.Module):
    """Layers of encoder transformer blocks plus a final layernorm.

    Args:
        num_layers (int): number of layers.
        dim_model (int): main dimension of modules in transformer blocks.
        dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
        num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
        dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
        num_kv_heads (int, optional): number of key/value heads; -1 (the default) means use num_heads.
        activate_fn (str, optional): activation function of the feed-forward layers. Defaults to "gelu".
        dtype (optional): Defaults to torch.half.
        eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
        dropout_p (float, optional): dropout probability. Defaults to None.
        scale (bool, optional): passed through to the transformer blocks. Defaults to True.
        mask_modules (List[Tuple[bool, bool]], optional): per-layer (mask_att, mask_ffn) flags; must have length num_layers.
        use_flash_attn (bool, optional): whether the attention blocks use FlashAttention. Defaults to False.
    """

    def __init__(
        self,
        num_layers: int,
        dim_model: int,
        dim_ff: int,
        num_heads: int,
        dim_head: int,
        num_kv_heads: int = -1,
        activate_fn: str = "gelu",
        dtype: torch.dtype = torch.half,
        eps: float = 1e-6,
        dropout_p: Optional[float] = None,
        scale: bool = True,
        mask_modules: Optional[List[Tuple[bool, bool]]] = None,
        use_flash_attn: bool = False,
    ):
        super().__init__()

        if num_kv_heads == -1:
            num_kv_heads = num_heads
        self.num_layers = num_layers

        if mask_modules is not None:
            assert len(mask_modules) == num_layers, "The total number of masks should equal num_layers"
            for mask_module in mask_modules:
                assert len(mask_module) == 2, "For the encoder, each mask should be a (mask_att, mask_ffn) pair"
        else:
            mask_modules = [(False, False)] * num_layers
        self.layers = torch.nn.ModuleList(
            [
                TransformerBlock(
                    dim_model=dim_model,
                    dim_ff=dim_ff,
                    num_heads=num_heads,
                    num_kv_heads=num_kv_heads,
                    dim_head=dim_head,
                    activate_fn=activate_fn,
                    dtype=dtype,
                    eps=eps,
                    dropout_p=dropout_p,
                    scale=scale,
                    mask_att=mask_modules[ith][0],
                    mask_ffn=mask_modules[ith][1],
                    use_flash_attn=use_flash_attn,
                )
                for ith in range(num_layers)
            ]
        )

        self.output_layernorm = LayerNorm(dim_norm=dim_model, dtype=dtype, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: torch.Tensor,
        use_cache: bool = False,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        pos_bias_type: Optional[str] = "relative",
        length_mask: Optional[torch.Tensor] = None,
        context_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``): Input of the encoder, e.g. the embedding of a batch of sequences.
            attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_enc, seq_enc)``): Prevents invalid positions from taking part in the attention computation.
            position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_enc, seq_enc)``): Provides position information to the attention mechanism.

        Return:
            :obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``: The encoder output. When ``use_cache`` is True, the per-layer key/value caches and hidden states are returned as well.
        """
        if not use_cache:
            for layer in self.layers:
                hidden_states = layer(
                    hidden_states,
                    attention_mask,
                    position_bias,
                    pos_bias_type=pos_bias_type,
                    length_mask=length_mask,
                    context_mask=context_mask,
                )
            hidden_states = self.output_layernorm(hidden_states)
            return hidden_states
        else:
            # Incremental decoding path: run under no_grad and collect the
            # per-layer key/value caches and hidden states.
            with torch.no_grad():
                current_key_values = []
                current_hidden_states = []
                for i, module in enumerate(self.layers):
                    hidden_states = module(
                        hidden_states,
                        attention_mask,
                        position_bias,
                        past_key_value=past_key_values[i] if past_key_values else None,
                        use_cache=use_cache,
                        pos_bias_type=pos_bias_type,
                        length_mask=length_mask,
                        context_mask=context_mask,
                    )
                    # With use_cache, each block returns (hidden_states, present_key_value).
                    current_key_values.append(hidden_states[1])
                    current_hidden_states.append(hidden_states[0])
                    hidden_states = hidden_states[0]
                hidden_states = self.output_layernorm(hidden_states)
                return hidden_states, current_key_values, current_hidden_states
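

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): build a small
    # Encoder and run one forward pass without caching. The shapes, the
    # all-ones boolean attention mask, and the zero position_bias are
    # illustrative assumptions, not values from the training code; whether
    # TransformerBlock accepts exactly these inputs depends on the sibling
    # modules in this package.
    batch, seq_enc, dim_model, num_heads, dim_head = 2, 8, 64, 4, 16
    encoder = Encoder(
        num_layers=2,
        dim_model=dim_model,
        dim_ff=128,
        num_heads=num_heads,
        dim_head=dim_head,
        dtype=torch.float32,
    )
    hidden_states = torch.randn(batch, seq_enc, dim_model)
    attention_mask = torch.ones(batch, seq_enc, seq_enc, dtype=torch.bool)
    position_bias = torch.zeros(num_heads, seq_enc, seq_enc)
    out = encoder(hidden_states, attention_mask, position_bias)
    print(out.shape)  # expected: torch.Size([2, 8, 64])

    # With use_cache=True the forward pass also returns the per-layer
    # key/value caches and hidden states (see forward above):
    # out, current_key_values, current_hidden_states = encoder(
    #     hidden_states, attention_mask, position_bias, use_cache=True
    # )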