From 2818af0b0967d7695f27658acac0b7e2c2728e5d Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 11 Oct 2023 23:16:01 +0800
Subject: [PATCH] refactor model_dtype, fix PPO trainer

---
 src/llmtuner/extras/misc.py                | 28 +++++++++++
 src/llmtuner/extras/patches/llama_patch.py |  8 ++--
 src/llmtuner/hparams/model_args.py         |  6 +--
 src/llmtuner/tuner/core/loader.py          | 21 +++++----
 src/llmtuner/tuner/core/parser.py          | 55 +++-------------------
 src/llmtuner/tuner/core/utils.py           | 11 +++--
 src/llmtuner/tuner/ppo/trainer.py          | 50 +++++++++-----------
 src/llmtuner/tuner/ppo/utils.py            | 39 +++++++--------
 src/llmtuner/tuner/ppo/workflow.py         |  2 +-
 src/llmtuner/webui/runner.py               |  3 ++
 10 files changed, 104 insertions(+), 119 deletions(-)

diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index f9ee2bea..960d43ee 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -3,6 +3,19 @@ import torch
 from typing import TYPE_CHECKING, Tuple
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
 
+try:
+    from transformers.utils import (
+        is_torch_bf16_cpu_available,
+        is_torch_bf16_gpu_available,
+        is_torch_cuda_available,
+        is_torch_npu_available
+    )
+    _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
+    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
+except ImportError:
+    _is_fp16_available = torch.cuda.is_available()
+    _is_bf16_available = torch.cuda.is_bf16_supported()
+
 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
 
@@ -49,7 +62,22 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param
 
 
+def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
+    r"""
+    Infers the optimal dtype according to the model_dtype and device compatibility.
+    """
+    if _is_bf16_available and model_dtype == torch.bfloat16:
+        return torch.bfloat16
+    elif _is_fp16_available:
+        return torch.float16
+    else:
+        return torch.float32
+
+
 def get_logits_processor() -> LogitsProcessorList:
