fix #1325

parent 8b912690e3
commit 083787dbfe
```diff
@@ -72,10 +72,11 @@ def get_dataset(
                 dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
 
         if dataset_attr.system_prompt: # add system prompt
+            system_prompt = dataset_attr.system_prompt
             if data_args.streaming:
-                dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt})
+                dataset = dataset.map(lambda _: {"system": system_prompt})
             else:
-                dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))
+                dataset = dataset.add_column("system", [system_prompt] * len(dataset))
 
         all_datasets.append(dataset)
```
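Hoisting the prompt into the local `system_prompt` makes the `map` lambda close over a plain string rather than the whole `dataset_attr` object, which presumably keeps the function's pickle-based fingerprint (used by `datasets` for caching) small and stable. A minimal sketch of the pattern, with toy data that is not from this commit:

```python
from datasets import Dataset

# Toy dataset standing in for the loader's real output.
ds = Dataset.from_dict({"instruction": ["hi", "bye"]})

system_prompt = "You are a helpful assistant."  # plain-string capture, cheap to pickle
ds = ds.map(lambda _: {"system": system_prompt})  # adds a constant "system" column
print(ds[0]["system"])
```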
```diff
@@ -12,6 +12,8 @@ class DatasetAttr:
     dataset_sha1: Optional[str] = None
     system_prompt: Optional[str] = None
     ranking: Optional[bool] = False
+    formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
+
     prompt: Optional[str] = "instruction"
     query: Optional[str] = "input"
     response: Optional[str] = "output"
```
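The new `formatting` field tags each dataset with its template family. A self-contained sketch, trimmed to the fields visible in this hunk, of how the `Literal` default behaves:

```python
from dataclasses import dataclass
from typing import Literal, Optional

@dataclass
class DatasetAttr:  # trimmed to the fields visible in this hunk
    dataset_sha1: Optional[str] = None
    system_prompt: Optional[str] = None
    ranking: Optional[bool] = False
    formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"

attr = DatasetAttr()
print(attr.formatting)  # "alpaca" unless a dataset declares "sharegpt"
```

Note that `Literal` only constrains type checkers; nothing rejects other strings at runtime.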
```diff
@@ -117,9 +117,6 @@ def get_train_args(
     if finetuning_args.stage == "ppo" and model_args.reward_model is None:
         raise ValueError("Reward model is necessary for PPO training.")
 
-    if finetuning_args.stage == "ppo" and data_args.streaming:
-        raise ValueError("Streaming mode does not suppport PPO training currently.")
-
     if finetuning_args.stage == "ppo" and model_args.shift_attn:
         raise ValueError("PPO training is incompatible with S^2-Attn.")
```
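Dropping this guard lets PPO accept streaming datasets. The reason the restriction existed, sketched with a hypothetical stand-in (not code from this commit): an `IterableDataset` has no length, so epoch-derived step counts are impossible, but the `max_steps` path added below sidesteps the length entirely.

```python
from torch.utils.data import IterableDataset

class Stream(IterableDataset):  # hypothetical stand-in for a streaming dataset
    def __iter__(self):
        yield from range(10)

try:
    len(Stream())
except TypeError as err:
    print("no length:", err)  # epoch math is impossible; train by max_steps instead
```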
```diff
@@ -1,4 +1,5 @@
 import os
+import sys
 import math
 import torch
 from tqdm import tqdm
```
```diff
@@ -39,9 +40,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         **kwargs
     ):
         PPOTrainer.__init__(self, **kwargs)
-        if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
-            raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
-
         self.args = training_args
         self.model_args = model_args
         self.finetuning_args = finetuning_args
```
```diff
@@ -54,6 +52,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.control = TrainerControl()
         self.log_callback, self.save_callback = callbacks[0], callbacks[1]
         assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
+        if self.args.max_steps > 0:
+            logger.info("max_steps is given, it will override any value given in num_train_epochs")
 
     def ppo_train(self) -> None:
         r"""
```
```diff
@@ -62,10 +62,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         total_train_batch_size = (
             self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
         )
-        len_dataloader = len(self.dataloader)
-        num_examples = len(self.dataset)
-        num_train_epochs = self.args.num_train_epochs
-        max_steps = math.ceil(num_train_epochs * len_dataloader)
+        if self.args.max_steps > 0:
+            num_examples = total_train_batch_size * self.args.max_steps
+            num_train_epochs = sys.maxsize
+            max_steps = self.args.max_steps
+            steps_in_epoch = self.args.max_steps * self.args.gradient_accumulation_steps
+        else:
+            len_dataloader = len(self.dataloader)
+            num_examples = len(self.dataset)
+            num_train_epochs = self.args.num_train_epochs
+            max_steps = math.ceil(num_train_epochs * len_dataloader)
+            steps_in_epoch = len_dataloader
 
         self.state.max_steps = max_steps
         self.state.num_train_epochs = num_train_epochs
```
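With `max_steps` set, training length is fixed in optimizer steps and the epoch counter becomes effectively unbounded (`sys.maxsize`); otherwise everything derives from the dataloader length. The two branches traced with made-up numbers (none of these values come from the commit):

```python
import math
import sys

# Hypothetical values; none of these come from the commit.
gradient_accumulation_steps = 4
len_dataloader, num_train_epochs_arg = 250, 3.0
max_steps_arg = 100

if max_steps_arg > 0:
    num_train_epochs = sys.maxsize                # epochs no longer bound training
    max_steps = max_steps_arg
    steps_in_epoch = max_steps_arg * gradient_accumulation_steps
else:
    num_train_epochs = num_train_epochs_arg
    max_steps = math.ceil(num_train_epochs * len_dataloader)
    steps_in_epoch = len_dataloader

print(max_steps, steps_in_epoch)  # 100 400
```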
```diff
@@ -84,14 +91,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         dataiter = iter(self.dataloader)
-        steps_trained = 0
         loss_meter = AverageMeter()
         reward_meter = AverageMeter()
         self.log_callback.on_train_begin(self.args, self.state, self.control)
 
         for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
-            batch = next(dataiter)
-            steps_trained += 1
+            try:
+                batch = next(dataiter)
+            except StopIteration:
+                dataiter = iter(self.dataloader)
+                batch = next(dataiter)
 
             # Cast to inference mode
             unwrapped_model.gradient_checkpointing_disable()
```
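The `try`/`except StopIteration` restart lets a fixed-`max_steps` loop run past a single pass over the dataloader. The same idea as a free-standing generator (a sketch, not code from this commit):

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def cycling(iterable: Iterable[T]) -> Iterator[T]:
    """Yield from `iterable` forever, re-creating the iterator on each pass."""
    while True:
        for item in iterable:
            yield item

dataiter = cycling([1, 2, 3])
print([next(dataiter) for _ in range(7)])  # [1, 2, 3, 1, 2, 3, 1]
```

Unlike `itertools.cycle`, re-iterating the source on each pass lets a shuffling dataloader produce a fresh order per epoch; it also makes the manual `steps_trained` bookkeeping removed further down redundant.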
```diff
@@ -130,7 +139,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                     loss=round(loss_meter.avg, 4),
                     reward=round(reward_meter.avg, 4),
                     learning_rate=stats["ppo/learning_rate"],
-                    epoch=round(step / len_dataloader, 2)
+                    epoch=round(step / steps_in_epoch, 2)
                 )
                 tqdm.write(str(logs))
                 logs["step"] = step
```
```diff
@@ -150,10 +159,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             if self.control.should_epoch_stop or self.control.should_training_stop:
                 break
 
-            if steps_trained == len_dataloader:
-                dataiter = iter(self.dataloader)
-                steps_trained = 0
-
         self.log_callback.on_train_end(self.args, self.state, self.control)
         self.save_callback.on_train_end(
             self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
```
```diff
@@ -51,10 +51,14 @@ def run_ppo(
     )
 
     optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
-    total_train_batch_size = (
-        training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
-    )
-    num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
+    if training_args.max_steps > 0:
+        num_training_steps = training_args.max_steps
+    else:
+        total_train_batch_size = (
+            training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
+        )
+        num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
 
     lr_scheduler = get_scheduler(
         training_args.lr_scheduler_type,
         optimizer=optimizer,
```
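`get_scheduler` needs the total step count up front to shape the learning-rate curve, so this mirrors the trainer's `max_steps` branch. The arithmetic with made-up sizes (not values from the commit):

```python
import math

# Hypothetical sizes; not values from the commit.
max_steps, num_train_epochs = 0, 3.0
per_device_bs, grad_accum, world_size, dataset_len = 4, 4, 1, 1000

if max_steps > 0:
    num_training_steps = max_steps
else:
    total_train_batch_size = per_device_bs * grad_accum * world_size  # 16
    num_training_steps = num_train_epochs * math.ceil(dataset_len / total_train_batch_size)

print(num_training_steps)  # 3.0 * 63 = 189.0
```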