This commit is contained in:
hiyouga 2023-11-01 23:38:49 +08:00
parent 8b912690e3
commit 083787dbfe
5 changed files with 33 additions and 24 deletions

View File

@ -72,10 +72,11 @@ def get_dataset(
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name) dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
if dataset_attr.system_prompt: # add system prompt if dataset_attr.system_prompt: # add system prompt
system_prompt = dataset_attr.system_prompt
if data_args.streaming: if data_args.streaming:
dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt}) dataset = dataset.map(lambda _: {"system": system_prompt})
else: else:
dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset)) dataset = dataset.add_column("system", [system_prompt] * len(dataset))
all_datasets.append(dataset) all_datasets.append(dataset)

View File

@ -12,6 +12,8 @@ class DatasetAttr:
dataset_sha1: Optional[str] = None dataset_sha1: Optional[str] = None
system_prompt: Optional[str] = None system_prompt: Optional[str] = None
ranking: Optional[bool] = False ranking: Optional[bool] = False
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
prompt: Optional[str] = "instruction" prompt: Optional[str] = "instruction"
query: Optional[str] = "input" query: Optional[str] = "input"
response: Optional[str] = "output" response: Optional[str] = "output"

View File

@ -117,9 +117,6 @@ def get_train_args(
if finetuning_args.stage == "ppo" and model_args.reward_model is None: if finetuning_args.stage == "ppo" and model_args.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.") raise ValueError("Reward model is necessary for PPO training.")
if finetuning_args.stage == "ppo" and data_args.streaming:
raise ValueError("Streaming mode does not suppport PPO training currently.")
if finetuning_args.stage == "ppo" and model_args.shift_attn: if finetuning_args.stage == "ppo" and model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.") raise ValueError("PPO training is incompatible with S^2-Attn.")

View File

@ -1,4 +1,5 @@
import os import os
import sys
import math import math
import torch import torch
from tqdm import tqdm from tqdm import tqdm
@ -39,9 +40,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
**kwargs **kwargs
): ):
PPOTrainer.__init__(self, **kwargs) PPOTrainer.__init__(self, **kwargs)
if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
self.args = training_args self.args = training_args
self.model_args = model_args self.model_args = model_args
self.finetuning_args = finetuning_args self.finetuning_args = finetuning_args
@ -54,6 +52,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.control = TrainerControl() self.control = TrainerControl()
self.log_callback, self.save_callback = callbacks[0], callbacks[1] self.log_callback, self.save_callback = callbacks[0], callbacks[1]
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback) assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
if self.args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
def ppo_train(self) -> None: def ppo_train(self) -> None:
r""" r"""
@ -62,10 +62,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
total_train_batch_size = ( total_train_batch_size = (
self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
) )
len_dataloader = len(self.dataloader) if self.args.max_steps > 0:
num_examples = len(self.dataset) num_examples = total_train_batch_size * self.args.max_steps
num_train_epochs = self.args.num_train_epochs num_train_epochs = sys.maxsize
max_steps = math.ceil(num_train_epochs * len_dataloader) max_steps = self.args.max_steps
steps_in_epoch = self.args.max_steps * self.args.gradient_accumulation_steps
else:
len_dataloader = len(self.dataloader)
num_examples = len(self.dataset)
num_train_epochs = self.args.num_train_epochs
max_steps = math.ceil(num_train_epochs * len_dataloader)
steps_in_epoch = len_dataloader
self.state.max_steps = max_steps self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs self.state.num_train_epochs = num_train_epochs
@ -84,14 +91,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader) dataiter = iter(self.dataloader)
steps_trained = 0
loss_meter = AverageMeter() loss_meter = AverageMeter()
reward_meter = AverageMeter() reward_meter = AverageMeter()
self.log_callback.on_train_begin(self.args, self.state, self.control) self.log_callback.on_train_begin(self.args, self.state, self.control)
for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()): for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
batch = next(dataiter) try:
steps_trained += 1 batch = next(dataiter)
except StopIteration:
dataiter = iter(self.dataloader)
batch = next(dataiter)
# Cast to inference mode # Cast to inference mode
unwrapped_model.gradient_checkpointing_disable() unwrapped_model.gradient_checkpointing_disable()
@ -130,7 +139,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
loss=round(loss_meter.avg, 4), loss=round(loss_meter.avg, 4),
reward=round(reward_meter.avg, 4), reward=round(reward_meter.avg, 4),
learning_rate=stats["ppo/learning_rate"], learning_rate=stats["ppo/learning_rate"],
epoch=round(step / len_dataloader, 2) epoch=round(step / steps_in_epoch, 2)
) )
tqdm.write(str(logs)) tqdm.write(str(logs))
logs["step"] = step logs["step"] = step
@ -150,10 +159,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.control.should_epoch_stop or self.control.should_training_stop: if self.control.should_epoch_stop or self.control.should_training_stop:
break break
if steps_trained == len_dataloader:
dataiter = iter(self.dataloader)
steps_trained = 0
self.log_callback.on_train_end(self.args, self.state, self.control) self.log_callback.on_train_end(self.args, self.state, self.control)
self.save_callback.on_train_end( self.save_callback.on_train_end(
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model) self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)

View File

@ -51,10 +51,14 @@ def run_ppo(
) )
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate) optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
total_train_batch_size = ( if training_args.max_steps > 0:
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size num_training_steps = training_args.max_steps
) else:
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size) total_train_batch_size = (
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
)
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
lr_scheduler = get_scheduler( lr_scheduler = get_scheduler(
training_args.lr_scheduler_type, training_args.lr_scheduler_type,
optimizer=optimizer, optimizer=optimizer,