diff --git a/src/cli_demo.py b/src/cli_demo.py
index 441e6aba..3da88aa6 100644
--- a/src/cli_demo.py
+++ b/src/cli_demo.py
@@ -4,22 +4,24 @@
 
 import torch
-from utils import ModelArguments, load_pretrained
+from utils import ModelArguments, FinetuningArguments, load_pretrained
 from transformers import HfArgumentParser
 
 
 def main():
 
-    parser = HfArgumentParser(ModelArguments)
-    model_args, = parser.parse_args_into_dataclasses()
+    parser = HfArgumentParser((ModelArguments, FinetuningArguments))
+    model_args, finetuning_args = parser.parse_args_into_dataclasses()
     model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
-    model, tokenizer = load_pretrained(model_args)
+    model, tokenizer = load_pretrained(model_args, finetuning_args)
+
     if torch.cuda.device_count() > 1:
         from accelerate import dispatch_model, infer_auto_device_map
         device_map = infer_auto_device_map(model)
         model = dispatch_model(model, device_map)
     else:
         model = model.cuda()
+    model.eval()
 
     def format_example(query):
diff --git a/src/train_ppo.py b/src/train_ppo.py
index 672dd8a9..9dbe9c0e 100644
--- a/src/train_ppo.py
+++ b/src/train_ppo.py
@@ -70,7 +70,7 @@ def main():
     ppo_trainer.save_model()
     ppo_trainer.save_state() # must be after save_model
     if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
-        plot_loss(training_args, keys=["loss", "reward"])
+        plot_loss(training_args.output_dir, keys=["loss", "reward"])
 
 
 def _mp_fn(index):
diff --git a/src/train_pt.py b/src/train_pt.py
index af88cb6b..7f5a4779 100644
--- a/src/train_pt.py
+++ b/src/train_pt.py
@@ -55,7 +55,7 @@ def main():
     trainer.save_state()
     trainer.save_model()
     if trainer.is_world_process_zero() and model_args.plot_loss:
-        plot_loss(training_args, keys=["loss", "eval_loss"])
+        plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
     # Evaluation
     if training_args.do_eval:
diff --git a/src/train_rm.py b/src/train_rm.py
index 8b51f7bf..11aec993 100644
--- a/src/train_rm.py
+++ b/src/train_rm.py
@@ -56,7 +56,7 @@ def main():
     trainer.save_state()
     trainer.save_model()
     if trainer.is_world_process_zero() and model_args.plot_loss:
-        plot_loss(training_args, keys=["loss", "eval_loss"])
+        plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
     # Evaluation
     if training_args.do_eval:
diff --git a/src/train_sft.py b/src/train_sft.py
index 29c593f5..bad98c8a 100644
--- a/src/train_sft.py
+++ b/src/train_sft.py
@@ -72,7 +72,7 @@ def main():
     trainer.save_state()
     trainer.save_model()
    if trainer.is_world_process_zero() and model_args.plot_loss:
-        plot_loss(training_args, keys=["loss", "eval_loss"])
+        plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
     # Evaluation
     if training_args.do_eval:
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
index 680975a1..9e536b87 100644
--- a/src/utils/__init__.py
+++ b/src/utils/__init__.py
@@ -13,5 +13,5 @@ from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
 from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer
 from .ppo import PPOPeftTrainer
 
-from .config import ModelArguments
+from .config import ModelArguments, FinetuningArguments
 from .other import get_logits_processor, plot_loss
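Note on the parser changes above: HfArgumentParser accepts a tuple of dataclasses and returns one populated instance per dataclass, which is how cli_demo.py (and web_demo.py below) now obtain FinetuningArguments alongside ModelArguments. A minimal, self-contained sketch of the pattern — the two dataclasses below are hypothetical stand-ins for illustration, not the real ModelArguments/FinetuningArguments:

    # Sketch of two-dataclass parsing with HfArgumentParser.
    # DemoModelArgs and DemoTuneArgs are illustrative stand-ins.
    from dataclasses import dataclass, field
    from transformers import HfArgumentParser

    @dataclass
    class DemoModelArgs:
        model_name_or_path: str = field(default="facebook/opt-125m")

    @dataclass
    class DemoTuneArgs:
        finetuning_type: str = field(default="lora")

    parser = HfArgumentParser((DemoModelArgs, DemoTuneArgs))
    model_args, tune_args = parser.parse_args_into_dataclasses()  # one instance per dataclass
    print(model_args.model_name_or_path, tune_args.finetuning_type)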
diff --git a/src/utils/common.py b/src/utils/common.py
index 26523286..9009906b 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -42,8 +42,7 @@ from .other import (
     load_valuehead_params,
     print_trainable_params,
     prepare_model_for_training,
-    IGNORE_INDEX,
-    FINETUNING_ARGS_NAME
+    IGNORE_INDEX
 )
 
 check_min_version("4.29.1")
@@ -128,7 +127,7 @@
 
 def load_pretrained(
     model_args: ModelArguments,
-    finetuning_args: Optional[FinetuningArguments] = None,
+    finetuning_args: FinetuningArguments,
     is_trainable: Optional[bool] = False,
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
 ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
@@ -137,16 +136,9 @@ def load_pretrained(
 
     Support both training and inference.
     """
-    if finetuning_args is None: # load the fine-tuning arguments
-        if model_args.checkpoint_dir is None:
-            logger.warning("Checkpoint is not found at evaluation, load the original model.")
-            finetuning_args = FinetuningArguments(finetuning_type="none")
-        elif os.path.exists(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)):
-            finetuning_args = FinetuningArguments.load_from_json(
-                os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)
-            )
-        else:
-            raise ValueError("Missing fine-tuning arguments in the provided dictionary.")
+    if (not is_trainable) and model_args.checkpoint_dir is None:
+        logger.warning("Checkpoint is not found at evaluation, load the original model.")
+        finetuning_args = FinetuningArguments(finetuning_type="none")
 
     assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
         "RM and PPO training can only be performed with LoRA method."
diff --git a/src/utils/data_collator.py b/src/utils/data_collator.py
index 6a6de42d..27d1b7ac 100644
--- a/src/utils/data_collator.py
+++ b/src/utils/data_collator.py
@@ -2,7 +2,7 @@
 import torch
 
 from typing import Dict, Optional, Sequence, Union
 
-from transformers import DataCollatorWithPadding
+from transformers import DataCollatorWithPadding, BatchEncoding
 from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils import PreTrainedTokenizer
@@ -34,7 +34,7 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
         attention_mask = attention_mask.bool()
         return attention_mask
 
-    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
+    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> BatchEncoding:
         r"""
         Pads batched data to the longest sequence in the batch.
 
@@ -64,4 +64,4 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
 
         batch["input_ids"] = input_ids
         batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device)
-        return batch
+        return BatchEncoding(batch)
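Note on the collator change above: BatchEncoding behaves like the dict it wraps but adds conveniences such as a one-call device move, which a plain dict lacks. A small sketch of what the new return type enables — toy tensors only, no real tokenizer involved:

    # BatchEncoding wraps a dict of tensors; .to(device) moves them all at once.
    import torch
    from transformers import BatchEncoding

    batch = BatchEncoding({
        "input_ids": torch.tensor([[1, 2, 3]]),
        "attention_mask": torch.tensor([[1, 1, 1]]),
    })
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch = batch.to(device)  # single call instead of a per-key loop
    print(batch["input_ids"].device)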
""" @@ -156,9 +155,10 @@ def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]: return smoothed -def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]] = ["loss"]) -> None: +def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None: import matplotlib.pyplot as plt - data = json.load(open(os.path.join(training_args.output_dir, TRAINER_STATE_NAME), "r")) + with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: + data = json.load(f) for key in keys: steps, metrics = [], [] @@ -174,9 +174,9 @@ def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]] plt.figure() plt.plot(steps, metrics, alpha=0.4, label="original") plt.plot(steps, smooth(metrics), label="smoothed") - plt.title("training {} of {}".format(key, training_args.output_dir)) + plt.title("training {} of {}".format(key, save_dictionary)) plt.xlabel("step") plt.ylabel(key) plt.legend() - plt.savefig(os.path.join(training_args.output_dir, "training_{}.png".format(key)), format="png", dpi=100) - print("Figure saved:", os.path.join(training_args.output_dir, "training_{}.png".format(key))) + plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100) + print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key))) diff --git a/src/utils/peft_trainer.py b/src/utils/peft_trainer.py index 0afe4fb2..f951fd8c 100644 --- a/src/utils/peft_trainer.py +++ b/src/utils/peft_trainer.py @@ -109,7 +109,8 @@ class PeftTrainer(Seq2SeqTrainer): if hasattr(model, "v_head"): # save valuehead weights torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME)) - torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f: + f.write(self.args.to_json_string() + "\n") self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME)) def _load_best_model(self): diff --git a/src/utils/ppo.py b/src/utils/ppo.py index e279c199..b782d1e6 100644 --- a/src/utils/ppo.py +++ b/src/utils/ppo.py @@ -75,7 +75,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): self.finetuning_args = finetuning_args self.log_callback = callbacks[0] self.state = TrainerState() - self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) + self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer def ppo_train(self, max_target_length: int) -> None: r""" @@ -148,7 +148,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): # Compute rewards replace_model(unwrapped_model, target="reward") _, _, values = self.model(**self.prepare_model_inputs(queries, responses)) - rewards = [reward for reward in values[:, -1]] + rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type replace_model(unwrapped_model, target="default") # make sure the model is default at the end # Run PPO step @@ -214,13 +214,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): return response[:, inputs["input_ids"].size(1):] return response - def prepare_model_inputs(self, queries: List[torch.Tensor], responses: List[torch.Tensor]) -> Dict[str, torch.Tensor]: - input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)] - input_data = self.data_collator([{"input_ids": ids} for ids in input_ids]) - input_data = {k: v.to(self.current_device) for k, v in input_data.items() if v is not None} 
- input_data.pop("labels", None) # we don't want to compute LM losses - return input_data - @PPODecorators.empty_cuda_cache() def batched_forward_pass( self, diff --git a/src/web_demo.py b/src/web_demo.py index ca766592..83ccdf9a 100644 --- a/src/web_demo.py +++ b/src/web_demo.py @@ -7,21 +7,23 @@ import torch import mdtex2html import gradio as gr -from utils import ModelArguments, load_pretrained +from utils import ModelArguments, FinetuningArguments, load_pretrained from transformers import HfArgumentParser from transformers.utils.versions import require_version require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems -parser = HfArgumentParser(ModelArguments) -model_args, = parser.parse_args_into_dataclasses() -model, tokenizer = load_pretrained(model_args) +parser = HfArgumentParser((ModelArguments, FinetuningArguments)) +model_args, finetuning_args = parser.parse_args_into_dataclasses() +model, tokenizer = load_pretrained(model_args, finetuning_args) + if torch.cuda.device_count() > 1: from accelerate import dispatch_model, infer_auto_device_map device_map = infer_auto_device_map(model) model = dispatch_model(model, device_map) else: model = model.cuda() + model.eval() @@ -74,10 +76,10 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT def format_example(query): - prompt = "Below is an instruction that describes a task. " - prompt += "Write a response that appropriately completes the request.\n" - prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query) - return prompt + prompt = "Below is an instruction that describes a task. " + prompt += "Write a response that appropriately completes the request.\n" + prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query) + return prompt def predict(input, chatbot, max_length, top_p, temperature, history):