From 0c9fda01e3c61727c939efd9d9398f657a2d69b6 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sun, 28 May 2023 21:30:28 +0800
Subject: [PATCH] use fp16 model, add logcallback

---
 src/train_ppo.py          |  2 ++
 src/train_rm.py           |  2 ++
 src/train_sft.py          |  4 ++-
 src/utils/__init__.py     |  2 ++
 src/utils/common.py       | 12 +++++++--
 src/utils/peft_trainer.py | 53 ++++++++++++++++++++++++++++++++++++++-
 src/utils/ppo.py          | 47 +++++++++++++++++++++++++++++-----
 7 files changed, 112 insertions(+), 10 deletions(-)

diff --git a/src/train_ppo.py b/src/train_ppo.py
index 41f89a57..f6e57c05 100644
--- a/src/train_ppo.py
+++ b/src/train_ppo.py
@@ -17,6 +17,7 @@ from utils import (
     preprocess_data,
     DataCollatorForLLaMA,
     PPOTrainerForLLaMA,
+    LogCallback,
     plot_loss
 )
 
@@ -54,6 +55,7 @@ def main():
     ppo_trainer = PPOTrainerForLLaMA(
         training_args=training_args,
         finetuning_args=finetuning_args,
+        callbacks=[LogCallback()],
         config=ppo_config,
         model=model,
         ref_model=None,
diff --git a/src/train_rm.py b/src/train_rm.py
index dd544f3a..ecd7f714 100644
--- a/src/train_rm.py
+++ b/src/train_rm.py
@@ -12,6 +12,7 @@ from utils import (
     preprocess_data,
     PairwiseDataCollatorForLLaMA,
     PairwiseTrainerForLLaMA,
+    LogCallback,
     plot_loss
 )
 
@@ -43,6 +44,7 @@ def main():
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
+        callbacks=[LogCallback()],
         **trainer_kwargs
     )
 
diff --git a/src/train_sft.py b/src/train_sft.py
index d34a8e4e..3bc0f850 100644
--- a/src/train_sft.py
+++ b/src/train_sft.py
@@ -12,6 +12,7 @@ from utils import (
     DataCollatorForLLaMA,
     Seq2SeqTrainerForLLaMA,
     ComputeMetrics,
+    LogCallback,
     get_logits_processor,
     plot_loss
 )
@@ -49,6 +50,7 @@ def main():
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
+        callbacks=[LogCallback()],
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
         **trainer_kwargs
     )