+    r"""
+    Gets logits processor that removes NaN and Inf logits.
+ """ logits_processor = LogitsProcessorList() logits_processor.append(InfNanRemoveLogitsProcessor()) return logits_processor diff --git a/src/llmtuner/extras/patches/llama_patch.py b/src/llmtuner/extras/patches/llama_patch.py index e516df76..a8473311 100644 --- a/src/llmtuner/extras/patches/llama_patch.py +++ b/src/llmtuner/extras/patches/llama_patch.py @@ -138,11 +138,11 @@ class LlamaFlashAttention2(LlamaAttention): input_dtype = query_states.dtype if input_dtype == torch.float32: logger.warning_once("The input hidden states seems to be silently casted in float32.") - query_states = query_states.to(torch.float16) - key_states = key_states.to(torch.float16) - value_states = value_states.to(torch.float16) + query_states = query_states.to(self.config.torch_dtype) + key_states = key_states.to(self.config.torch_dtype) + value_states = value_states.to(self.config.torch_dtype) - if getattr(self, "num_key_value_groups"): + if getattr(self, "num_key_value_groups", None): key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index a26f8aa2..a3d6d917 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -67,9 +67,9 @@ class ModelArguments: default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."} ) - layernorm_dtype: Optional[Literal["auto", "fp16", "bf16", "fp32"]] = field( - default="auto", - metadata={"help": "Data type of the layer norm weights."} + upcast_layernorm: Optional[bool] = field( + default=False, + metadata={"help": "Whether to upcast the layernorm weights in fp32."} ) def __post_init__(self): diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index 820307a7..4a81548a 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -24,7 +24,7 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v from transformers.deepspeed import is_deepspeed_zero3_enabled from llmtuner.extras.logging import reset_logging, get_logger -from llmtuner.extras.misc import count_parameters +from llmtuner.extras.misc import count_parameters, infer_optim_dtype from llmtuner.extras.patches import llama_patch as LlamaPatches from llmtuner.extras.save_and_load import load_valuehead_params from llmtuner.hparams import FinetuningArguments @@ -86,11 +86,17 @@ def load_model_and_tokenizer( if getattr(config, "model_type", None) == "chatglm": tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer) + # Set model dtype + if model_args.compute_dtype is not None: + setattr(config, "torch_dtype", model_args.compute_dtype) + else: # priority: bf16 > fp16 > fp32 + optim_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) + setattr(config, "torch_dtype", optim_dtype) + # Fix config (for Qwen) if getattr(config, "model_type", None) == "qwen": - setattr(config, "fp16", model_args.compute_dtype == torch.float16) - setattr(config, "bf16", model_args.compute_dtype == torch.bfloat16) - setattr(config, "fp32", model_args.compute_dtype == torch.float32) + for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]: + setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype) # Set RoPE scaling if model_args.rope_scaling is not None: @@ -131,9 +137,7 @@ def load_model_and_tokenizer( if model_args.flash_attn: if getattr(config, 
"model_type", None) == "llama": LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2 - LlamaModule.LlamaModel._prepare_decoder_attention_mask = ( - LlamaPatches._prepare_decoder_attention_mask - ) + LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask logger.info("Using FlashAttention-2 for faster training and inference.") elif getattr(config, "model_type", None) == "qwen": logger.info("Qwen models automatically enable FlashAttention if installed.") @@ -180,7 +184,6 @@ def load_model_and_tokenizer( model = AutoModelForCausalLM.from_pretrained( model_to_load, config=config, - torch_dtype=model_args.compute_dtype, low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()), **config_kwargs ) @@ -203,7 +206,7 @@ def load_model_and_tokenizer( # Initialize adapters if is_trainable: - model = prepare_model_for_training(model, model_args.layernorm_dtype, finetuning_args.finetuning_type) + model = prepare_model_for_training(model, model_args.upcast_layernorm, finetuning_args.finetuning_type) model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable) model = model.train() if is_trainable else model.eval() diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py index ff3ddff5..a8ed914a 100644 --- a/src/llmtuner/tuner/core/parser.py +++ b/src/llmtuner/tuner/core/parser.py @@ -8,16 +8,6 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments from transformers.utils.versions import require_version from transformers.trainer_utils import get_last_checkpoint -try: - from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available, is_torch_cuda_available - is_fp16_available = is_torch_cuda_available() - is_bf16_available = is_torch_bf16_gpu_available() - is_npu_available = is_torch_npu_available() -except ImportError: - is_fp16_available = torch.cuda.is_available() - is_bf16_available = torch.cuda.is_bf16_supported() - is_npu_available = False - from llmtuner.extras.logging import get_logger from llmtuner.hparams import ( ModelArguments, @@ -31,17 +21,6 @@ from llmtuner.hparams import ( logger = get_logger(__name__) -def _infer_dtype() -> torch.dtype: - if is_npu_available: - return torch.float16 - elif is_bf16_available: - return torch.bfloat16 - elif is_fp16_available: - return torch.float16 - else: - return torch.float32 - - def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: if args is not None: return parser.parse_dict(args) @@ -178,12 +157,15 @@ def get_train_args( if not finetuning_args.resume_lora_training: raise ValueError("Quantized model cannot create new LoRA weight. Merge them first.") - if model_args.quantization_bit is not None and (not training_args.do_train): - logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") + if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm): + logger.warning("We recommend enable `upcast_layernorm` in quantized training.") if training_args.do_train and (not training_args.fp16) and (not training_args.bf16): logger.warning("We recommend enable mixed precision training.") + if (not training_args.do_train) and model_args.quantization_bit is not None: + logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") + # postprocess data_args if data_args.max_samples is not None and data_args.streaming: logger.warning("`max_samples` is incompatible with `streaming`. 
@@ -206,10 +188,9 @@
         and os.path.isdir(training_args.output_dir)
         and not training_args.overwrite_output_dir
     ):
-        require_version("transformers>=4.31.0", "Resuming training requires transformers>=4.31.0.")
         last_checkpoint = get_last_checkpoint(training_args.output_dir)
         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
-            raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
+            raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
 
         if last_checkpoint is not None:
             training_args_dict = training_args.to_dict()
@@ -220,26 +201,7 @@
             )
 
     # postprocess model_args
-    if training_args.bf16:
-        if not is_bf16_available:
-            raise ValueError("Current device does not support bf16 training.")
-        model_args.compute_dtype = torch.bfloat16
-    elif training_args.fp16:
-        model_args.compute_dtype = torch.float16
-    else:
-        model_args.compute_dtype = _infer_dtype()
-
-    if model_args.layernorm_dtype == "bf16":
-        if not is_bf16_available:
-            raise ValueError("Current device does not support bf16 type.")
-        model_args.layernorm_dtype = torch.bfloat16
-    elif model_args.layernorm_dtype == "fp16":
-        model_args.layernorm_dtype = torch.float16
-    elif model_args.layernorm_dtype == "fp32":
-        model_args.layernorm_dtype = torch.float32
-    else:
-        model_args.layernorm_dtype = model_args.compute_dtype
-
+    model_args.compute_dtype = torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
     model_args.model_max_length = data_args.cutoff_len
 
     # Log on each process the small summary:
@@ -278,7 +240,4 @@
     if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
         raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
Merge them first.") - # auto-detect cuda capability - model_args.compute_dtype = _infer_dtype() - return model_args, data_args, finetuning_args, generating_args diff --git a/src/llmtuner/tuner/core/utils.py b/src/llmtuner/tuner/core/utils.py index a07fa31c..4d3630d6 100644 --- a/src/llmtuner/tuner/core/utils.py +++ b/src/llmtuner/tuner/core/utils.py @@ -31,11 +31,11 @@ def find_all_linear_modules( def prepare_model_for_training( model: "PreTrainedModel", - layernorm_dtype: torch.dtype, + upcast_layernorm: bool, finetuning_type: str, output_layer_name: Optional[str] = "lm_head", use_gradient_checkpointing: Optional[bool] = True, - layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES + layernorm_names: Optional[List[str]] = LAYERNORM_NAMES ) -> "PreTrainedModel": r""" Includes: @@ -44,9 +44,10 @@ def prepare_model_for_training( (3) upcast the lm_head to fp32 Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33 """ - for name, param in model.named_parameters(): - if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): - param.data = param.data.to(layernorm_dtype) + if upcast_layernorm: + for name, param in model.named_parameters(): + if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names): + param.data = param.data.to(torch.float32) if use_gradient_checkpointing: if hasattr(model, "enable_input_require_grads"): diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index 12d05aa9..85e48279 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -10,14 +10,15 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR from trl import PPOTrainer from trl.core import PPODecorators, logprobs_from_logits +from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback from llmtuner.extras.logging import get_logger from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor -from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model +from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback from trl import AutoModelForCausalLMWithValueHead - from llmtuner.hparams import GeneratingArguments + from llmtuner.hparams import ModelArguments, GeneratingArguments logger = get_logger(__name__) @@ -30,10 +31,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer): def __init__( self, + model_args: "ModelArguments", training_args: "Seq2SeqTrainingArguments", generating_args: "GeneratingArguments", callbacks: List["TrainerCallback"], - compute_dtype: torch.dtype, **kwargs ): PPOTrainer.__init__(self, **kwargs) @@ -41,11 +42,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer): raise ValueError("PPOTrainer is incompatible with DeepSpeed.") self.args = training_args - self.generating_args = generating_args - self.log_callback, self.save_callback = callbacks[0], callbacks[1] - self.compute_dtype = compute_dtype + self.model_args = model_args + self.generation_config = GenerationConfig( + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, + **generating_args.to_dict() + ) self.state = TrainerState() self.control = TrainerControl() + self.log_callback, self.save_callback = callbacks[0], callbacks[1] + assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback) def 
         r"""
@@ -74,13 +80,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         logger.info(f"  Total optimization steps = {max_steps}")
         logger.info(f"  Number of trainable parameters = {count_parameters(self.model)[0]}")
 
-        # Keyword arguments for `model.generate`
-        generating_args = self.generating_args.to_dict()
-        generating_args.update(dict(
-            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
-            pad_token_id=self.tokenizer.pad_token_id
-        ))
-
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         dataiter = iter(self.dataloader)
         steps_trained = 0
@@ -98,7 +97,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.model.eval()
 
             # Get inputs
-            queries, responses = self.get_inputs(batch, generating_args)
+            queries, responses = self.get_inputs(batch)
             self.tokenizer.padding_side = "right" # change padding side
             rewards = self.get_rewards(queries, responses, unwrapped_model)
 
@@ -159,27 +158,24 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         )
 
     @torch.no_grad()
-    def get_inputs(
-        self,
-        batch: Dict[str, torch.Tensor],
-        generating_args: Dict[str, Any]
-    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
         """
-        gen_kwargs = dict(
-            generation_config=GenerationConfig(**generating_args),
+        if self.model_args.upcast_layernorm:
+            layernorm_params = dump_layernorm(self.model)
+
+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+        response: torch.Tensor = unwrapped_model.generate(
+            generation_config=self.generation_config,
             logits_processor=get_logits_processor(),
             **batch
         )
 
-        input_ids = batch["input_ids"]
-        self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
-        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        response: torch.Tensor = unwrapped_model.generate(**gen_kwargs)
-        self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)
-        query, response = input_ids.detach().cpu(), response[:, input_ids.size(-1):].detach().cpu()
+        if self.model_args.upcast_layernorm:
+            restore_layernorm(self.model, layernorm_params)
+        query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
 
         queries, responses = [], []
         for i in range(len(query)):
             query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
diff --git a/src/llmtuner/tuner/ppo/utils.py b/src/llmtuner/tuner/ppo/utils.py
index 2257eead..74453a39 100644
--- a/src/llmtuner/tuner/ppo/utils.py
+++ b/src/llmtuner/tuner/ppo/utils.py
@@ -1,40 +1,35 @@
 import torch
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
-
-from llmtuner.extras.constants import LAYERNORM_NAMES
+from typing import TYPE_CHECKING, Dict, Literal, Optional
 
 if TYPE_CHECKING:
+    from transformers import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
 
 
 def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
     if target == "reward": # save default head temporarily
-        valuehead_state_dict = model.v_head.state_dict()
+        valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
         setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
         setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
valuehead_state_dict["summary.bias"].detach().clone()) model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active model.v_head.load_state_dict({ - "summary.weight": getattr(model, "{}_head_weight".format(target)), - "summary.bias": getattr(model, "{}_head_bias".format(target)) + "summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(), + "summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone() }) -def cast_layernorm_dtype( - model: "AutoModelForCausalLMWithValueHead", - compute_dtype: torch.dtype, - layer_norm_params: Optional[Dict[str, torch.Tensor]] = None, - layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES -) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]: - - layer_norm_state_dict = {} - +def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]: + layer_norm_params = {} for name, param in model.named_parameters(): - if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): - if layer_norm_params is None: - layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability - param.data = param.data.to(compute_dtype) - else: - param.data = layer_norm_params[name] # restore float32 weights + if param.data.dtype == torch.float32: + layer_norm_params[name] = param.data.detach().clone() + param.data = param.data.to(model.config.torch_dtype) - return model, layer_norm_state_dict + return layer_norm_params + + +def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None: + for name, param in model.named_parameters(): + if name in layernorm_params: + param.data = layernorm_params[name] diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py index bfc58cb6..7fd2f29b 100644 --- a/src/llmtuner/tuner/ppo/workflow.py +++ b/src/llmtuner/tuner/ppo/workflow.py @@ -65,10 +65,10 @@ def run_ppo( # Initialize our Trainer ppo_trainer = CustomPPOTrainer( + model_args=model_args, training_args=training_args, generating_args=generating_args, callbacks=callbacks + [SavePeftModelCallback()], - compute_dtype=model_args.compute_dtype, config=ppo_config, model=model, ref_model=None, diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 08d4557c..4ea54168 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -145,6 +145,9 @@ class Runner: ) args[compute_type] = True + if args["quantization_bit"] is not None: + args["upcast_layernorm"] = True + if args["stage"] == "ppo": args["reward_model"] = reward_model val_size = 0