forked from p04798526/LLaMA-Factory-Mirror
fix #1422
This commit is contained in:
parent c52336d144
commit 11c1e1e157
@@ -100,11 +100,13 @@ def get_train_args(
     if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")

-    if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
-        raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
-
-    if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
-        raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
+    if finetuning_args.stage in ["rm", "ppo"]:
+        if finetuning_args.finetuning_type != "lora":
+            raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
+        if training_args.resume_from_checkpoint is not None:
+            raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
+        if training_args.load_best_model_at_end:
+            raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")

     if finetuning_args.stage == "ppo" and not training_args.do_train:
         raise ValueError("PPO training does not support evaluation.")
@@ -33,6 +33,12 @@ def run_dpo(
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     )

+    # Create reference model
+    ref_model = None
+    if not isinstance(model, PeftModel):
+        ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
+
+    # Update arguments
     training_args_dict = training_args.to_dict()
     training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
     training_args = Seq2SeqTrainingArguments(**training_args_dict)
@@ -41,7 +47,7 @@ def run_dpo(
     trainer = CustomDPOTrainer(
         beta=finetuning_args.dpo_beta,
         model=model,
-        ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
+        ref_model=ref_model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
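The two hunks above drop the `deepcopy(model)` reference in favour of a `ref_model` that is either `None` (LoRA/PEFT policy) or a separately loaded frozen copy. The sketch below only illustrates that split and is not project code: it assumes a tokenized `batch` and the `ref_model` created above, and relies on PEFT's `disable_adapter()` context manager to recover the frozen base weights when the policy is a `PeftModel`.

    import torch
    from peft import PeftModel

    def reference_logits(model, ref_model, batch):
        # With a LoRA policy, turning the adapters off reproduces the frozen base
        # model, so no second copy of the weights is needed for the DPO reference.
        with torch.no_grad():
            if isinstance(model, PeftModel):
                with model.disable_adapter():
                    return model(**batch).logits
            # Full-parameter policy: fall back to the separately loaded frozen model.
            return ref_model(**batch).logits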
@@ -190,8 +190,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

             if len(response_index) == 0:
                 response_length = 1 # allow empty response
-            elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
-                response_length = response_index[-1].item() + 2 # save the EOS token
             else:
                 response_length = response_index[-1].item() + 1

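A toy walk-through of the response-length rule kept above. `response_index` is not defined inside this hunk; the sketch assumes it holds the positions of the non-pad tokens of one response, which is only a guess about the surrounding code.

    import torch

    pad_token_id = 0
    response = torch.tensor([11, 12, 13, 0, 0])             # one padded response (made up)
    response_index = (response != pad_token_id).nonzero()   # assumed definition

    if len(response_index) == 0:
        response_length = 1  # allow empty response
    else:
        response_length = response_index[-1].item() + 1
    print(response_length)  # 3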
@@ -221,7 +219,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

         rewards = []
         for i in range(values.size(0)):
-            end_index = batch["attention_mask"][i].nonzero()[-1].item() # use the score on the EOS token
+            end_indexes = (batch["input_ids"][i] != self.tokenizer.eos_token_id).nonzero()
+            end_index = end_indexes[-1].item() if len(end_indexes) else 0
             rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type

         replace_model(unwrapped_model, target="default")
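A made-up, self-contained example of the new end-index selection above: the score is now read at the last token that is not the EOS token rather than at the last attended position. All tensor values are invented for illustration.

    import torch

    eos_token_id = 2
    input_ids = torch.tensor([5, 9, 7, 2, 2])            # one sequence, EOS-padded
    values = torch.tensor([0.1, 0.3, 0.8, 0.2, 0.0])     # value-head output per token

    end_indexes = (input_ids != eos_token_id).nonzero()
    end_index = end_indexes[-1].item() if len(end_indexes) else 0
    reward = values[end_index].float()                   # score at the last non-EOS position
    print(end_index, reward)                             # 2 tensor(0.8000)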
@@ -34,7 +34,7 @@ class PairwiseTrainer(Trainer):

        Subclass and override to inject custom behavior.

        Note that the first element will be removed from the output tuple.
        See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
        """
        # Compute rewards
@@ -45,9 +45,6 @@ class PairwiseTrainer(Trainer):
        # Split the inputs and rewards into two parts, chosen and rejected
        batch_size = inputs["input_ids"].size(0) // 2
        chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
-        chosen_attn_mask, rejected_attn_mask = (
-            inputs["attention_mask"][:batch_size], inputs["attention_mask"][batch_size:]
-        )
        chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
        chosen_scores, rejected_scores = [], []

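The split kept above relies on the pairwise collator stacking all chosen examples first and all rejected examples second. A minimal toy illustration of that layout, with an invented batch:

    import torch

    input_ids = torch.arange(24).reshape(4, 6)    # made-up batch: 2 chosen + 2 rejected rows
    batch_size = input_ids.size(0) // 2
    chosen_input_ids = input_ids[:batch_size]     # first half: chosen responses
    rejected_input_ids = input_ids[batch_size:]   # second half: rejected responses
    print(chosen_input_ids.shape, rejected_input_ids.shape)  # torch.Size([2, 6]) twice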
@@ -55,8 +52,8 @@ class PairwiseTrainer(Trainer):
        # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
        loss = 0
        for i in range(batch_size):
-            chosen_length = chosen_attn_mask[i].nonzero()[-1] + 1
-            rejected_length = rejected_attn_mask[i].nonzero()[-1] + 1
+            chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
+            rejected_length = (rejected_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
            check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()

            if len(check_divergence) == 0:
@@ -69,7 +66,7 @@ class PairwiseTrainer(Trainer):
            assert div_index > 0
            chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
            rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
-            if return_outputs: # use the score on the EOS token for inference
+            if return_outputs: # use the score on the last token except pad token for inference
                chosen_scores.append(chosen_rewards[i, chosen_length-1])
                rejected_scores.append(rejected_rewards[i, rejected_length-1])
            loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
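A toy numeric check of the loss term above, with invented reward slices for the span where the two sequences diverge; the loss is small when the chosen response out-scores the rejected one at each position.

    import torch

    chosen_trunc_rewards = torch.tensor([0.6, 0.9, 1.2])     # made-up values
    rejected_trunc_rewards = torch.tensor([0.4, 0.1, -0.3])  # made-up values

    loss = -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
    print(loss)  # ≈ 0.39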
@@ -95,7 +92,6 @@ class PairwiseTrainer(Trainer):

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info(f"Saving prediction results to {output_prediction_file}")

        chosen_scores, rejected_scores = predict_results.predictions

        with open(output_prediction_file, "w", encoding="utf-8") as writer:
@@ -28,6 +28,7 @@ def run_rm(
    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
    data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4)

+    # Update arguments
    training_args_dict = training_args.to_dict()
    training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
    training_args = Seq2SeqTrainingArguments(**training_args_dict)
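Both workflows rebuild `Seq2SeqTrainingArguments` from its dict form to flip `remove_unused_columns`, since `TrainingArguments` is not meant to be mutated after construction. A standalone sketch of that pattern, with an assumed toy `output_dir`:

    from transformers import Seq2SeqTrainingArguments

    training_args = Seq2SeqTrainingArguments(output_dir="toy_output")  # assumed value
    training_args_dict = training_args.to_dict()
    training_args_dict.update(dict(remove_unused_columns=False))  # keep the extra pairwise columns
    training_args = Seq2SeqTrainingArguments(**training_args_dict)
    print(training_args.remove_unused_columns)  # False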