From 083787dbfe41f58ff59cb16ddde02df98593aef5 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 1 Nov 2023 23:38:49 +0800 Subject: [PATCH] fix #1325 --- src/llmtuner/dsets/loader.py | 5 +++-- src/llmtuner/hparams/data_args.py | 2 ++ src/llmtuner/tuner/core/parser.py | 3 --- src/llmtuner/tuner/ppo/trainer.py | 35 +++++++++++++++++------------- src/llmtuner/tuner/ppo/workflow.py | 12 ++++++---- 5 files changed, 33 insertions(+), 24 deletions(-) diff --git a/src/llmtuner/dsets/loader.py b/src/llmtuner/dsets/loader.py index 826b548c..fe88ce50 100644 --- a/src/llmtuner/dsets/loader.py +++ b/src/llmtuner/dsets/loader.py @@ -72,10 +72,11 @@ def get_dataset( dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name) if dataset_attr.system_prompt: # add system prompt + system_prompt = dataset_attr.system_prompt if data_args.streaming: - dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt}) + dataset = dataset.map(lambda _: {"system": system_prompt}) else: - dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset)) + dataset = dataset.add_column("system", [system_prompt] * len(dataset)) all_datasets.append(dataset) diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index 184cc3ca..fa2989ef 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -12,6 +12,8 @@ class DatasetAttr: dataset_sha1: Optional[str] = None system_prompt: Optional[str] = None ranking: Optional[bool] = False + formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca" + prompt: Optional[str] = "instruction" query: Optional[str] = "input" response: Optional[str] = "output" diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py index 59ece34d..71b2c810 100644 --- a/src/llmtuner/tuner/core/parser.py +++ b/src/llmtuner/tuner/core/parser.py @@ -117,9 +117,6 @@ def get_train_args( if finetuning_args.stage == "ppo" and model_args.reward_model is None: raise ValueError("Reward model is necessary for PPO training.") - if finetuning_args.stage == "ppo" and data_args.streaming: - raise ValueError("Streaming mode does not suppport PPO training currently.") - if finetuning_args.stage == "ppo" and model_args.shift_attn: raise ValueError("PPO training is incompatible with S^2-Attn.") diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index baa36404..372c4891 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -1,4 +1,5 @@ import os +import sys import math import torch from tqdm import tqdm @@ -39,9 +40,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer): **kwargs ): PPOTrainer.__init__(self, **kwargs) - if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None: - raise ValueError("PPOTrainer is incompatible with DeepSpeed.") - self.args = training_args self.model_args = model_args self.finetuning_args = finetuning_args @@ -54,6 +52,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer): self.control = TrainerControl() self.log_callback, self.save_callback = callbacks[0], callbacks[1] assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback) + if self.args.max_steps > 0: + logger.info("max_steps is given, it will override any value given in num_train_epochs") def ppo_train(self) -> None: r""" @@ -62,10 +62,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer): total_train_batch_size = ( self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size ) - len_dataloader = len(self.dataloader) - num_examples = len(self.dataset) - num_train_epochs = self.args.num_train_epochs - max_steps = math.ceil(num_train_epochs * len_dataloader) + if self.args.max_steps > 0: + num_examples = total_train_batch_size * self.args.max_steps + num_train_epochs = sys.maxsize + max_steps = self.args.max_steps + steps_in_epoch = self.args.max_steps * self.args.gradient_accumulation_steps + else: + len_dataloader = len(self.dataloader) + num_examples = len(self.dataset) + num_train_epochs = self.args.num_train_epochs + max_steps = math.ceil(num_train_epochs * len_dataloader) + steps_in_epoch = len_dataloader self.state.max_steps = max_steps self.state.num_train_epochs = num_train_epochs @@ -84,14 +91,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer): unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) dataiter = iter(self.dataloader) - steps_trained = 0 loss_meter = AverageMeter() reward_meter = AverageMeter() self.log_callback.on_train_begin(self.args, self.state, self.control) for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()): - batch = next(dataiter) - steps_trained += 1 + try: + batch = next(dataiter) + except StopIteration: + dataiter = iter(self.dataloader) + batch = next(dataiter) # Cast to inference mode unwrapped_model.gradient_checkpointing_disable() @@ -130,7 +139,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): loss=round(loss_meter.avg, 4), reward=round(reward_meter.avg, 4), learning_rate=stats["ppo/learning_rate"], - epoch=round(step / len_dataloader, 2) + epoch=round(step / steps_in_epoch, 2) ) tqdm.write(str(logs)) logs["step"] = step @@ -150,10 +159,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer): if self.control.should_epoch_stop or self.control.should_training_stop: break - if steps_trained == len_dataloader: - dataiter = iter(self.dataloader) - steps_trained = 0 - self.log_callback.on_train_end(self.args, self.state, self.control) self.save_callback.on_train_end( self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model) diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py index 3fcb72fd..4c35f628 100644 --- a/src/llmtuner/tuner/ppo/workflow.py +++ b/src/llmtuner/tuner/ppo/workflow.py @@ -51,10 +51,14 @@ def run_ppo( ) optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate) - total_train_batch_size = ( - training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size - ) - num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size) + if training_args.max_steps > 0: + num_training_steps = training_args.max_steps + else: + total_train_batch_size = ( + training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size + ) + num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size) + lr_scheduler = get_scheduler( training_args.lr_scheduler_type, optimizer=optimizer,