# CPM-9G-8B/9G-Train/cpm/layers/blocks.py

from typing import Optional
from typing import Tuple
from typing import Union
import bmtrain as bmt
import torch
from .attention import Attention
from .feedforward import FeedForward
from .layernorm import LayerNorm
from .position_embedding import RotaryEmbedding
from .position_embedding import RotaryEmbeddingESM


class SelfAttentionBlock(bmt.DistributedModule):
    """The whole self-attention block. A sequence of operations: layernorm, self-attention, and a residual connection.

    Args:
        dim_model (int): main dimension of modules in transformer blocks.
        num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
        num_kv_heads (int): num_kv_heads used in :py:class:`model_center.layer.Attention`.
        dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
        dtype (optional): Defaults to torch.half.
        eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5.
        dropout_p (float, optional): Defaults to None.
    """  # noqa: E501

    def __init__(
        self,
        dim_model: int,
        num_heads: int,
        num_kv_heads: int,
        dim_head: int,
        dtype=torch.half,
        eps: float = 1e-5,
        dropout_p: Optional[float] = None,
        scale: bool = True,
        add_qkv_bias: bool = False,
        use_flash_attn: bool = False,
        tp: int = 0,
    ):
        super().__init__()
        self.layernorm_before_attention = LayerNorm(
            dim_model,
            dtype=dtype,
            eps=eps,
        )
        self.self_attention = Attention(
            dim_model=dim_model,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            dim_head=dim_head,
            dtype=dtype,
            dropout_p=dropout_p,
            scale=scale,
            add_qkv_bias=add_qkv_bias,
            use_flash_attn=use_flash_attn,
            tp=tp,
        )
        if dropout_p:
            self.dropout = torch.nn.Dropout(dropout_p)
        else:
            self.dropout = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_bias: Optional[Union[torch.Tensor, RotaryEmbedding, RotaryEmbeddingESM]] = None,
        use_cache: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        pos_bias_type: Optional[str] = "relative",
        length_mask: Optional[torch.Tensor] = None,
        attention_mask_bias: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of the self-attention block. It can be the embedding of a batch of sequences.
            attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Prevents invalid positions from taking part in the attention computation.
            position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provides positional information to the self-attention block.

        Return:
            :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of the attention block.
        """  # noqa: E501
        x = self.layernorm_before_attention(hidden_states)
        x = self.self_attention(
            x,
            x,
            attention_mask,
            position_bias,
            use_cache,
            past_key_value,
            pos_bias_type=pos_bias_type,
            length_mask=length_mask,
            attention_mask_bias=attention_mask_bias,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            position_ids=position_ids,
        )
        if use_cache:
            x, current_key_value = x
        else:
            current_key_value = None
        if self.dropout is not None:
            x = self.dropout(x)
        hidden_states = hidden_states + x  # / 1.05
        if use_cache:
            return hidden_states, current_key_value
        else:
            return hidden_states
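
# Usage sketch for SelfAttentionBlock (illustrative, not part of the original
# module). The dimensions, the boolean causal-mask layout, and the assumption
# that masked attention works with position_bias left as None are all guesses
# about the Attention implementation; bmtrain must be initialized
# (bmt.init_distributed()) and a CUDA device available before running it:
#
#     block = SelfAttentionBlock(dim_model=512, num_heads=8, num_kv_heads=8, dim_head=64)
#     x = torch.randn(2, 16, 512, dtype=torch.half, device="cuda")
#     causal = torch.tril(torch.ones(16, 16, dtype=torch.bool, device="cuda"))
#     out = block(x, attention_mask=causal.unsqueeze(0).expand(2, -1, -1))  # (2, 16, 512)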


class FFNBlock(torch.nn.Module):
    """The whole feed-forward block. A sequence of operations: layernorm, feed-forward, and a residual connection.

    Args:
        dim_model (int): main dimension of modules in transformer blocks.
        dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
        activate_fn (str): activation function used in :py:class:`model_center.layer.FeedForward`.
        dtype (optional): Defaults to torch.half.
        eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
        dropout_p (float, optional): Defaults to 0.
    """  # noqa: E501

    def __init__(
        self,
        dim_model: int,
        dim_ff: int,
        activate_fn: str,
        dtype=torch.half,
        eps: float = 1e-6,
        dropout_p: Optional[float] = 0,
        scale: bool = True,
        tp: int = 0,
    ):
        super().__init__()
        self.layernorm_before_ffn = LayerNorm(
            dim_model,
            dtype=dtype,
            eps=eps,
        )
        self.ffn = FeedForward(
            dim_model,
            dim_ff,
            activate_fn=activate_fn,
            dtype=dtype,
            dropout_p=dropout_p,
            scale=scale,
            tp=tp,
        )
        if dropout_p:
            self.dropout = torch.nn.Dropout(dropout_p)
        else:
            self.dropout = None

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """
        Args:
            hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Hidden states before the feed-forward layer.

        Return:
            :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of the feed-forward block.
        """  # noqa: E501
        x = self.layernorm_before_ffn(hidden_states)
        x = self.ffn(x)
        if self.dropout is not None:
            x = self.dropout(x)
        hidden_states = hidden_states + x  # / 1.05
        return hidden_states
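
# Usage sketch for FFNBlock (illustrative, not part of the original module).
# The dimensions are arbitrary, and whether FeedForward accepts "silu" as an
# activation name depends on this repo's implementation:
#
#     ffn_block = FFNBlock(dim_model=512, dim_ff=2048, activate_fn="silu")
#     x = torch.randn(2, 16, 512, dtype=torch.half, device="cuda")
#     out = ffn_block(x)  # same shape as x: (2, 16, 512)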


class TransformerBlock(torch.nn.Module):
    """The whole transformer block. A sequence of operations. Consists of a self-attention block and a feed-forward block; either can be skipped via ``mask_att`` / ``mask_ffn``.

    Args:
        dim_model (int): main dimension of modules in transformer blocks.
        dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
        num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
        num_kv_heads (int): num_kv_heads used in :py:class:`model_center.layer.Attention`.
        dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
        activate_fn (str, optional): activation function used in :py:class:`model_center.layer.FeedForward`. Defaults to "gelu".
        dtype (optional): Defaults to torch.half.
        eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
        dropout_p (float, optional): Defaults to None.
        mask_att (bool, optional): if True, the self-attention block is omitted. Defaults to False.
        mask_ffn (bool, optional): if True, the feed-forward block is omitted. Defaults to False.
    """  # noqa: E501

    def __init__(
        self,
        dim_model: int,
        dim_ff: int,
        num_heads: int,
        num_kv_heads: int,
        dim_head: int,
        activate_fn: str = "gelu",
        dtype=torch.half,
        eps: float = 1e-6,
        dropout_p: Optional[float] = None,
        scale: bool = True,
        add_qkv_bias: bool = False,
        mask_att: bool = False,
        mask_ffn: bool = False,
        use_flash_attn: bool = False,
        tp: int = 0,
    ):
        super().__init__()
        self.mask_att = mask_att
        self.mask_ffn = mask_ffn
        if not self.mask_att:
            self.self_att = SelfAttentionBlock(
                dim_model=dim_model,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                dim_head=dim_head,
                dtype=dtype,
                eps=eps,
                dropout_p=dropout_p,
                scale=scale,
                add_qkv_bias=add_qkv_bias,
                use_flash_attn=use_flash_attn,
                tp=tp,
            )
        if not self.mask_ffn:
            self.ffn = FFNBlock(
                dim_model=dim_model,
                dim_ff=dim_ff,
                activate_fn=activate_fn,
                dtype=dtype,
                eps=eps,
                dropout_p=dropout_p,
                scale=scale,
                tp=tp,
            )

    def forward(
        self,
        self_hidden_states: torch.Tensor,
        self_attention_mask: Optional[torch.Tensor] = None,
        self_position_bias: Optional[Union[torch.Tensor, RotaryEmbedding, RotaryEmbeddingESM]] = None,
        use_cache: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        pos_bias_type: Optional[str] = "relative",
        length_mask: Optional[torch.Tensor] = None,
        attention_mask_bias: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            self_hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of the transformer block (the self-attention block). It can be the raw embedding of a batch of sequences.
            self_attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Prevents invalid positions from taking part in the self-attention computation.
            self_position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provides positional information to the self-attention block.

        Return:
            :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of the transformer block.
        """  # noqa: E501
        # (batch, seq_self, dim_model)
        current_key_value = None
        if not self.mask_att:
            hidden_states = self.self_att(
                self_hidden_states,
                attention_mask=self_attention_mask,
                position_bias=self_position_bias,
                use_cache=use_cache,
                past_key_value=past_key_value,
                pos_bias_type=pos_bias_type,
                length_mask=length_mask,
                attention_mask_bias=attention_mask_bias,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                position_ids=position_ids,
            )
            if use_cache:
                hidden_states, current_key_value = hidden_states
        else:
            hidden_states = self_hidden_states
        # (batch, seq_self, dim_model)
        if not self.mask_ffn:
            hidden_states = self.ffn(hidden_states)
        if use_cache:
            return hidden_states, current_key_value
        else:
            return hidden_states
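

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the original
    # module). Everything below is an assumption: the hyperparameters are
    # arbitrary, bmtrain is initialized for a single process, and the
    # Attention implementation is assumed to accept a boolean causal mask
    # with position_bias left as None. A CUDA device is required.
    bmt.init_distributed(seed=0)
    block = TransformerBlock(
        dim_model=512,
        dim_ff=2048,
        num_heads=8,
        num_kv_heads=8,
        dim_head=64,
    )
    batch, seq_len = 2, 16
    hidden = torch.randn(batch, seq_len, 512, dtype=torch.half, device="cuda")
    # Causal mask: position i may attend to positions j <= i.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device="cuda"))
    mask = causal.unsqueeze(0).expand(batch, -1, -1)
    out = block(hidden, self_attention_mask=mask)
    print(out.shape)  # expected: torch.Size([2, 16, 512])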