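Fix #1285: pass finetuning_args to CustomPPOTrainer and read upcast_layernorm from finetuning_args instead of model_args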
This commit is contained in:
parent
d357e08b58
commit
aff9363ce3
|
@ -18,7 +18,7 @@ from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||||
from trl import AutoModelForCausalLMWithValueHead
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
from llmtuner.hparams import ModelArguments, GeneratingArguments
|
from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
@ -33,6 +33,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
self,
|
self,
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
|
finetuning_args: "FinetuningArguments",
|
||||||
generating_args: "GeneratingArguments",
|
generating_args: "GeneratingArguments",
|
||||||
callbacks: List["TrainerCallback"],
|
callbacks: List["TrainerCallback"],
|
||||||
**kwargs
|
**kwargs
|
||||||
|
@ -43,6 +44,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
|
|
||||||
self.args = training_args
|
self.args = training_args
|
||||||
self.model_args = model_args
|
self.model_args = model_args
|
||||||
|
self.finetuning_args = finetuning_args
|
||||||
self.generation_config = GenerationConfig(
|
self.generation_config = GenerationConfig(
|
||||||
pad_token_id=self.tokenizer.pad_token_id,
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||||
|
@ -162,7 +164,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
r"""
|
r"""
|
||||||
Generates model's responses given queries.
|
Generates model's responses given queries.
|
||||||
"""
|
"""
|
||||||
if self.model_args.upcast_layernorm:
|
if self.finetuning_args.upcast_layernorm:
|
||||||
layernorm_params = dump_layernorm(self.model)
|
layernorm_params = dump_layernorm(self.model)
|
||||||
|
|
||||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||||
|
@ -172,7 +174,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
**batch
|
**batch
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.model_args.upcast_layernorm:
|
if self.finetuning_args.upcast_layernorm:
|
||||||
restore_layernorm(self.model, layernorm_params)
|
restore_layernorm(self.model, layernorm_params)
|
||||||
|
|
||||||
query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
|
query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
|
||||||
|
|
|
@ -66,6 +66,7 @@ def run_ppo(
|
||||||
ppo_trainer = CustomPPOTrainer(
|
ppo_trainer = CustomPPOTrainer(
|
||||||
model_args=model_args,
|
model_args=model_args,
|
||||||
training_args=training_args,
|
training_args=training_args,
|
||||||
|
finetuning_args=finetuning_args,
|
||||||
generating_args=generating_args,
|
generating_args=generating_args,
|
||||||
callbacks=callbacks + [SavePeftModelCallback()],
|
callbacks=callbacks + [SavePeftModelCallback()],
|
||||||
config=ppo_config,
|
config=ppo_config,
|
||||||
|
|
Loading…
Reference in New Issue