From 4a6ca621c09d179561acc5957c8c911a4e44184c Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Mon, 1 Apr 2024 22:53:52 +0800
Subject: [PATCH] fix #3083

---
 src/llmtuner/model/patcher.py      | 29 +++++++++++++++++------------
 src/llmtuner/train/dpo/trainer.py  |  5 ++++-
 src/llmtuner/train/orpo/trainer.py |  2 +-
 src/llmtuner/train/ppo/trainer.py  |  2 +-
 4 files changed, 23 insertions(+), 15 deletions(-)

diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 97399a2c..db9849cf 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -235,6 +235,12 @@ def _configure_quantization(
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
 
+def _fp32_forward_post_hook(
+    module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
+) -> "torch.Tensor":
+    return output.to(torch.float32)
+
+
 def _prepare_model_for_training(
     model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
 ) -> None:
@@ -263,14 +269,10 @@ def _prepare_model_for_training(
             logger.info("Gradient checkpointing enabled.")
 
     if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
-
-        def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
-            return output.to(torch.float32)
-
         logger.info("Upcasting lm_head outputs in float32.")
         output_layer = getattr(model, output_layer_name)
         if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
-            output_layer.register_forward_hook(fp32_forward_post_hook)
+            output_layer.register_forward_hook(_fp32_forward_post_hook)
 
 
 def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
@@ -316,13 +318,6 @@ def patch_config(
 def patch_model(
     model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
 ) -> None:
-    if "GenerationMixin" not in str(model.generate.__func__):
-        model.generate = MethodType(PreTrainedModel.generate, model)
-
-    if getattr(model.config, "model_type", None) == "chatglm":
-        setattr(model, "lm_head", model.transformer.output_layer)
-        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
-
     gen_config = model.generation_config  # check and fix generation config
     if not gen_config.do_sample and (
         (gen_config.temperature is not None and gen_config.temperature != 1.0)
@@ -331,6 +326,16 @@ def patch_model(
     ):
         gen_config.do_sample = True
 
+    if "GenerationMixin" not in str(model.generate.__func__):
+        model.generate = MethodType(PreTrainedModel.generate, model)
+
+    if is_trainable and getattr(model.config, "model_type", None) == "chatglm":
+        setattr(model, "lm_head", model.transformer.output_layer)
+        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
+
+    if is_trainable and getattr(model.config, "model_type", None) == "qwen2" and model_args.flash_attn:
+        setattr(model.config, "use_cache", False)  # qwen2 does not support use_cache when using flashattn
+
     if is_trainable and model_args.resize_vocab:
         _resize_embedding_layer(model, tokenizer)
 
diff --git a/src/llmtuner/train/dpo/trainer.py b/src/llmtuner/train/dpo/trainer.py
index 11727420..0b316c62 100644
--- a/src/llmtuner/train/dpo/trainer.py
+++ b/src/llmtuner/train/dpo/trainer.py
@@ -95,7 +95,10 @@ class CustomDPOTrainer(DPOTrainer):
         batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error
 
         all_logits: "torch.Tensor" = model(
-            input_ids=batch_copied["input_ids"], attention_mask=batch_copied["attention_mask"], return_dict=True
+            input_ids=batch_copied["input_ids"],
+            attention_mask=batch_copied["attention_mask"],
+            return_dict=True,
+            use_cache=False,
         ).logits.to(torch.float32)
 
         all_logps = self.get_batch_logps(
diff --git a/src/llmtuner/train/orpo/trainer.py b/src/llmtuner/train/orpo/trainer.py
index f5b7ff42..d84e0199 100644
--- a/src/llmtuner/train/orpo/trainer.py
+++ b/src/llmtuner/train/orpo/trainer.py
@@ -73,7 +73,7 @@ class CustomORPOTrainer(DPOTrainer):
         Computes the average log probabilities of the labels under the given logits.
         """
         all_logits: "torch.Tensor" = model(
-            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True
+            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True, use_cache=False
         ).logits.to(torch.float32)
 
         all_logps = self.get_batch_logps(
diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py
index de87532a..6be45958 100644
--- a/src/llmtuner/train/ppo/trainer.py
+++ b/src/llmtuner/train/ppo/trainer.py
@@ -353,7 +353,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         batch = self.prepare_model_inputs(queries, responses)
 
         with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
-            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
+            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
 
         if getattr(unwrapped_model.config, "model_type", None) == "chatglm":  # assume same architecture
             values = torch.transpose(values, 0, 1)
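
Note (illustrative, not part of the commit): the snippet below is a minimal, self-contained sketch of the forward-post-hook upcasting that the patch moves to module level, using a plain torch.nn.Linear as a hypothetical stand-in for lm_head; the layer sizes and the bfloat16 dtype are arbitrary. The use_cache=False arguments added in the DPO/ORPO/PPO trainers simply disable the key-value cache during training forward passes, where cached states are not needed.

# Illustrative sketch only -- not the patched library code.
from typing import Tuple

import torch


def _fp32_forward_post_hook(
    module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
    # Upcast the layer's output to float32 regardless of the compute dtype.
    return output.to(torch.float32)


# Hypothetical stand-in for an lm_head kept in a half-precision dtype.
output_layer = torch.nn.Linear(8, 4, dtype=torch.bfloat16)
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
    output_layer.register_forward_hook(_fp32_forward_post_hook)

logits = output_layer(torch.randn(2, 8, dtype=torch.bfloat16))
print(logits.dtype)  # torch.float32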