CPM-9G-8B/FM_9G/fm9g/layers/embedding.py

170 lines
6.5 KiB
Python

import math
from typing import Optional
import bmtrain as bmt
import torch
import torch.nn.functional as F
from .position_embedding import RotaryEmbedding
class Embedding(bmt.DistributedModule):
    """Token embedding table whose weight also serves as the output projection.

    The weight has shape ``(vocab_size, embedding_size)``. When ``scale`` is
    true, both the lookup result and the projection input are divided by
    ``sqrt(embedding_size)``.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_size: int,
        dtype: torch.dtype = torch.half,
        scale: bool = True,
        init_mean: float = 0.0,
        init_std: float = 1,
    ):
        """
        Args:
            vocab_size: Number of rows in the embedding table.
            embedding_size: Width of each embedding vector.
            dtype: Data type of the weight tensor.
            scale: Whether to divide by ``sqrt(embedding_size)`` in
                :meth:`forward` and :meth:`projection`.
            init_mean: Mean of the normal initializer.
            init_std: Standard deviation of the normal initializer.
        """
        super().__init__()
        self.dim_model = embedding_size

        normal_init = bmt.ParameterInitializer(
            torch.nn.init.normal_, mean=init_mean, std=init_std
        )
        table = torch.empty(vocab_size, embedding_size, dtype=dtype)
        self.weight = bmt.DistributedParameter(table, init_method=normal_init)

        self.scale = scale

    def forward(self, ids: torch.Tensor):
        """Look up the embedding vectors for ``ids``.

        Args:
            ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``):
                Indices of input sequence tokens.
        Return:
            :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``:
            The embedding output.
        """
        looked_up = F.embedding(ids, self.weight)
        if not self.scale:
            return looked_up
        return looked_up / math.sqrt(self.dim_model)

    def projection(self, x: torch.Tensor):
        """Project hidden states back onto the vocabulary using the embedding weight.

        The embedding maps vocab_size to embed_size; the projection then maps
        embed_size back to vocab_size with the same weight matrix.

        Args:
            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``):
                Input of projection.
        Returns:
            :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``:
            The projection output.
        """
        if self.scale:
            # Scale the input (not the logits) before the linear map.
            x = x / math.sqrt(self.dim_model)
        return F.linear(x, self.weight)
class EmbeddingExt(bmt.DistributedModule):
    """Embedding table with a rotary embedding applied to the lookup result.

    The lookup is always divided by ``sqrt(embedding_size)`` and then passed,
    together with ``ids_sub``, through ``self.rotary_emb``. The projection can
    optionally append logits computed against an external vocabulary table.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_size: int,
        dtype: torch.dtype = torch.half,
        init_mean: float = 0.0,
        init_std: float = 1,
        distance_scale: int = 16,
    ):
        """
        Args:
            vocab_size: Number of rows in the embedding table.
            embedding_size: Width of each embedding vector.
            dtype: Data type of the weight tensor and of the rotary embedding.
            init_mean: Mean of the normal initializer.
            init_std: Standard deviation of the normal initializer.
            distance_scale: Forwarded unchanged to ``RotaryEmbedding``.
        """
        super().__init__()
        self.dim_model = embedding_size

        self.rotary_emb = RotaryEmbedding(
            dim=embedding_size, distance_scale=distance_scale, dtype=dtype
        )

        normal_init = bmt.ParameterInitializer(
            torch.nn.init.normal_, mean=init_mean, std=init_std
        )
        table = torch.empty(vocab_size, embedding_size, dtype=dtype)
        self.weight = bmt.DistributedParameter(table, init_method=normal_init)

    def forward(self, ids: torch.Tensor, ids_sub: torch.Tensor):
        """Embed ``ids`` and apply the rotary embedding using ``ids_sub``.

        Args:
            ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``):
                Indices of input sequence tokens.
            ids_sub (:obj:`torch.Tensor`): Subscript of input sequence tokens.
                NOTE(review): exact shape is not established in this file --
                confirm against ``RotaryEmbedding``.
        Return:
            :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``:
            The embedding output.
        """
        denom = math.sqrt(self.dim_model)
        scaled = F.embedding(ids, self.weight) / denom
        return self.rotary_emb(scaled, ids_sub)

    def projection(self, x: torch.Tensor, ext_table: Optional[torch.Tensor] = None):
        """Project hidden states back onto the vocabulary using the embedding weight.

        The embedding maps vocab_size to embed_size; the projection then maps
        embed_size back to vocab_size with the same weight matrix.

        Args:
            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``):
                Input of projection.
            ext_table (:obj:`torch.Tensor` of shape ``(ext_table_size, dim_model)``):
                Ext vocab table. When given, its logits are concatenated after
                the regular vocabulary logits along the last dimension.
        Returns:
            :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_size + ext_table_size)``:
            The projection output.
        """
        vocab_logits = F.linear(x / math.sqrt(self.dim_model), self.weight)
        if ext_table is None:
            return vocab_logits

        # The extension logits use the unscaled input, unlike the vocab logits.
        ext_logits = F.linear(x, ext_table)
        return torch.cat([vocab_logits, ext_logits], dim=-1)
