fix #1285
commit aff9363ce3
parent d357e08b58
@@ -18,7 +18,7 @@ from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
     from trl import AutoModelForCausalLMWithValueHead
-    from llmtuner.hparams import ModelArguments, GeneratingArguments
+    from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments


 logger = get_logger(__name__)
@@ -33,6 +33,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self,
         model_args: "ModelArguments",
         training_args: "Seq2SeqTrainingArguments",
+        finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
         callbacks: List["TrainerCallback"],
         **kwargs
@@ -43,6 +44,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

         self.args = training_args
         self.model_args = model_args
+        self.finetuning_args = finetuning_args
         self.generation_config = GenerationConfig(
             pad_token_id=self.tokenizer.pad_token_id,
             eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
@@ -162,7 +164,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         r"""
         Generates model's responses given queries.
         """
-        if self.model_args.upcast_layernorm:
+        if self.finetuning_args.upcast_layernorm:
            layernorm_params = dump_layernorm(self.model)

         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
@@ -172,7 +174,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             **batch
         )

-        if self.model_args.upcast_layernorm:
+        if self.finetuning_args.upcast_layernorm:
            restore_layernorm(self.model, layernorm_params)

         query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
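Taken together, the trainer hunks above switch the upcast_layernorm check from model_args to finetuning_args on both sides of response generation. Below is a minimal sketch of how the changed lines sit in the generation path; only the attribute and helper names (finetuning_args.upcast_layernorm, dump_layernorm, restore_layernorm, accelerator.unwrap_model, generation_config) come from the diff, while the function name generate_with_optional_upcast and everything outside the diffed lines are illustrative assumptions, not the repository's actual code.

from typing import Dict, Tuple

import torch
from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm  # import path taken from the hunk context


@torch.no_grad()
def generate_with_optional_upcast(trainer, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    # "trainer" stands in for a CustomPPOTrainer instance; attribute names follow the diff.
    # After this commit the flag is read from finetuning_args rather than model_args.
    if trainer.finetuning_args.upcast_layernorm:
        layernorm_params = dump_layernorm(trainer.model)  # stash layernorm weights before generating

    unwrapped_model = trainer.accelerator.unwrap_model(trainer.model)
    response = unwrapped_model.generate(
        generation_config=trainer.generation_config,
        **batch,
    )

    if trainer.finetuning_args.upcast_layernorm:
        restore_layernorm(trainer.model, layernorm_params)  # put the stashed weights back

    # Keep the prompt and the newly generated tokens separately, moved to CPU.
    query = batch["input_ids"].detach().cpu()
    response = response[:, batch["input_ids"].size(-1):].detach().cpu()
    return query, response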
@@ -66,6 +66,7 @@ def run_ppo(
     ppo_trainer = CustomPPOTrainer(
         model_args=model_args,
         training_args=training_args,
+        finetuning_args=finetuning_args,
         generating_args=generating_args,
         callbacks=callbacks + [SavePeftModelCallback()],
         config=ppo_config,
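The workflow hunk above only threads finetuning_args through to the trainer so the flag read during generation has a source. A minimal, self-contained sketch of that wiring follows; FinetuningArgumentsSketch and PPOTrainerSketch are hypothetical stand-ins, and only the upcast_layernorm field name and the keyword-argument flow come from the diff.

from dataclasses import dataclass


@dataclass
class FinetuningArgumentsSketch:
    # Stand-in for llmtuner.hparams.FinetuningArguments; only the field name comes from the diff.
    upcast_layernorm: bool = False


class PPOTrainerSketch:
    # Stand-in for CustomPPOTrainer: stores the new keyword argument exactly as the diff does.
    def __init__(self, model_args=None, training_args=None, finetuning_args=None, **kwargs):
        self.args = training_args
        self.model_args = model_args
        self.finetuning_args = finetuning_args  # the attribute the generation path now reads


trainer = PPOTrainerSketch(finetuning_args=FinetuningArgumentsSketch(upcast_layernorm=True))
print(trainer.finetuning_args.upcast_layernorm)  # True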