diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index 4e0a8122..ca304761 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -158,9 +158,14 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
         self.additional_target = split_arg(self.additional_target)
         self.ref_model_checkpoint = split_arg(self.ref_model_checkpoint)
         self.reward_model_checkpoint = split_arg(self.reward_model_checkpoint)
 
+        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
         assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
+
+        if self.stage == "ppo" and self.reward_model is None:
+            raise ValueError("Reward model is necessary for PPO training.")
+
         if self.reward_model_type == "lora" and self.finetuning_type != "lora":
             raise ValueError("Lora reward model only supports lora training.")
 
@@ -175,4 +180,5 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
         r"""Creates an instance from the content of `json_path`."""
         with open(json_path, "r", encoding="utf-8") as f:
             text = f.read()
+
         return cls(**json.loads(text))
diff --git a/src/llmtuner/model/parser.py b/src/llmtuner/model/parser.py
index 64f48e17..051978b8 100644
--- a/src/llmtuner/model/parser.py
+++ b/src/llmtuner/model/parser.py
@@ -95,9 +95,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
             if not dataset_attr.ranking:
                 raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
 
-    if finetuning_args.stage == "ppo" and model_args.reward_model is None:
-        raise ValueError("Reward model is necessary for PPO training.")
-
     if finetuning_args.stage == "ppo" and model_args.shift_attn:
         raise ValueError("PPO training is incompatible with S^2-Attn.")
 
diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py
index 70307b18..807c44b5 100644
--- a/src/llmtuner/train/utils.py
+++ b/src/llmtuner/train/utils.py
@@ -57,7 +57,7 @@ def create_reward_model(
         for name, param in model.named_parameters():  # https://github.com/huggingface/peft/issues/1090
             if "default" in name:
                 param.data = param.data.to(torch.float32)  # trainable params should in fp32
-        vhead_params = load_valuehead_params(model_args.checkpoint_dir[-1], model_args)
+        vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
         assert vhead_params is not None, "Reward model is not correctly loaded."
         model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
         model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
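
Note: the net effect of this patch is that the PPO reward model is configured and validated on FinetuningArguments (and checked in __post_init__) rather than on ModelArguments in the parser, and the value head is loaded from finetuning_args.reward_model instead of the last checkpoint dir. The sketch below is a hypothetical usage example, not part of the diff; it assumes the FinetuningArguments fields all have defaults so the dataclass can be constructed directly, and the checkpoint path is illustrative only.

    # Hypothetical usage sketch after this change; the path below is illustrative.
    from llmtuner.hparams.finetuning_args import FinetuningArguments

    finetuning_args = FinetuningArguments(
        stage="ppo",
        finetuning_type="lora",
        reward_model="saves/llama2-7b/lora/reward",  # illustrative path, not from the diff
    )

    # Omitting reward_model now fails early in FinetuningArguments.__post_init__:
    #     ValueError: Reward model is necessary for PPO training.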