alter rewards data type

This commit is contained in:
hiyouga 2023-06-02 14:19:51 +08:00
parent e6126244c1
commit 50d9a20f81
12 changed files with 40 additions and 50 deletions

View File

@ -4,22 +4,24 @@
import torch
from utils import ModelArguments, load_pretrained
from utils import ModelArguments, FinetuningArguments, load_pretrained
from transformers import HfArgumentParser
def main():
parser = HfArgumentParser(ModelArguments)
model_args, = parser.parse_args_into_dataclasses()
parser = HfArgumentParser((ModelArguments, FinetuningArguments))
model_args, finetuning_args = parser.parse_args_into_dataclasses()
model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
model, tokenizer = load_pretrained(model_args)
model, tokenizer = load_pretrained(model_args, finetuning_args)
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model, infer_auto_device_map
device_map = infer_auto_device_map(model)
model = dispatch_model(model, device_map)
else:
model = model.cuda()
model.eval()
def format_example(query):

View File

@ -70,7 +70,7 @@ def main():
ppo_trainer.save_model()
ppo_trainer.save_state() # must be after save_model
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args, keys=["loss", "reward"])
plot_loss(training_args.output_dir, keys=["loss", "reward"])
def _mp_fn(index):

View File

@ -55,7 +55,7 @@ def main():
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args, keys=["loss", "eval_loss"])
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:

View File

@ -56,7 +56,7 @@ def main():
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args, keys=["loss", "eval_loss"])
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:

View File

@ -72,7 +72,7 @@ def main():
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args, keys=["loss", "eval_loss"])
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:

View File

@ -13,5 +13,5 @@ from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer
from .ppo import PPOPeftTrainer
from .config import ModelArguments
from .config import ModelArguments, FinetuningArguments
from .other import get_logits_processor, plot_loss

View File

@ -42,8 +42,7 @@ from .other import (
load_valuehead_params,
print_trainable_params,
prepare_model_for_training,
IGNORE_INDEX,
FINETUNING_ARGS_NAME
IGNORE_INDEX
)
check_min_version("4.29.1")
@ -128,7 +127,7 @@ def init_adapter(
def load_pretrained(
model_args: ModelArguments,
finetuning_args: Optional[FinetuningArguments] = None,
finetuning_args: FinetuningArguments,
is_trainable: Optional[bool] = False,
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
@ -137,16 +136,9 @@ def load_pretrained(
Support both training and inference.
"""
if finetuning_args is None: # load the fine-tuning arguments
if model_args.checkpoint_dir is None:
if (not is_trainable) and model_args.checkpoint_dir is None:
logger.warning("Checkpoint is not found at evaluation, load the original model.")
finetuning_args = FinetuningArguments(finetuning_type="none")
elif os.path.exists(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)):
finetuning_args = FinetuningArguments.load_from_json(
os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)
)
else:
raise ValueError("Missing fine-tuning arguments in the provided dictionary.")
assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
"RM and PPO training can only be performed with LoRA method."

View File

@ -2,7 +2,7 @@ import torch
from typing import Dict, Optional, Sequence, Union
from transformers import DataCollatorWithPadding
from transformers import DataCollatorWithPadding, BatchEncoding
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
@ -34,7 +34,7 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
attention_mask = attention_mask.bool()
return attention_mask
def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> BatchEncoding:
r"""
Pads batched data to the longest sequence in the batch.
@ -64,4 +64,4 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
batch["input_ids"] = input_ids
batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device)
return batch
return BatchEncoding(batch)

View File

@ -5,7 +5,6 @@ import torch
import logging
from typing import Dict, List, Optional
from transformers import Seq2SeqTrainingArguments
from transformers.trainer import TRAINER_STATE_NAME
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import LogitsProcessorList
@ -143,7 +142,7 @@ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
"""
EMA implementation according to TensorBoard.
"""
@ -156,9 +155,10 @@ def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
return smoothed
def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]] = ["loss"]) -> None:
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
import matplotlib.pyplot as plt
data = json.load(open(os.path.join(training_args.output_dir, TRAINER_STATE_NAME), "r"))
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
data = json.load(f)
for key in keys:
steps, metrics = [], []
@ -174,9 +174,9 @@ def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]]
plt.figure()
plt.plot(steps, metrics, alpha=0.4, label="original")
plt.plot(steps, smooth(metrics), label="smoothed")
plt.title("training {} of {}".format(key, training_args.output_dir))
plt.title("training {} of {}".format(key, save_dictionary))
plt.xlabel("step")
plt.ylabel(key)
plt.legend()
plt.savefig(os.path.join(training_args.output_dir, "training_{}.png".format(key)), format="png", dpi=100)
print("Figure saved:", os.path.join(training_args.output_dir, "training_{}.png".format(key)))
plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))

View File

@ -109,7 +109,8 @@ class PeftTrainer(Seq2SeqTrainer):
if hasattr(model, "v_head"): # save valuehead weights
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
f.write(self.args.to_json_string() + "\n")
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
def _load_best_model(self):

View File

@ -75,7 +75,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
self.finetuning_args = finetuning_args
self.log_callback = callbacks[0]
self.state = TrainerState()
self.data_collator = self.accelerator.prepare(kwargs["data_collator"])
self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
def ppo_train(self, max_target_length: int) -> None:
r"""
@ -148,7 +148,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
# Compute rewards
replace_model(unwrapped_model, target="reward")
_, _, values = self.model(**self.prepare_model_inputs(queries, responses))
rewards = [reward for reward in values[:, -1]]
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
replace_model(unwrapped_model, target="default") # make sure the model is default at the end
# Run PPO step
@ -214,13 +214,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
return response[:, inputs["input_ids"].size(1):]
return response
def prepare_model_inputs(self, queries: List[torch.Tensor], responses: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
input_data = self.data_collator([{"input_ids": ids} for ids in input_ids])
input_data = {k: v.to(self.current_device) for k, v in input_data.items() if v is not None}
input_data.pop("labels", None) # we don't want to compute LM losses
return input_data
@PPODecorators.empty_cuda_cache()
def batched_forward_pass(
self,

View File

@ -7,21 +7,23 @@ import torch
import mdtex2html
import gradio as gr
from utils import ModelArguments, load_pretrained
from utils import ModelArguments, FinetuningArguments, load_pretrained
from transformers import HfArgumentParser
from transformers.utils.versions import require_version
require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems
parser = HfArgumentParser(ModelArguments)
model_args, = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args)
parser = HfArgumentParser((ModelArguments, FinetuningArguments))
model_args, finetuning_args = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args, finetuning_args)
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model, infer_auto_device_map
device_map = infer_auto_device_map(model)
model = dispatch_model(model, device_map)
else:
model = model.cuda()
model.eval()