744 lines
30 KiB
744 lines
30 KiB
# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils import logging
from transformers.generation.utils import GenerationConfig
from .configuration_baichuan import BaichuanConfig
logger = logging.get_logger(__name__)
# Copied from transformers.models.bloom.modeling_bloom._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
Make causal mask used for self-attention.
batch_size, target_length = input_ids_shape
mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
# ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
seq_ids = torch.arange(target_length, device=device)
mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
if past_key_values_length > 0:
mask[:, :past_key_values_length] = False
expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
return expanded_mask
# Copied from transformers.models.bloom.modeling_bloom._expand_mask
def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
batch_size, src_length = mask.shape
tgt_length = tgt_length if tgt_length is not None else src_length
expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
# Copied from transformers.models.bloom.modeling_bloom.build_alibi_tensor
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
    """
    Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
    relies on a translation invariance of softmax for quick implementation: with `l` being a tensor and `a` a
    fixed value, `softmax(l+a) = softmax(l)`.

    Args:
        attention_mask (`torch.Tensor`):
            Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
        num_heads (`int`, *required*):
            number of heads
        dtype (`torch.dtype`, *required*):
            dtype of the output tensor

    Returns:
        Tensor shaped (batch_size * num_heads, 1, max_seq_len)
    """
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
    )
    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    # When num_heads is not a power of two, the remaining heads take every other slope from the
    # geometric sequence that belongs to the next power of two.
    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
    # => the query_length dimension will then be broadcasted correctly
    # The cumulative sum turns the mask into 0-based positions; multiplying by the mask zeroes masked slots.
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None] * arange_tensor
    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned per-feature scale (no bias)."""

    def __init__(self, hidden_size, epsilon=1e-6):
        """
        Args:
            hidden_size: size of the last (feature) axis; one scale weight is created per feature.
            epsilon: value added to the mean square before the reciprocal square root.
        """
        # nn.Module must be initialised before a Parameter can be assigned to the instance.
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.epsilon = epsilon

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Scale `hidden_states` by the reciprocal RMS of its last axis and return it in the input dtype."""
        input_dtype = hidden_states.dtype
        # The statistic is computed in float32 regardless of the input dtype.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
        return (self.weight * hidden_states).to(input_dtype)
class MLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, all projections bias-free."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        """
        Args:
            hidden_size: input and output feature size.
            intermediate_size: feature size of the gate/up projections.
            hidden_act: key into `ACT2FN` selecting the activation applied to the gate projection.
        """
        # nn.Module must be initialised before sub-modules can be assigned to the instance.
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        """Apply the gated projection to `x` and map the result back to `hidden_size` features."""
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class BaichuanAttention(nn.Module):
    """Multi-head self-attention whose scores are biased by an ALiBi tensor; Q, K and V share one fused projection."""

    def __init__(self, config: BaichuanConfig):
        """
        Args:
            config: model configuration; `hidden_size`, `num_attention_heads` and `model_max_length` are read.

        Raises:
            ValueError: if `hidden_size` is not divisible by `num_attention_heads`.
        """
        # nn.Module must be initialised before sub-modules can be assigned to the instance.
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.model_max_length

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
            )

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.beta = 1.0

        self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `tensor` to `[bsz, num_heads, seq_len, head_dim]` as a contiguous tensor."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Args:
            hidden_states: `[batch_size, q_length, hidden_size]`.
            alibi: bias added to the scores; must broadcast to `[batch_size * num_heads, q_length, kv_length]`.
            attention_mask: boolean mask that must broadcast to `[batch_size, num_heads, q_length, kv_length]`;
                `True` entries are excluded from attention.
            past_key_value: optional `(key, value)` cache with key `[batch_size * num_heads, head_dim, past]`
                and value `[batch_size * num_heads, past, head_dim]`.
            output_attentions: when false, the returned attention probabilities are `None`.
            use_cache: when false, the returned cache is `None`.

        Returns:
            `(attn_output, attention_probs, past_key_value)` with `attn_output` of shape
            `[batch_size, q_length, hidden_size]`.
        """
        bsz, q_len, _ = hidden_states.size()

        proj = self.W_pack(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
        # Split the fused projection and move the size-3 axis to the front: proj[0]=Q, proj[1]=K, proj[2]=V.
        proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
        query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim)
        key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim)
        value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim)

        # Keys are stored pre-transposed ([.., head_dim, seq]) so they can be used directly as `batch2`.
        query_states = query_states.transpose(1, 2).reshape(bsz * self.num_heads, q_len, self.head_dim)
        key_states = key_states.permute(0, 2, 3, 1).reshape(bsz * self.num_heads, self.head_dim, q_len)
        value_states = value_states.transpose(1, 2).reshape(bsz * self.num_heads, q_len, self.head_dim)

        if past_key_value is not None:
            # reuse k, v, self_attention
            past_key, past_value = past_key_value
            key_states = torch.cat([past_key, key_states], dim=2)
            value_states = torch.cat([past_value, value_states], dim=1)

        _, _, kv_seq_len = key_states.shape
        past_key_value = (key_states, value_states) if use_cache else None

        # [batch_size * num_heads, q_length, kv_length]
        # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
        # result = beta * alibi + alpha * (query @ key)
        matmul_result = alibi.baddbmm(
            batch1=query_states,
            batch2=key_states,
            beta=self.beta,
            alpha=self.inv_norm_factor,
        )

        # change view to [batch_size, num_heads, q_length, kv_length]
        attention_scores = matmul_result.view(bsz, self.num_heads, q_len, kv_seq_len)

        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
        # [batch_size, num_heads, q_length, kv_length]
        input_dtype = attention_scores.dtype
        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
        if input_dtype == torch.float16:
            attention_scores = attention_scores.to(torch.float)
        attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)

        # change view [batch_size x num_heads, q_length, kv_length]
        attention_probs_reshaped = attention_probs.view(bsz * self.num_heads, q_len, kv_seq_len)

        # matmul: [batch_size * num_heads, q_length, head_dim]
        attn_output = torch.bmm(attention_probs_reshaped, value_states)

        attn_output = attn_output.view(bsz, self.num_heads, q_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attention_probs = None

        return attn_output, attention_probs, past_key_value
# One decoder block: RMSNorm -> self-attention -> residual add, then RMSNorm -> MLP -> residual add.
# NOTE(review): this listing has lost its indentation and several lines. Not visible here:
# a `super().__init__()` call, the arguments and closing parenthesis of `MLP(`, the `self`
# parameter of `forward`, and the arguments and closing parenthesis of `self.self_attn(`.
# Restore them from the upstream file; the MLP arguments in particular cannot be recovered
# from what is shown.
class BaichuanLayer(nn.Module):
def __init__(self, config: BaichuanConfig):
self.hidden_size = config.hidden_size
self.self_attn = BaichuanAttention(config=config)
# NOTE(review): the constructor arguments of MLP are missing after the next line.
self.mlp = MLP(
# Two separate norms: one applied before attention, one before the MLP.
self.input_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
# Returns a tuple whose first element is the new hidden states; attention weights and the
# key/value cache are appended only when requested (see the end of the method).
def forward(
hidden_states: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
# Keep the un-normalised input for the residual connection around attention.
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
# NOTE(review): the call arguments are missing after the next line.
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
# Output order is (hidden_states[, attn_weights][, present_key_value]); callers index by position.
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
# Shared base class: declares the config class and loading hints, initialises weights, and
# converts the key/value cache between the flattened layout used in this file and a
# [batch_size, num_heads, ...] layout.
# NOTE(review): this listing has lost its indentation, its docstring quotes and several lines
# (see the notes below) — restore them from the upstream file.
class BaichuanPreTrainedModel(PreTrainedModel):
config_class = BaichuanConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["BaichuanLayer"]
_skip_keys_device_placement = "past_key_values"
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
# Draws Linear and Embedding weights from a normal distribution with std = config.initializer_range.
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
# NOTE(review): the body of the next `if` is missing (presumably it zeroes the bias — confirm).
if module.bias is not None:
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
# NOTE(review): the body of the next `if` is missing (presumably it zeroes the padding row — confirm).
if module.padding_idx is not None:
# Toggles the `gradient_checkpointing` flag, but only on BaichuanModel instances.
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BaichuanModel):
module.gradient_checkpointing = value
# NOTE(review): no `self` parameter, yet `_reorder_cache` calls this as `self._convert_to_standard_cache(...)`,
# so a `@staticmethod` decorator was presumably lost here — confirm.
# The two lines after the signature are docstring text whose quotes are missing.
def _convert_to_standard_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
num_heads, ...]))
# Shapes are read from the first layer's key and reused for every layer.
batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
num_heads = batch_size_times_num_heads // batch_size
# key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
# value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
# NOTE(review): the parentheses that form each (key, value) pair and close `tuple(` are missing below.
return tuple(
layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
for layer_past in past_key_value
# Inverse of `_convert_to_standard_cache`: merges the batch and head axes back into one.
# NOTE(review): same missing decorator, docstring quotes and parentheses as above.
def _convert_to_baichuan_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
Converts the cache to the format expected by Baichuan, i.e. to tuple(tuple([batch_size * num_heads, ...]))
batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
batch_size_times_num_heads = batch_size * num_heads
# key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
return tuple(
layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
for layer_past in past_key_value
# Decoder stack: token embedding -> BaichuanLayer blocks -> final RMSNorm. ALiBi biases and the
# combined causal/padding mask are built once per forward pass and handed to the layers.
# NOTE(review): this listing has lost its indentation and many lines (`self` parameters, a
# `super().__init__(config)` call, `else:` branches, closing parentheses and whole argument
# lists). Each gap is marked below; restore them from the upstream file.
class BaichuanModel(BaichuanPreTrainedModel):
def __init__(self, config: BaichuanConfig):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.n_head = config.num_attention_heads
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([BaichuanLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
self.gradient_checkpointing = config.gradient_checkpointing
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Thin wrapper around the module-level `build_alibi_tensor`.
def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
return build_alibi_tensor(attention_mask, num_heads, dtype)
# Combines the padding mask with a causal mask; the causal part is only built when more than
# one token is processed in this step.
def _prepare_attn_mask(
self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
) -> torch.BoolTensor:
# create causal mask
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
combined_attention_mask = None
device = attention_mask.device
_, src_length = input_shape
if src_length > 1:
# NOTE(review): the closing parenthesis of this call is missing.
combined_attention_mask = _make_causal_mask(
input_shape, device=device, past_key_values_length=past_key_values_length
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
# Logical OR of the two masks (or the padding mask alone if no causal mask was built).
# NOTE(review): the closing parenthesis of this expression is missing.
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
return combined_attention_mask
# NOTE(review): the `self` parameter is missing from this signature.
def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
# Per-call flags fall back to the values stored on the config.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# NOTE(review): the closing parenthesis of this expression is missing.
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Exactly one of input_ids / inputs_embeds must be given; batch and sequence sizes come from it.
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot provide both input_ids and inputs_embeds simultaneously")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
# NOTE(review): an `else:` appears to be missing before this raise.
raise ValueError("You need to provide input_ids or inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
# NOTE(review): BaichuanAttention caches keys as [batch_size * num_heads, head_dim, seq_length]
# (and `_convert_to_standard_cache` unpacks them that way), so `.shape[1]` is head_dim, not the
# cached length. This looks like it should be `.shape[2]` — confirm against upstream.
past_key_values_length = past_key_values[0][0].shape[1]
seq_length_with_past = seq_length_with_past + past_key_values_length
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
# Without a mask, every position (including cached ones) is treated as a real token.
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
attention_mask = attention_mask.to(hidden_states.device)
if self.gradient_checkpointing and self.training:
if use_cache:
# NOTE(review): the bare string below is presumably the argument of a logger warning call whose
# opening and closing lines are missing — confirm.
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
use_cache = False
# Compute alibi tensor: check build_alibi_tensor documentation
alibi = self.build_alibi_tensor(attention_mask, self.n_head, dtype=hidden_states.dtype)
# NOTE(review): the remaining arguments and the closing parenthesis of this call are missing.
causal_mask = self._prepare_attn_mask(
input_shape=(batch_size, seq_length),
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
# Hidden states are recorded before each layer; the final (normalised) state is added after the loop.
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
# NOTE(review): the arguments of `checkpoint(`, an `else:` branch, and the arguments of
# `decoder_layer(` are all missing from the next two lines.
layer_outputs = torch.utils.checkpoint.checkpoint(
layer_outputs = decoder_layer(
hidden_states = layer_outputs[0]
# The cache sits after the attention weights in the layer's output tuple when those are returned.
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
# Tuple output drops every element that is None.
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
# NOTE(review): the keyword arguments and closing parenthesis of this return are missing.
return BaseModelOutputWithPast(
# Causal language-model head on top of BaichuanModel, plus generation helpers (input preparation,
# beam-search cache reordering), a quantisation hook and a chat convenience API.
# NOTE(review): this listing has lost its indentation, its docstring quotes and many lines
# (`self` parameters, a `super().__init__(config)` call, `try:` / `else:` lines, closing
# parentheses and whole argument lists). Each gap is marked below; restore them from the
# upstream file.
class BaichuanForCausalLM(BaichuanPreTrainedModel):
def __init__(self, config):
self.model = BaichuanModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
# NOTE(review): the statement this comment refers to is missing.
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
# Runs the decoder, projects to vocabulary logits and, when `labels` is given, computes the
# next-token cross-entropy loss.
# NOTE(review): the `self` parameter is missing from this signature.
def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# NOTE(review): the closing parenthesis of this expression is missing.
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
# NOTE(review): the arguments and closing parenthesis of this call are missing.
outputs = self.model(
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
# Tuple output: (loss,) is prepended only when a loss was computed.
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
# NOTE(review): the keyword arguments and closing parenthesis of this return are missing.
return CausalLMOutputWithPast(
# Builds the keyword arguments for one generation step.
# NOTE(review): `self` and a `**kwargs` parameter are missing from the signature (`kwargs` is read
# below), as are an `else:` before the second `model_inputs = ...` and the call that wraps the three
# dictionary entries.
def prepare_inputs_for_generation(
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> dict:
# With a cache, only the newest token needs to be fed to the model.
if past_key_values:
input_ids = input_ids[:, -1:]
# the cache may be in the standard format (e.g. in contrastive search)
# A leading dimension equal to the batch size (rather than batch * heads) identifies that format.
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_baichuan_cache(past_key_values)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
model_inputs = {"input_ids": input_ids}
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
return model_inputs
# Reorders the cache along the batch axis to follow `beam_idx`. The four lines after the signature
# are docstring text whose quotes are missing.
def _reorder_cache(
self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
Output shares the same memory storage as `past`.
# Split batch and head axes first so that index_select on dim 0 picks whole batch entries.
standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
# Get a copy of `beam_idx` on all the devices where we need those indices.
# NOTE(review): the closing brace of this dict comprehension is missing.
device_to_beam_idx = {
past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
# NOTE(review): the parentheses that form each (key, value) pair and close `tuple(` are missing.
# The value tensor is indexed using the key tensor's device (`layer_past[0].device`); that is only
# correct when both live on the same device — confirm whether `layer_past[1].device` was intended.
reordered_past = tuple(
layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
for layer_past in standardized_past
return self._convert_to_baichuan_cache(reordered_past)
# Replaces the five linear projections of every decoder layer with QLinear modules and returns self.
# NOTE(review): the `try:` line, the closing parenthesis of the `raise`, and every QLinear argument
# except `bias = None` are missing; QLinear's signature is not visible in this file.
def quantize(self, bits: int):
from .quantizer import QLinear
except ImportError:
raise ImportError(
f"Needs QLinear to run quantize."
for layer in self.model.layers:
layer.self_attn.W_pack = QLinear(
bias = None,
layer.self_attn.o_proj = QLinear(
bias = None,
layer.mlp.gate_proj = QLinear(
bias = None,
layer.mlp.down_proj = QLinear(
bias = None,
layer.mlp.up_proj = QLinear(
bias = None,
return self
# Encodes a chat history into a single row of token ids, walking the messages newest-first and
# keeping at most `max_input_tokens` tokens from the right-hand end.
# NOTE(review): several lines are missing inside the loop (the bodies after both length checks, an
# `else:` before `round_input = []`, the token ids inside the assistant branch's two lists, and an
# `else:` before the final raise). Anything between the truncation and the tensor conversion cannot
# be told from here.
def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
# Reserve room for the reply, but never let the prompt budget drop below half the context length.
max_input_tokens = self.config.model_max_length - max_new_tokens
max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens)
total_input, round_input = [], []
for i, message in enumerate(messages[::-1]):
content_tokens = tokenizer.encode(message['content'])
if message['role'] == 'user':
round_input = [self.generation_config.user_token_id] + content_tokens + round_input
if total_input and len(total_input) + len(round_input) > max_input_tokens:
total_input = round_input + total_input
if len(total_input) >= max_input_tokens:
round_input = []
elif message['role'] == 'assistant':
round_input = [
] + content_tokens + [
] + round_input
raise ValueError(f"message role not supported yet: {message['role']}")
total_input = total_input[-max_input_tokens:] # truncate left
total_input = torch.LongTensor([total_input]).to(self.device)
return total_input
# Generates a reply to `messages`; with `stream=True` a generator of progressively decoded text is
# returned instead of a single string.
# Both branches rebind `generate` on the class, so the choice affects every instance of the class.
def chat(self, tokenizer, messages: List[dict], stream=False,
generation_config: Optional[GenerationConfig]=None):
generation_config = generation_config or self.generation_config
input_ids = self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens)
if stream:
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
self.__class__.generate = NewGenerationMixin.generate
self.__class__.sample_stream = NewGenerationMixin.sample_stream
stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
def stream_generator():
outputs = []
# NOTE(review): as shown, `token` is never added to `outputs`, so every yield would decode an empty
# list — a line appending the token appears to be missing.
for token in self.generate(input_ids, generation_config=stream_config):
yield tokenizer.decode(outputs, skip_special_tokens=True)
return stream_generator()
# NOTE(review): an `else:` appears to be missing before the non-streaming path below.
self.__class__.generate = PreTrainedModel.generate # disable stream
outputs = self.generate(input_ids, generation_config=generation_config)
# Strip the prompt tokens so only the newly generated text is decoded.
response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
return response