diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index d80afa70..abd88d0f 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -11,7 +11,7 @@ from transformers import (
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
-from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead

 from llmtuner.extras.logging import get_logger
@@ -36,7 +36,7 @@ def load_model_and_tokenizer(
     finetuning_args: FinetuningArguments,
     is_trainable: Optional[bool] = False,
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
-) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
     r"""
     Loads pretrained model and tokenizer.

@@ -113,12 +113,12 @@ def load_model_and_tokenizer(
     )

     # Register auto class to save the custom code files.
-    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map and isinstance(config, PretrainedConfig):
+    if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
         config.__class__.register_for_auto_class()
-    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map and isinstance(tokenizer, PreTrainedTokenizer):
-        tokenizer.__class__.register_for_auto_class()
-    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map and isinstance(model, PreTrainedModel):
+    if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
         model.__class__.register_for_auto_class()
+    if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
+        tokenizer.__class__.register_for_auto_class()

     # Initialize adapters
     model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
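
Side note (not part of the patch): the tokenizer branch above now checks "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}) rather than in config.auto_map, since a remote-code tokenizer records its own auto_map, and the isinstance() guards run first so the lookups only happen on real transformers objects. Below is a minimal, hypothetical sketch of the same guards in isolation; the helper name register_custom_code and the example checkpoint are illustrative, while auto_map, init_kwargs and register_for_auto_class() are taken from the hunk itself.

    from transformers import (
        AutoConfig,
        AutoModelForCausalLM,
        AutoTokenizer,
        PretrainedConfig,
        PreTrainedModel,
        PreTrainedTokenizerBase,
    )


    def register_custom_code(config, model, tokenizer) -> None:
        # Config and model advertise their custom classes through `config.auto_map`;
        # getattr(..., {}) keeps the membership test safe when the attribute is absent.
        if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
            config.__class__.register_for_auto_class()
        if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
            model.__class__.register_for_auto_class()
        # A remote-code tokenizer records its mapping in its own init_kwargs instead.
        if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
            tokenizer.__class__.register_for_auto_class()


    # Hypothetical usage with a trust_remote_code checkpoint; after registration,
    # save_pretrained() copies the custom *.py files next to the exported weights.
    name = "baichuan-inc/Baichuan-13B-Chat"
    config = AutoConfig.from_pretrained(name, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True)
    register_custom_code(config, model, tokenizer)
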
diff --git a/tests/modeling_baichuan.py b/tests/modeling_baichuan.py
index 3dedbcd9..2a2d4357 100644
--- a/tests/modeling_baichuan.py
+++ b/tests/modeling_baichuan.py
@@ -300,6 +300,45 @@ class BaichuanPreTrainedModel(PreTrainedModel):
         if isinstance(module, BaichuanModel):
             module.gradient_checkpointing = value

+    @staticmethod
+    def _convert_to_standard_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
+        num_heads, ...]))
+        """
+        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        num_heads = batch_size_times_num_heads // batch_size
+        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
+        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
+    @staticmethod
+    def _convert_to_baichuan_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Converts the cache to the format expected by Baichuan, i.e. to tuple(tuple([batch_size * num_heads, ...]))
+        """
+        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        batch_size_times_num_heads = batch_size * num_heads
+        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
+        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+

 class BaichuanModel(BaichuanPreTrainedModel):
@@ -318,9 +357,9 @@ class BaichuanModel(BaichuanPreTrainedModel):

     def get_input_embeddings(self):
         return self.embed_tokens
-    
+
     def set_input_embeddings(self, value):
-        self.embed_tokens = value    
+        self.embed_tokens = value

     def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
         return build_alibi_tensor(attention_mask, num_heads, dtype)
@@ -468,7 +507,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
-    
+


 class BaichuanForCausalLM(BaichuanPreTrainedModel):
@@ -498,7 +537,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):

     def get_decoder(self):
         return self.model
-    
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -528,7 +567,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-        )    
+        )

         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
@@ -559,11 +598,20 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         )

     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-    ):
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs
+    ) -> dict:
         if past_key_values:
             input_ids = input_ids[:, -1:]

+            # the cache may be in the standard format (e.g. in contrastive search)
+            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+                past_key_values = self._convert_to_baichuan_cache(past_key_values)
+
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
@@ -571,21 +619,38 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             model_inputs = {"input_ids": input_ids}

         model_inputs.update(
-            {    
+            {
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
-            }    
-        )    
+            }
+        )
         return model_inputs

-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
-            for layer_past in past_key_values
-        )
+    def _reorder_cache(
+        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+        beam_idx at every generation step. Output shares the same memory storage as `past`.
+        """
+        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
+
+        # Get a copy of `beam_idx` on all the devices where we need those indices.
+        device_to_beam_idx = {
+            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
+        }
+        reordered_past = tuple(
+            (
+                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
+                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
+            )
+            for layer_past in standardized_past
+        )
+        return self._convert_to_baichuan_cache(reordered_past)

     def quantize(self, bits: int):
         try:
@@ -594,7 +659,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             raise ImportError(
                 f"Needs QLinear to run quantize."
             )
-        
+
         for layer in self.model.layers:
             layer.self_attn.W_pack = QLinear(
                 bits=bits,
@@ -621,7 +686,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
                 weight=layer.mlp.up_proj.weight,
                 bias = None,
             )
-        return self    
+        return self

     def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
         max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
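
Side note (not part of the patch): the two cache helpers added to BaichuanPreTrainedModel only reinterpret the fused [batch_size * num_heads, ...] key/value layout as [batch_size, num_heads, ...] and back, which is what lets the new _reorder_cache index_select one leading row per beam. A small, self-contained sketch of that round trip with hypothetical sizes and plain tensors; it mirrors the view() calls from the hunks rather than calling the model itself.

    import torch

    batch_size, num_heads, head_dim, seq_length, num_layers = 2, 4, 8, 5, 3

    # Baichuan layout: key [batch*heads, head_dim, seq], value [batch*heads, seq, head_dim]
    baichuan_cache = tuple(
        (
            torch.randn(batch_size * num_heads, head_dim, seq_length),
            torch.randn(batch_size * num_heads, seq_length, head_dim),
        )
        for _ in range(num_layers)
    )

    # Same reshapes as _convert_to_standard_cache: one leading row per batch item.
    standard = tuple(
        (
            key.view(batch_size, num_heads, head_dim, seq_length),
            value.view(batch_size, num_heads, seq_length, head_dim),
        )
        for key, value in baichuan_cache
    )

    # Same reshapes as _convert_to_baichuan_cache: fuse batch and heads again.
    roundtrip = tuple(
        (
            key.view(batch_size * num_heads, head_dim, seq_length),
            value.view(batch_size * num_heads, seq_length, head_dim),
        )
        for key, value in standard
    )

    assert standard[0][0].shape == (batch_size, num_heads, head_dim, seq_length)
    assert torch.equal(roundtrip[0][0], baichuan_cache[0][0])
    # view() returns views, so the conversions share storage with the original cache.
    assert roundtrip[0][0].data_ptr() == baichuan_cache[0][0].data_ptr()

The shape test added to prepare_inputs_for_generation, past_key_values[0][0].shape[0] == input_ids.shape[0], relies on exactly this difference: a standard-format cache has batch_size leading rows, while a Baichuan-format cache has batch_size * num_heads.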