hiyouga 2023-10-26 16:34:52 +08:00
parent d357e08b58
commit aff9363ce3
2 changed files with 6 additions and 3 deletions

View File

@@ -18,7 +18,7 @@ from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
     from trl import AutoModelForCausalLMWithValueHead
-    from llmtuner.hparams import ModelArguments, GeneratingArguments
+    from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments


 logger = get_logger(__name__)
@@ -33,6 +33,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self,
         model_args: "ModelArguments",
         training_args: "Seq2SeqTrainingArguments",
+        finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
         callbacks: List["TrainerCallback"],
         **kwargs
@@ -43,6 +44,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.args = training_args
         self.model_args = model_args
+        self.finetuning_args = finetuning_args
         self.generation_config = GenerationConfig(
             pad_token_id=self.tokenizer.pad_token_id,
             eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
@@ -162,7 +164,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         r"""
         Generates model's responses given queries.
         """
-        if self.model_args.upcast_layernorm:
+        if self.finetuning_args.upcast_layernorm:
             layernorm_params = dump_layernorm(self.model)

         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
@@ -172,7 +174,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             **batch
         )

-        if self.model_args.upcast_layernorm:
+        if self.finetuning_args.upcast_layernorm:
             restore_layernorm(self.model, layernorm_params)

         query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
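This commit only changes which argument object the generation path reads the flag from; the dump_layernorm/restore_layernorm helpers imported at the top of the file are untouched. For context, here is a minimal sketch of what that pair plausibly does around generate(): snapshot the upcast fp32 layernorm weights, cast them down so half-precision generation works, then restore them afterwards. The bodies below are assumptions (the real helpers live in llmtuner.tuner.ppo.utils and are not shown in this diff).

# Sketch only: dtype-based parameter matching and the use of
# model.config.torch_dtype are assumptions, not the actual helper code.
from typing import Dict, Optional

import torch
from transformers import PreTrainedModel


def dump_layernorm(model: PreTrainedModel) -> Dict[str, torch.Tensor]:
    # Save fp32 (upcast) parameters and cast them to the model's compute dtype
    # so generation under half precision does not hit dtype mismatches.
    saved = {}
    for name, param in model.named_parameters():
        if param.data.dtype == torch.float32:
            saved[name] = param.data.detach().clone()
            param.data = param.data.to(model.config.torch_dtype)
    return saved


def restore_layernorm(model: PreTrainedModel, saved: Optional[Dict[str, torch.Tensor]] = None) -> None:
    # Put the fp32 copies back after generation so PPO updates keep the upcast weights.
    if not saved:
        return
    for name, param in model.named_parameters():
        if name in saved:
            param.data = saved[name]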

View File

@@ -66,6 +66,7 @@ def run_ppo(
     ppo_trainer = CustomPPOTrainer(
         model_args=model_args,
         training_args=training_args,
+        finetuning_args=finetuning_args,
         generating_args=generating_args,
         callbacks=callbacks + [SavePeftModelCallback()],
         config=ppo_config,
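The workflow change simply forwards finetuning_args into CustomPPOTrainer so that the generation path above can consult finetuning_args.upcast_layernorm. The code that performs the upcast at model-load time is outside this diff; the following is a hedged sketch of how such a flag is typically consumed. The helper name and the name-based matching are illustrative assumptions, not the project's actual loader code.

# Hypothetical consumer of `upcast_layernorm`; names below are illustrative.
import torch
from transformers import PreTrainedModel

LAYERNORM_NAMES = ("norm", "ln_f", "ln_1", "ln_2")  # assumed substrings identifying layernorm params


def maybe_upcast_layernorm(model: PreTrainedModel, upcast_layernorm: bool) -> None:
    if not upcast_layernorm:
        return
    for name, param in model.named_parameters():
        # Keep 1-D normalization weights in fp32 for numerical stability during PPO.
        if param.ndim == 1 and any(key in name for key in LAYERNORM_NAMES):
            param.data = param.data.to(torch.float32)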