From aff9363ce399fd9f1b29d5a088b49c865001f37a Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 26 Oct 2023 16:34:52 +0800 Subject: [PATCH] fix #1285 --- src/llmtuner/tuner/ppo/trainer.py | 8 +++++--- src/llmtuner/tuner/ppo/workflow.py | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index 44df9672..657d658c 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -18,7 +18,7 @@ from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_ if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback from trl import AutoModelForCausalLMWithValueHead - from llmtuner.hparams import ModelArguments, GeneratingArguments + from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments logger = get_logger(__name__) @@ -33,6 +33,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): self, model_args: "ModelArguments", training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", generating_args: "GeneratingArguments", callbacks: List["TrainerCallback"], **kwargs @@ -43,6 +44,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): self.args = training_args self.model_args = model_args + self.finetuning_args = finetuning_args self.generation_config = GenerationConfig( pad_token_id=self.tokenizer.pad_token_id, eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, @@ -162,7 +164,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): r""" Generates model's responses given queries. """ - if self.model_args.upcast_layernorm: + if self.finetuning_args.upcast_layernorm: layernorm_params = dump_layernorm(self.model) unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) @@ -172,7 +174,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): **batch ) - if self.model_args.upcast_layernorm: + if self.finetuning_args.upcast_layernorm: restore_layernorm(self.model, layernorm_params) query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu() diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py index 1dd3205e..3fcb72fd 100644 --- a/src/llmtuner/tuner/ppo/workflow.py +++ b/src/llmtuner/tuner/ppo/workflow.py @@ -66,6 +66,7 @@ def run_ppo( ppo_trainer = CustomPPOTrainer( model_args=model_args, training_args=training_args, + finetuning_args=finetuning_args, generating_args=generating_args, callbacks=callbacks + [SavePeftModelCallback()], config=ppo_config,