fix #1567
This commit is contained in:
parent
00baaa990e
commit
99a3f06377
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
import math
|
||||
import torch
|
||||
from types import MethodType
|
||||
|
@ -202,6 +201,7 @@ def load_model_and_tokenizer(
|
|||
# Prepare model with valuehead for RLHF
|
||||
if stage in ["rm", "ppo"]:
|
||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
setattr(model, "_keys_to_ignore_on_save", [name for name, _ in model.named_parameters() if "pretrained_model" in name])
|
||||
vhead_path = (
|
||||
model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
|
||||
)
|
||||
|
|
|
@ -40,6 +40,18 @@ _EVAL_CLS = Tuple[
|
|||
]
|
||||
|
||||
|
||||
def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
if (
|
||||
model_args.checkpoint_dir is not None
|
||||
and len(model_args.checkpoint_dir) != 1
|
||||
and finetuning_args.finetuning_type != "lora"
|
||||
):
|
||||
raise ValueError("Multiple checkpoints are only available for LoRA tuning.")
|
||||
|
||||
|
||||
def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||
return parse_args(parser, args)
|
||||
|
@ -81,19 +93,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|||
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
|
||||
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
|
||||
|
||||
if finetuning_args.stage in ["rm", "ppo"]:
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
|
||||
if training_args.load_best_model_at_end:
|
||||
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
|
||||
if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
|
||||
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
|
||||
|
||||
if finetuning_args.stage == "ppo" and not training_args.do_train:
|
||||
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
|
||||
|
||||
if finetuning_args.stage in ["rm", "dpo"]:
|
||||
for dataset_attr in data_args.dataset_list:
|
||||
if not dataset_attr.ranking:
|
||||
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
|
||||
if finetuning_args.stage in ["rm", "dpo"] and (not all([data_attr.ranking for data_attr in data_args.dataset_list])):
|
||||
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
|
||||
|
||||
if finetuning_args.stage == "ppo" and model_args.shift_attn:
|
||||
raise ValueError("PPO training is incompatible with S^2-Attn.")
|
||||
|
@ -107,15 +114,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|||
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
|
||||
raise ValueError("Please specify `lora_target` in LoRA training.")
|
||||
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
if (
|
||||
model_args.checkpoint_dir is not None
|
||||
and len(model_args.checkpoint_dir) != 1
|
||||
and finetuning_args.finetuning_type != "lora"
|
||||
):
|
||||
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
|
||||
if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
|
||||
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
|
||||
|
@ -154,9 +153,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
logger.info(
|
||||
"Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
|
||||
)
|
||||
logger.info("Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
|
||||
training_args.resume_from_checkpoint
|
||||
))
|
||||
|
||||
if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
|
||||
logger.warning("Add {} to `checkpoint_dir` to resume training from checkpoint.".format(
|
||||
training_args.resume_from_checkpoint
|
||||
))
|
||||
|
||||
# postprocess model_args
|
||||
model_args.compute_dtype = (
|
||||
|
@ -183,15 +187,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
|||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
if (
|
||||
model_args.checkpoint_dir is not None
|
||||
and len(model_args.checkpoint_dir) != 1
|
||||
and finetuning_args.finetuning_type != "lora"
|
||||
):
|
||||
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
|
@ -202,8 +198,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
|||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
|
||||
transformers.set_seed(eval_args.seed)
|
||||
|
||||
|
|
|
@ -74,10 +74,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
else:
|
||||
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
|
||||
|
||||
def ppo_train(self) -> None:
|
||||
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
|
||||
r"""
|
||||
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
|
||||
"""
|
||||
if resume_from_checkpoint is not None:
|
||||
raise ValueError("`resume_from_checkpoint` will be supported in the future version.")
|
||||
|
||||
total_train_batch_size = (
|
||||
self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
|
||||
)
|
||||
|
|
|
@ -94,7 +94,7 @@ def run_ppo(
|
|||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
ppo_trainer.ppo_train()
|
||||
ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
ppo_trainer.save_model()
|
||||
ppo_trainer.save_state() # must be called after save_model to have a folder
|
||||
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
|
|
|
@ -47,7 +47,7 @@ def run_rm(
|
|||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train()
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
trainer.save_model()
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
|
|
Loading…
Reference in New Issue