from typing import Optional
from typing import Tuple
from typing import Union

import bmtrain as bmt
import torch

from .attention import Attention
from .feedforward import FeedForward
from .layernorm import LayerNorm
from .position_embedding import RotaryEmbedding
from .position_embedding import RotaryEmbeddingESM


class SelfAttentionBlock(bmt.DistributedModule):
    """The whole self-attention block. A sequence of operations. Consists of layernorm, self-attention and residual connection.

    Args:
        dim_model (int): main dimension of modules in transformer blocks.
        num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
        dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
        dtype (optional): Defaults to torch.half.
        eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5.
        dropout_p (float, optional): Defaults to None.
    """  # noqa: E501

    def __init__(
        self,
        dim_model: int,
        num_heads: int,
        num_kv_heads: int,
        dim_head: int,
        dtype=torch.half,
        eps: float = 1e-5,
        dropout_p: Optional[float] = None,
        scale: bool = True,
        add_qkv_bias: bool = False,
        use_flash_attn: bool = False,
        tp: int = 0,
    ):
        super().__init__()

        self.layernorm_before_attention = LayerNorm(
            dim_model,
            dtype=dtype,
            eps=eps,
        )

        self.self_attention = Attention(
            dim_model=dim_model,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            dim_head=dim_head,
            dtype=dtype,
            dropout_p=dropout_p,
            scale=scale,
            add_qkv_bias=add_qkv_bias,
            use_flash_attn=use_flash_attn,
            tp=tp,
        )

        if dropout_p:
            self.dropout = torch.nn.Dropout(dropout_p)
        else:
            self.dropout = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_bias: Optional[Union[torch.Tensor, RotaryEmbedding, RotaryEmbeddingESM]] = None,
        use_cache: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        pos_bias_type: Optional[str] = "relative",
        length_mask: Optional[torch.Tensor] = None,
        attention_mask_bias: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of self-attention block. It can be the embedding of a batch of sequences.
            attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Avoid invalid areas to participate in the calculation.
            position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provide positional information to self-attention block.

        Return:
            :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of attention block.
        """  # noqa: E501
        x = self.layernorm_before_attention(hidden_states)
        x = self.self_attention(
            x,
            x,
            attention_mask,
            position_bias,
            use_cache,
            past_key_value,
            pos_bias_type=pos_bias_type,
            length_mask=length_mask,
            attention_mask_bias=attention_mask_bias,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            position_ids=position_ids,
        )
        if use_cache:
            x, current_key_value = x
        else:
            current_key_value = None

        if self.dropout is not None:
            x = self.dropout(x)

        hidden_states = hidden_states + x  # / 1.05

        if use_cache:
            return hidden_states, current_key_value
        else:
            return hidden_states
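

# A minimal plain-PyTorch sketch (not part of this library) of the pre-LayerNorm
# residual pattern that SelfAttentionBlock implements above: normalize, run the
# sublayer, optionally apply dropout, then add the result back to the input.
# torch.nn.MultiheadAttention stands in for the bmtrain-backed Attention module
# purely for illustration; it is an assumption, not the class used above.
class _PreLNSelfAttentionSketch(torch.nn.Module):
    def __init__(self, dim_model: int, num_heads: int, dropout_p: float = 0.0):
        super().__init__()
        self.norm = torch.nn.LayerNorm(dim_model)
        self.attn = torch.nn.MultiheadAttention(dim_model, num_heads, batch_first=True)
        self.dropout = torch.nn.Dropout(dropout_p) if dropout_p else None

    def forward(self, hidden_states: torch.Tensor, attn_mask=None) -> torch.Tensor:
        x = self.norm(hidden_states)  # layernorm_before_attention
        x, _ = self.attn(x, x, x, attn_mask=attn_mask)  # self-attention (q = k = v)
        if self.dropout is not None:
            x = self.dropout(x)
        return hidden_states + x  # residual connection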


class FFNBlock(torch.nn.Module):
    """The whole feed-forward block. A sequence of operations. Consists of layernorm, feed-forward and residual connection.

    Args:
        dim_model (int): main dimension of modules in transformer blocks.
        dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
        dtype (optional): Defaults to torch.half.
        eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
        dropout_p (float, optional): Defaults to 0.
    """  # noqa: E501

    def __init__(
        self,
        dim_model: int,
        dim_ff: int,
        activate_fn: str,
        dtype=torch.half,
        eps: float = 1e-6,
        dropout_p: Optional[float] = 0,
        scale: bool = True,
        tp: int = 0,
    ):
        super().__init__()

        self.layernorm_before_ffn = LayerNorm(
            dim_model,
            dtype=dtype,
            eps=eps,
        )

        self.ffn = FeedForward(
            dim_model,
            dim_ff,
            activate_fn=activate_fn,
            dtype=dtype,
            dropout_p=dropout_p,
            scale=scale,
            tp=tp,
        )

        if dropout_p:
            self.dropout = torch.nn.Dropout(dropout_p)
        else:
            self.dropout = None

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """
        Args:
            hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Hidden states before feed forward layer.

        Return:
            :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of feed-forward block.
        """  # noqa: E501
        x = self.layernorm_before_ffn(hidden_states)
        x = self.ffn(x)
        if self.dropout is not None:
            x = self.dropout(x)
        hidden_states = hidden_states + x  # / 1.05
        return hidden_states
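

# A minimal plain-PyTorch sketch (not part of this library) of the feed-forward
# residual pattern FFNBlock implements above: pre-LayerNorm, a position-wise MLP,
# optional dropout, then the residual add. The two-layer GELU MLP stands in for
# the bmtrain-backed FeedForward module and is an assumption for illustration.
class _PreLNFFNSketch(torch.nn.Module):
    def __init__(self, dim_model: int, dim_ff: int, dropout_p: float = 0.0):
        super().__init__()
        self.norm = torch.nn.LayerNorm(dim_model)
        self.ffn = torch.nn.Sequential(
            torch.nn.Linear(dim_model, dim_ff),
            torch.nn.GELU(),
            torch.nn.Linear(dim_ff, dim_model),
        )
        self.dropout = torch.nn.Dropout(dropout_p) if dropout_p else None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self.ffn(self.norm(hidden_states))  # layernorm_before_ffn -> ffn
        if self.dropout is not None:
            x = self.dropout(x)
        return hidden_states + x  # residual: shape stays (batch, seq_self, dim_model)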


class TransformerBlock(torch.nn.Module):
    """The whole transformer block. A sequence of operations. Consists of a self-attention block and a feed-forward block.

    Args:
        dim_model (int): main dimension of modules in transformer blocks.
        dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
        num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
        dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
        dtype (optional): Defaults to torch.half.
        eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
        dropout_p (float, optional): Defaults to None.
    """  # noqa: E501

    def __init__(
        self,
        dim_model: int,
        dim_ff: int,
        num_heads: int,
        num_kv_heads: int,
        dim_head: int,
        activate_fn: str = "gelu",
        dtype=torch.half,
        eps: float = 1e-6,
        dropout_p: Optional[float] = None,
        scale: bool = True,
        add_qkv_bias: bool = False,
        mask_att: bool = False,
        mask_ffn: bool = False,
        use_flash_attn: bool = False,
        tp: int = 0,
    ):
        super().__init__()
        self.mask_att = mask_att
        self.mask_ffn = mask_ffn

        if not self.mask_att:
            self.self_att = SelfAttentionBlock(
                dim_model=dim_model,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                dim_head=dim_head,
                dtype=dtype,
                eps=eps,
                dropout_p=dropout_p,
                scale=scale,
                add_qkv_bias=add_qkv_bias,
                use_flash_attn=use_flash_attn,
                tp=tp,
            )

        if not self.mask_ffn:
            self.ffn = FFNBlock(
                dim_model=dim_model,
                dim_ff=dim_ff,
                activate_fn=activate_fn,
                dtype=dtype,
                eps=eps,
                dropout_p=dropout_p,
                scale=scale,
                tp=tp,
            )

    def forward(
        self,
        self_hidden_states: torch.Tensor,
        self_attention_mask: Optional[torch.Tensor] = None,
        self_position_bias: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        pos_bias_type: Optional[str] = "relative",
        length_mask: Optional[torch.Tensor] = None,
        attention_mask_bias: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            self_hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of transformer block (self-attention block). It can be the raw embedding of a batch of sequences.
            self_attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Avoid invalid areas to participate in the calculation of self-attention.
            self_position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provide positional information to self-attention block.

        Return:
            :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of transformer block.
        """  # noqa: E501
        # (batch, seq_self, dim_model)
        current_key_value = None
        if not self.mask_att:
            hidden_states = self.self_att(
                self_hidden_states,
                attention_mask=self_attention_mask,
                position_bias=self_position_bias,
                use_cache=use_cache,
                past_key_value=past_key_value,
                pos_bias_type=pos_bias_type,
                length_mask=length_mask,
                attention_mask_bias=attention_mask_bias,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                position_ids=position_ids,
            )
            if use_cache:
                hidden_states, current_key_value = hidden_states
        else:
            hidden_states = self_hidden_states

        # (batch, seq_self, dim_model)
        if not self.mask_ffn:
            hidden_states = self.ffn(hidden_states)

        if use_cache:
            return hidden_states, current_key_value
        else:
            return hidden_states
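

# A small, self-contained demo (illustration only) that composes the two plain
# sketches above the way TransformerBlock composes SelfAttentionBlock and FFNBlock:
# attention sub-block first, feed-forward sub-block second, each with its own
# pre-LayerNorm and residual. The dimensions are made up for the demo; the real
# blocks above additionally support GQA (num_kv_heads), flash attention, KV
# caching and tensor parallelism, which this sketch deliberately omits.
def _transformer_block_sketch_demo() -> torch.Tensor:
    dim_model, num_heads, dim_ff = 64, 4, 256
    batch, seq_self = 2, 8
    self_att = _PreLNSelfAttentionSketch(dim_model, num_heads)
    ffn = _PreLNFFNSketch(dim_model, dim_ff)
    hidden_states = torch.randn(batch, seq_self, dim_model)
    hidden_states = self_att(hidden_states)  # (batch, seq_self, dim_model)
    hidden_states = ffn(hidden_states)       # (batch, seq_self, dim_model)
    return hidden_states


if __name__ == "__main__":
    print(_transformer_block_sketch_demo().shape)  # torch.Size([2, 8, 64])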