diff --git a/src/llmtuner/tuner/core/adapter.py b/src/llmtuner/tuner/core/adapter.py
index 25330545..4c2984b1 100644
--- a/src/llmtuner/tuner/core/adapter.py
+++ b/src/llmtuner/tuner/core/adapter.py
@@ -104,7 +104,7 @@ def init_adapter(
 def load_valuehead_params(
     model: "PreTrainedModel",
     model_args: "ModelArguments"
-) -> None:
+) -> bool:
     kwargs = {
         "path_or_repo_id": model_args.reward_model,
         "cache_dir": model_args.cache_dir,
@@ -117,10 +117,12 @@ def load_valuehead_params(
     try:
         vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
     except:
-        raise ValueError("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
+        logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
+        return False
 
     vhead_params = torch.load(vhead_file, map_location="cpu")
     model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
     model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
     model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
     model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
+    return True
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 829599c3..2931f087 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -204,7 +204,7 @@ def load_model_and_tokenizer(
         reset_logging()
         if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
             logger.warning("Only the last checkpoint containing valuehead will be loaded.")
-            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
+            if load_valuehead_params(model, model_args):
                 model.v_head.load_state_dict({
                     "summary.weight": getattr(model, "reward_head_weight"),
                     "summary.bias": getattr(model, "reward_head_bias")
@@ -214,7 +214,7 @@ def load_model_and_tokenizer(
             logger.info("Load reward model from {}".format(model_args.reward_model))
             if getattr(model, "is_peft_model", False):
                 model.pretrained_model.load_adapter(model_args.reward_model, "reward")
-            load_valuehead_params(model, model_args)
+            assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
 
     # Prepare model for inference
     if not is_trainable:
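Note on the two files above: load_valuehead_params now reports failure through a boolean status (warning + False) instead of raising ValueError, and the two call sites in loader.py consume that status with different severity. A minimal sketch of the resulting contract, paraphrased from the hunks above (the helper name and stage argument are illustrative, not part of the patch):

```python
from llmtuner.tuner.core.adapter import load_valuehead_params

def restore_or_require_valuehead(model, model_args, stage: str) -> None:
    # Sketch of the new contract: load_valuehead_params() logs a warning and
    # returns False when the checkpoint lacks valuehead weights, rather than
    # raising ValueError as before. Each caller decides how fatal that is.
    if stage == "rm":
        # Evaluating a reward model: restoring v_head is optional, so branch.
        if load_valuehead_params(model, model_args):
            model.v_head.load_state_dict({
                "summary.weight": getattr(model, "reward_head_weight"),
                "summary.bias": getattr(model, "reward_head_bias")
            })
    elif stage == "ppo":
        # PPO needs a working reward model, so a load failure must abort.
        assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
```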
diff --git a/src/llmtuner/tuner/dpo/workflow.py b/src/llmtuner/tuner/dpo/workflow.py
index c4acb331..99bd4cc6 100644
--- a/src/llmtuner/tuner/dpo/workflow.py
+++ b/src/llmtuner/tuner/dpo/workflow.py
@@ -59,13 +59,15 @@ def run_dpo(
         if trainer.is_world_process_zero() and model_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
-    if training_args.push_to_hub:
-        trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-    else:
-        trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
-
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
+
+    # Create model card
+    if training_args.do_train:
+        if training_args.push_to_hub:
+            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
+        else:
+            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
diff --git a/src/llmtuner/tuner/pt/workflow.py b/src/llmtuner/tuner/pt/workflow.py
index c7edff21..ab0e0206 100644
--- a/src/llmtuner/tuner/pt/workflow.py
+++ b/src/llmtuner/tuner/pt/workflow.py
@@ -45,11 +45,6 @@ def run_pt(
         if trainer.is_world_process_zero() and model_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
-    if training_args.push_to_hub:
-        trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-    else:
-        trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
-
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
@@ -61,3 +56,10 @@ def run_pt(
         metrics["perplexity"] = perplexity
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
+
+    # Create model card
+    if training_args.do_train:
+        if training_args.push_to_hub:
+            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
+        else:
+            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
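Context for the run_pt hunk above: the `perplexity` value assigned into `metrics` is computed in lines the diff elides, presumably from the evaluation loss. A minimal sketch of that elided computation, assuming the standard exp(eval_loss) definition used in the Hugging Face language-modeling examples:

```python
import math

# Hypothetical reconstruction of the elided evaluation lines in run_pt:
# perplexity of a causal LM is exp(mean cross-entropy loss); guard against
# overflow when eval_loss is extremely large.
try:
    perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
    perplexity = float("inf")
metrics["perplexity"] = perplexity
```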
diff --git a/src/llmtuner/tuner/rm/workflow.py b/src/llmtuner/tuner/rm/workflow.py
index eedec5e7..dffa5517 100644
--- a/src/llmtuner/tuner/rm/workflow.py
+++ b/src/llmtuner/tuner/rm/workflow.py
@@ -53,11 +53,6 @@ def run_rm(
         if trainer.is_world_process_zero() and model_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
-    if training_args.push_to_hub:
-        trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-    else:
-        trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
-
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
@@ -70,3 +65,10 @@ def run_rm(
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
         trainer.save_predictions(predict_results)
+
+    # Create model card
+    if training_args.do_train:
+        if training_args.push_to_hub:
+            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
+        else:
+            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py
index 04b37ac7..ef902fe7 100644
--- a/src/llmtuner/tuner/sft/workflow.py
+++ b/src/llmtuner/tuner/sft/workflow.py
@@ -72,11 +72,6 @@ def run_sft(
         if trainer.is_world_process_zero() and model_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
-    if training_args.push_to_hub:
-        trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-    else:
-        trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
-
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
@@ -93,3 +88,10 @@ def run_sft(
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
         trainer.save_predictions(predict_results)
+
+    # Create model card
+    if training_args.do_train:
+        if training_args.push_to_hub:
+            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
+        else:
+            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
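All four workflows (run_pt, run_sft, run_rm, run_dpo) receive the same epilogue: model-card creation moves after evaluation/prediction and is gated on `do_train`, so an eval- or predict-only run no longer creates or pushes a card for a model it never trained. The duplicated epilogue could be factored into a helper along these lines (hypothetical sketch, not part of this patch; `generate_model_card` is the project's own utility, `push_to_hub`/`create_model_card` are standard Trainer methods):

```python
def create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args):
    # Hypothetical helper capturing the epilogue now repeated in every workflow.
    if not training_args.do_train:
        return  # never publish a card for a model this run did not train
    card_kwargs = generate_model_card(model_args, data_args, finetuning_args)
    if training_args.push_to_hub:
        trainer.push_to_hub(**card_kwargs)        # uploads weights + card to the Hub
    else:
        trainer.create_model_card(**card_kwargs)  # writes README.md into output_dir
```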