CPM-9G-8B/FM_9G/fm9g/dragonfly/modeling_dragonfly.py


# typing: strict
# coding=utf-8
# Copyright 2024 QiYuan Inc.
import math
from typing import Optional
from typing import Tuple
import bmtrain as bmt
import torch
import torch.nn.functional as F
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from .configuration_dragonfly import DragonflyConfig # from fm9g.utils import Config
# TODO:
# 1. add scale_emb to embed and layernorm
# 2. add scale_width to all layers
# 3. add scale_depth to residual
class ScaledRotaryEmbeddingESM(bmt.DistributedModule):
"""
Rotary position embeddings based on those in
[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
matrices which depend on their relative positions.
Add multiple Positional Interpolation methods:
+ [Linear](http://arxiv.org/abs/2306.15595)
+ [NTK-aware](https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/)
+ [Dynamic Scaling](https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/)
+ [NTK-by-parts](https://github.com/jquesnelle/yarn/pull/1)
+ [YaRN](http://arxiv.org/abs/2309.00071)
Args:
dim: Dimension of the input, attn_dim // n_heads.
max_position_embeddings: Maximum number of positions to be embedded.
base: Base of the positional encoding function.
pose_prob: Probability of using PoSE.
pose_scaling_factor: max_position_embeddings scaling factor for PoSE.
        scaling_type: Type of scaling to use, one of ["Linear", "NTK-aware", "Dynamic NTK", "NTK-by-parts", "YaRN", ""].
rope_scaling_factor: RoPE Scaling factor for scaling type, new max length / before extend max length.
beta_fast: Number of rotations to use for fast angular velocity.
beta_slow: Number of rotations to use for slow angular velocity.
        extrapolation_factor: In [0, 1]; 0 disables the NTK-by-parts/YaRN correction (pure interpolation), 1 applies it fully.
        attn_factor: Uniform attention scale factor for tuning YaRN; 1 works best for LLaMA-1/2.
"""
def __init__(
self,
dim: int,
max_position_embeddings: int = 2048,
base: int = 10000,
pose_prob: float = 0.0,
pose_scaling_factor: float = 1.0,
scaling_type: str = "",
rope_scaling_factor: float = 1.0,
beta_fast: float = 32.0,
beta_slow: float = 1.0,
        extrapolation_factor: float = 1.0,
        attn_factor: float = 1.0,
original_max_position_embeddings: int = 2048,
persistent: bool = True,
dynamic_scaling_seq_len: int = 512,
device=None,
):
assert scaling_type in ["Linear", "NTK-aware", "Dynamic NTK", "NTK-by-parts", "YaRN", ""]
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.persistent = persistent
self.device = device
# scaling config
self.scaling_type = scaling_type
self.pose_scaling_factor = pose_scaling_factor
self.rope_scaling_factor = rope_scaling_factor
# PoSE
self.pose_prob = pose_prob
# NTK-by-parts and YaRN args
self.beta_fast = beta_fast
self.beta_slow = beta_slow
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.original_max_position_embeddings = original_max_position_embeddings
if pose_prob > 0:
self.scaled_max_position_embeddings = int(max_position_embeddings * pose_scaling_factor)
else:
self.scaled_max_position_embeddings = max_position_embeddings
if self.scaling_type == "NTK-aware":
base = self.base * (self.rope_scaling_factor ** (self.dim / (self.dim - 2)))
else:
base = self.base
# TODO: Implement base NTK-aware in NTK-by-parts
if self.scaling_type in ["NTK-by-parts", "YaRN"]:
self._ntk_parts_update_inv_freq(self.scaled_max_position_embeddings)
else:
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(self.device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Get n-d magnitude scaling corrected for interpolation
self.m_scale = float(self._get_m_scale(self.rope_scaling_factor) * self.attn_factor)
self._set_cos_sin_cache(dynamic_scaling_seq_len)
def _get_m_scale(self, scale=1.0):
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0
def _ntk_parts_update_inv_freq(self, seq_len):
# Inverse dim formula to find dim based on number of rotations
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
# Find dim range bounds based on rotations
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1) # Clamp values just in case
def linear_ramp_mask(min, max, dim):
if min == max:
max += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(self.device) / self.dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (self.rope_scaling_factor * pos_freqs)
low, high = find_correction_range(
self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings
)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (
1 - linear_ramp_mask(low, high, self.dim // 2).float().to(self.device)
) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
self.register_buffer("inv_freq", inv_freq, persistent=self.persistent)
def _set_cos_sin_cache(self, seq_len, device=None):
self.max_seq_len_cached = seq_len
if device is not None:
self.device = device
if self.scaling_type == "Dynamic NTK" and seq_len > self.max_position_embeddings:
base = self.base * (
(self.rope_scaling_factor * seq_len / self.max_position_embeddings) - (self.rope_scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(self.device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=self.persistent)
t = torch.arange(self.max_seq_len_cached, device=self.device).type_as(self.inv_freq)
if self.scaling_type == "Linear":
freqs = torch.outer(t / self.rope_scaling_factor, self.inv_freq.to(device=t.device).to(t.dtype))
else:
freqs = torch.outer(t, self.inv_freq.to(device=t.device).to(t.dtype))
emb = torch.cat((freqs, freqs), dim=-1)
if self.scaling_type == "YaRN":
self.register_buffer("cos_cached", (emb.cos() * self.m_scale), persistent=self.persistent)
self.register_buffer("sin_cached", (emb.sin() * self.m_scale), persistent=self.persistent)
else:
self.register_buffer("cos_cached", emb.cos(), persistent=self.persistent)
self.register_buffer("sin_cached", emb.sin(), persistent=self.persistent)
def _rotate_half(self, x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _apply_rotary_pos_emb(self, q, k, cos, sin, position_ids) -> Tuple[torch.Tensor, torch.Tensor]:
        # Index the cached cos/sin tables by position id and add a head dimension for broadcasting.
orig_dtype = k.dtype
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_fp32 = q.to(dtype=torch.float32, device=q.device)
k_fp32 = k.to(dtype=torch.float32, device=k.device)
q_embed = (q_fp32 * cos) + (self._rotate_half(q_fp32) * sin)
k_embed = (k_fp32 * cos) + (self._rotate_half(k_fp32) * sin)
return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)
def forward(
self, q: torch.Tensor, k: torch.Tensor, seq_dim, offset=0, cu_seqlens=None, max_length=None, position_ids=None
) -> Tuple[torch.Tensor, torch.Tensor]:
seq_dim = (seq_dim + k.dim()) % k.dim()
# get max current seq len from all workers
if self.pose_prob > 0.0:
seq_len = torch.max(position_ids) + 1
else:
seq_len = k.size(seq_dim) + offset
seq_len_tensor = torch.tensor(seq_len, device=self.device)
seq_len_tensor_reduced = bmt.distributed.all_reduce(seq_len_tensor, op="max")
seq_len_reduced = seq_len_tensor_reduced.item()
# update cache if needed
if seq_len_reduced > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len)
cos, sin = (
self.cos_cached[:seq_len_reduced],
self.sin_cached[:seq_len_reduced],
)
        if position_ids.dtype != torch.long:  # position_ids may arrive as int32; cast for indexing
position_ids = position_ids.to(dtype=torch.long)
if cu_seqlens is None:
q_embed, k_embed = self._apply_rotary_pos_emb(q, k, cos, sin, position_ids)
else:
assert offset == 0, "past kv is not supported in flash attn"
q_embed, k_embed = self._apply_rotary_pos_emb(q, k, cos, sin, position_ids.view(-1))
return q_embed, k_embed
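# --- Illustrative reference (not used by the model) -------------------------------------
# A minimal, pure-torch sketch of the rotate-half RoPE convention that
# ScaledRotaryEmbeddingESM follows (no interpolation/scaling applied). The function name
# and defaults below are ours, introduced only for this sketch.
def _rope_rotate_half_reference(x: torch.Tensor, position_ids: torch.Tensor, dim: int, base: float = 10000.0) -> torch.Tensor:
    """x: (..., seq_len, dim); position_ids: (seq_len,) -> tensor of the same shape as x."""
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    freqs = torch.outer(position_ids.float(), inv_freq)  # (seq_len, dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)              # (seq_len, dim)
    cos, sin = emb.cos(), emb.sin()
    x1, x2 = x[..., : dim // 2], x[..., dim // 2 :]
    return x * cos + torch.cat((-x2, x1), dim=-1) * sin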
def Linear(*args, **kwargs):
    tp = kwargs.pop("tp", 0)
    if tp == 0:
        return NormalLinear(*args, **kwargs)
    if tp == 1:
        return ColumnParallelLinear(*args, **kwargs)
    if tp == 2:
        return RowParallelLinear(*args, **kwargs)
    raise ValueError(f"Unsupported tp mode: {tp} (expected 0, 1 or 2)")
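# Usage sketch (ours, illustrative): tp=0 returns a plain linear layer, tp=1 shards the
# output dimension across the tensor-parallel group (column parallel), and tp=2 shards the
# input dimension (row parallel). When constructed with tp=1, the Attention module below
# builds column-parallel Q/K/V projections and a row-parallel (tp * 2) output projection.
#
#     proj_q = Linear(dim_in=4096, dim_out=4096, tp=1)   # shards dim_out
#     attn_o = Linear(dim_in=4096, dim_out=4096, tp=2)   # shards dim_in
#
# All variants use bmt.DistributedParameter, so bmt.init_distributed(...) must have been
# called (it populates bmt.config["tp_size"] used by the parallel variants).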
class NormalLinear(bmt.DistributedModule):
def __init__(
self,
dim_in: int,
dim_out: int,
dtype: torch.dtype = torch.bfloat16,
init_mean: float = 0.0,
init_std: float = 0.02,
):
super().__init__()
self.dim_in = self.in_features = dim_in
self.dim_out = self.out_features = dim_out
# TODO:init
# init_std = 1 / ((dim_in + dim_out) ** 0.5)
self.weight = bmt.DistributedParameter(
torch.empty((dim_out, dim_in), dtype=dtype),
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
)
def forward(self, x: torch.Tensor):
"""
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
Returns:
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
""" # noqa: E501
x = F.linear(x, self.weight, None)
return x
class ColumnParallelLinear(bmt.DistributedModule):
def __init__(
self,
dim_in: int,
dim_out: int,
dtype: torch.dtype = torch.bfloat16,
init_mean: float = 0.0,
init_std: float = 0.02,
gather_output=False,
gather_input=True,
):
super().__init__()
assert dim_out % bmt.config["tp_size"] == 0
# TODO: init
# init_std = 1 / ((dim_in + dim_out) ** 0.5)
dim_out = dim_out // bmt.config["tp_size"]
self.dim_in = self.in_features = dim_in
self.dim_out = self.out_features = dim_out
self.gather_input = gather_input
self.gather_output = gather_output
self.weight = bmt.DistributedParameter(
torch.empty((dim_out, dim_in), dtype=dtype),
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
tp_split_dim=0,
tp_mode=True,
)
self.bias = None
def forward(self, x: torch.Tensor):
"""
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
Returns:
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
""" # noqa: E501
x = bmt.nn.OpParallelLinear.apply(
x, self.weight, self.bias, self.gather_input, self.gather_output, False, None, 1
)
return x
class RowParallelLinear(bmt.DistributedModule):
def __init__(
self,
dim_in: int,
dim_out: int,
dtype: torch.dtype = torch.bfloat16,
init_mean: float = 0.0,
init_std: float = 0.02,
split_input=False,
all_reduce_output=False,
):
super().__init__()
assert dim_in % bmt.config["tp_size"] == 0
# init_std = 1 / ((dim_in + dim_out) ** 0.5)
dim_in = dim_in // bmt.config["tp_size"]
self.dim_in = self.in_features = dim_in
self.dim_out = self.out_features = dim_out
self.split_input = split_input
self.all_reduce_output = all_reduce_output
self.weight = bmt.DistributedParameter(
torch.empty((dim_out, dim_in), dtype=dtype),
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
tp_split_dim=1,
tp_mode=True,
)
def forward(self, x: torch.Tensor):
"""
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
Returns:
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
""" # noqa: E501
if not self.all_reduce_output:
x = x.view(x.shape[0] * bmt.config["tp_size"], -1, x.shape[-1])
x = bmt.nn.OpParallelLinear.apply(
x, self.weight, None, self.split_input, False, self.split_input, 1 if self.all_reduce_output else 2, 1
)
return x
@torch.jit.script
def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
old_dtype = hidden.dtype
variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype)
return hidden * weight
class LayerNorm(bmt.DistributedModule):
"""RMS LayerNorm"""
def __init__(
self,
dim_norm: int,
dtype: torch.dtype = torch.bfloat16,
eps: float = 1e-6,
init_var: float = 1.0,
):
super().__init__()
self.eps = eps
self.dim_norm = dim_norm
self.weight = bmt.DistributedParameter(torch.full((dim_norm,), init_var, dtype=dtype))
def forward(self, x: torch.Tensor):
"""
Args:
x (:obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``): Input tensor that need to be normalized.
Return:
:obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``: The layernorm output.
""" # noqa: E501
assert x.size(-1) == self.dim_norm
return rms_layernorm(x, self.weight, self.eps)
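# Note (ours): rms_layernorm / LayerNorm implement RMSNorm, i.e.
#     y = x / sqrt(mean(x ** 2, dim=-1) + eps) * weight
# with the mean-of-squares computed in fp32 and the result cast back to the input dtype;
# unlike standard LayerNorm there is no mean subtraction and no bias term. A usage sketch,
# assuming bmtrain has been initialised:
#
#     norm = LayerNorm(dim_norm=4096, dtype=torch.bfloat16, eps=1e-6)
#     y = norm(x)   # x: (batch, seq_len, 4096) -> y: same shape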
class DenseGatedACT(bmt.DistributedModule):
def __init__(
self,
dim_in: int,
dim_ff: int,
activate_fn: str = "silu",
dtype=torch.bfloat16,
tp: int = 0,
scale: bool = False,
init_std: float = 0.02,
scale_width: float = 1.0,
):
super().__init__()
_std = init_std / math.sqrt(scale_width) if scale else init_std
self.w_0 = Linear(
dim_in=dim_in,
dim_out=dim_ff,
dtype=dtype,
tp=tp,
init_std=_std,
)
self.w_1 = Linear(dim_in=dim_in, dim_out=dim_ff, dtype=dtype, tp=tp, init_std=_std)
if activate_fn == "gelu":
self.act = torch.nn.GELU()
elif activate_fn == "silu":
self.act = torch.nn.functional.silu
else:
raise NotImplementedError(f"{activate_fn} is not supported")
def forward(self, x: torch.Tensor):
"""This model inherits from bmt.DistributedModule.
Transform an input tensor from one feature space to another via a nonlinear operation
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): Tensor that will be subject to nonlinear operations.
Return:
out (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_ff)``)
""" # noqa: E501
gate_score = self.act(self.w_0(x))
x = self.w_1(x)
x = gate_score * x
return x
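# Illustrative reference (ours, not called by the model): with the default activate_fn="silu",
# DenseGatedACT computes a SwiGLU-style gate, out = silu(W0 x) * (W1 x), where both
# projections map dim_in -> dim_ff. The helper below shows the same computation with plain
# dense weight matrices.
def _gated_act_reference(x: torch.Tensor, w_0: torch.Tensor, w_1: torch.Tensor) -> torch.Tensor:
    """x: (..., dim_in); w_0, w_1: (dim_ff, dim_in) -> (..., dim_ff)."""
    return F.silu(F.linear(x, w_0)) * F.linear(x, w_1)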
class FeedForward(bmt.DistributedModule):
r"""FeedForward module
Args:
        dim_model (int): input and output dimension.
        dim_ff (int): intermediate (hidden) dimension.
        activate_fn (str, optional): Defaults to ``silu``.
        dtype (optional): Defaults to torch.bfloat16.
        dropout_p (float, optional): Defaults to None (no dropout).
        init_std (float, optional): std of :math:`\mathbf{W}\sim\mathcal{N}(0, \text{std}^2)` for the fully-connected modules. Defaults to 0.02.
        scale_width (float, optional): when ``scale`` is True, ``init_std`` is divided by ``sqrt(scale_width)``. Defaults to 1.0.
""" # noqa: E501
def __init__(
self,
dim_model: int,
dim_ff: int,
activate_fn: str = "silu",
dtype=torch.bfloat16,
dropout_p: Optional[float] = None,
tp: int = 0,
scale: bool = False,
init_std: float = 0.02,
scale_width: float = 1.0,
):
super().__init__()
self.w_in = DenseGatedACT(
dim_in=dim_model,
dim_ff=dim_ff,
activate_fn=activate_fn,
dtype=dtype,
scale=scale,
init_std=init_std,
scale_width=scale_width,
)
if dropout_p is not None:
self.dropout = torch.nn.Dropout(dropout_p)
else:
self.dropout = None
_std = init_std / math.sqrt(scale_width) if scale else init_std
self.w_out = Linear(dim_in=dim_ff, dim_out=dim_model, dtype=dtype, init_std=_std)
def forward(self, x: torch.Tensor):
"""
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of feed-forward module.
Return:
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of feed-forward module.
""" # noqa: E501
x = self.w_in(x)
if self.dropout is not None:
x = self.dropout(x)
x = self.w_out(x)
return x
class Embedding(bmt.DistributedModule):
def __init__(
self,
vocab_size: int,
embedding_size: int,
dtype: torch.dtype = torch.bfloat16,
init_mean: float = 0.0,
init_std: float = 1,
scale: bool = False,
scale_emb: float = 1.0,
scale_width: float = 1.0,
tp: int = 0,
):
super().__init__()
self.dim_model = embedding_size
self.weight = bmt.DistributedParameter(
torch.empty(vocab_size, embedding_size, dtype=dtype),
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
)
self.tp = tp
self.scale = scale
self.scale_emb = scale_emb
self.scale_width = scale_width
def forward(self, x: torch.Tensor):
"""
Args:
x (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.
Return:
:obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
""" # noqa: E501
if self.tp:
x = x.view(-1).chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]].view(x.size(0), -1)
embeds = F.embedding(x, self.weight)
if self.scale:
embeds = embeds * self.scale_emb
return embeds
def projection(self, x: torch.Tensor):
"""
        Projection based on the embedding's weight: the embedding maps vocab_size to embedding_size, and this projection maps embedding_size back to vocab_size.
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
Returns:
:obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output.
""" # noqa: E501
if self.scale:
x = x / self.scale_width # TODO: check if it is ok to add before all_gather
logits = F.linear(x, self.weight)
return logits
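# Illustrative reference (ours, not called by the model): Embedding.projection reuses the
# embedding matrix as the output head (weight tying), i.e. logits = hidden @ W_emb^T,
# optionally undoing the width scaling applied on the input side. The scale_width default
# below is an assumption for this sketch only.
def _tied_projection_reference(hidden: torch.Tensor, emb_weight: torch.Tensor, scale_width: float = 1.0) -> torch.Tensor:
    """hidden: (..., dim_model); emb_weight: (vocab_size, dim_model) -> (..., vocab_size)."""
    return F.linear(hidden / scale_width, emb_weight)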
class Attention(bmt.DistributedModule):
def __init__(
self,
dim_model: int,
num_heads: int,
num_kv_heads: int,
dim_head: int,
dtype: torch.dtype = torch.bfloat16,
dropout_p: Optional[float] = None,
tp: int = 0,
scale: bool = False,
init_std: float = 0.02,
scale_width: float = 1.0,
qk_norm: bool = False,
) -> None:
super().__init__()
self.dim_model = dim_model
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_groups = num_heads // num_kv_heads
self.dim_head = dim_head
self.scale = scale
_std = init_std / math.sqrt(scale_width) if scale else init_std
self.project_q = Linear(
self.dim_model,
self.num_heads * self.dim_head,
dtype=dtype,
tp=tp,
init_std=_std,
)
self.project_k = Linear(
self.dim_model,
self.num_kv_heads * self.dim_head,
dtype=dtype,
tp=tp,
init_std=_std,
)
self.project_v = Linear(
self.dim_model,
self.num_kv_heads * self.dim_head,
dtype=dtype,
tp=tp,
init_std=_std,
)
self.attention_out = Linear(
self.num_heads * self.dim_head,
self.dim_model,
dtype=dtype,
tp=tp * 2, # TODO
init_std=_std,
)
if dropout_p is not None:
self.dropout = torch.nn.Dropout(p=dropout_p)
self.dropout_p = dropout_p
else:
self.dropout = None
self.tp = tp
def forward(
self,
hidden_q: torch.Tensor,
hidden_kv: torch.Tensor,
        position_bias: ScaledRotaryEmbeddingESM,
cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
position_ids: Optional[torch.Tensor] = None,
):
"""This model inherits from bmt.DistributedModule.
Args:
hidden_q (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix.
hidden_kv (:obj:`torch.Tensor` of shape ``(batch, len_k, dim_model)``): Length of input sequence before padding.
attention_mask (:obj:`torch.Tensor` of shape ``(batch, len_q, len_k)``): Used to avoid performing attention on padding token indices.
position_bias(:obj:`torch.Tensor` of shape ``(num_heads, len_q, len_k)`` or ``(1, num_heads, len_k, len_q)``): Provide positional information about tensor `key_value` and `query`.
Return:
out (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): The attention output.
""" # noqa: E501
batch_size = hidden_q.size(0)
if self.tp:
assert hidden_q.data_ptr() == hidden_kv.data_ptr()
hidden_q = bmt.nn.OpParallelLinear.apply(
hidden_q,
torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0),
torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0)
if self.project_q.bias is not None
else None,
True,
False,
False,
None,
1,
)
hidden_q = hidden_q.view(batch_size, -1, hidden_q.shape[-1])
block_size = hidden_q.shape[-1] // (self.head_groups + 1 + 1)
h_q = hidden_q[..., : block_size * self.head_groups]
h_k = hidden_q[..., block_size * self.head_groups : block_size * (self.head_groups + 1)]
h_v = hidden_q[..., block_size * (self.head_groups + 1) :]
else:
h_q = self.project_q(hidden_q)
h_k = self.project_k(hidden_kv)
h_v = self.project_v(hidden_kv)
len_q = h_q.size(1)
len_k = h_k.size(1)
h_q = h_q.view(batch_size * len_q, -1, self.dim_head)
h_k = h_k.view(batch_size * len_k, -1, self.dim_head)
h_v = h_v.view(batch_size * len_k, -1, self.dim_head)
h_q, h_k = position_bias(h_q, h_k, -3, cu_seqlens=cu_seqlens, max_length=max_seqlen, position_ids=position_ids)
score = flash_attn_varlen_func(
h_q,
h_k,
h_v,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
self.dropout_p,
causal=True,
deterministic=True,
)
#print("DEBUG! use flash!!!!!! ARQ")
score = score.view(batch_size, len_q, -1)
score = self.attention_out(score)
return score
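# Illustrative helper (ours, not called by the model): flash_attn_varlen_func expects the
# batch packed into one long sequence plus cumulative sequence lengths. In the standard
# varlen convention, cu_seqlens is the per-sample length prefix sum with a leading zero
# (the data pipeline here may use its own packing; see the comment in Dragonfly.forward).
def _build_cu_seqlens(seq_lens: torch.Tensor) -> torch.Tensor:
    """seq_lens: (batch,) int lengths -> cu_seqlens: (batch + 1,) int32, e.g. [3, 5] -> [0, 3, 8]."""
    return F.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))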
class SelfAttentionBlock(bmt.DistributedModule):
"""The whole cross-attention block. A sequence of operation. Consists of layernorm, self-attention and residual connection.
Args:
dim_model (int): main dimension of modules in transformer blocks.
num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
dtype (optional): Defaults to torch.bfloat16.
        eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
dropout_p (float, optional): Defaults to 0.
""" # noqa: E501
def __init__(
self,
dim_model: int,
num_heads: int,
num_kv_heads: int,
dim_head: int,
dtype=torch.bfloat16,
eps: float = 1e-6,
dropout_p: Optional[float] = None,
tp: int = 0,
scale: bool = False,
init_std: float = 0.02,
scale_width: float = 1.0,
scale_depth: float = -1,
qk_norm: bool = False,
layer_id: int = 0,
num_layers: int = 0,
):
super().__init__()
self.layernorm_before_attention = LayerNorm(
dim_model,
dtype=dtype,
eps=eps,
)
self.self_attention = Attention(
dim_model=dim_model,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
dim_head=dim_head,
dtype=dtype,
dropout_p=dropout_p,
tp=tp,
scale=scale,
init_std=init_std,
scale_width=scale_width,
qk_norm=qk_norm,
)
if dropout_p:
self.dropout = torch.nn.Dropout(dropout_p)
else:
self.dropout = None
self.scale = scale
self.scale_depth = scale_depth
self.num_layers = num_layers
def forward(
self,
hidden_states: torch.Tensor,
position_bias: ScaledRotaryEmbeddingESM,
cu_seqlens: torch.Tensor,
max_seqlen: int = None,
position_ids: Optional[torch.Tensor] = None,
):
"""
Args:
hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of self-attention block. It can be the embedding of a batch of sequences.
            position_bias (:obj:`ScaledRotaryEmbeddingESM`): Rotary embedding module that provides positional information to the self-attention block.
            cu_seqlens (:obj:`torch.Tensor`): Cumulative sequence lengths for the packed (variable-length) attention kernel.
Return:
:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of attention block.
""" # noqa: E501
x = self.layernorm_before_attention(hidden_states)
x = self.self_attention(
x,
x,
position_bias,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
position_ids=position_ids,
)
if self.dropout is not None:
x = self.dropout(x)
if self.scale_depth > 0:
hidden_states = hidden_states + x * (
self.scale_depth / math.sqrt(self.num_layers)
) # https://arxiv.org/pdf/2310.02244.pdf
else:
hidden_states = hidden_states + x
return hidden_states
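# Worked example (ours): with scale_depth = 1.4 and num_layers = 32 (numbers chosen only
# for illustration), the attention branch is scaled by 1.4 / sqrt(32) ≈ 0.247 before being
# added to the residual stream, following the depth-scaled residual of
# https://arxiv.org/pdf/2310.02244.pdf; with scale_depth <= 0 the plain residual is used.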
class FFNBlock(torch.nn.Module):
"""The whole feed-forward block. A sequence of operation. Consists of layernorm, feed-forward and residual connection.
Args:
dim_model (int): main dimension of modules in transformer blocks.
dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
dtype (optional): Defaults to torch.bfloat16.
        eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
dropout_p (float, optional): Defaults to 0.
""" # noqa: E501
def __init__(
self,
dim_model: int,
dim_ff: int,
activate_fn: str,
dtype=torch.bfloat16,
eps: float = 1e-6,
dropout_p: Optional[float] = 0,
tp: int = 0,
scale: bool = False,
init_std: float = 0.02,
scale_width: float = 1.0,
scale_depth: float = -1,
layer_id: int = 0,
num_layers: int = 0,
):
super().__init__()
self.layernorm_before_ffn = LayerNorm(
dim_model,
dtype=dtype,
eps=eps,
)
self.ffn = FeedForward(
dim_model,
dim_ff,
activate_fn=activate_fn,
dtype=dtype,
dropout_p=dropout_p,
tp=tp,
scale=scale,
init_std=init_std,
scale_width=scale_width,
)
if dropout_p:
self.dropout = torch.nn.Dropout(dropout_p)
else:
self.dropout = None
self.scale = scale
self.scale_depth = scale_depth
self.num_layers = num_layers
def forward(
self,
hidden_states: torch.Tensor,
):
"""
Args:
hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Hidden states before feed forward layer.
Return:
:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of feed-forward block
""" # noqa: E501
x = self.layernorm_before_ffn(hidden_states)
x = self.ffn(x)
if self.dropout is not None:
x = self.dropout(x)
if self.scale_depth > 0:
hidden_states = hidden_states + x.view_as(hidden_states) * (
self.scale_depth / math.sqrt(self.num_layers)
) # https://arxiv.org/pdf/2310.02244.pdf
else:
hidden_states = hidden_states + x.view_as(hidden_states)
return hidden_states
class TransformerBlock(torch.nn.Module):
"""The whole transformer block. A sequence of operation. Consists of self-attention block[, cross-attention block] and feed-forward block.
Args:
dim_model (int): main dimension of modules in transformer blocks.
dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
dtype (optional): Defaults to torch.bfloat16.
        eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
dropout_p (float, optional): Defaults to 0.
""" # noqa: E501
def __init__(
self,
dim_model: int,
dim_ff: int,
num_heads: int,
num_kv_heads: int,
dim_head: int,
activate_fn: str = "silu",
dtype=torch.bfloat16,
eps: float = 1e-6,
dropout_p: Optional[float] = None,
tp: int = 0,
scale: bool = False,
init_std: float = 0.02,
scale_width: float = 1.0,
scale_depth: float = -1,
qk_norm: bool = False,
layer_id: int = 0,
num_layers: int = 0,
):
super().__init__()
self.self_att = SelfAttentionBlock(
dim_model=dim_model,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
dim_head=dim_head,
dtype=dtype,
eps=eps,
dropout_p=dropout_p,
tp=tp,
scale=scale,
init_std=init_std,
scale_width=scale_width,
scale_depth=scale_depth,
qk_norm=qk_norm,
layer_id=layer_id,
num_layers=num_layers,
)
self.ffn = FFNBlock(
dim_model=dim_model,
dim_ff=dim_ff,
activate_fn=activate_fn,
dtype=dtype,
eps=eps,
dropout_p=dropout_p,
tp=tp,
scale=scale,
init_std=init_std,
scale_width=scale_width,
scale_depth=scale_depth,
layer_id=layer_id,
num_layers=num_layers,
)
def forward(
self,
self_hidden_states: torch.Tensor,
self_position_bias: Optional[torch.Tensor] = None, # TODO
cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
position_ids: Optional[torch.Tensor] = None,
):
"""
Args:
            self_hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of the transformer block (self-attention block). It can be the raw embedding of a batch of sequences.
            self_position_bias (:obj:`ScaledRotaryEmbeddingESM`): Rotary embedding module that provides positional information to the self-attention block.
            cu_seqlens (:obj:`torch.Tensor`, optional): Cumulative sequence lengths for the packed (variable-length) attention kernel.
Return:
:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of transformer block.
""" # noqa: E501
# (batch, dim_model, seq_self)
hidden_states = self.self_att(
self_hidden_states,
position_bias=self_position_bias,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
position_ids=position_ids,
)
# (batch, dim_model, seq_self)
hidden_states = self.ffn(hidden_states)
return hidden_states
class Encoder(bmt.DistributedModule):
"""Layers of encoder transformer blocks plus an final layernorm.
Args:
num_layers (int): number of layers.
dim_model (int): main dimension of modules in transformer blocks.
dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
dtype (optional): Defaults to torch.bfloat16.
        eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
dropout_p (float, optional): Defaults to 0.
""" # noqa: E501
def __init__(
self,
num_layers: int,
dim_model: int,
dim_ff: int,
num_heads: int,
dim_head: int,
num_kv_heads: int = -1,
activate_fn: str = "silu",
dtype: torch.dtype = torch.bfloat16,
eps: float = 1e-6,
dropout_p: Optional[float] = None,
tp: int = 0,
scale: bool = False,
init_std: float = 0.02,
scale_width: float = 1.0,
scale_depth: float = -1,
qk_norm: bool = False,
use_checkpoint: bool = True,
):
super().__init__()
if num_kv_heads == -1:
num_kv_heads = num_heads
self.num_layers = num_layers
self.layers = bmt.TransformerBlockList(
[
bmt.CheckpointBlock(
TransformerBlock(
dim_model=dim_model,
dim_ff=dim_ff,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
dim_head=dim_head,
activate_fn=activate_fn,
dtype=dtype,
eps=eps,
dropout_p=dropout_p,
tp=tp,
scale=scale,
init_std=init_std,
scale_width=scale_width,
scale_depth=scale_depth,
qk_norm=qk_norm,
layer_id=layer_id,
num_layers=num_layers,
),
use_checkpoint=use_checkpoint
)
for layer_id in range(num_layers)
]
)
self.output_layernorm = LayerNorm(dim_norm=dim_model, dtype=dtype, eps=eps)
def forward(
self,
hidden_states: torch.Tensor,
position_bias: torch.Tensor = None,
cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
position_ids: Optional[torch.Tensor] = None,
):
"""
Args:
            hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``): Input of the encoder, usually the embedding of a batch of sequences.
            position_bias (:obj:`ScaledRotaryEmbeddingESM`): Rotary embedding module that provides positional information to the attention layers.
            cu_seqlens (:obj:`torch.Tensor`, optional): Cumulative sequence lengths for the packed (variable-length) attention kernel.
Return:
:obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``: The encoder output.
""" # noqa: E501
hidden_states = self.layers(
hidden_states,
position_bias,
cu_seqlens,
max_seqlen,
position_ids,
)
hidden_states = self.output_layernorm(hidden_states)
return hidden_states
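# Note (ours): each TransformerBlock is wrapped in bmt.CheckpointBlock, so with
# use_checkpoint=True activations are recomputed during backward to save memory.
# bmt.TransformerBlockList chains hidden_states through the blocks while passing the
# remaining arguments (position_bias, cu_seqlens, max_seqlen, position_ids) to each one.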
class Dragonfly(bmt.DistributedModule):
def __init__(self, config: DragonflyConfig):
super().__init__()
self.encoder = Encoder(
num_layers=config.num_layers,
dim_model=config.dim_model,
dim_ff=config.dim_ff,
num_heads=config.num_heads,
num_kv_heads=config.num_kv_heads,
dim_head=config.dim_head,
activate_fn=config.activate_fn,
dtype=config.dtype,
eps=config.eps,
dropout_p=config.dropout_p,
tp=config.tp,
scale=config.scale,
init_std=config.init_std,
scale_width=config.scale_width,
scale_depth=config.scale_depth,
qk_norm=config.qk_norm,
use_checkpoint=config.use_checkpoint,
)
self.input_embedding = Embedding(
vocab_size=config.vocab_size,
embedding_size=config.dim_model,
dtype=config.dtype,
init_std=config.init_std,
tp=config.tp,
scale=config.scale,
scale_emb=config.scale_emb,
scale_width=config.scale_width,
)
self.position_bias = ScaledRotaryEmbeddingESM(
dim=config.dim_head,
max_position_embeddings=config.max_length,
base=config.base,
pose_prob=config.pose_prob,
pose_scaling_factor=config.pose_scaling_factor,
scaling_type=config.rope_scaling_type,
rope_scaling_factor=config.rope_scaling_factor,
original_max_position_embeddings=config.orig_max_length,
dynamic_scaling_seq_len=config.max_length, # disable dynamic scaling
persistent=False,
device="cuda",
)
if config.tie_lm_head is False:
self.lm_head = Embedding(
vocab_size=config.vocab_size,
embedding_size=config.dim_model,
dtype=config.dtype,
init_std=config.init_std,
scale=config.scale,
scale_width=config.scale_width,
tp=config.tp,
)
self.config = config
def forward(
self,
input: torch.Tensor, # (batch, seqlen) int32
cu_seqlens: torch.Tensor = None, # (real_batch+2) int32
        max_seqlen: Optional[int] = None,
position_ids: torch.Tensor = None, # (batch, seqlen) int32
):
hidden_states = self.input_embedding(input)
assert cu_seqlens is not None, "cu_seqlens are needed in Flash Attention cuda impl"
hidden_states = self.encoder(
hidden_states,
position_bias=self.position_bias,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
position_ids=position_ids,
)
if self.config.tie_lm_head is True:
logits = self.input_embedding.projection(hidden_states)
else:
logits = self.lm_head.projection(hidden_states)
return logits
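# Usage sketch (ours; illustrative only, not executed). Running the model requires a
# bmtrain process group, a CUDA device and flash-attn, so the example is left as comments.
# Import paths are assumed from the repository layout; adapt them to your setup.
#
#     import bmtrain as bmt
#     from fm9g.dragonfly.configuration_dragonfly import DragonflyConfig
#     from fm9g.dragonfly.modeling_dragonfly import Dragonfly
#
#     bmt.init_distributed()
#     config = DragonflyConfig(...)       # fill in / load the model hyper-parameters
#     model = Dragonfly(config)
#     bmt.init_parameters(model)
#
#     # inputs packed for the varlen attention kernel: flat int32 token ids, cumulative
#     # sequence lengths, the maximum sequence length, and per-token position ids
#     logits = model(input=input_ids, cu_seqlens=cu_seqlens,
#                    max_seqlen=int(max_seqlen), position_ids=position_ids)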