fix ppo callbacks
parent 33f2ddb8b6
commit 4c296001c4

@@ -70,7 +70,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         training_args: "Seq2SeqTrainingArguments",
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
-        callbacks: List["TrainerCallback"],
+        callbacks: Optional[List["TrainerCallback"]],
         model: "AutoModelForCausalLMWithValueHead",
         reward_model: Optional["AutoModelForCausalLMWithValueHead"],
         ref_model: Optional["AutoModelForCausalLMWithValueHead"],
@@ -78,7 +78,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         processor: Optional["ProcessorMixin"],
         dataset: "Dataset",
         data_collator: "DataCollatorWithPadding",
-    ):
+    ) -> None:
         backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
         ppo_config = PPOConfig(
             model_name=model_args.model_name_or_path,
@@ -144,7 +144,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
         self.callback_handler = CallbackHandler(
-            [callbacks], self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
+            callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
         )

         if self.args.max_steps > 0:
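
The substantive fix is in the CallbackHandler hunk above: transformers' CallbackHandler takes a list of callbacks and registers each element individually, so wrapping the already-a-list `callbacks` argument as `[callbacks]` registered the inner list itself as a single "callback", and any later event dispatch (e.g. on_step_end) would look up hooks on a plain list and fail. A minimal sketch of the difference, assuming `callbacks` is a list of TrainerCallback instances (LogCallback below is an illustrative stand-in):

from transformers.trainer_callback import CallbackHandler, TrainerCallback

class LogCallback(TrainerCallback):
    """Illustrative stand-in for a real training callback."""

callbacks = [LogCallback()]

# Buggy form: CallbackHandler([callbacks], ...) iterates the outer list and
# registers the inner list object as one "callback".
# Fixed form: pass the list through unchanged so each element is registered.
handler = CallbackHandler(callbacks, None, None, None, None)  # model, tokenizer, optimizer, scheduler elided
assert isinstance(handler.callbacks[0], LogCallback)

The two hunks below apply the matching change on the workflow side of the PPO entry point.
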
@@ -22,7 +22,7 @@ from transformers import DataCollatorWithPadding
 from ...data import get_dataset
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ..callbacks import FixValueHeadModelCallback, fix_valuehead_checkpoint
+from ..callbacks import fix_valuehead_checkpoint
 from ..trainer_utils import create_ref_model, create_reward_model
 from .trainer import CustomPPOTrainer

@@ -59,7 +59,7 @@ def run_ppo(
         training_args=training_args,
         finetuning_args=finetuning_args,
         generating_args=generating_args,
-        callbacks=callbacks + [FixValueHeadModelCallback()],
+        callbacks=callbacks,
         model=model,
         reward_model=reward_model,
         ref_model=ref_model,
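
On the workflow side, run_ppo now forwards `callbacks` to CustomPPOTrainer untouched instead of appending FixValueHeadModelCallback, and the matching import is dropped; fix_valuehead_checkpoint stays imported, so the value-head checkpoint repair is presumably triggered elsewhere. The parameter's new Optional[List["TrainerCallback"]] hint also admits None; these hunks don't show whether the trainer normalizes that case, but a defensive version would look like this hypothetical helper (init_callbacks is illustrative, not from this commit):

from typing import List, Optional

from transformers import TrainerCallback

def init_callbacks(callbacks: Optional[List[TrainerCallback]]) -> List[TrainerCallback]:
    # Hypothetical normalization: map None to an empty list so downstream
    # code such as CallbackHandler can iterate it without a TypeError.
    return list(callbacks) if callbacks is not None else []
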