diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 57f0b848..8b89e38a 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -70,7 +70,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         training_args: "Seq2SeqTrainingArguments",
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
-        callbacks: List["TrainerCallback"],
+        callbacks: Optional[List["TrainerCallback"]],
         model: "AutoModelForCausalLMWithValueHead",
         reward_model: Optional["AutoModelForCausalLMWithValueHead"],
         ref_model: Optional["AutoModelForCausalLMWithValueHead"],
@@ -78,7 +78,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         processor: Optional["ProcessorMixin"],
         dataset: "Dataset",
         data_collator: "DataCollatorWithPadding",
-    ):
+    ) -> None:
         backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
         ppo_config = PPOConfig(
             model_name=model_args.model_name_or_path,
@@ -144,7 +144,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
         self.callback_handler = CallbackHandler(
-            [callbacks], self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
+            callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
         )

         if self.args.max_steps > 0:
diff --git a/src/llamafactory/train/ppo/workflow.py b/src/llamafactory/train/ppo/workflow.py
index 651296f3..df22dae5 100644
--- a/src/llamafactory/train/ppo/workflow.py
+++ b/src/llamafactory/train/ppo/workflow.py
@@ -22,7 +22,7 @@ from transformers import DataCollatorWithPadding
 from ...data import get_dataset
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ..callbacks import FixValueHeadModelCallback, fix_valuehead_checkpoint
+from ..callbacks import fix_valuehead_checkpoint
 from ..trainer_utils import create_ref_model, create_reward_model
 from .trainer import CustomPPOTrainer

@@ -59,7 +59,7 @@ def run_ppo(
         training_args=training_args,
         finetuning_args=finetuning_args,
         generating_args=generating_args,
-        callbacks=callbacks + [FixValueHeadModelCallback()],
+        callbacks=callbacks,
         model=model,
         reward_model=reward_model,
         ref_model=ref_model,
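
The trainer.py hunk at line 144 passes `callbacks` directly to `CallbackHandler` instead of `[callbacks]`; since `CallbackHandler` registers each element of its first argument as a callback, the old form registered one `list` object rather than the individual callbacks. A minimal sketch of that difference, using placeholder strings instead of real `TrainerCallback` instances (not part of the patch):

# Placeholder values standing in for TrainerCallback instances.
callbacks = ["log_callback", "progress_callback"]

old_arg = [callbacks]   # before the fix: a one-element list containing the list
new_arg = callbacks     # after the fix: the callback list itself

# CallbackHandler iterates its first argument and adds each element,
# so the old form would hand it a single `list` object.
print([type(cb).__name__ for cb in old_arg])  # ['list']
print([type(cb).__name__ for cb in new_arg])  # ['str', 'str'] (TrainerCallback in practice)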