@@ -57,7 +59,7 @@ def main():
     gen_kwargs = {
         "do_sample": True,
         "top_p": 0.7,
-        "max_length": data_args.max_source_length + data_args.max_target_length + 1,
+        "max_new_tokens": data_args.max_target_length + 1,
         "temperature": 0.95,
         "logits_processor": get_logits_processor()
     }
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
index c19e82ad..a9104cc0 100644
--- a/src/utils/__init__.py
+++ b/src/utils/__init__.py
@@ -7,6 +7,8 @@ from .common import (
 
 from .data_collator import DataCollatorForLLaMA
 
+from .peft_trainer import LogCallback
+
 from .seq2seq import ComputeMetrics, Seq2SeqTrainerForLLaMA
 from .pairwise import PairwiseDataCollatorForLLaMA, PairwiseTrainerForLLaMA
 from .ppo import PPOTrainerForLLaMA
diff --git a/src/utils/common.py b/src/utils/common.py
index db6bfd22..2798124e 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -6,6 +6,7 @@ from typing import List, Literal, Optional, Tuple
 
 import transformers
 from transformers import (
+    LlamaConfig,
     LlamaForCausalLM,
     LlamaTokenizer,
     HfArgumentParser,
@@ -151,7 +152,7 @@ def load_pretrained(
         use_fast=model_args.use_fast_tokenizer,
         padding_side="left"
     )
-    tokenizer.pad_token_id = 0 # set as the <unk> token
+    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
 
     # Quantization configurations (using bitsandbytes library).
     config_kwargs = {}
@@ -168,8 +169,15 @@ def load_pretrained(
         config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
         logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
 
+    config = LlamaConfig.from_pretrained(model_args.model_name_or_path)
+
     # Load and prepare pretrained models (without valuehead).
-    model = LlamaForCausalLM.from_pretrained(model_args.model_name_or_path, **config_kwargs)
+    model = LlamaForCausalLM.from_pretrained(
+        model_args.model_name_or_path,
+        config=config,
+        torch_dtype=torch.float16, # the llama weights are float16 type
+        **config_kwargs
+    )
     model = prepare_model_for_training(model) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable)
 
diff --git a/src/utils/peft_trainer.py b/src/utils/peft_trainer.py
index 57d54a8d..0afe4fb2 100644
--- a/src/utils/peft_trainer.py
+++ b/src/utils/peft_trainer.py
@@ -1,8 +1,18 @@
 import os
+import json
+import time
 import torch
 from typing import Dict, Optional
+from datetime import timedelta
+
+from transformers import (
+    Seq2SeqTrainer,
+    TrainerCallback,
+    TrainerControl,
+    TrainerState,
+    TrainingArguments
+)
 
-from transformers import Seq2SeqTrainer
 from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.modeling_utils import unwrap_model
 
@@ -23,6 +33,44 @@ from .other import (
 logger = get_logger(__name__)
 
 
+class LogCallback(TrainerCallback):
+    r"""
+    TrainerCallback that records the training state; refer to the TrainerCallback class for details on its events.
+    The on_log function collects training metrics such as the training loss, the learning rate and the current epoch,
+    together with progress information such as the percentage completed and the estimated remaining time. Every time
+    a log is emitted, a new record is appended to the file "trainer_log.jsonl" in the output directory for dynamic
+    visualization purposes.
+    """
+
+    def __init__(self):
+        self.start_time = time.time()
+
+    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
+        r"""
+        Event called after logging the last logs.
+        """
+        cur_time = time.time()
+        cur_steps = state.log_history[-1].get("step")
+        elapsed_time = cur_time - self.start_time
+        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
+        remaining_steps = state.max_steps - cur_steps
+        remaining_time = remaining_steps * avg_time_per_step
+        log_dict = {
+            "current_steps": cur_steps,
+            "total_steps": state.max_steps,
+            "loss": state.log_history[-1].get("loss", None),
+            "reward": state.log_history[-1].get("reward", None),
+            "learning_rate": state.log_history[-1].get("learning_rate", None),
+            "epoch": state.log_history[-1].get("epoch", None),
+            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
+            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
+            "remaining_time": str(timedelta(seconds=int(remaining_time)))
+        }
+        os.makedirs(args.output_dir, exist_ok=True)
+        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a") as f:
+            f.write(json.dumps(log_dict) + "\n")
+
+
 class PeftTrainer(Seq2SeqTrainer):
     r"""
     Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
@@ -31,6 +79,9 @@ class PeftTrainer(Seq2SeqTrainer):
     def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
+        if os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
+            logger.warning("Previous log file in this folder will be deleted.")
+            os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
 
     def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
         r"""
diff --git a/src/utils/ppo.py b/src/utils/ppo.py
index 85c69505..8a068876 100644
--- a/src/utils/ppo.py
+++ b/src/utils/ppo.py
@@ -4,15 +4,14 @@ import torch
 
 from tqdm import tqdm
 from typing import Callable, Dict, List, Literal, Optional, Tuple
 
-from transformers import Seq2SeqTrainingArguments
-from transformers.trainer import TrainerState
+from transformers import Seq2SeqTrainingArguments, TrainerState
 from transformers.modeling_utils import PreTrainedModel
 
 from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
 from trl.core import LengthSampler
 from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits
 
-from .peft_trainer import PeftTrainer
+from .peft_trainer import PeftTrainer, LogCallback
 
 from .config import FinetuningArguments
 
@@ -40,15 +39,41 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def
     })
 
 
+def cast_layernorm_dtype(
+    model: AutoModelForCausalLMWithValueHead,
+    layer_norm_names: List[str] = ["layernorm"], # for chatglm setting
+    layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
+) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
+
+    layer_norm_state_dict = {}
+
+    for name, param in model.named_parameters():
+        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
+            if layer_norm_params is not None:
+                param.data = layer_norm_params[name] # restore float32 weights
+            else:
+                layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
+                param.data = param.data.to(torch.float16)
+
+    return model, layer_norm_state_dict
+
+
 class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
     r"""
     Inherits PPOTrainer.
""" - def __init__(self, training_args: Seq2SeqTrainingArguments, finetuning_args: FinetuningArguments, **kwargs): + def __init__( + self, + training_args: Seq2SeqTrainingArguments, + finetuning_args: FinetuningArguments, + callbacks: List[LogCallback], + **kwargs + ): PPOTrainer.__init__(self, **kwargs) self.args = training_args self.finetuning_args = finetuning_args + self.log_callback = callbacks[0] self.state = TrainerState() self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) @@ -63,6 +88,11 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): num_train_epochs = self.args.num_train_epochs max_steps = math.ceil(num_train_epochs * num_steps_per_epoch) + self.state.max_steps = max_steps + self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = self.is_local_process_zero() + self.state.is_world_process_zero = self.is_world_process_zero() + if self.is_world_process_zero(): logger.info("***** Running training *****") logger.info(f" Num examples = {num_examples}") @@ -144,6 +174,7 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): print(logs) logs["step"] = step self.state.log_history.append(logs) + self.log_callback.on_log(self.args, self.state, None) loss_meter.reset() reward_meter.reset() @@ -154,8 +185,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): def generate( self, inputs: Dict[str, torch.Tensor], - length_sampler: Callable = None, - return_prompt: bool = True, + length_sampler: Optional[Callable] = None, + return_prompt: Optional[bool] = True, **generation_kwargs, ) -> torch.Tensor: r""" @@ -163,6 +194,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): Subclass and override to inject custom behavior. """ + self.model, layer_norm_params = cast_layernorm_dtype(self.model) + if length_sampler is not None: generation_kwargs["max_new_tokens"] = length_sampler() @@ -175,6 +208,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): if unwrapped_model.pretrained_model.generation_config._from_model_config: unwrapped_model.pretrained_model.generation_config._from_model_config = False + self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params) + if not return_prompt and not self.is_encoder_decoder: return response[:, inputs["input_ids"].size(1):] return response