fix saving custom code

hiyouga 2023-07-16 18:04:41 +08:00
parent 2c867b9bb1
commit 1e1358431d
2 changed files with 89 additions and 24 deletions

Changed file 1 of 2:

@@ -11,7 +11,7 @@ from transformers import (
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
-from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead
 
 from llmtuner.extras.logging import get_logger
@@ -36,7 +36,7 @@ def load_model_and_tokenizer(
     finetuning_args: FinetuningArguments,
     is_trainable: Optional[bool] = False,
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
-) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
     r"""
     Loads pretrained model and tokenizer.
 
@@ -113,12 +113,12 @@ def load_model_and_tokenizer(
     )
 
     # Register auto class to save the custom code files.
-    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map and isinstance(config, PretrainedConfig):
+    if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
         config.__class__.register_for_auto_class()
-    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map and isinstance(tokenizer, PreTrainedTokenizer):
-        tokenizer.__class__.register_for_auto_class()
-    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map and isinstance(model, PreTrainedModel):
+    if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
         model.__class__.register_for_auto_class()
+    if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
+        tokenizer.__class__.register_for_auto_class()
 
     # Initialize adapters
     model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
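
Note on the loader change above: registering the config, model, and tokenizer classes for auto class is what lets save_pretrained() copy the custom code files (the remote modeling_*.py and tokenization_*.py modules) alongside the saved weights and record them under auto_map, so the checkpoint can be reloaded with trust_remote_code=True. The tokenizer check now reads auto_map from tokenizer.init_kwargs, which is where a remote-code tokenizer carries that mapping (from its tokenizer_config.json), rather than from the model config. A minimal standalone sketch of the same pattern; the checkpoint name and output directory are placeholders, not part of the commit:

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizerBase

name = "some-org/some-remote-code-model"  # placeholder checkpoint that ships custom code
config = AutoConfig.from_pretrained(name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True)

# Same registration logic as the patched loader: only register classes that
# actually come from remote code, i.e. appear in an auto_map entry.
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
    config.__class__.register_for_auto_class()
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
    model.__class__.register_for_auto_class()
if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
    tokenizer.__class__.register_for_auto_class()

# With the classes registered, saving also writes the custom .py files,
# so the output can be reloaded later with trust_remote_code=True.
model.save_pretrained("output/custom-model")      # placeholder output directory
tokenizer.save_pretrained("output/custom-model")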

Changed file 2 of 2:

@@ -300,6 +300,45 @@ class BaichuanPreTrainedModel(PreTrainedModel):
         if isinstance(module, BaichuanModel):
             module.gradient_checkpointing = value
 
+    @staticmethod
+    def _convert_to_standard_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
+        num_heads, ...]))
+        """
+        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        num_heads = batch_size_times_num_heads // batch_size
+        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
+        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
+    @staticmethod
+    def _convert_to_baichuan_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Converts the cache to the format expected by Baichuan, i.e. to tuple(tuple([batch_size * num_heads, ...]))
+        """
+        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        batch_size_times_num_heads = batch_size * num_heads
+        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
+        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
 
 class BaichuanModel(BaichuanPreTrainedModel):
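
The two helpers added above convert between the cache layout Baichuan inherits from Bloom, where the leading dimension is batch_size * num_heads, and the standard layout with a separate batch dimension that the generation utilities in transformers expect. A standalone shape sketch (the sizes are arbitrary, for illustration only):

import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5

# Baichuan layout: key [bsz * heads, head_dim, seq], value [bsz * heads, seq, head_dim]
key = torch.randn(batch_size * num_heads, head_dim, seq_length)
value = torch.randn(batch_size * num_heads, seq_length, head_dim)

# What _convert_to_standard_cache does per layer: split the leading dimension.
std_key = key.view(batch_size, num_heads, head_dim, seq_length)
std_value = value.view(batch_size, num_heads, seq_length, head_dim)

# What _convert_to_baichuan_cache does per layer: merge it back. The round trip
# is lossless because both conversions are pure views of the same storage.
assert torch.equal(std_key.view(batch_size * num_heads, head_dim, seq_length), key)
assert torch.equal(std_value.view(batch_size * num_heads, seq_length, head_dim), value)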
@@ -318,9 +357,9 @@ class BaichuanModel(BaichuanPreTrainedModel):
     def get_input_embeddings(self):
         return self.embed_tokens
 
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
     def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
         return build_alibi_tensor(attention_mask, num_heads, dtype)
@@ -468,7 +507,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
 
 
 class BaichuanForCausalLM(BaichuanPreTrainedModel):
@@ -498,7 +537,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
     def get_decoder(self):
         return self.model
 
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -528,7 +567,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
 
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
@@ -559,11 +598,20 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-    ):
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs
+    ) -> dict:
         if past_key_values:
             input_ids = input_ids[:, -1:]
 
+            # the cache may be in the standard format (e.g. in contrastive search)
+            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+                past_key_values = self._convert_to_baichuan_cache(past_key_values)
+
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
@@ -571,21 +619,38 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             model_inputs = {"input_ids": input_ids}
 
         model_inputs.update(
             {
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
             }
         )
         return model_inputs
 
-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
-            for layer_past in past_key_values
-        )
+    def _reorder_cache(
+        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+        beam_idx at every generation step.
+
+        Output shares the same memory storage as `past`.
+        """
+        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
+
+        # Get a copy of `beam_idx` on all the devices where we need those indices.
+        device_to_beam_idx = {
+            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
+        }
+        reordered_past = tuple(
+            (
+                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
+                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
+            )
+            for layer_past in standardized_past
+        )
+        return self._convert_to_baichuan_cache(reordered_past)
 
     def quantize(self, bits: int):
         try:
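
The rewritten _reorder_cache above, together with the format check added to prepare_inputs_for_generation, handles the fact that beam indices address whole batch entries while the Baichuan cache folds the attention heads into its batch dimension: the cache is first converted to the standard layout, re-indexed along the batch dimension with index_select, then converted back. A standalone sketch of that reordering for a single key tensor; the sizes and beam_idx values are illustrative only:

import torch

batch_size, num_heads, head_dim, seq_length = 3, 2, 4, 5
beam_idx = torch.tensor([2, 2, 0])  # beams selected at this generation step

# One layer's key in Baichuan layout: [batch_size * num_heads, head_dim, seq_length]
key = torch.randn(batch_size * num_heads, head_dim, seq_length)

# Split heads out of the batch dimension, pick the surviving beams, merge back.
standard_key = key.view(batch_size, num_heads, head_dim, seq_length)
reordered = standard_key.index_select(0, beam_idx.to(standard_key.device))
baichuan_key = reordered.reshape(batch_size * num_heads, head_dim, seq_length)

# The first num_heads rows of the reordered cache now hold the history of beam 2.
assert torch.equal(baichuan_key[:num_heads], key[2 * num_heads:3 * num_heads])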
@@ -594,7 +659,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             raise ImportError(
                 f"Needs QLinear to run quantize."
             )
 
         for layer in self.model.layers:
             layer.self_attn.W_pack = QLinear(
                 bits=bits,
@@ -621,7 +686,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
                 weight=layer.mlp.up_proj.weight,
                 bias = None,
             )
 
         return self
 
     def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
         max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens