This commit is contained in:
hiyouga 2023-11-01 23:38:49 +08:00
parent 8b912690e3
commit 083787dbfe
5 changed files with 33 additions and 24 deletions

View File

@@ -72,10 +72,11 @@ def get_dataset(
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
if dataset_attr.system_prompt: # add system prompt
system_prompt = dataset_attr.system_prompt
if data_args.streaming:
dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt})
dataset = dataset.map(lambda _: {"system": system_prompt})
else:
dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))
dataset = dataset.add_column("system", [system_prompt] * len(dataset))
all_datasets.append(dataset)

View File

@@ -12,6 +12,8 @@ class DatasetAttr:
dataset_sha1: Optional[str] = None
system_prompt: Optional[str] = None
ranking: Optional[bool] = False
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"

View File

@@ -117,9 +117,6 @@ def get_train_args(
if finetuning_args.stage == "ppo" and model_args.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if finetuning_args.stage == "ppo" and data_args.streaming:
raise ValueError("Streaming mode does not suppport PPO training currently.")
if finetuning_args.stage == "ppo" and model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")

View File

@@ -1,4 +1,5 @@
import os
import sys
import math
import torch
from tqdm import tqdm
@@ -39,9 +40,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
**kwargs
):
PPOTrainer.__init__(self, **kwargs)
if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
self.args = training_args
self.model_args = model_args
self.finetuning_args = finetuning_args
@@ -54,6 +52,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.control = TrainerControl()
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
if self.args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
def ppo_train(self) -> None:
r"""
@@ -62,10 +62,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
total_train_batch_size = (
self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
)
if self.args.max_steps > 0:
num_examples = total_train_batch_size * self.args.max_steps
num_train_epochs = sys.maxsize
max_steps = self.args.max_steps
steps_in_epoch = self.args.max_steps * self.args.gradient_accumulation_steps
else:
len_dataloader = len(self.dataloader)
num_examples = len(self.dataset)
num_train_epochs = self.args.num_train_epochs
max_steps = math.ceil(num_train_epochs * len_dataloader)
steps_in_epoch = len_dataloader
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
@@ -84,14 +91,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
steps_trained = 0
loss_meter = AverageMeter()
reward_meter = AverageMeter()
self.log_callback.on_train_begin(self.args, self.state, self.control)
for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
try:
batch = next(dataiter)
except StopIteration:
dataiter = iter(self.dataloader)
batch = next(dataiter)
steps_trained += 1
# Cast to inference mode
unwrapped_model.gradient_checkpointing_disable()
@@ -130,7 +139,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
loss=round(loss_meter.avg, 4),
reward=round(reward_meter.avg, 4),
learning_rate=stats["ppo/learning_rate"],
epoch=round(step / len_dataloader, 2)
epoch=round(step / steps_in_epoch, 2)
)
tqdm.write(str(logs))
logs["step"] = step
@@ -150,10 +159,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.control.should_epoch_stop or self.control.should_training_stop:
break
if steps_trained == len_dataloader:
dataiter = iter(self.dataloader)
steps_trained = 0
self.log_callback.on_train_end(self.args, self.state, self.control)
self.save_callback.on_train_end(
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)

View File

@@ -51,10 +51,14 @@ def run_ppo(
)
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
if training_args.max_steps > 0:
num_training_steps = training_args.max_steps
else:
total_train_batch_size = (
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
)
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
lr_scheduler = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer,