import math
from typing import Tuple
from typing import Union

import bmtrain as bmt
import torch
import torch.nn.functional as F

try:
    from flash_attn.layers.rotary import apply_rotary_emb_func
except Exception:
    # flash_attn is an optional dependency. Any failure while importing it
    # (missing package, broken binary, ...) disables the fused kernel, but
    # KeyboardInterrupt / SystemExit are no longer swallowed.
    apply_rotary_emb_func = None


def _log_bucket_relative_position(relative_position, bidirectional, num_buckets, max_distance):
    """Map signed relative offsets to bucket ids.

    Offsets smaller than ``max_exact`` (half of the per-direction bucket count)
    get one bucket each; larger offsets are placed on a logarithmic scale up to
    ``max_distance`` and clamped to the last bucket.

    Args:
        relative_position: integer tensor of offsets (key index - query index).
        bidirectional: if true, half of the buckets are used for positive
            offsets and half for non-positive ones. Otherwise positive offsets
            are clamped to distance 0 and only non-positive offsets carry a
            distance.
        num_buckets: total number of buckets to produce ids for.
        max_distance: distance at which the logarithmic range saturates.

    Returns:
        ``torch.int32`` tensor of bucket ids with the shape of
        ``relative_position``.
    """
    if bidirectional:
        num_buckets //= 2
        relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        relative_buckets = 0
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))

    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # For offsets below max_exact the log term is meaningless (log(0) = -inf for
    # offset 0); those entries are discarded by the torch.where below.
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.int32)
    relative_position_if_large = torch.min(
        relative_position_if_large,
        torch.full_like(relative_position_if_large, num_buckets - 1),
    )
    return relative_buckets + torch.where(
        is_small, relative_position.to(torch.int32), relative_position_if_large
    )


class SegmentPositionEmbedding(bmt.DistributedModule):
    """Attention bias that combines in-segment position buckets with segment-pair buckets.

    The bias table has ``num_buckets`` rows for relative-position buckets followed by
    ``num_segments * num_segments`` rows, one per (query segment, key segment) pair.
    """

    def __init__(
        self,
        num_heads: int,
        num_segments: int = 1,
        num_buckets: int = 32,
        max_distance: int = 128,
        bidirectional: bool = False,
        dtype: torch.dtype = torch.half,
        init_mean: float = 0.0,
        init_std: float = 1,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.bidirectional = bidirectional
        self.num_segments = num_segments

        self.relative_attention_bias = bmt.DistributedParameter(
            torch.empty(num_segments * num_segments + num_buckets, num_heads, dtype=dtype),
            init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
        )

    def forward(
        self,
        key_pos: torch.Tensor,
        query_pos: torch.Tensor,
        key_segment: torch.Tensor,
        query_segment: torch.Tensor,
    ):
        """Compute the attention bias.

        Args:
            key_pos: tensor of shape ``(batch, len_k)``.
            query_pos: tensor of shape ``(batch, len_q)``.
            key_segment: tensor of shape ``(batch, len_k)`` with segment ids.
            query_segment: tensor of shape ``(batch, len_q)`` with segment ids.

        Returns:
            Tensor of shape ``(batch, num_heads, len_q, len_k)``.
        """
        with torch.no_grad():
            batch = key_pos.size(0)
            keylen = key_pos.size(1)
            querylen = query_pos.size(1)

            assert key_pos.size(0) == query_pos.size(0)
            assert keylen == key_segment.size(1) and querylen == query_segment.size(1)

            # NOTE: the reshaped positions are not read below; the in-segment
            # bucket is derived from index offsets (arange), not from
            # key_pos / query_pos values.
            key_pos = key_pos.view(batch, -1, keylen)
            query_pos = query_pos.view(batch, querylen, -1)
            key_segment = key_segment.view(batch, -1, keylen)
            query_segment = query_segment.view(batch, querylen, -1)

            relative_position_bucket = self._segment_relative_position_bucket(query_segment, key_segment)
            # Shift so the segment-pair buckets do not overlap the
            # relative-position bucket range [0, num_buckets).
            relative_position_bucket = relative_position_bucket + self.num_buckets

            # (len_q, len_k)
            absolute_position_bucket = self._position_bucket(
                torch.arange(keylen, dtype=torch.int32, device=relative_position_bucket.device)[None, :]
                - torch.arange(querylen, dtype=torch.int32, device=relative_position_bucket.device)[:, None],
                bidirectional=self.bidirectional,
                num_buckets=self.num_buckets,
                max_distance=self.max_distance,
            )
            # Same segment -> relative-position bucket, otherwise segment-pair bucket.
            relative_position_bucket = torch.where(
                (key_segment == query_segment),
                absolute_position_bucket[None, :, :],
                relative_position_bucket,
            )
            # (batch, len_q, len_k)

        # (batch, len_q, len_k, num_heads)
        embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
        # (batch, num_heads, len_q, len_k)
        embeds = embeds.permute(0, 3, 1, 2).contiguous()
        return embeds

    def _segment_relative_position_bucket(self, query_segment, key_segment):
        """Return a distinct id for every (query segment, key segment) pair."""
        return query_segment * self.num_segments + key_segment

    def _position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """Bucket relative offsets; see ``_log_bucket_relative_position``."""
        return _log_bucket_relative_position(
            relative_position,
            bidirectional=bidirectional,
            num_buckets=num_buckets,
            max_distance=max_distance,
        )


class BucketPositionBias(bmt.DistributedModule):
    """Attention bias from relative-position buckets plus caller-provided segment buckets.

    The bias table has ``num_buckets`` rows for in-segment relative positions
    followed by ``num_segment_bucket`` rows indexed by ``rel_buckets - 1``.
    """

    def __init__(
        self,
        num_heads: int,
        num_buckets: int = 32,
        num_segment_bucket: int = 32,
        max_distance: int = 128,
        dtype: torch.dtype = torch.half,
        init_mean: float = 0.0,
        init_std: float = 1,
    ) -> None:
        super().__init__()

        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.num_segment_bucket = num_segment_bucket
        self.max_distance = max_distance

        self.relative_attention_bias = bmt.DistributedParameter(
            torch.empty(num_buckets + num_segment_bucket, num_heads, dtype=dtype),
            init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
        )

    def forward(
        self,
        query_pos: torch.Tensor,  # (batch, len_q)
        key_pos: torch.Tensor,  # (batch, len_k)
        rel_buckets: torch.Tensor,  # (batch, len_q, len_k)
    ):
        """Compute the attention bias.

        Entries where ``rel_buckets == 0`` use a bidirectional relative-position
        bucket computed from ``key_pos - query_pos``; all other entries use row
        ``rel_buckets - 1 + num_buckets`` of the bias table.

        Returns:
            Tensor of shape ``(batch, num_heads, len_q, len_k)``.
        """
        with torch.no_grad():
            batch = key_pos.size(0)
            keylen = key_pos.size(1)
            querylen = query_pos.size(1)

            assert key_pos.size(0) == query_pos.size(0)
            assert rel_buckets.size(0) == batch and rel_buckets.size(1) == querylen and rel_buckets.size(2) == keylen

            # Shift so the segment buckets do not overlap the relative-position
            # bucket range [0, num_buckets).
            relative_position_bucket = rel_buckets - 1 + self.num_buckets

            # (batch, len_q, len_k)
            inner_segment_bucket = self._position_bucket(
                key_pos[..., None, :] - query_pos[..., :, None],
                num_buckets=self.num_buckets,
                max_distance=self.max_distance,
            )
            relative_position_bucket = torch.where(
                rel_buckets == 0,
                inner_segment_bucket,
                relative_position_bucket,
            )
            # (batch, len_q, len_k)

        # (batch, len_q, len_k, num_heads)
        embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
        # (batch, num_heads, len_q, len_k)
        embeds = embeds.permute(0, 3, 1, 2).contiguous()
        return embeds

    def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
        """Bidirectional bucketing; see ``_log_bucket_relative_position``."""
        return _log_bucket_relative_position(
            relative_position,
            bidirectional=True,
            num_buckets=num_buckets,
            max_distance=max_distance,
        )


class RotaryEmbedding(bmt.DistributedModule):
    """Rotary position embedding computed on the fly from explicit positions."""

    def __init__(
        self,
        dim,
        base: Union[int, float] = 10000,
        distance_scale: Union[int, float] = 1,
        dtype: torch.dtype = torch.half,
    ):
        super().__init__()
        # Computed in float32 and then cast to `dtype`. Stored as a plain
        # attribute (not a registered buffer), created on the "cuda" device.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim))
        inv_freq = inv_freq.to(dtype)
        self.distance_scale = distance_scale
        self.dtype = dtype
        self.inv_freq = inv_freq

    def forward(self, x: torch.Tensor, x_pos: torch.Tensor):
        """
        Args:
            x (:obj:`torch.Tensor` of shape ``(..., dim)``): Inputs.
            x_pos (:obj:`torch.Tensor` of shape ``(...)``): Positions of inputs.

        Returns:
            :obj:`torch.Tensor` of shape ``(..., dim)``: ``x`` with the rotation applied.
        """
        x_pos = x_pos * self.distance_scale
        freqs = x_pos[..., None].to(self.dtype) * self.inv_freq[None, :]  # (..., dim/2)

        # the same implementation as sat
        emb = torch.cat((freqs, freqs), dim=-1)  # (..., dim)
        emb_cos = emb.cos()  # (..., dim)
        emb_sin = emb.sin()  # (..., dim)

        half = x.size(-1) // 2
        rotate_x = torch.cat([-x[..., half:], x[..., :half]], dim=-1)  # (..., dim)

        return x * emb_cos + rotate_x * emb_sin


def rotate_half(x):
    """Split the last dim into halves ``(x1, x2)`` and return ``(-x2, x1)``."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(x, cos, sin, seq_dim, offset):
    """Rotate ``x`` with cos/sin tables, taking rows ``offset : offset + len(x)`` along ``seq_dim``.

    The tables are narrowed only when ``x`` is shorter than them along
    ``seq_dim``; when the lengths are equal they are used as-is.
    """
    if x.size(seq_dim) < cos.size(seq_dim):  # == do not need narrow
        cos = cos.narrow(seq_dim, offset, x.size(seq_dim))
        sin = sin.narrow(seq_dim, offset, x.size(seq_dim))
    return (x * cos) + (rotate_half(x) * sin)


def unpad_apply_rotary_pos_emb(x, cos, sin, seq_dim, position_ids):
    """Rotate ``x`` with cos/sin rows gathered along ``seq_dim`` by the flattened ``position_ids``."""
    cos = cos.index_select(seq_dim, position_ids.view(-1))
    sin = sin.index_select(seq_dim, position_ids.view(-1))
    return (x * cos) + (rotate_half(x) * sin)


class RotaryEmbeddingESM(bmt.DistributedModule):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.

    The cos/sin tables are cached on the module and rebuilt when a longer
    sequence, a different device, or a different tensor layout is seen.
    """

    def __init__(
        self,
        dim: int,
        base: Union[int, float] = 10000,
        distance_scale: Union[int, float] = 1,
        dtype=torch.half,
        persistent=True,
        mixed_precision=True,
    ):
        super().__init__()
        self.base = base
        self.distance_scale = distance_scale
        self.dtype = dtype

        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim))
        if mixed_precision:
            # Keep float32 frequencies; the tables are cast to `dtype` after cos/sin.
            self.register_buffer("inv_freq", inv_freq, persistent=persistent)
        else:
            self.register_buffer("inv_freq", inv_freq.to(self.dtype), persistent=persistent)

        self._seq_len_cached = -1
        self._cos_cached = None
        self._sin_cached = None
        # (x.dim(), seq_dim) the cached tables were shaped for.
        self._layout_cached = None
        self.mixed_precision = mixed_precision

        self.apply_rotary_pos_emb = apply_rotary_pos_emb
        self.unpad_apply_rotary_pos_emb = unpad_apply_rotary_pos_emb

    def _update_cos_sin_tables(self, x, seq_dim, seq_len):
        """Return cos/sin tables covering at least ``seq_len`` positions, broadcastable against ``x``.

        The tables are rebuilt when nothing is cached yet, when ``seq_len``
        exceeds the cached length, when ``x`` lives on another device, or when
        the ``(x.dim(), seq_dim)`` layout differs from the cached one (the
        tables carry singleton dims that are specific to that layout).
        """
        layout = (x.dim(), seq_dim)
        needs_rebuild = (
            self._cos_cached is None
            or seq_len > self._seq_len_cached
            or self._cos_cached.device != x.device
            or layout != self._layout_cached
        )
        if needs_rebuild:
            self._seq_len_cached = seq_len
            self._layout_cached = layout
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t * self.distance_scale, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            # Insert a singleton dim for every leading dim of x except seq_dim.
            for i in range(x.dim() - 1):
                if i != seq_dim:
                    emb = emb.unsqueeze(i)
            if self.mixed_precision:
                self._cos_cached = emb.cos().to(self.dtype)
                self._sin_cached = emb.sin().to(self.dtype)
            else:
                self._cos_cached = emb.cos()
                self._sin_cached = emb.sin()
        return self._cos_cached, self._sin_cached

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, seq_dim, offset=0, cu_seqlens=None, max_length=None, position_ids=None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding to ``q`` and ``k``.

        Args:
            q: query tensor.
            k: key tensor; its rank is used to normalize a negative ``seq_dim``.
            seq_dim: sequence dimension (may be negative).
            offset: start position of ``q`` / ``k`` in the tables (padded path only).
            cu_seqlens: when not ``None`` the unpadded path is taken and
                positions come from ``position_ids``.
            max_length: table length for the unpadded path (required there).
            position_ids: position index per token for the unpadded path
                (required there).

        Returns:
            Tuple ``(q_rotated, k_rotated)``.

        Raises:
            ValueError: if ``cu_seqlens`` is given without ``max_length`` or
                ``position_ids``.
        """
        seq_dim = (seq_dim + k.dim()) % k.dim()
        if cu_seqlens is None:
            cos, sin = self._update_cos_sin_tables(k, seq_dim, k.size(seq_dim) + offset)
            return (
                self.apply_rotary_pos_emb(q, cos, sin, seq_dim, offset),
                self.apply_rotary_pos_emb(k, cos, sin, seq_dim, offset),
            )

        assert offset == 0, "past kv is not supported in flash attn"
        if max_length is None or position_ids is None:
            raise ValueError("max_length and position_ids are required when cu_seqlens is given")
        cos, sin = self._update_cos_sin_tables(k, seq_dim, max_length)
        return (
            self.unpad_apply_rotary_pos_emb(q, cos, sin, seq_dim, position_ids),
            self.unpad_apply_rotary_pos_emb(k, cos, sin, seq_dim, position_ids),
        )


@torch.jit.script
def apply_chatglm_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # Rotates the first `rot_dim` channels of x as interleaved (even, odd) pairs
    # using the (cos, sin) pairs in rope_cache; the remaining channels pass
    # through unchanged.
    # x: [b, np, sq, hn]
    x = x.permute(2, 0, 1, 3)  # [b, np, sq, hn] -> [sq, b, np, hn]
    sq, b, np, hn = x.shape
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:sq]
    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    ret = torch.cat((x_out2, x_pass), dim=-1)
    ret = ret.permute(1, 2, 0, 3)  # [sq, b, np, hn] -> [b, np, sq, hn]
    return ret


class ChatGLMRotaryEmbedding(bmt.DistributedModule):
    """Builds the (cos, sin) rope cache consumed by ``apply_chatglm_rotary_pos_emb``."""

    def __init__(self, dim, device="cuda", dtype=torch.float16, persistent=True):
        super().__init__()
        # The buffer only supplies dtype/device in forward(); forward_impl
        # recomputes the frequencies itself.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=dtype, device=device) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=persistent)
        self.dim = dim

    def forward_impl(self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000):
        """Enhanced Transformer with Rotary Position Embedding.

        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.

        Returns:
            Tensor of shape ``(seq_len, n_elem // 2, 2)`` whose last dim holds
            ``(cos, sin)``.
        """
        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=dtype, device=device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

    def forward(self, max_seq_len, offset: int = 0):
        """Return the rope cache for ``max_seq_len`` positions.

        NOTE: ``offset`` is accepted but not used; the cache always starts at
        position 0.
        """
        return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device)