hiyouga 2023-08-23 20:21:15 +08:00
parent 1c702ad538
commit 57146c101f
4 changed files with 19 additions and 13 deletions


@@ -156,10 +156,9 @@ def get_train_args(
         and finetuning_args.finetuning_type == "lora"
     ):
         logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
-        training_args.ddp_find_unused_parameters = False
-
-    if training_args.optim == "adamw_hf":
-        training_args.optim = "adamw_torch" # suppress warning
+        training_args_dict = training_args.to_dict()
+        training_args_dict.update(dict(ddp_find_unused_parameters=False))
+        training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     if (
         training_args.resume_from_checkpoint is None
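The pattern introduced in this hunk, and repeated in every file of the commit, rebuilds the `Seq2SeqTrainingArguments` object instead of assigning to its attributes, which re-runs the dataclass `__post_init__` validation on the overridden values. A minimal, self-contained sketch of the pattern (the `output_dir` value is arbitrary):

```python
from transformers import Seq2SeqTrainingArguments

# Start from an arguments object such as HfArgumentParser would produce.
training_args = Seq2SeqTrainingArguments(output_dir="outputs")

# Serialize to a plain dict, override the field, and re-instantiate, so the
# dataclass __post_init__ validation runs again on the new value.
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(ddp_find_unused_parameters=False))
training_args = Seq2SeqTrainingArguments(**training_args_dict)

assert training_args.ddp_find_unused_parameters is False
```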
@@ -172,7 +171,9 @@ def get_train_args(
             raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
 
         if last_checkpoint is not None:
-            training_args.resume_from_checkpoint = last_checkpoint
+            training_args_dict = training_args.to_dict()
+            training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
+            training_args = Seq2SeqTrainingArguments(**training_args_dict)
             logger.info(
                 "Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
             )
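For context, the `last_checkpoint` consumed here comes from the standard transformers discovery helper. A hedged sketch of the surrounding logic, which the hunk does not show, assuming the `training_args` from above (the real function additionally gates this on `do_train` and an empty-directory check):

```python
import os

from transformers.trainer_utils import get_last_checkpoint

# Find the newest checkpoint-* directory under output_dir, if any.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
```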


@@ -31,13 +31,14 @@ def run_dpo(
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     )
-    training_args.remove_unused_columns = False # important for pairwise dataset
-    ref_model = deepcopy(model) if not isinstance(model, PeftModel) else None
+    training_args_dict = training_args.to_dict()
+    training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
+    training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     # Initialize our Trainer
     trainer = DPOPeftTrainer(
         finetuning_args=finetuning_args,
-        ref_model=ref_model,
+        ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
         model=model,
         args=training_args,
         tokenizer=tokenizer,
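The second change in this file inlines the reference-model choice: DPO needs a frozen snapshot of the starting policy, but when the policy is a PEFT/LoRA model the unmodified base weights already play that role, so no copy is kept. A sketch of how a trainer can then score the reference policy, assuming a PEFT-wrapped `model` and a tokenized `batch` (both hypothetical here); `disable_adapter()` is peft's real context manager:

```python
from copy import deepcopy

from peft import PeftModel

# Full fine-tuning: keep a frozen copy of the initial policy.
# LoRA/PEFT: pass None; the base model doubles as the reference policy.
ref_model = deepcopy(model) if not isinstance(model, PeftModel) else None

def reference_logits(model, ref_model, batch):
    # Hypothetical helper: with no ref_model, temporarily disable the
    # adapter so the forward pass uses only the frozen base weights.
    if ref_model is not None:
        return ref_model(**batch).logits
    with model.disable_adapter():
        return model(**batch).logits
```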


@@ -28,7 +28,9 @@ def run_rm(
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
     data_collator = PairwiseDataCollatorWithPadding(tokenizer)
-    training_args.remove_unused_columns = False # important for pairwise dataset
+    training_args_dict = training_args.to_dict()
+    training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
+    training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     # Initialize our Trainer
     trainer = PairwisePeftTrainer(
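`remove_unused_columns=False` matters here because `Trainer` by default drops any dataset column whose name is not in the model's `forward()` signature, and a pairwise reward-model dataset keeps the preferred and rejected sequences in extra columns that only the collator understands. An illustrative sketch; the column names and collator are assumptions, not taken from this repo:

```python
# One pairwise record as a custom collator might receive it; with the
# default remove_unused_columns=True, Trainer would strip these columns
# before the collator ever sees them, breaking pairwise batching.
example = {
    "accept_ids": [1, 2, 3, 42],  # token ids of the preferred response (hypothetical name)
    "reject_ids": [1, 2, 3, 7],   # token ids of the rejected response (hypothetical name)
}

def pairwise_collate(features):
    # Hypothetical collator core: put all accepted sequences first, then all
    # rejected ones, so the trainer can split the batch in half and score
    # both responses in a single forward pass.
    return [f["accept_ids"] for f in features] + [f["reject_ids"] for f in features]
```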


@@ -33,10 +33,12 @@ def run_sft(
     )
 
     # Override the decoding parameters of Seq2SeqTrainer
-    training_args.generation_max_length = training_args.generation_max_length if \
-                training_args.generation_max_length is not None else data_args.max_target_length
-    training_args.generation_num_beams = data_args.eval_num_beams if \
-                data_args.eval_num_beams is not None else training_args.generation_num_beams
+    training_args_dict = training_args.to_dict()
+    training_args_dict.update(dict(
+        generation_max_length=training_args.generation_max_length or data_args.max_target_length,
+        generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams
+    ))
+    training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     # Initialize our Trainer
     trainer = Seq2SeqPeftTrainer(
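Note that this rewrite also swaps the explicit `... if ... is not None else ...` ternaries for `or` fallbacks. The two are not strictly equivalent: `or` falls through on any falsy value, not only `None`, as this self-contained demonstration shows:

```python
max_target_length = 512

# Old check: fall back only when the value is None.
generation_max_length = None
assert (generation_max_length if generation_max_length is not None
        else max_target_length) == 512

# New check: `or` also falls back when the value is 0 (or any falsy value).
generation_max_length = 0
assert (generation_max_length or max_target_length) == 512  # 0 is discarded
assert (generation_max_length if generation_max_length is not None
        else max_target_length) == 0  # the old check would have kept the 0
```

Since a generation length or beam count of zero is not meaningful, the shorter form is arguably a safe simplification in this context.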