# Forked from jiuyuan/CPM-9G-8B (file-viewer metadata: 369 lines, 15 KiB, Python).
|
import math
|
||
|
from typing import Tuple
|
||
|
from typing import Union
|
||
|
|
||
|
import bmtrain as bmt
|
||
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
# flash-attn is an optional dependency: fall back to None when it is missing or
# fails to load, without swallowing KeyboardInterrupt/SystemExit like a bare except.
try:
    from flash_attn.layers.rotary import apply_rotary_emb_func
except Exception:
    apply_rotary_emb_func = None
|
||
|
|
||
|
|
||
|
class SegmentPositionEmbedding(bmt.DistributedModule):
    """Learned per-head attention bias indexed by relative position or segment pair.

    Query/key pairs in the same segment are mapped to one of ``num_buckets``
    relative-position buckets: exact buckets for short distances, log-spaced
    buckets up to ``max_distance``, then clamped. Pairs in different segments use
    one of ``num_segments * num_segments`` segment-pair rows stored after the
    position buckets.
    """

    def __init__(
        self,
        num_heads: int,
        num_segments: int = 1,
        num_buckets: int = 32,
        max_distance: int = 128,
        bidirectional: bool = False,
        dtype: torch.dtype = torch.half,
        init_mean: float = 0.0,
        init_std: float = 1,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.bidirectional = bidirectional
        self.num_segments = num_segments

        # Rows [0, num_buckets) hold position buckets; rows
        # [num_buckets, num_buckets + num_segments**2) hold segment-pair buckets.
        self.relative_attention_bias = bmt.DistributedParameter(
            torch.empty(num_segments * num_segments + num_buckets, num_heads, dtype=dtype),
            init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
        )

    def forward(
        self,
        key_pos: torch.Tensor,
        query_pos: torch.Tensor,
        key_segment: torch.Tensor,
        query_segment: torch.Tensor,
    ):
        """Compute the attention bias for every (query, key) pair.

        Args:
            key_pos: expected ``(batch, len_k)``. Only its shape is used.
            query_pos: expected ``(batch, len_q)``. Only its shape is used.
            key_segment: expected ``(batch, len_k)`` segment ids of the keys.
            query_segment: expected ``(batch, len_q)`` segment ids of the queries.

        Returns:
            ``(batch, num_heads, len_q, len_k)`` bias gathered from ``relative_attention_bias``.
        """
        with torch.no_grad():
            batch = key_pos.size(0)
            keylen = key_pos.size(1)
            querylen = query_pos.size(1)

            assert key_pos.size(0) == query_pos.size(0)
            assert keylen == key_segment.size(1) and querylen == query_segment.size(1)

            # NOTE(review): same-segment buckets are derived from index order (arange below);
            # the values stored in key_pos / query_pos are never read.
            key_pos = key_pos.view(batch, -1, keylen)
            query_pos = query_pos.view(batch, querylen, -1)
            key_segment = key_segment.view(batch, -1, keylen)
            query_segment = query_segment.view(batch, querylen, -1)

            # One bucket per (query segment, key segment) pair, shifted past the
            # position buckets so the two index ranges do not overlap.
            relative_position_bucket = self._segment_relative_position_bucket(query_segment, key_segment)
            relative_position_bucket = relative_position_bucket + self.num_buckets

            # (len_q, len_k) buckets of (key index - query index).
            absolute_position_bucket = self._position_bucket(
                torch.arange(keylen, dtype=torch.int32, device=relative_position_bucket.device)[None, :]
                - torch.arange(querylen, dtype=torch.int32, device=relative_position_bucket.device)[:, None],
                bidirectional=self.bidirectional,
                num_buckets=self.num_buckets,
                max_distance=self.max_distance,
            )
            # Same-segment pairs use the position bucket instead of the segment-pair bucket.
            relative_position_bucket = torch.where(
                (key_segment == query_segment),
                absolute_position_bucket[None, :, :],
                relative_position_bucket,
            )
            # (batch, len_q, len_k)

        # The lookup is outside no_grad so relative_attention_bias receives gradients.
        # (batch, len_q, len_k, num_heads)
        embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
        # (batch, num_heads, len_q, len_k)
        embeds = embeds.permute(0, 3, 1, 2).contiguous()
        return embeds

    def _segment_relative_position_bucket(self, query_segment, key_segment):
        """Flatten a (query_segment, key_segment) pair into a single index."""
        return query_segment * self.num_segments + key_segment

    def _position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """Map signed relative positions (key - query) to bucket ids in ``[0, num_buckets)``.

        Distances below ``max_exact`` get their own bucket. Larger distances are
        log-spaced up to ``max_distance`` and clamped to the last bucket.
        """
        relative_buckets = 0
        if bidirectional:
            # Upper half of the buckets for keys after the query, lower half otherwise.
            num_buckets //= 2
            relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            # Only distances to earlier keys are distinguished; later keys fall into bucket 0.
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # Clamp before the log. Lanes with distance < max_exact are discarded by the
        # torch.where below, but log(0) = -inf and casting inf to int32 is undefined,
        # so keep those discarded lanes finite. Kept lanes are unchanged.
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float().clamp(min=max_exact) / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.int32)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )
        relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_position_if_large)
        return relative_buckets
