fix ppo callbacks

This commit is contained in:
hiyouga 2024-07-02 17:34:56 +08:00
parent 33f2ddb8b6
commit 4c296001c4
2 changed files with 5 additions and 5 deletions

View File

@ -70,7 +70,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: List["TrainerCallback"],
callbacks: Optional[List["TrainerCallback"]],
model: "AutoModelForCausalLMWithValueHead",
reward_model: Optional["AutoModelForCausalLMWithValueHead"],
ref_model: Optional["AutoModelForCausalLMWithValueHead"],
@ -78,7 +78,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
processor: Optional["ProcessorMixin"],
dataset: "Dataset",
data_collator: "DataCollatorWithPadding",
):
) -> None:
backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
@ -144,7 +144,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
self.callback_handler = CallbackHandler(
[callbacks], self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
)
if self.args.max_steps > 0:

View File

@ -22,7 +22,7 @@ from transformers import DataCollatorWithPadding
from ...data import get_dataset
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..callbacks import FixValueHeadModelCallback, fix_valuehead_checkpoint
from ..callbacks import fix_valuehead_checkpoint
from ..trainer_utils import create_ref_model, create_reward_model
from .trainer import CustomPPOTrainer
@ -59,7 +59,7 @@ def run_ppo(
training_args=training_args,
finetuning_args=finetuning_args,
generating_args=generating_args,
callbacks=callbacks + [FixValueHeadModelCallback()],
callbacks=callbacks,
model=model,
reward_model=reward_model,
ref_model=ref_model,