class VocabParallelEmbedding(bmt.DistributedModule):
    """Embedding table whose vocabulary rows are split across tensor-parallel ranks.

    Each rank owns the contiguous id range ``[start_index, end_index)`` and
    stores only those rows. Ids outside the local range contribute zeros, and
    the per-rank partial results are summed across the tensor-parallel group.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_size: int,
        dtype: torch.dtype = torch.half,
        scale: bool = True,
        init_mean: float = 0.0,
        init_std: float = 1,
    ):
        """
        Args:
            vocab_size: Total vocabulary size; must be divisible by the
                tensor-parallel size.
            embedding_size: Width of each embedding vector.
            dtype: Data type of the weight tensor.
            scale: Whether to divide by ``sqrt(embedding_size)`` in
                :meth:`forward` and :meth:`projection`.
            init_mean: Mean of the normal initializer.
            init_std: Standard deviation of the normal initializer.
        """
        super().__init__()
        # Read the tensor-parallel settings through the already-imported
        # ``bmt`` namespace; the bare name ``config`` is not bound by the
        # imports visible in this module.
        tp_size = bmt.config["tp_size"]
        tp_rank = bmt.config["tp_rank"]

        self.dim_model = embedding_size
        # forward() and projection() both read self.scale, so it must be stored.
        self.scale = scale

        assert vocab_size % tp_size == 0, (
            f"vocab_size ({vocab_size}) must be divisible by tp_size ({tp_size})"
        )
        self.vocab_size_per_partition = vocab_size // tp_size
        self.start_index = tp_rank * self.vocab_size_per_partition
        self.end_index = (tp_rank + 1) * self.vocab_size_per_partition

        self.weight = bmt.DistributedParameter(
            torch.empty(self.vocab_size_per_partition, embedding_size, dtype=dtype),
            init_method=bmt.ParameterInitializer(
                torch.nn.init.normal_, mean=init_mean, std=init_std
            ),
            tp_split_dim=0,
            tp_mode=True,
        )

    def forward(self, ids: torch.Tensor, gather_input=True):
        """Look up embeddings for ``ids`` using the local vocabulary shard.

        Args:
            ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``):
                Indices of input sequence tokens.
            gather_input (bool): Whether ``ids`` must first be gathered across
                the tensor-parallel group.
        Return:
            :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``:
            The embedding output.
            NOTE(review): this shape presumes ``all_gather`` stacks the
            per-rank inputs along a new leading dimension, which is then
            chunked and flattened away below. With ``gather_input=False`` the
            chunk/flatten operate on the batch dimension instead -- confirm
            the intended behaviour with callers.
        """
        tp_comm = bmt.config["tp_comm"]
        if gather_input:
            ids = bmt.distributed.all_gather(ids, comm=tp_comm)

        # Ids owned by another rank are remapped to row 0 so the lookup stays
        # in bounds; their rows are zeroed afterwards.
        input_mask = (ids < self.start_index) | (ids >= self.end_index)
        local_ids = (ids - self.start_index).masked_fill(input_mask, 0)

        embeds = F.embedding(local_ids, self.weight)
        embeds[input_mask, :] = 0.0

        # Every id is owned by exactly one rank, so summing the partial
        # results yields the full lookup.
        embeds = bmt.distributed.all_reduce(embeds, op="sum", comm=tp_comm)

        # Keep only this rank's slice along dim 0 and merge the first two dims.
        embed_list = embeds.chunk(bmt.config["tp_size"], dim=0)
        embeds = embed_list[bmt.config["tp_rank"]].flatten(0, 1)

        if self.scale:
            embeds = embeds / math.sqrt(self.dim_model)
        return embeds

    def projection(self, x: torch.Tensor, gather_output=False, gather_input=True):
        """Project hidden states onto the vocabulary using the sharded weight.

        The embedding maps vocab_size to embed_size; the projection then maps
        embed_size back to vocab_size with the same weight matrix.

        Args:
            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``):
                Input of projection.
            gather_output (bool): Forwarded to ``bmt.nn.OpParallelLinear`` as
                its ``gather_output`` argument.
            gather_input (bool): Forwarded to ``bmt.nn.OpParallelLinear`` as
                its ``gather_input`` argument.
        Returns:
            :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``:
            The projection output.
        """
        if self.scale:
            x = x / math.sqrt(self.dim_model)
        out = bmt.nn.OpParallelLinear.apply(
            x, self.weight, None, gather_input, gather_output, False, None
        )
        return out