fix reward model loading

This commit is contained in:
hiyouga 2023-11-07 17:20:51 +08:00
parent d92f112951
commit c52336d144
6 changed files with 34 additions and 24 deletions

View File

@@ -104,7 +104,7 @@ def init_adapter(
 def load_valuehead_params(
     model: "PreTrainedModel",
     model_args: "ModelArguments"
-) -> None:
+) -> bool:
     kwargs = {
         "path_or_repo_id": model_args.reward_model,
         "cache_dir": model_args.cache_dir,
@@ -117,10 +117,12 @@ def load_valuehead_params(
     try:
         vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
     except:
-        raise ValueError("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
+        logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
+        return False
 
     vhead_params = torch.load(vhead_file, map_location="cpu")
     model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
     model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
     model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
     model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
+    return True
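
Note: the weights are attached via register_buffer(..., persistent=False), so they follow the model across devices but never leak into its state_dict. A minimal sketch of that PyTorch behavior (toy module, not code from this repo):

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent buffers move with .to(device)/.cuda() but are
        # excluded from state_dict(), so saved checkpoints stay clean.
        self.register_buffer("reward_head_weight", torch.ones(4), persistent=False)

toy = Toy()
assert "reward_head_weight" not in toy.state_dict()
assert torch.equal(toy.reward_head_weight, torch.ones(4))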

View File

@@ -204,7 +204,7 @@ def load_model_and_tokenizer(
         reset_logging()
     if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
         logger.warning("Only the last checkpoint containing valuehead will be loaded.")
-        if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
+        if load_valuehead_params(model, model_args):
             model.v_head.load_state_dict({
                 "summary.weight": getattr(model, "reward_head_weight"),
                 "summary.bias": getattr(model, "reward_head_bias")
@@ -214,7 +214,7 @@ def load_model_and_tokenizer(
         logger.info("Load reward model from {}".format(model_args.reward_model))
         if getattr(model, "is_peft_model", False):
             model.pretrained_model.load_adapter(model_args.reward_model, "reward")
-        load_valuehead_params(model, model_args)
+        assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
 
     # Prepare model for inference
     if not is_trainable:
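
The zero-filled default_head_* buffers pair with the reward_head_* ones, so a PPO loop can flip the valuehead between the trained reward head and a neutral default. A rough sketch of such a swap, reusing the buffer names from this diff (the helper itself is hypothetical, not part of this commit):

def replace_model(model, target: str) -> None:
    # Hypothetical helper: point v_head at the reward head loaded by
    # load_valuehead_params, or back at the zeroed default head.
    prefix = "reward" if target == "reward" else "default"
    model.v_head.load_state_dict({
        "summary.weight": getattr(model, "{}_head_weight".format(prefix)),
        "summary.bias": getattr(model, "{}_head_bias".format(prefix)),
    })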

View File

@@ -59,13 +59,15 @@ def run_dpo(
         if trainer.is_world_process_zero() and model_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
 
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
 
+    # Create model card
+    if training_args.do_train:
+        if training_args.push_to_hub:
+            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
+        else:
+            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
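
The same reordering is applied to the pt, rm, and sft workflows below: model-card creation moves after evaluation and is guarded by do_train, so evaluation-only runs no longer try to push or write a card. The resulting control flow, simplified (trainer setup and argument plumbing omitted):

# Shared skeleton of run_dpo / run_pt / run_rm / run_sft after this commit.
if training_args.do_train:
    trainer.train()
    trainer.save_model()

if training_args.do_eval:
    metrics = trainer.evaluate(metric_key_prefix="eval")
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

# A model card only makes sense when something was trained.
if training_args.do_train:
    if training_args.push_to_hub:
        trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
    else:
        trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))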

View File

@@ -45,11 +45,6 @@ def run_pt(
         if trainer.is_world_process_zero() and model_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
-
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
@@ -61,3 +56,10 @@ def run_pt(
         metrics["perplexity"] = perplexity
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
+
+    # Create model card
+    if training_args.do_train:
+        if training_args.push_to_hub:
+            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
+        else:
+            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
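
For context, the perplexity stored in the hunk above is conventionally the exponential of the evaluation loss, computed just before this hunk begins; a typical formulation (illustrative only, the exact lines sit outside this diff):

import math

# Perplexity of a causal LM is exp(mean cross-entropy loss);
# guard against overflow when the loss is still very large.
try:
    perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
    perplexity = float("inf")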

View File

@@ -53,11 +53,6 @@ def run_rm(
         if trainer.is_world_process_zero() and model_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
-
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
@@ -70,3 +65,10 @@ def run_rm(
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
         trainer.save_predictions(predict_results)
+
+    # Create model card
+    if training_args.do_train:
+        if training_args.push_to_hub:
+            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
+        else:
+            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))

View File

@@ -72,11 +72,6 @@ def run_sft(
         if trainer.is_world_process_zero() and model_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
-
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
@@ -93,3 +88,10 @@ def run_sft(
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
         trainer.save_predictions(predict_results)
+
+    # Create model card
+    if training_args.do_train:
+        if training_args.push_to_hub:
+            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
+        else:
+            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))