fix rlhf callback
parent 856522a3df
commit 1817ffc86f
@@ -12,6 +12,7 @@ from llmtuner.extras.logging import get_logger
 
 if TYPE_CHECKING:
     from transformers import TrainingArguments, TrainerState, TrainerControl
+    from trl import AutoModelForCausalLMWithValueHead
 
 
 logger = get_logger(__name__)
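Note: the added trl import lives under the `if TYPE_CHECKING:` guard, so it is evaluated only by static type checkers and never at runtime; the callbacks refer to the class through quoted annotations. A minimal sketch of the pattern (the function and module layout are illustrative, not the project's code):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # resolved by type checkers only; keeps trl off this module's runtime import path
        from trl import AutoModelForCausalLMWithValueHead

    def handle(model: "AutoModelForCausalLMWithValueHead") -> None:
        # the quoted annotation is never evaluated at runtime
        print(type(model).__name__)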
@@ -25,16 +26,22 @@ class SavePeftModelCallback(TrainerCallback):
         """
         if args.should_save:
             output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
-            model = kwargs.pop("model")
+            model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
+            model.pretrained_model.config.save_pretrained(output_dir)
+            if model.pretrained_model.can_generate():
+                model.pretrained_model.generation_config.save_pretrained(output_dir)
             if getattr(model, "is_peft_model", False):
-                getattr(model, "pretrained_model").save_pretrained(output_dir)
+                model.pretrained_model.save_pretrained(output_dir)
 
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of training.
         """
         if args.should_save:
-            model = kwargs.pop("model")
+            model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
+            model.pretrained_model.config.save_pretrained(args.output_dir)
+            if model.pretrained_model.can_generate():
+                model.pretrained_model.generation_config.save_pretrained(args.output_dir)
             if getattr(model, "is_peft_model", False):
                 getattr(model, "pretrained_model").save_pretrained(args.output_dir)
 
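Note: each checkpoint directory written by the callback (and the final output directory) now also receives the backbone's config via `config.save_pretrained(...)`, plus the generation config when the wrapped model can generate, alongside the PEFT adapter weights. A hedged sketch of consuming such a checkpoint afterwards (the directory name is a placeholder):

    from transformers import AutoConfig, GenerationConfig

    ckpt_dir = "path/to/checkpoint-500"  # hypothetical checkpoint produced by the callback

    config = AutoConfig.from_pretrained(ckpt_dir)  # written by config.save_pretrained(...)
    try:
        gen_config = GenerationConfig.from_pretrained(ckpt_dir)  # only present for generative models
    except OSError:
        gen_config = None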
@@ -166,7 +166,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
         if self.stage == "ppo" and self.reward_model is None:
             raise ValueError("Reward model is necessary for PPO training.")
 
-        if self.reward_model_type == "lora" and self.finetuning_type != "lora":
+        if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
             raise ValueError("Lora reward model only supports lora training.")
 
     def save_to_json(self, json_path: str):
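Note: the second validation is now gated on `self.stage == "ppo"`, so a non-PPO run no longer trips the check when `reward_model_type` is "lora" but `finetuning_type` is not; the constraint is only meaningful when a reward model is actually used. A reduced sketch of the resulting checks (the dataclass and its defaults are an illustrative stand-in for FinetuningArguments, limited to the fields named in the diff):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ArgsSketch:  # stand-in for FinetuningArguments, not the real class
        stage: str = "sft"
        finetuning_type: str = "lora"
        reward_model: Optional[str] = None
        reward_model_type: str = "lora"

        def __post_init__(self):
            if self.stage == "ppo" and self.reward_model is None:
                raise ValueError("Reward model is necessary for PPO training.")
            if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
                raise ValueError("Lora reward model only supports lora training.")

    ArgsSketch(stage="sft", finetuning_type="full")  # accepted after this commit; rejected before it
    try:
        ArgsSketch(stage="ppo", reward_model="rm", finetuning_type="full")
    except ValueError as err:
        print(err)  # still rejected: a LoRA reward model requires LoRA fine-tuning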
@@ -203,12 +203,13 @@ def load_model_and_tokenizer(
     if stage in ["rm", "ppo"]:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
         reset_logging()
-        if model_args.checkpoint_dir is not None: # load valuehead weights if exists
-            logger.warning("Only the last checkpoint containing valuehead will be loaded.")
-            vhead_params = load_valuehead_params(model_args.checkpoint_dir[-1], model_args)
-            if vhead_params is not None:
-                model.load_state_dict(vhead_params, strict=False)
-                logger.info("Loaded valuehead from checkpoint: {}".format(model_args.checkpoint_dir[-1]))
+        vhead_path = (
+            model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
+        )
+        vhead_params = load_valuehead_params(vhead_path, model_args)
+        if vhead_params is not None:
+            model.load_state_dict(vhead_params, strict=False)
+            logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
 
     # Prepare model for inference
     if not is_trainable:
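Note: for the "rm" and "ppo" stages the value-head weights are now always looked up: from the last entry of `checkpoint_dir` when checkpoints are provided, otherwise from `model_name_or_path` itself, instead of skipping the load entirely when no checkpoint is given. The selection rule in isolation (the argument object and model names are placeholders):

    from types import SimpleNamespace

    def pick_vhead_path(model_args) -> str:
        # mirrors the conditional expression introduced by this commit
        return (
            model_args.checkpoint_dir[-1]
            if model_args.checkpoint_dir is not None
            else model_args.model_name_or_path
        )

    args = SimpleNamespace(checkpoint_dir=None, model_name_or_path="base-model")
    assert pick_vhead_path(args) == "base-model"
    args = SimpleNamespace(checkpoint_dir=["ckpt-1", "ckpt-2"], model_name_or_path="base-model")
    assert pick_vhead_path(args) == "ckpt-2"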
@@ -45,6 +45,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.args = training_args
         self.model_args = model_args
         self.finetuning_args = finetuning_args
+        self.reward_model = reward_model
 
         self.generation_config = GenerationConfig(
             pad_token_id=self.tokenizer.pad_token_id,
@@ -72,8 +73,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 self.reward_model = self._prepare_deepspeed(self.reward_model)
             else:
                 self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
-        else:
-            self.reward_model = None
 
     def ppo_train(self) -> None:
         r"""
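Note: `CustomPPOTrainer.__init__` now stores the incoming `reward_model` up front, which makes the trailing `else: self.reward_model = None` branch redundant, so it is removed; the deepspeed/accelerator preparation then only rewraps the attribute inside its own guard. A stripped-down sketch of the resulting control flow (the class, the guard condition, and the helper methods are illustrative assumptions, not the real trainer):

    class PPOTrainerSketch:
        def __init__(self, reward_model=None, is_deepspeed_enabled=False):
            self.reward_model = reward_model  # always assigned, may be None
            if reward_model is not None:  # assumed guard; the real condition may differ
                if is_deepspeed_enabled:
                    self.reward_model = self._prepare_deepspeed(self.reward_model)
                else:
                    self.reward_model = self._prepare_accelerator(self.reward_model)
            # no trailing else needed: self.reward_model already holds None in that case

        def _prepare_deepspeed(self, model):
            return model  # placeholder for the real DeepSpeed wrapping

        def _prepare_accelerator(self, model):
            return model  # placeholder for the real accelerator wrapping

    trainer = PPOTrainerSketch(reward_model=None)
    assert trainer.reward_model is None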