use fp16 model, add logcallback

This commit is contained in:
hiyouga 2023-05-28 21:30:28 +08:00
parent 769c6ab56b
commit 0c9fda01e3
7 changed files with 112 additions and 10 deletions

View File

@@ -17,6 +17,7 @@ from utils import (
preprocess_data,
DataCollatorForLLaMA,
PPOTrainerForLLaMA,
LogCallback,
plot_loss
)
@@ -54,6 +55,7 @@ def main():
ppo_trainer = PPOTrainerForLLaMA(
training_args=training_args,
finetuning_args=finetuning_args,
callbacks=[LogCallback()],
config=ppo_config,
model=model,
ref_model=None,

View File

@@ -12,6 +12,7 @@ from utils import (
preprocess_data,
PairwiseDataCollatorForLLaMA,
PairwiseTrainerForLLaMA,
LogCallback,
plot_loss
)
@@ -43,6 +44,7 @@ def main():
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=[LogCallback()],
**trainer_kwargs
)

View File

@@ -12,6 +12,7 @@ from utils import (
DataCollatorForLLaMA,
Seq2SeqTrainerForLLaMA,
ComputeMetrics,
LogCallback,
get_logits_processor,
plot_loss
)
@@ -49,6 +50,7 @@ def main():
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=[LogCallback()],
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
**trainer_kwargs
)
@@ -57,7 +59,7 @@ def main():
gen_kwargs = {
"do_sample": True,
"top_p": 0.7,
"max_length": data_args.max_source_length + data_args.max_target_length + 1,
"max_new_tokens": data_args.max_target_length + 1,
"temperature": 0.95,
"logits_processor": get_logits_processor()
}
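Illustrative only, not part of this commit: the switch from max_length to max_new_tokens matters because max_length counts the prompt tokens plus the generated tokens, so a long prompt eats into the generation budget, while max_new_tokens bounds only the newly generated tokens. A minimal sketch (the checkpoint name and token budget are placeholders):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", torch_dtype="auto")

inputs = tokenizer("Summarize the following passage: ...", return_tensors="pt")
# Old behaviour: prompt and completion share a single 512-token cap.
out_old = model.generate(**inputs, max_length=512)
# New behaviour: up to 512 fresh tokens are generated, regardless of prompt length.
out_new = model.generate(**inputs, max_new_tokens=512)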

View File

@@ -7,6 +7,8 @@ from .common import (
from .data_collator import DataCollatorForLLaMA
from .peft_trainer import LogCallback
from .seq2seq import ComputeMetrics, Seq2SeqTrainerForLLaMA
from .pairwise import PairwiseDataCollatorForLLaMA, PairwiseTrainerForLLaMA
from .ppo import PPOTrainerForLLaMA

View File

@@ -6,6 +6,7 @@ from typing import List, Literal, Optional, Tuple
import transformers
from transformers import (
LlamaConfig,
LlamaForCausalLM,
LlamaTokenizer,
HfArgumentParser,
@@ -151,7 +152,7 @@ def load_pretrained(
use_fast=model_args.use_fast_tokenizer,
padding_side="left"
)
tokenizer.pad_token_id = 0 # set as the <unk> token
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
# Quantization configurations (using bitsandbytes library).
config_kwargs = {}
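Aside, not part of the diff: the pad_token_id assignment above matters because the LLaMA tokenizer ships without a pad token, so any padding collator would fail until one is assigned; token id 0 is the tokenizer's <unk> token. A minimal sketch (placeholder checkpoint path):

from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")  # placeholder path
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = 0  # reuse <unk> as the padding token
print(tokenizer.pad_token_id)  # 0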
@@ -168,8 +169,15 @@ def load_pretrained(
config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
config = LlamaConfig.from_pretrained(model_args.model_name_or_path)
# Load and prepare pretrained models (without valuehead).
model = LlamaForCausalLM.from_pretrained(model_args.model_name_or_path, **config_kwargs)
model = LlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
config=config,
torch_dtype=torch.float16, # the llama weights are float16 type
**config_kwargs
)
model = prepare_model_for_training(model) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable)
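Outside the diff, a short sketch of what the fp16 change buys (model path is a placeholder): the LLaMA checkpoint weights are stored in float16, so loading with torch_dtype=torch.float16 keeps them as-is instead of upcasting to float32, roughly halving memory at load time; the resulting dtype can be checked directly.

import torch
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig.from_pretrained("huggyllama/llama-7b")  # placeholder path
model = LlamaForCausalLM.from_pretrained(
    "huggyllama/llama-7b",
    config=config,
    torch_dtype=torch.float16,  # keep the float16 checkpoint weights instead of upcasting
)
print(next(model.parameters()).dtype)  # torch.float16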

View File

@@ -1,8 +1,18 @@
import os
import json
import time
import torch
from typing import Dict, Optional
from datetime import timedelta
from transformers import (
Seq2SeqTrainer,
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments
)
from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.modeling_utils import unwrap_model
@@ -23,6 +33,44 @@ from .other import (
logger = get_logger(__name__)
class LogCallback(TrainerCallback):
r"""
TrainerCallback includes the state function during training, for more details refer to the TrainerCallback class.
The on_log function primarily collects process parameters during training, such as training loss, learning rate,
and training epochs, as well as progress parameters like the current percentage progress and estimated remaining
time. Every time a log is triggered, a new record is appended to the file "messages.log" for dynamic visualization
purposes.
"""
def __init__(self):
self.start_time = time.time()
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
r"""
Event called after logging the last logs.
"""
cur_time = time.time()
cur_steps = state.log_history[-1].get("step")
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
remaining_steps = state.max_steps - cur_steps
remaining_time = remaining_steps * avg_time_per_step
log_dict = {
"current_steps": cur_steps,
"total_steps": state.max_steps,
"loss": state.log_history[-1].get("loss", None),
"reward": state.log_history[-1].get("reward", None),
"learning_rate": state.log_history[-1].get("learning_rate", None),
"epoch": state.log_history[-1].get("epoch", None),
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
"remaining_time": str(timedelta(seconds=int(remaining_time)))
}
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a") as f:
f.write(json.dumps(log_dict) + "\n")
class PeftTrainer(Seq2SeqTrainer):
r"""
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
@@ -31,6 +79,9 @@ class PeftTrainer(Seq2SeqTrainer):
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
if os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
r"""

View File

@@ -4,15 +4,14 @@ import torch
from tqdm import tqdm
from typing import Callable, Dict, List, Literal, Optional, Tuple
from transformers import Seq2SeqTrainingArguments
from transformers.trainer import TrainerState
from transformers import Seq2SeqTrainingArguments, TrainerState
from transformers.modeling_utils import PreTrainedModel
from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits
from .peft_trainer import PeftTrainer
from .peft_trainer import PeftTrainer, LogCallback
from .config import FinetuningArguments
@@ -40,15 +39,41 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def
})
def cast_layernorm_dtype(
model: AutoModelForCausalLMWithValueHead,
layer_norm_names: List[str] = ["layernorm"], # matches LLaMA's input_layernorm / post_attention_layernorm modules
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
layer_norm_state_dict = {}
for name, param in model.named_parameters():
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
if layer_norm_params is not None:
param.data = layer_norm_params[name] # restore float32 weights
else:
layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
param.data = param.data.to(torch.float16)
return model, layer_norm_state_dict
class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
r"""
Inherits PPOTrainer.
"""
def __init__(self, training_args: Seq2SeqTrainingArguments, finetuning_args: FinetuningArguments, **kwargs):
def __init__(
self,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: List[LogCallback],
**kwargs
):
PPOTrainer.__init__(self, **kwargs)
self.args = training_args
self.finetuning_args = finetuning_args
self.log_callback = callbacks[0]
self.state = TrainerState()
self.data_collator = self.accelerator.prepare(kwargs["data_collator"])
@@ -63,6 +88,11 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
num_train_epochs = self.args.num_train_epochs
max_steps = math.ceil(num_train_epochs * num_steps_per_epoch)
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()
if self.is_world_process_zero():
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples}")
@@ -144,6 +174,7 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
print(logs)
logs["step"] = step
self.state.log_history.append(logs)
self.log_callback.on_log(self.args, self.state, None)
loss_meter.reset()
reward_meter.reset()
@@ -154,8 +185,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
def generate(
self,
inputs: Dict[str, torch.Tensor],
length_sampler: Callable = None,
return_prompt: bool = True,
length_sampler: Optional[Callable] = None,
return_prompt: Optional[bool] = True,
**generation_kwargs,
) -> torch.Tensor:
r"""
@@ -163,6 +194,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
Subclass and override to inject custom behavior.
"""
self.model, layer_norm_params = cast_layernorm_dtype(self.model)
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
@@ -175,6 +208,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
if unwrapped_model.pretrained_model.generation_config._from_model_config:
unwrapped_model.pretrained_model.generation_config._from_model_config = False
self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
if not return_prompt and not self.is_encoder_decoder:
return response[:, inputs["input_ids"].size(1):]
return response
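Outside the diff, a toy illustration of the layer-norm dtype round trip that cast_layernorm_dtype performs around generation (module and variable names are arbitrary): float32 copies of the norm parameters are stashed, the live parameters are cast to float16 so they match the fp16 model during generate, and the float32 values are restored afterwards for training stability.

import torch

norm = torch.nn.LayerNorm(8)  # stand-in for a module whose name contains "layernorm"
backup = {n: p.data.detach().clone() for n, p in norm.named_parameters()}  # store float32 copies
for p in norm.parameters():
    p.data = p.data.to(torch.float16)  # cast to float16 before generation
# ... generation would run here ...
for n, p in norm.named_parameters():
    p.data = backup[n]  # restore the float32 weights afterwards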