use fp16 model, add logcallback
This commit is contained in:
parent
769c6ab56b
commit
0c9fda01e3
|
@ -17,6 +17,7 @@ from utils import (
|
||||||
preprocess_data,
|
preprocess_data,
|
||||||
DataCollatorForLLaMA,
|
DataCollatorForLLaMA,
|
||||||
PPOTrainerForLLaMA,
|
PPOTrainerForLLaMA,
|
||||||
|
LogCallback,
|
||||||
plot_loss
|
plot_loss
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -54,6 +55,7 @@ def main():
|
||||||
ppo_trainer = PPOTrainerForLLaMA(
|
ppo_trainer = PPOTrainerForLLaMA(
|
||||||
training_args=training_args,
|
training_args=training_args,
|
||||||
finetuning_args=finetuning_args,
|
finetuning_args=finetuning_args,
|
||||||
|
callbacks=[LogCallback()],
|
||||||
config=ppo_config,
|
config=ppo_config,
|
||||||
model=model,
|
model=model,
|
||||||
ref_model=None,
|
ref_model=None,
|
||||||
|
|
|
@ -12,6 +12,7 @@ from utils import (
|
||||||
preprocess_data,
|
preprocess_data,
|
||||||
PairwiseDataCollatorForLLaMA,
|
PairwiseDataCollatorForLLaMA,
|
||||||
PairwiseTrainerForLLaMA,
|
PairwiseTrainerForLLaMA,
|
||||||
|
LogCallback,
|
||||||
plot_loss
|
plot_loss
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -43,6 +44,7 @@ def main():
|
||||||
args=training_args,
|
args=training_args,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
|
callbacks=[LogCallback()],
|
||||||
**trainer_kwargs
|
**trainer_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ from utils import (
|
||||||
DataCollatorForLLaMA,
|
DataCollatorForLLaMA,
|
||||||
Seq2SeqTrainerForLLaMA,
|
Seq2SeqTrainerForLLaMA,
|
||||||
ComputeMetrics,
|
ComputeMetrics,
|
||||||
|
LogCallback,
|
||||||
get_logits_processor,
|
get_logits_processor,
|
||||||
plot_loss
|
plot_loss
|
||||||
)
|
)
|
||||||
|
@ -49,6 +50,7 @@ def main():
|
||||||
args=training_args,
|
args=training_args,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
|
callbacks=[LogCallback()],
|
||||||
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
|
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
|
||||||
**trainer_kwargs
|
**trainer_kwargs
|
||||||
)
|
)
|
||||||
|
@ -57,7 +59,7 @@ def main():
|
||||||
gen_kwargs = {
|
gen_kwargs = {
|
||||||
"do_sample": True,
|
"do_sample": True,
|
||||||
"top_p": 0.7,
|
"top_p": 0.7,
|
||||||
"max_length": data_args.max_source_length + data_args.max_target_length + 1,
|
"max_new_tokens": data_args.max_target_length + 1,
|
||||||
"temperature": 0.95,
|
"temperature": 0.95,
|
||||||
"logits_processor": get_logits_processor()
|
"logits_processor": get_logits_processor()
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,8 @@ from .common import (
|
||||||
|
|
||||||
from .data_collator import DataCollatorForLLaMA
|
from .data_collator import DataCollatorForLLaMA
|
||||||
|
|
||||||
|
from .peft_trainer import LogCallback
|
||||||
|
|
||||||
from .seq2seq import ComputeMetrics, Seq2SeqTrainerForLLaMA
|
from .seq2seq import ComputeMetrics, Seq2SeqTrainerForLLaMA
|
||||||
from .pairwise import PairwiseDataCollatorForLLaMA, PairwiseTrainerForLLaMA
|
from .pairwise import PairwiseDataCollatorForLLaMA, PairwiseTrainerForLLaMA
|
||||||
from .ppo import PPOTrainerForLLaMA
|
from .ppo import PPOTrainerForLLaMA
|
||||||
|
|
|
@ -6,6 +6,7 @@ from typing import List, Literal, Optional, Tuple
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
LlamaConfig,
|
||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
LlamaTokenizer,
|
LlamaTokenizer,
|
||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
|
@ -151,7 +152,7 @@ def load_pretrained(
|
||||||
use_fast=model_args.use_fast_tokenizer,
|
use_fast=model_args.use_fast_tokenizer,
|
||||||
padding_side="left"
|
padding_side="left"
|
||||||
)
|
)
|
||||||
tokenizer.pad_token_id = 0 # set as the <unk> token
|
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
|
||||||
|
|
||||||
# Quantization configurations (using bitsandbytes library).
|
# Quantization configurations (using bitsandbytes library).
|
||||||
config_kwargs = {}
|
config_kwargs = {}
|
||||||
|
@ -168,8 +169,15 @@ def load_pretrained(
|
||||||
config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
|
config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
|
||||||
logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
|
logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
|
||||||
|
|
||||||
|
config = LlamaConfig.from_pretrained(model_args.model_name_or_path)
|
||||||
|
|
||||||
# Load and prepare pretrained models (without valuehead).
|
# Load and prepare pretrained models (without valuehead).
|
||||||
model = LlamaForCausalLM.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
|
model_args.model_name_or_path,
|
||||||
|
config=config,
|
||||||
|
torch_dtype=torch.float16, # the llama weights are float16 type
|
||||||
|
**config_kwargs
|
||||||
|
)
|
||||||
model = prepare_model_for_training(model) if is_trainable else model
|
model = prepare_model_for_training(model) if is_trainable else model
|
||||||
model = init_adapter(model, model_args, finetuning_args, is_trainable)
|
model = init_adapter(model, model_args, finetuning_args, is_trainable)
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,18 @@
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
import torch
|
import torch
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
Seq2SeqTrainer,
|
||||||
|
TrainerCallback,
|
||||||
|
TrainerControl,
|
||||||
|
TrainerState,
|
||||||
|
TrainingArguments
|
||||||
|
)
|
||||||
|
|
||||||
from transformers import Seq2SeqTrainer
|
|
||||||
from transformers.trainer import TRAINING_ARGS_NAME
|
from transformers.trainer import TRAINING_ARGS_NAME
|
||||||
from transformers.modeling_utils import unwrap_model
|
from transformers.modeling_utils import unwrap_model
|
||||||
|
|
||||||
|
@ -23,6 +33,44 @@ from .other import (
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LogCallback(TrainerCallback):
|
||||||
|
r"""
|
||||||
|
TrainerCallback includes the state function during training, for more details refer to the TrainerCallback class.
|
||||||
|
The on_log function primarily collects process parameters during training, such as training loss, learning rate,
|
||||||
|
and training epochs, as well as progress parameters like the current percentage progress and estimated remaining
|
||||||
|
time. Every time a log is triggered, a new record is appended to the file "messages.log" for dynamic visualization
|
||||||
|
purposes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.start_time = time.time()
|
||||||
|
|
||||||
|
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
|
||||||
|
r"""
|
||||||
|
Event called after logging the last logs.
|
||||||
|
"""
|
||||||
|
cur_time = time.time()
|
||||||
|
cur_steps = state.log_history[-1].get("step")
|
||||||
|
elapsed_time = cur_time - self.start_time
|
||||||
|
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
|
||||||
|
remaining_steps = state.max_steps - cur_steps
|
||||||
|
remaining_time = remaining_steps * avg_time_per_step
|
||||||
|
log_dict = {
|
||||||
|
"current_steps": cur_steps,
|
||||||
|
"total_steps": state.max_steps,
|
||||||
|
"loss": state.log_history[-1].get("loss", None),
|
||||||
|
"reward": state.log_history[-1].get("reward", None),
|
||||||
|
"learning_rate": state.log_history[-1].get("learning_rate", None),
|
||||||
|
"epoch": state.log_history[-1].get("epoch", None),
|
||||||
|
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
|
||||||
|
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
|
||||||
|
"remaining_time": str(timedelta(seconds=int(remaining_time)))
|
||||||
|
}
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a") as f:
|
||||||
|
f.write(json.dumps(log_dict) + "\n")
|
||||||
|
|
||||||
|
|
||||||
class PeftTrainer(Seq2SeqTrainer):
|
class PeftTrainer(Seq2SeqTrainer):
|
||||||
r"""
|
r"""
|
||||||
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
|
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
|
||||||
|
@ -31,6 +79,9 @@ class PeftTrainer(Seq2SeqTrainer):
|
||||||
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
|
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.finetuning_args = finetuning_args
|
self.finetuning_args = finetuning_args
|
||||||
|
if os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
|
||||||
|
logger.warning("Previous log file in this folder will be deleted.")
|
||||||
|
os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
|
||||||
|
|
||||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
|
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
|
||||||
r"""
|
r"""
|
||||||
|
|
|
@ -4,15 +4,14 @@ import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from typing import Callable, Dict, List, Literal, Optional, Tuple
|
from typing import Callable, Dict, List, Literal, Optional, Tuple
|
||||||
|
|
||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainingArguments, TrainerState
|
||||||
from transformers.trainer import TrainerState
|
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
|
from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
|
||||||
from trl.core import LengthSampler
|
from trl.core import LengthSampler
|
||||||
from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits
|
from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits
|
||||||
|
|
||||||
from .peft_trainer import PeftTrainer
|
from .peft_trainer import PeftTrainer, LogCallback
|
||||||
|
|
||||||
from .config import FinetuningArguments
|
from .config import FinetuningArguments
|
||||||
|
|
||||||
|
@ -40,15 +39,41 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def cast_layernorm_dtype(
|
||||||
|
model: AutoModelForCausalLMWithValueHead,
|
||||||
|
layer_norm_names: List[str] = ["layernorm"], # for chatglm setting
|
||||||
|
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
|
||||||
|
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
|
||||||
|
|
||||||
|
layer_norm_state_dict = {}
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
||||||
|
if layer_norm_params is not None:
|
||||||
|
param.data = layer_norm_params[name] # restore float32 weights
|
||||||
|
else:
|
||||||
|
layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
|
||||||
|
param.data = param.data.to(torch.float16)
|
||||||
|
|
||||||
|
return model, layer_norm_state_dict
|
||||||
|
|
||||||
|
|
||||||
class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
|
class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
|
||||||
r"""
|
r"""
|
||||||
Inherits PPOTrainer.
|
Inherits PPOTrainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, training_args: Seq2SeqTrainingArguments, finetuning_args: FinetuningArguments, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
training_args: Seq2SeqTrainingArguments,
|
||||||
|
finetuning_args: FinetuningArguments,
|
||||||
|
callbacks: List[LogCallback],
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
PPOTrainer.__init__(self, **kwargs)
|
PPOTrainer.__init__(self, **kwargs)
|
||||||
self.args = training_args
|
self.args = training_args
|
||||||
self.finetuning_args = finetuning_args
|
self.finetuning_args = finetuning_args
|
||||||
|
self.log_callback = callbacks[0]
|
||||||
self.state = TrainerState()
|
self.state = TrainerState()
|
||||||
self.data_collator = self.accelerator.prepare(kwargs["data_collator"])
|
self.data_collator = self.accelerator.prepare(kwargs["data_collator"])
|
||||||
|
|
||||||
|
@ -63,6 +88,11 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
|
||||||
num_train_epochs = self.args.num_train_epochs
|
num_train_epochs = self.args.num_train_epochs
|
||||||
max_steps = math.ceil(num_train_epochs * num_steps_per_epoch)
|
max_steps = math.ceil(num_train_epochs * num_steps_per_epoch)
|
||||||
|
|
||||||
|
self.state.max_steps = max_steps
|
||||||
|
self.state.num_train_epochs = num_train_epochs
|
||||||
|
self.state.is_local_process_zero = self.is_local_process_zero()
|
||||||
|
self.state.is_world_process_zero = self.is_world_process_zero()
|
||||||
|
|
||||||
if self.is_world_process_zero():
|
if self.is_world_process_zero():
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(f" Num examples = {num_examples}")
|
logger.info(f" Num examples = {num_examples}")
|
||||||
|
@ -144,6 +174,7 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
|
||||||
print(logs)
|
print(logs)
|
||||||
logs["step"] = step
|
logs["step"] = step
|
||||||
self.state.log_history.append(logs)
|
self.state.log_history.append(logs)
|
||||||
|
self.log_callback.on_log(self.args, self.state, None)
|
||||||
loss_meter.reset()
|
loss_meter.reset()
|
||||||
reward_meter.reset()
|
reward_meter.reset()
|
||||||
|
|
||||||
|
@ -154,8 +185,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
inputs: Dict[str, torch.Tensor],
|
inputs: Dict[str, torch.Tensor],
|
||||||
length_sampler: Callable = None,
|
length_sampler: Optional[Callable] = None,
|
||||||
return_prompt: bool = True,
|
return_prompt: Optional[bool] = True,
|
||||||
**generation_kwargs,
|
**generation_kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
r"""
|
r"""
|
||||||
|
@ -163,6 +194,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
|
||||||
|
|
||||||
Subclass and override to inject custom behavior.
|
Subclass and override to inject custom behavior.
|
||||||
"""
|
"""
|
||||||
|
self.model, layer_norm_params = cast_layernorm_dtype(self.model)
|
||||||
|
|
||||||
if length_sampler is not None:
|
if length_sampler is not None:
|
||||||
generation_kwargs["max_new_tokens"] = length_sampler()
|
generation_kwargs["max_new_tokens"] = length_sampler()
|
||||||
|
|
||||||
|
@ -175,6 +208,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
|
||||||
if unwrapped_model.pretrained_model.generation_config._from_model_config:
|
if unwrapped_model.pretrained_model.generation_config._from_model_config:
|
||||||
unwrapped_model.pretrained_model.generation_config._from_model_config = False
|
unwrapped_model.pretrained_model.generation_config._from_model_config = False
|
||||||
|
|
||||||
|
self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
|
||||||
|
|
||||||
if not return_prompt and not self.is_encoder_decoder:
|
if not return_prompt and not self.is_encoder_decoder:
|
||||||
return response[:, inputs["input_ids"].size(1):]
|
return response[:, inputs["input_ids"].size(1):]
|
||||||
return response
|
return response
|
||||||
|
|
Loading…
Reference in New Issue