# Forked from jiuyuan/CPM-9G-8B (original file: 170 lines, 6.5 KiB, Python).
import math
from typing import Optional

import bmtrain as bmt
import torch
import torch.nn.functional as F
from bmtrain.distributed import all_gather, all_reduce
from bmtrain.global_var import config

from .position_embedding import RotaryEmbedding
|
|
|
|
|
|
class Embedding(bmt.DistributedModule):
    """Token embedding table whose weight is reused for the output projection.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_size: Width of each embedding vector; stored as ``dim_model``.
        dtype: Dtype of the weight tensor.
        scale: When true, ``forward`` divides its output and ``projection``
            divides its input by ``sqrt(embedding_size)``.
        init_mean: Mean of the normal initializer applied to the weight.
        init_std: Standard deviation of the normal initializer applied to the weight.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_size: int,
        dtype: torch.dtype = torch.half,
        scale: bool = True,
        init_mean: float = 0.0,
        init_std: float = 1,
    ):
        super().__init__()

        self.dim_model = embedding_size
        table = torch.empty(vocab_size, embedding_size, dtype=dtype)
        normal_init = bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std)
        self.weight = bmt.DistributedParameter(table, init_method=normal_init)
        self.scale = scale

    def forward(self, ids: torch.Tensor):
        """Look up the embedding vectors for ``ids``.

        Args:
            ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.

        Return:
            :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
        """  # noqa: E501
        looked_up = F.embedding(ids, self.weight)
        if not self.scale:
            return looked_up
        # Note: the output is divided (not multiplied) by sqrt(dim_model).
        return looked_up / math.sqrt(self.dim_model)

    def projection(self, x: torch.Tensor):
        """Map hidden states back to vocabulary logits using the embedding weight (tied weights).

        Where the embedding maps vocab_size -> embedding_size, this maps
        embedding_size -> vocab_size.

        Args:
            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection.

        Returns:
            :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output.
        """  # noqa: E501
        projected_in = x / math.sqrt(self.dim_model) if self.scale else x
        return F.linear(projected_in, self.weight)
|
|
|
|
|
|
class EmbeddingExt(bmt.DistributedModule):
    """Token embedding followed by a rotary transform, with an optional extra ("ext") vocabulary.

    ``forward`` divides the looked-up vectors by ``sqrt(embedding_size)`` and
    then applies :class:`RotaryEmbedding` using ``ids_sub``. ``projection``
    reuses the embedding weight to produce logits and can append logits
    against an additional ext table.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_size: Width of each embedding vector; stored as ``dim_model``.
        dtype: Dtype of the weight tensor; also forwarded to ``RotaryEmbedding``.
        init_mean: Mean of the normal initializer applied to the weight.
        init_std: Standard deviation of the normal initializer applied to the weight.
        distance_scale: Forwarded unchanged to ``RotaryEmbedding``.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_size: int,
        dtype: torch.dtype = torch.half,
        init_mean: float = 0.0,
        init_std: float = 1,
        distance_scale: int = 16,
    ):
        super().__init__()

        self.dim_model = embedding_size
        self.rotary_emb = RotaryEmbedding(dim=embedding_size, distance_scale=distance_scale, dtype=dtype)

        self.weight = bmt.DistributedParameter(
            torch.empty(vocab_size, embedding_size, dtype=dtype),
            init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
        )

    def forward(self, ids: torch.Tensor, ids_sub: torch.Tensor):
        """Embed ``ids``, scale by ``1 / sqrt(dim_model)``, then apply the rotary embedding.

        Args:
            ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.
            ids_sub (:obj:`torch.Tensor`): Subscript of input sequence tokens, passed to
                ``RotaryEmbedding`` as its position argument. Previously documented as shape
                ``(batch_size)``; presumably it must broadcast against ``ids`` — TODO confirm.

        Return:
            :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
        """  # noqa: E501
        embeds = F.embedding(ids, self.weight) / math.sqrt(self.dim_model)
        return self.rotary_emb(embeds, ids_sub)

    def projection(self, x: torch.Tensor, ext_table: Optional[torch.Tensor] = None):
        """Map hidden states to vocabulary logits using the embedding weight (tied weights).

        Where the embedding maps vocab_size -> embedding_size, this maps
        embedding_size -> vocab_size. If ``ext_table`` is given, its logits are
        concatenated after the base-vocabulary logits on the last dimension.

        Args:
            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection.
            ext_table (:obj:`torch.Tensor` of shape ``(ext_table_size, dim_model)``, optional): Ext vocab table.

        Returns:
            :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_size + ext_table_size)``: The projection output
            (``(batch, seq_len, vocab_size)`` when ``ext_table`` is None).
        """  # noqa: E501
        logits = F.linear(x / math.sqrt(self.dim_model), self.weight)
        if ext_table is not None:
            # x is deliberately NOT scaled here. Presumably ext_table is produced by
            # forward(), which already divides by sqrt(dim_model) — verify against callers.
            logits_ext = F.linear(x, ext_table)
            logits = torch.cat([logits, logits_ext], dim=-1)
        return logits
|
|
|
|
|
|
class VocabParallelEmbedding(bmt.DistributedModule):
    """Embedding table whose vocabulary rows are split across the tensor-parallel group.

    Each tensor-parallel rank owns the contiguous row slice
    ``[start_index, end_index)`` of size ``vocab_size // config["tp_size"]``.

    Args:
        vocab_size: Total vocabulary size; must be divisible by ``config["tp_size"]``.
        embedding_size: Width of each embedding vector; stored as ``dim_model``.
        dtype: Dtype of the weight tensor.
        scale: When true, ``forward`` divides its output and ``projection``
            divides its input by ``sqrt(embedding_size)``.
        init_mean: Mean of the normal initializer applied to the weight.
        init_std: Standard deviation of the normal initializer applied to the weight.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_size: int,
        dtype: torch.dtype = torch.half,
        scale: bool = True,
        init_mean: float = 0.0,
        init_std: float = 1,
    ):
        super().__init__()

        self.dim_model = embedding_size
        # forward() and projection() both read self.scale. It was previously never
        # stored, which made every call raise AttributeError.
        self.scale = scale

        tp_size = config["tp_size"]
        assert vocab_size % tp_size == 0, (
            f"vocab_size ({vocab_size}) must be divisible by the tensor-parallel size ({tp_size})"
        )
        tp_rank = config["tp_rank"]
        self.vocab_size_per_partition = vocab_size // tp_size
        self.start_index = tp_rank * self.vocab_size_per_partition
        self.end_index = (tp_rank + 1) * self.vocab_size_per_partition
        self.weight = bmt.DistributedParameter(
            torch.empty(self.vocab_size_per_partition, embedding_size, dtype=dtype),
            init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
            tp_split_dim=0,
            tp_mode=True,
        )

    def forward(self, ids: torch.Tensor, gather_input=True):
        """Look up embeddings when the vocabulary is sharded across the tensor-parallel group.

        Args:
            ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.
            gather_input (bool): Whether ``ids`` must first be all-gathered across the tensor-parallel group.

        Return:
            :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
        """  # noqa: E501
        if gather_input:
            # Presumably stacks every rank's ids along a new leading dim of size tp_size;
            # the chunk/flatten below relies on that layout — TODO confirm for gather_input=False.
            ids = all_gather(ids, comm=config["tp_comm"])

        # Ids outside this rank's slice look up local row 0 (always valid),
        # and their vectors are zeroed afterwards.
        input_mask = (ids < self.start_index) | (ids >= self.end_index)
        local_ids = ids - self.start_index  # subtraction allocates; the caller's tensor is untouched
        local_ids[input_mask] = 0

        embeds = F.embedding(local_ids, self.weight)
        embeds[input_mask, :] = 0.0

        # Each in-range id falls in exactly one rank's slice, so summing across ranks
        # reconstructs the full embedding for every token.
        embeds = all_reduce(embeds, op="sum", comm=config["tp_comm"])
        # Keep only this rank's share along dim 0 and merge it into the following dim.
        embed_list = embeds.chunk(config["tp_size"], dim=0)
        embeds = embed_list[config["tp_rank"]].flatten(0, 1)

        if self.scale:
            embeds = embeds / math.sqrt(self.dim_model)

        return embeds

    def projection(self, x: torch.Tensor, gather_output=False, gather_input=True):
        """Map hidden states to vocabulary logits using this rank's embedding shard (tied weights).

        Where the embedding maps vocab_size -> embed_size, this maps embed_size
        back to vocab_size via ``bmt.nn.OpParallelLinear``.

        Args:
            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection.
            gather_output (bool): Forwarded to ``OpParallelLinear`` as its gather-output flag.
            gather_input (bool): Forwarded to ``OpParallelLinear`` as its gather-input flag.

        Returns:
            :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output.
        """  # noqa: E501
        if self.scale:
            x = x / math.sqrt(self.dim_model)
        out = bmt.nn.OpParallelLinear.apply(x, self.weight, None, gather_input, gather_output, False, None)
        return out
|