fix reward model loading

commit c52336d144
parent d92f112951
@@ -104,7 +104,7 @@ def init_adapter(
 def load_valuehead_params(
     model: "PreTrainedModel",
     model_args: "ModelArguments"
-) -> None:
+) -> bool:
     kwargs = {
         "path_or_repo_id": model_args.reward_model,
         "cache_dir": model_args.cache_dir,
@@ -117,10 +117,12 @@ def load_valuehead_params(
     try:
         vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
     except:
-        raise ValueError("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
+        logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
+        return False
 
     vhead_params = torch.load(vhead_file, map_location="cpu")
     model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
     model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
     model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
     model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
+    return True
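
Taken together, the two hunks above change load_valuehead_params from raising a ValueError to reporting success through its return value, so each caller can decide whether a missing value head is fatal. A minimal sketch of the function after this commit, assuming the module's existing imports (torch, cached_file, SAFE_WEIGHTS_NAME, logger); the kwargs dict may carry additional entries not shown in these hunks:

def load_valuehead_params(model: "PreTrainedModel", model_args: "ModelArguments") -> bool:
    # Locate the value-head weight file next to the reward model.
    kwargs = {"path_or_repo_id": model_args.reward_model, "cache_dir": model_args.cache_dir}
    try:
        vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
    except Exception:  # the diff uses a bare except; Exception is shown here for clarity
        logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
        return False  # let the caller decide whether this is fatal

    # Stash the reward head (and a zeroed default head) as non-persistent buffers.
    vhead_params = torch.load(vhead_file, map_location="cpu")
    model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
    model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
    model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
    model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
    return True
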
@@ -204,7 +204,7 @@ def load_model_and_tokenizer(
         reset_logging()
         if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
             logger.warning("Only the last checkpoint containing valuehead will be loaded.")
-            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
+            if load_valuehead_params(model, model_args):
                 model.v_head.load_state_dict({
                     "summary.weight": getattr(model, "reward_head_weight"),
                     "summary.bias": getattr(model, "reward_head_bias")
@@ -214,7 +214,7 @@ def load_model_and_tokenizer(
             logger.info("Load reward model from {}".format(model_args.reward_model))
             if getattr(model, "is_peft_model", False):
                 model.pretrained_model.load_adapter(model_args.reward_model, "reward")
-            load_valuehead_params(model, model_args)
+            assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
 
     # Prepare model for inference
     if not is_trainable:
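
With the boolean return in place, the two call sites above handle a missing value head differently: the rm evaluation path (first hunk) simply skips restoring v_head, while the ppo path (second hunk) treats it as a hard error. Schematically, using only the names that appear in the hunks:

# rm evaluation: tolerate a checkpoint without value-head weights
if load_valuehead_params(model, model_args):
    model.v_head.load_state_dict({
        "summary.weight": getattr(model, "reward_head_weight"),
        "summary.bias": getattr(model, "reward_head_bias")
    })

# ppo: the reward model must ship value-head weights
assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
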
@@ -59,13 +59,15 @@ def run_dpo(
         if trainer.is_world_process_zero() and model_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
-
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
+
+    # Create model card
+    if training_args.do_train:
+        if training_args.push_to_hub:
+            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
+        else:
+            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
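
The same relocation is applied to run_pt, run_rm, and run_sft in the hunks that follow: the push_to_hub/create_model_card block is removed from the training branch and re-added after evaluation (and prediction, where present), guarded by training_args.do_train, so a model card is only created or pushed when training actually ran.
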
@@ -45,11 +45,6 @@ def run_pt(
         if trainer.is_world_process_zero() and model_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
-
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
@@ -61,3 +56,10 @@ def run_pt(
         metrics["perplexity"] = perplexity
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
+
+    # Create model card
+    if training_args.do_train:
+        if training_args.push_to_hub:
+            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
+        else:
+            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
@@ -53,11 +53,6 @@ def run_rm(
         if trainer.is_world_process_zero() and model_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
-
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
@@ -70,3 +65,10 @@ def run_rm(
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
         trainer.save_predictions(predict_results)
+
+    # Create model card
+    if training_args.do_train:
+        if training_args.push_to_hub:
+            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
+        else:
+            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
@@ -72,11 +72,6 @@ def run_sft(
         if trainer.is_world_process_zero() and model_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
-
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
@@ -93,3 +88,10 @@ def run_sft(
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
         trainer.save_predictions(predict_results)
+
+    # Create model card
+    if training_args.do_train:
+        if training_args.push_to_hub:
+            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
+        else:
+            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))