fix bug in PPO training

hiyouga 2023-11-16 02:32:54 +08:00
parent 35b91ea34c
commit 856522a3df
3 changed files with 7 additions and 4 deletions


@@ -158,9 +158,14 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
         self.additional_target = split_arg(self.additional_target)
         self.ref_model_checkpoint = split_arg(self.ref_model_checkpoint)
         self.reward_model_checkpoint = split_arg(self.reward_model_checkpoint)
         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
         assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
+        if self.stage == "ppo" and self.reward_model is None:
+            raise ValueError("Reward model is necessary for PPO training.")
+        if self.reward_model_type == "lora" and self.finetuning_type != "lora":
+            raise ValueError("Lora reward model only supports lora training.")
@@ -175,4 +180,5 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
         r"""Creates an instance from the content of `json_path`."""
         with open(json_path, "r", encoding="utf-8") as f:
             text = f.read()
         return cls(**json.loads(text))
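The hunk above shows only the body of the JSON loader. A self-contained sketch of the full method, assuming it is a classmethod-style constructor (the class name and fields here are illustrative, only the open/json.loads pattern comes from the diff):

# Sketch of a JSON-backed dataclass constructor; TrainingConfig is hypothetical.
import json
from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainingConfig:
    stage: str = "sft"
    reward_model: Optional[str] = None

    @classmethod
    def load_from_json(cls, json_path: str) -> "TrainingConfig":
        r"""Creates an instance from the content of `json_path`."""
        with open(json_path, "r", encoding="utf-8") as f:
            text = f.read()
        # Expand the top-level JSON object into keyword arguments, so
        # unknown keys fail fast with a TypeError.
        return cls(**json.loads(text))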


@@ -95,9 +95,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
             if not dataset_attr.ranking:
                 raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
-    if finetuning_args.stage == "ppo" and model_args.reward_model is None:
-        raise ValueError("Reward model is necessary for PPO training.")
     if finetuning_args.stage == "ppo" and model_args.shift_attn:
         raise ValueError("PPO training is incompatible with S^2-Attn.")


@@ -57,7 +57,7 @@ def create_reward_model(
         for name, param in model.named_parameters():  # https://github.com/huggingface/peft/issues/1090
             if "default" in name:
                 param.data = param.data.to(torch.float32)  # trainable params should be in fp32
-        vhead_params = load_valuehead_params(model_args.checkpoint_dir[-1], model_args)
+        vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
         assert vhead_params is not None, "Reward model is not correctly loaded."
         model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
         model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)