This commit is contained in:
hiyouga 2023-10-26 16:34:52 +08:00
parent d357e08b58
commit aff9363ce3
2 changed files with 6 additions and 3 deletions

View File

@@ -18,7 +18,7 @@ from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
-from llmtuner.hparams import ModelArguments, GeneratingArguments
+from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments
logger = get_logger(__name__)
@@ -33,6 +33,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self,
model_args: "ModelArguments",
training_args: "Seq2SeqTrainingArguments",
+finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: List["TrainerCallback"],
**kwargs
@@ -43,6 +44,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.args = training_args
self.model_args = model_args
+self.finetuning_args = finetuning_args
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
@@ -162,7 +164,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
r"""
Generates model's responses given queries.
"""
-if self.model_args.upcast_layernorm:
+if self.finetuning_args.upcast_layernorm:
layernorm_params = dump_layernorm(self.model)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
@@ -172,7 +174,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
**batch
)
-if self.model_args.upcast_layernorm:
+if self.finetuning_args.upcast_layernorm:
restore_layernorm(self.model, layernorm_params)
query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()

View File

@@ -66,6 +66,7 @@ def run_ppo(
ppo_trainer = CustomPPOTrainer(
model_args=model_args,
training_args=training_args,
+finetuning_args=finetuning_args,
generating_args=generating_args,
callbacks=callbacks + [SavePeftModelCallback()],
config=ppo_config,