|
||
|
|
||
|
|
||
|
class BucketPositionBias(bmt.DistributedModule):
    """Learned per-head attention bias chosen by relative position or a relation bucket.

    For each (query, key) pair, ``rel_buckets == 0`` selects one of ``num_buckets``
    bidirectional relative-position buckets computed from ``key_pos - query_pos``.
    Any other value ``r`` selects row ``num_buckets + r - 1`` of the bias table,
    which is one of the ``num_segment_bucket`` rows stored after the position
    buckets (assuming nonzero ids lie in ``[1, num_segment_bucket]`` — TODO confirm
    with the caller).
    """

    def __init__(
        self,
        num_heads: int,
        num_buckets: int = 32,
        num_segment_bucket: int = 32,
        max_distance: int = 128,
        dtype: torch.dtype = torch.half,
        init_mean: float = 0.0,
        init_std: float = 1,
    ) -> None:
        super().__init__()

        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.num_segment_bucket = num_segment_bucket
        self.max_distance = max_distance

        # Rows [0, num_buckets): position buckets; the following num_segment_bucket rows: relation buckets.
        self.relative_attention_bias = bmt.DistributedParameter(
            torch.empty(num_buckets + num_segment_bucket, num_heads, dtype=dtype),
            init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
        )

    def forward(
        self,
        query_pos: torch.Tensor,  # (batch, len_q)
        key_pos: torch.Tensor,  # (batch, len_k)
        rel_buckets: torch.Tensor,  # (batch, len_q, len_k)
    ):
        """Compute the attention bias for every (query, key) pair.

        Args:
            query_pos: ``(batch, len_q)`` query positions.
            key_pos: ``(batch, len_k)`` key positions.
            rel_buckets: ``(batch, len_q, len_k)`` relation ids. 0 means "use the
                relative-position bucket".

        Returns:
            ``(batch, num_heads, len_q, len_k)`` bias gathered from ``relative_attention_bias``.
        """
        with torch.no_grad():
            batch = key_pos.size(0)
            keylen = key_pos.size(1)
            querylen = query_pos.size(1)

            assert key_pos.size(0) == query_pos.size(0)
            assert rel_buckets.size(0) == batch and rel_buckets.size(1) == querylen and rel_buckets.size(2) == keylen

            # Shift relation ids past the position buckets so the two index ranges do not overlap.
            relative_position_bucket = rel_buckets - 1 + self.num_buckets

            # (batch, len_q, len_k) buckets of key_pos - query_pos.
            inner_segment_bucket = self._position_bucket(
                key_pos[..., None, :] - query_pos[..., :, None],
                num_buckets=self.num_buckets,
                max_distance=self.max_distance,
            )
            relative_position_bucket = torch.where(
                rel_buckets == 0,
                inner_segment_bucket,
                relative_position_bucket,
            )
            # (batch, len_q, len_k)

        # The lookup is outside no_grad so relative_attention_bias receives gradients.
        # (batch, len_q, len_k, num_heads)
        embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
        # (batch, num_heads, len_q, len_k)
        embeds = embeds.permute(0, 3, 1, 2).contiguous()
        return embeds

    def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
        """Bidirectionally bucket signed relative positions into ``[0, num_buckets)``.

        Assumes an even ``num_buckets``. Distances below ``max_exact`` get their own
        bucket. Larger distances are log-spaced up to ``max_distance`` and clamped.
        """
        # Upper half of the buckets for positive offsets (key after query), lower half otherwise.
        num_buckets //= 2
        relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
        relative_position = torch.abs(relative_position)

        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # Clamp before the log. Lanes with distance < max_exact are discarded by the
        # torch.where below, but log(0) = -inf and casting inf to int32 is undefined,
        # so keep those discarded lanes finite. Kept lanes are unchanged.
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float().clamp(min=max_exact) / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.int32)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )
        relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_position_if_large)
        return relative_buckets
|
||
|
|
||
|
|
||
|
class RotaryEmbedding(bmt.DistributedModule):
    """Rotary position embedding that rotates the two halves of the last dimension of ``x``."""

    def __init__(
        self,
        dim,
        base: Union[int, float] = 10000,
        distance_scale: Union[int, float] = 1,
        dtype: torch.dtype = torch.half,
        device: Union[str, torch.device] = "cuda",
    ):
        """
        Args:
            dim: rotary size; ``dim // 2`` frequencies are generated (even ``dim`` expected).
            base: frequency base; ``inv_freq[i] = base ** (-2i / dim)``.
            distance_scale: multiplier applied to positions in ``forward``.
            dtype: dtype of ``inv_freq`` and of the angle computation.
            device: where ``inv_freq`` is created. Defaults to ``"cuda"`` as before,
                but can now be overridden (e.g. for CPU use).
        """
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        inv_freq = inv_freq.to(dtype)
        self.distance_scale = distance_scale
        self.dtype = dtype
        # Plain attribute, not a registered buffer: Module.to() will not move it.
        self.inv_freq = inv_freq

    def forward(self, x: torch.Tensor, x_pos: torch.Tensor):
        """
        Args:
            x (:obj:`torch.Tensor` of shape ``(..., dim)``): Inputs.
            x_pos (:obj:`torch.Tensor` of shape ``(...)``): Positions of inputs.

        Returns:
            ``x`` rotated by the position-dependent angles, shape ``(..., dim)``.
        """
        x_pos = x_pos * self.distance_scale
        # NOTE(review): positions are cast to self.dtype before the product, so in
        # float16/bfloat16 large positions are not exactly representable.
        freqs = x_pos[..., None].to(self.dtype) * self.inv_freq[None, :]  # (..., dim/2)

        # the same implementation as sat
        emb = torch.cat((freqs, freqs), dim=-1)  # (..., dim)
        emb_cos = emb.cos()  # (..., dim)
        emb_sin = emb.sin()  # (..., dim)

        # [a, b] -> [-b, a] over the last dimension.
        rotate_x = torch.cat([-x[..., x.size(-1) // 2 :], x[..., : x.size(-1) // 2]], dim=-1)  # (..., dim)

        return x * emb_cos + rotate_x * emb_sin
|
||
|
|
||
|
|
||
|
def rotate_half(x):
    """Return ``[-back, front]`` where ``front, back`` are the two halves of the last dim.

    The halves follow ``torch.chunk`` semantics: for an odd size the front half is
    the larger one.
    """
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat([back.neg(), front], dim=-1)
|
||
|
|
||
|
|
||
|
def apply_rotary_pos_emb(x, cos, sin, seq_dim, offset):
    """Rotate ``x`` with precomputed ``cos``/``sin`` tables.

    When the tables are longer than ``x`` along ``seq_dim``, both are sliced to
    ``x``'s length starting at ``offset``. When the lengths match, the tables are
    used as-is and ``offset`` is not applied.
    """
    length = x.size(seq_dim)
    if length < cos.size(seq_dim):
        cos, sin = cos.narrow(seq_dim, offset, length), sin.narrow(seq_dim, offset, length)
    return x * cos + rotate_half(x) * sin
|
||
|
|
||
|
|
||
|
def unpad_apply_rotary_pos_emb(x, cos, sin, seq_dim, position_ids):
    """Rotate ``x`` using the ``cos``/``sin`` rows selected by ``position_ids``.

    ``position_ids`` is flattened and its entries pick table rows along ``seq_dim``
    (one row per element of ``x`` along that dimension).
    """
    flat_ids = position_ids.view(-1)
    picked_cos = cos.index_select(seq_dim, flat_ids)
    picked_sin = sin.index_select(seq_dim, flat_ids)
    return x * picked_cos + rotate_half(x) * picked_sin
|
||
|
|
||
|
|
||
|
class RotaryEmbeddingESM(bmt.DistributedModule):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(
        self,
        dim: int,
        base: Union[int, float] = 10000,
        distance_scale: Union[int, float] = 1,
        dtype=torch.half,
        persistent=True,
        mixed_precision=True,
        device: Union[str, torch.device] = "cuda",
    ):
        """
        Args:
            dim: rotary size; ``dim // 2`` frequencies are generated (even ``dim`` expected).
            base: frequency base; ``inv_freq[i] = base ** (-2i / dim)``.
            distance_scale: multiplier applied to positions when building the tables.
            dtype: dtype of the cos/sin tables (mixed precision) or of ``inv_freq`` (otherwise).
            persistent: whether ``inv_freq`` is saved in the state dict.
            mixed_precision: if True, ``inv_freq`` stays float32 and only the tables are
                cast to ``dtype``; if False, ``inv_freq`` itself is stored in ``dtype``.
            device: where ``inv_freq`` is created. Defaults to ``"cuda"`` as before.
        """
        super().__init__()
        self.base = base
        self.distance_scale = distance_scale
        self.dtype = dtype

        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        if mixed_precision:
            self.register_buffer("inv_freq", inv_freq, persistent=persistent)
        else:
            self.register_buffer("inv_freq", inv_freq.to(self.dtype), persistent=persistent)

        self._seq_len_cached = -1
        self._cos_cached = None
        self._sin_cached = None
        # (input rank, seq_dim) the cached tables were shaped for.
        self._cached_layout = None
        self.mixed_precision = mixed_precision

        self.apply_rotary_pos_emb = apply_rotary_pos_emb
        self.unpad_apply_rotary_pos_emb = unpad_apply_rotary_pos_emb

    def _update_cos_sin_tables(self, x, seq_dim, seq_len):
        """Return cos/sin tables covering at least ``seq_len`` positions, rebuilding them if needed.

        The tables have the same rank as ``x``: size ``seq_len`` along ``seq_dim``,
        the rotary size in the last dimension, and 1 elsewhere, so they broadcast
        against ``x``. They are rebuilt in these cases:
        - they are too short;
        - they live on a different device than ``x``;
        - they were shaped for a different rank or ``seq_dim``, e.g. when switching
          between the padded and unpadded paths.
        """
        layout = (x.dim(), seq_dim)
        needs_rebuild = (
            self._cos_cached is None
            or seq_len > self._seq_len_cached
            or self._cos_cached.device != x.device
            or self._cached_layout != layout
        )
        if needs_rebuild:
            self._seq_len_cached = seq_len
            self._cached_layout = layout
            # Move inv_freq first. Otherwise type_as() would carry t back to
            # inv_freq's device and the table would never match x.device.
            inv_freq = self.inv_freq.to(x.device)
            t = torch.arange(seq_len, device=x.device).type_as(inv_freq)
            freqs = torch.outer(t * self.distance_scale, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            # Insert singleton dims everywhere except seq_dim (and the last dim).
            for i in range(x.dim() - 1):
                if i != seq_dim:
                    emb = emb.unsqueeze_(i)
            if self.mixed_precision:
                self._cos_cached = emb.cos().to(self.dtype)
                self._sin_cached = emb.sin().to(self.dtype)
            else:
                self._cos_cached = emb.cos()
                self._sin_cached = emb.sin()
        return self._cos_cached, self._sin_cached

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, seq_dim, offset=0, cu_seqlens=None, max_length=None, position_ids=None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to ``q`` and ``k``.

        Args:
            q, k: inputs. Their sequence axis is ``seq_dim``, which may be negative.
            offset: padded path only. The first position of ``k`` along ``seq_dim``.
            cu_seqlens: when not None, selects the unpadded path. Its value is
                otherwise unused here. ``offset`` must then be 0.
            max_length: unpadded path only. Number of table rows to build.
            position_ids: unpadded path only. Table rows to apply, one per element
                along ``seq_dim``.

        Returns:
            The rotated ``(q, k)``.
        """
        seq_dim = (seq_dim + k.dim()) % k.dim()
        if cu_seqlens is None:
            self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dim, k.size(seq_dim) + offset)
            return (
                self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dim, offset),
                self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dim, offset),
            )
        else:
            assert offset == 0, "past kv is not supported in flash attn"
            self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dim, max_length)
            return (
                self.unpad_apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dim, position_ids),
                self.unpad_apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dim, position_ids),
            )
|
||
|
|
||
|
|
||
|
@torch.jit.script
def apply_chatglm_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    """Apply ChatGLM-style pairwise rotary embedding to the leading channels of ``x``.

    Args:
        x: ``[b, np, sq, hn]`` input.
        rope_cache: table whose last two dims are ``[rot_dim // 2, 2]``, holding
            (cos, sin) pairs. Its first dim indexes positions and is truncated to ``sq``.

    Returns:
        ``[b, np, sq, hn]`` tensor. Channels ``[:rot_dim]`` are rotated in adjacent
        pairs ``(2i, 2i + 1)``; channels ``[rot_dim:]`` pass through unchanged.
    """
    seq_first = x.permute(2, 0, 1, 3)  # [b, np, sq, hn] -> [sq, b, np, hn]
    seq_len = seq_first.size(0)
    heads = seq_first.size(2)
    rot_dim = rope_cache.shape[-2] * 2
    rotated_part = seq_first[..., :rot_dim]
    passthrough = seq_first[..., rot_dim:]

    # Group rotated channels into (even, odd) pairs and align the table with them.
    pairs = rotated_part.reshape(seq_len, -1, heads, rot_dim // 2, 2)
    table = rope_cache[:seq_len].view(seq_len, -1, 1, pairs.size(3), 2)
    x_even = pairs[..., 0]
    x_odd = pairs[..., 1]
    cos = table[..., 0]
    sin = table[..., 1]
    out_pairs = torch.stack([x_even * cos - x_odd * sin, x_odd * cos + x_even * sin], -1)

    merged = torch.cat((out_pairs.flatten(3), passthrough), dim=-1)
    return merged.permute(1, 2, 0, 3)  # [sq, b, np, hn] -> [b, np, sq, hn]
|
||
|
|
||
|
|
||
|
class ChatGLMRotaryEmbedding(bmt.DistributedModule):
    """Builds the ChatGLM-style rotary table of (cos, sin) pairs."""

    def __init__(self, dim, device="cuda", dtype=torch.float16, persistent=True):
        super().__init__()
        # forward() reads only this buffer's dtype and device; forward_impl
        # recomputes the frequencies itself.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=dtype, device=device) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=persistent)
        self.dim = dim

    def forward_impl(self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000):
        """Enhanced Transformer with Rotary Position Embedding.

        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.

        Returns:
            Tensor of shape ``[seq_len, ceil(n_elem / 2), 2]`` holding (cos, sin).
            It is cast to half/bfloat16 when ``dtype`` is float16/int8 or bfloat16,
            and is float32 otherwise.
        """
        # Build angles in float32 (float64 if requested), never in `dtype`. In
        # float16 not every integer above 2048 is representable (bfloat16: above
        # 256), and int8 overflows past 127, so a half/int8 position index would
        # silently merge or wrap positions.
        work_dtype = torch.float64 if dtype == torch.float64 else torch.float32

        # $\Theta = {\theta_i = base^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=work_dtype, device=device) / n_elem))

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=work_dtype, device=device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

    def forward(self, max_seq_len, offset: int = 0):
        """Return the rotary table for ``max_seq_len`` positions. ``offset`` is accepted but unused."""
        return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
|