diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 20b9b5d4..18a33043 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -1,4 +1,3 @@
-import os
 import math
 import torch
 from types import MethodType
@@ -202,6 +201,7 @@ def load_model_and_tokenizer(
     # Prepare model with valuehead for RLHF
     if stage in ["rm", "ppo"]:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+        setattr(model, "_keys_to_ignore_on_save", [name for name, _ in model.named_parameters() if "pretrained_model" in name])
         vhead_path = (
             model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
         )
diff --git a/src/llmtuner/model/parser.py b/src/llmtuner/model/parser.py
index 051978b8..d298996e 100644
--- a/src/llmtuner/model/parser.py
+++ b/src/llmtuner/model/parser.py
@@ -40,6 +40,18 @@ _EVAL_CLS = Tuple[
 ]
 
 
+def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
+
+    if (
+        model_args.checkpoint_dir is not None
+        and len(model_args.checkpoint_dir) != 1
+        and finetuning_args.finetuning_type != "lora"
+    ):
+        raise ValueError("Multiple checkpoints are only available for LoRA tuning.")
+
+
 def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
     return parse_args(parser, args)
@@ -81,19 +93,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
-    if finetuning_args.stage in ["rm", "ppo"]:
-        if training_args.resume_from_checkpoint is not None:
-            raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
-        if training_args.load_best_model_at_end:
-            raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
+    if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
+        raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
 
     if finetuning_args.stage == "ppo" and not training_args.do_train:
         raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
 
-    if finetuning_args.stage in ["rm", "dpo"]:
-        for dataset_attr in data_args.dataset_list:
-            if not dataset_attr.ranking:
-                raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
+    if finetuning_args.stage in ["rm", "dpo"] and (not all([data_attr.ranking for data_attr in data_args.dataset_list])):
+        raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
 
     if finetuning_args.stage == "ppo" and model_args.shift_attn:
         raise ValueError("PPO training is incompatible with S^2-Attn.")
@@ -107,15 +114,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
         raise ValueError("Please specify `lora_target` in LoRA training.")
 
-    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Quantization is only compatible with the LoRA method.")
-
-    if (
-        model_args.checkpoint_dir is not None
-        and len(model_args.checkpoint_dir) != 1
-        and finetuning_args.finetuning_type != "lora"
-    ):
-        raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+    _verify_model_args(model_args, finetuning_args)
 
     if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
         logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
@@ -154,9 +153,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
             training_args_dict = training_args.to_dict()
             training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
             training_args = Seq2SeqTrainingArguments(**training_args_dict)
-            logger.info(
-                "Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
-            )
+            logger.info("Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
+                training_args.resume_from_checkpoint
+            ))
+
+    if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
+        logger.warning("Add {} to `checkpoint_dir` to resume training from checkpoint.".format(
+            training_args.resume_from_checkpoint
+        ))
 
     # postprocess model_args
     model_args.compute_dtype = (
@@ -183,15 +187,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
 
-    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Quantization is only compatible with the LoRA method.")
-
-    if (
-        model_args.checkpoint_dir is not None
-        and len(model_args.checkpoint_dir) != 1
-        and finetuning_args.finetuning_type != "lora"
-    ):
-        raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+    _verify_model_args(model_args, finetuning_args)
 
     return model_args, data_args, finetuning_args, generating_args
 
@@ -202,8 +198,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
 
-    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Quantization is only compatible with the LoRA method.")
+    _verify_model_args(model_args, finetuning_args)
 
     transformers.set_seed(eval_args.seed)
 
diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py
index 949e2ce8..a8bae1f4 100644
--- a/src/llmtuner/train/ppo/trainer.py
+++ b/src/llmtuner/train/ppo/trainer.py
@@ -74,10 +74,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         else:
             self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
 
-    def ppo_train(self) -> None:
+    def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
         r"""
         Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
         """
+        if resume_from_checkpoint is not None:
+            raise ValueError("`resume_from_checkpoint` will be supported in the future version.")
+
         total_train_batch_size = (
             self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
         )
diff --git a/src/llmtuner/train/ppo/workflow.py b/src/llmtuner/train/ppo/workflow.py
index 41a99e2c..2a779edd 100644
--- a/src/llmtuner/train/ppo/workflow.py
+++ b/src/llmtuner/train/ppo/workflow.py
@@ -94,7 +94,7 @@ def run_ppo(
 
     # Training
     if training_args.do_train:
-        ppo_trainer.ppo_train()
+        ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
         ppo_trainer.save_model()
         ppo_trainer.save_state() # must be called after save_model to have a folder
         if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
diff --git a/src/llmtuner/train/rm/workflow.py b/src/llmtuner/train/rm/workflow.py
index a2e97e8d..ecc409b7 100644
--- a/src/llmtuner/train/rm/workflow.py
+++ b/src/llmtuner/train/rm/workflow.py
@@ -47,7 +47,7 @@ def run_rm(
 
     # Training
     if training_args.do_train:
-        train_result = trainer.train()
+        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
         trainer.save_model()
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)