alter rewards data type

This commit is contained in:
hiyouga 2023-06-02 14:19:51 +08:00
parent e6126244c1
commit 50d9a20f81
12 changed files with 40 additions and 50 deletions

View File

@@ -4,22 +4,24 @@
 import torch
-from utils import ModelArguments, load_pretrained
+from utils import ModelArguments, FinetuningArguments, load_pretrained
 from transformers import HfArgumentParser
 
 
 def main():
 
-    parser = HfArgumentParser(ModelArguments)
-    model_args, = parser.parse_args_into_dataclasses()
+    parser = HfArgumentParser((ModelArguments, FinetuningArguments))
+    model_args, finetuning_args = parser.parse_args_into_dataclasses()
     model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
-    model, tokenizer = load_pretrained(model_args)
+    model, tokenizer = load_pretrained(model_args, finetuning_args)
 
     if torch.cuda.device_count() > 1:
         from accelerate import dispatch_model, infer_auto_device_map
         device_map = infer_auto_device_map(model)
         model = dispatch_model(model, device_map)
     else:
         model = model.cuda()
 
     model.eval()
 
     def format_example(query):
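Note: HfArgumentParser accepts a tuple of dataclasses and returns one parsed instance per dataclass, which is what makes the two-value unpacking above work. A minimal self-contained sketch; the dataclasses and defaults below are illustrative, not the repository's real definitions:

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class DemoModelArgs:
    model_name_or_path: str = field(default="facebook/opt-125m")  # hypothetical default

@dataclass
class DemoFinetuningArgs:
    finetuning_type: str = field(default="lora")  # hypothetical default

parser = HfArgumentParser((DemoModelArgs, DemoFinetuningArgs))
model_args, finetuning_args = parser.parse_args_into_dataclasses()  # one instance per dataclass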

View File

@@ -70,7 +70,7 @@ def main():
     ppo_trainer.save_model()
     ppo_trainer.save_state() # must be after save_model
     if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
-        plot_loss(training_args, keys=["loss", "reward"])
+        plot_loss(training_args.output_dir, keys=["loss", "reward"])
 
 
 def _mp_fn(index):

View File

@@ -55,7 +55,7 @@ def main():
     trainer.save_state()
     trainer.save_model()
     if trainer.is_world_process_zero() and model_args.plot_loss:
-        plot_loss(training_args, keys=["loss", "eval_loss"])
+        plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
     # Evaluation
     if training_args.do_eval:

View File

@@ -56,7 +56,7 @@ def main():
     trainer.save_state()
     trainer.save_model()
     if trainer.is_world_process_zero() and model_args.plot_loss:
-        plot_loss(training_args, keys=["loss", "eval_loss"])
+        plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
     # Evaluation
     if training_args.do_eval:

View File

@@ -72,7 +72,7 @@ def main():
     trainer.save_state()
     trainer.save_model()
     if trainer.is_world_process_zero() and model_args.plot_loss:
-        plot_loss(training_args, keys=["loss", "eval_loss"])
+        plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
     # Evaluation
     if training_args.do_eval:

View File

@@ -13,5 +13,5 @@ from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
 from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer
 from .ppo import PPOPeftTrainer
-from .config import ModelArguments
+from .config import ModelArguments, FinetuningArguments
 from .other import get_logits_processor, plot_loss

View File

@@ -42,8 +42,7 @@ from .other import (
     load_valuehead_params,
     print_trainable_params,
     prepare_model_for_training,
-    IGNORE_INDEX,
-    FINETUNING_ARGS_NAME
+    IGNORE_INDEX
 )
 
 check_min_version("4.29.1")
@@ -128,7 +127,7 @@ def init_adapter(
 
 def load_pretrained(
     model_args: ModelArguments,
-    finetuning_args: Optional[FinetuningArguments] = None,
+    finetuning_args: FinetuningArguments,
     is_trainable: Optional[bool] = False,
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
 ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
@@ -137,16 +136,9 @@ def load_pretrained(
     Support both training and inference.
     """
-    if finetuning_args is None: # load the fine-tuning arguments
-        if model_args.checkpoint_dir is None:
-            logger.warning("Checkpoint is not found at evaluation, load the original model.")
-            finetuning_args = FinetuningArguments(finetuning_type="none")
-        elif os.path.exists(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)):
-            finetuning_args = FinetuningArguments.load_from_json(
-                os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)
-            )
-        else:
-            raise ValueError("Missing fine-tuning arguments in the provided dictionary.")
+    if (not is_trainable) and model_args.checkpoint_dir is None:
+        logger.warning("Checkpoint is not found at evaluation, load the original model.")
+        finetuning_args = FinetuningArguments(finetuning_type="none")
 
     assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
         "RM and PPO training can only be performed with LoRA method."

View File

@@ -2,7 +2,7 @@ import torch
 
 from typing import Dict, Optional, Sequence, Union
-from transformers import DataCollatorWithPadding
+from transformers import DataCollatorWithPadding, BatchEncoding
 from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils import PreTrainedTokenizer
@@ -34,7 +34,7 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
         attention_mask = attention_mask.bool()
         return attention_mask
 
-    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
+    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> BatchEncoding:
         r"""
         Pads batched data to the longest sequence in the batch.
@@ -64,4 +64,4 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
         batch["input_ids"] = input_ids
         batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device)
-        return batch
+        return BatchEncoding(batch)
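Note: returning a BatchEncoding instead of a plain dict keeps dict-style access while adding the conveniences expected of tokenizer output, notably moving every tensor to a device in one call. A standalone illustration, not repository code:

import torch
from transformers import BatchEncoding

batch = BatchEncoding({
    "input_ids": torch.tensor([[1, 2, 3]]),
    "attention_mask": torch.tensor([[1, 1, 1]]),
})
batch = batch.to("cpu")          # one call moves every tensor in the batch
print(batch["input_ids"].shape)  # dict-style access still works
print(batch.input_ids.shape)     # attribute-style access also works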

View File

@@ -5,7 +5,6 @@ import torch
 
 import logging
 from typing import Dict, List, Optional
-from transformers import Seq2SeqTrainingArguments
 from transformers.trainer import TRAINER_STATE_NAME
 from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.utils import LogitsProcessorList
@@ -143,7 +142,7 @@ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -
     model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
 
 
-def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
+def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
     """
     EMA implementation according to TensorBoard.
     """
@@ -156,9 +155,10 @@ def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
     return smoothed
 
 
-def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]] = ["loss"]) -> None:
+def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
     import matplotlib.pyplot as plt
-    data = json.load(open(os.path.join(training_args.output_dir, TRAINER_STATE_NAME), "r"))
+    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
+        data = json.load(f)
 
     for key in keys:
         steps, metrics = [], []
@@ -174,9 +174,9 @@ def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]]
         plt.figure()
         plt.plot(steps, metrics, alpha=0.4, label="original")
         plt.plot(steps, smooth(metrics), label="smoothed")
-        plt.title("training {} of {}".format(key, training_args.output_dir))
+        plt.title("training {} of {}".format(key, save_dictionary))
         plt.xlabel("step")
         plt.ylabel(key)
         plt.legend()
-        plt.savefig(os.path.join(training_args.output_dir, "training_{}.png".format(key)), format="png", dpi=100)
-        print("Figure saved:", os.path.join(training_args.output_dir, "training_{}.png".format(key)))
+        plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
+        print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))

View File

@@ -109,7 +109,8 @@ class PeftTrainer(Seq2SeqTrainer):
         if hasattr(model, "v_head"): # save valuehead weights
             torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
 
-        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+        with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
+            f.write(self.args.to_json_string() + "\n")
         self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
 
     def _load_best_model(self):
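Note: serializing the training arguments with to_json_string() instead of torch.save makes the saved file human-readable and loadable without unpickling a Python object. A minimal round-trip sketch; the file name is a placeholder:

import json
from transformers import Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(output_dir="out")
with open("training_args.json", "w", encoding="utf-8") as f:
    f.write(args.to_json_string() + "\n")

with open("training_args.json", "r", encoding="utf-8") as f:
    print(json.load(f)["output_dir"])  # -> "out"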

View File

@@ -75,7 +75,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         self.finetuning_args = finetuning_args
         self.log_callback = callbacks[0]
         self.state = TrainerState()
-        self.data_collator = self.accelerator.prepare(kwargs["data_collator"])
+        self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
 
     def ppo_train(self, max_target_length: int) -> None:
         r"""
@@ -148,7 +148,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             # Compute rewards
             replace_model(unwrapped_model, target="reward")
             _, _, values = self.model(**self.prepare_model_inputs(queries, responses))
-            rewards = [reward for reward in values[:, -1]]
+            rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
             replace_model(unwrapped_model, target="default") # make sure the model is default at the end
 
             # Run PPO step
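Note: this is the change the commit title refers to. Under mixed precision the value head emits half-precision scores, and keeping the rewards in fp16 risks precision loss when they are later accumulated and logged, so they are cast to float32 first. A standalone illustration with fabricated values:

import torch

values = torch.tensor([[0.13, 0.52, 0.87]], dtype=torch.float16)  # fake value-head output
rewards = [reward for reward in values[:, -1].to(torch.float32)]  # same expression as in the hunk
print(rewards[0].dtype)  # torch.float32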
@@ -214,13 +214,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             return response[:, inputs["input_ids"].size(1):]
         return response
 
-    def prepare_model_inputs(self, queries: List[torch.Tensor], responses: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
-        input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
-        input_data = self.data_collator([{"input_ids": ids} for ids in input_ids])
-        input_data = {k: v.to(self.current_device) for k, v in input_data.items() if v is not None}
-        input_data.pop("labels", None)  # we don't want to compute LM losses
-        return input_data
-
     @PPODecorators.empty_cuda_cache()
     def batched_forward_pass(
         self,

View File

@@ -7,21 +7,23 @@ import torch
 
 import mdtex2html
 import gradio as gr
-from utils import ModelArguments, load_pretrained
+from utils import ModelArguments, FinetuningArguments, load_pretrained
 from transformers import HfArgumentParser
 from transformers.utils.versions import require_version
 
 
 require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems
 
-parser = HfArgumentParser(ModelArguments)
-model_args, = parser.parse_args_into_dataclasses()
-model, tokenizer = load_pretrained(model_args)
+parser = HfArgumentParser((ModelArguments, FinetuningArguments))
+model_args, finetuning_args = parser.parse_args_into_dataclasses()
+model, tokenizer = load_pretrained(model_args, finetuning_args)
+
 if torch.cuda.device_count() > 1:
     from accelerate import dispatch_model, infer_auto_device_map
     device_map = infer_auto_device_map(model)
     model = dispatch_model(model, device_map)
 else:
     model = model.cuda()
 
 model.eval()