fix rlhf callback

This commit is contained in:
hiyouga 2023-11-16 03:26:19 +08:00
parent 856522a3df
commit 1817ffc86f
4 changed files with 19 additions and 12 deletions

View File

@@ -12,6 +12,7 @@ from llmtuner.extras.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainingArguments, TrainerState, TrainerControl from transformers import TrainingArguments, TrainerState, TrainerControl
from trl import AutoModelForCausalLMWithValueHead
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -25,16 +26,22 @@ class SavePeftModelCallback(TrainerCallback):
""" """
if args.should_save: if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)) output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
model = kwargs.pop("model") model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
model.pretrained_model.config.save_pretrained(output_dir)
if model.pretrained_model.can_generate():
model.pretrained_model.generation_config.save_pretrained(output_dir)
if getattr(model, "is_peft_model", False): if getattr(model, "is_peft_model", False):
getattr(model, "pretrained_model").save_pretrained(output_dir) model.pretrained_model.save_pretrained(output_dir)
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called at the end of training. Event called at the end of training.
""" """
if args.should_save: if args.should_save:
model = kwargs.pop("model") model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
model.pretrained_model.config.save_pretrained(args.output_dir)
if model.pretrained_model.can_generate():
model.pretrained_model.generation_config.save_pretrained(args.output_dir)
if getattr(model, "is_peft_model", False): if getattr(model, "is_peft_model", False):
getattr(model, "pretrained_model").save_pretrained(args.output_dir) getattr(model, "pretrained_model").save_pretrained(args.output_dir)

View File

@@ -166,7 +166,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
if self.stage == "ppo" and self.reward_model is None: if self.stage == "ppo" and self.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.") raise ValueError("Reward model is necessary for PPO training.")
if self.reward_model_type == "lora" and self.finetuning_type != "lora": if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("Lora reward model only supports lora training.") raise ValueError("Lora reward model only supports lora training.")
def save_to_json(self, json_path: str): def save_to_json(self, json_path: str):

View File

@@ -203,12 +203,13 @@ def load_model_and_tokenizer(
if stage in ["rm", "ppo"]: if stage in ["rm", "ppo"]:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model) model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
reset_logging() reset_logging()
if model_args.checkpoint_dir is not None: # load valuehead weights if exists vhead_path = (
logger.warning("Only the last checkpoint containing valuehead will be loaded.") model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
vhead_params = load_valuehead_params(model_args.checkpoint_dir[-1], model_args) )
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None: if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False) model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(model_args.checkpoint_dir[-1])) logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
# Prepare model for inference # Prepare model for inference
if not is_trainable: if not is_trainable:

View File

@@ -45,6 +45,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.args = training_args self.args = training_args
self.model_args = model_args self.model_args = model_args
self.finetuning_args = finetuning_args self.finetuning_args = finetuning_args
self.reward_model = reward_model
self.generation_config = GenerationConfig( self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id, pad_token_id=self.tokenizer.pad_token_id,
@@ -72,8 +73,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.reward_model = self._prepare_deepspeed(self.reward_model) self.reward_model = self._prepare_deepspeed(self.reward_model)
else: else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True) self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
else:
self.reward_model = None
def ppo_train(self) -> None: def ppo_train(self) -> None:
r""" r"""