alter rewards data type
This commit is contained in: commit 50d9a20f81 (parent e6126244c1)
@@ -4,22 +4,24 @@
 import torch
-from utils import ModelArguments, load_pretrained
+from utils import ModelArguments, FinetuningArguments, load_pretrained
 from transformers import HfArgumentParser


 def main():

-    parser = HfArgumentParser(ModelArguments)
-    model_args, = parser.parse_args_into_dataclasses()
+    parser = HfArgumentParser((ModelArguments, FinetuningArguments))
+    model_args, finetuning_args = parser.parse_args_into_dataclasses()
     model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
-    model, tokenizer = load_pretrained(model_args)
+    model, tokenizer = load_pretrained(model_args, finetuning_args)

     if torch.cuda.device_count() > 1:
         from accelerate import dispatch_model, infer_auto_device_map
         device_map = infer_auto_device_map(model)
         model = dispatch_model(model, device_map)
     else:
         model = model.cuda()

     model.eval()


 def format_example(query):
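For context, the demo's multi-GPU branch relies on accelerate's automatic layer placement. A self-contained sketch of the same pattern, with "gpt2" as a stand-in checkpoint (not the model this repository targets):

    # Minimal sketch of the dispatch pattern used above; "gpt2" is a placeholder.
    import torch
    from transformers import AutoModelForCausalLM
    from accelerate import dispatch_model, infer_auto_device_map

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    if torch.cuda.device_count() > 1:
        device_map = infer_auto_device_map(model)  # assigns submodules to the visible GPUs
        model = dispatch_model(model, device_map)  # places weights and adds cross-device hooks
    else:
        model = model.cuda()
    model.eval()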
@@ -70,7 +70,7 @@ def main():
     ppo_trainer.save_model()
     ppo_trainer.save_state() # must be after save_model
     if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
-        plot_loss(training_args, keys=["loss", "reward"])
+        plot_loss(training_args.output_dir, keys=["loss", "reward"])


 def _mp_fn(index):
@@ -55,7 +55,7 @@ def main():
     trainer.save_state()
     trainer.save_model()
     if trainer.is_world_process_zero() and model_args.plot_loss:
-        plot_loss(training_args, keys=["loss", "eval_loss"])
+        plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

     # Evaluation
     if training_args.do_eval:
@@ -56,7 +56,7 @@ def main():
     trainer.save_state()
     trainer.save_model()
     if trainer.is_world_process_zero() and model_args.plot_loss:
-        plot_loss(training_args, keys=["loss", "eval_loss"])
+        plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

     # Evaluation
     if training_args.do_eval:
@@ -72,7 +72,7 @@ def main():
     trainer.save_state()
     trainer.save_model()
     if trainer.is_world_process_zero() and model_args.plot_loss:
-        plot_loss(training_args, keys=["loss", "eval_loss"])
+        plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

     # Evaluation
     if training_args.do_eval:
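Each training entry point gets the same one-line fix, matching the new plot_loss signature that appears later in this commit. Because the function now takes a bare path, it can also be replayed on a finished run directory after the fact; a sketch, with a hypothetical path:

    from utils import plot_loss

    # "output/sft" is a made-up directory; it must contain the trainer_state.json
    # written by trainer.save_state() during the run.
    plot_loss("output/sft", keys=["loss", "eval_loss"])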
@@ -13,5 +13,5 @@ from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
 from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer
 from .ppo import PPOPeftTrainer

-from .config import ModelArguments
+from .config import ModelArguments, FinetuningArguments
 from .other import get_logits_processor, plot_loss
@@ -42,8 +42,7 @@ from .other import (
     load_valuehead_params,
     print_trainable_params,
     prepare_model_for_training,
-    IGNORE_INDEX,
-    FINETUNING_ARGS_NAME
+    IGNORE_INDEX
 )

 check_min_version("4.29.1")
@@ -128,7 +127,7 @@ def init_adapter(

 def load_pretrained(
     model_args: ModelArguments,
-    finetuning_args: Optional[FinetuningArguments] = None,
+    finetuning_args: FinetuningArguments,
     is_trainable: Optional[bool] = False,
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
 ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
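With the default removed, finetuning_args becomes a required argument, which is why every caller in this commit now parses and forwards it. A minimal inference-time call under the new signature, assuming the repository's own utils package (argument names taken from the diff):

    from transformers import HfArgumentParser
    from utils import ModelArguments, FinetuningArguments, load_pretrained

    # Sketch: load a model for evaluation; finetuning_args can no longer be omitted.
    parser = HfArgumentParser((ModelArguments, FinetuningArguments))
    model_args, finetuning_args = parser.parse_args_into_dataclasses()
    model, tokenizer = load_pretrained(model_args, finetuning_args, is_trainable=False, stage="sft")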
@@ -137,16 +136,9 @@ def load_pretrained(

     Support both training and inference.
     """
-    if finetuning_args is None: # load the fine-tuning arguments
-        if model_args.checkpoint_dir is None:
-            logger.warning("Checkpoint is not found at evaluation, load the original model.")
-            finetuning_args = FinetuningArguments(finetuning_type="none")
-        elif os.path.exists(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)):
-            finetuning_args = FinetuningArguments.load_from_json(
-                os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)
-            )
-        else:
-            raise ValueError("Missing fine-tuning arguments in the provided dictionary.")
+    if (not is_trainable) and model_args.checkpoint_dir is None:
+        logger.warning("Checkpoint is not found at evaluation, load the original model.")
+        finetuning_args = FinetuningArguments(finetuning_type="none")

     assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
         "RM and PPO training can only be performed with LoRA method."
@@ -2,7 +2,7 @@ import torch

 from typing import Dict, Optional, Sequence, Union

-from transformers import DataCollatorWithPadding
+from transformers import DataCollatorWithPadding, BatchEncoding
 from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils import PreTrainedTokenizer
@@ -34,7 +34,7 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
         attention_mask = attention_mask.bool()
         return attention_mask

-    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
+    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> BatchEncoding:
         r"""
         Pads batched data to the longest sequence in the batch.
@@ -64,4 +64,4 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
         batch["input_ids"] = input_ids
         batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device)

-        return batch
+        return BatchEncoding(batch)
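Returning a BatchEncoding rather than a plain dict is a small ergonomic win: BatchEncoding is a dict subclass from stock transformers that adds a .to(device) helper, so callers can move the whole batch in one call. A minimal sketch, assuming only standard transformers behavior:

    import torch
    from transformers import BatchEncoding

    batch = BatchEncoding({
        "input_ids": torch.tensor([[1, 2, 3]]),
        "attention_mask": torch.tensor([[1, 1, 1]]),
    })
    batch = batch.to("cuda" if torch.cuda.is_available() else "cpu")  # moves every tensor at once
    print(batch["input_ids"].shape)  # still indexes like a dict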
@@ -5,7 +5,6 @@ import torch
 import logging
 from typing import Dict, List, Optional

-from transformers import Seq2SeqTrainingArguments
 from transformers.trainer import TRAINER_STATE_NAME
 from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.utils import LogitsProcessorList
@@ -143,7 +142,7 @@ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -
     model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))


-def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
+def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
     """
     EMA implementation according to TensorBoard.
     """
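The default smoothing weight drops from 0.95 to 0.9, shortening the EMA's memory so the smoothed curve tracks the raw one more closely. For reference, a sketch of the TensorBoard-style EMA the docstring refers to; the repository's body may differ in detail:

    from typing import List

    def smooth(scalars: List[float], weight: float = 0.9) -> List[float]:
        last = scalars[0]  # seed with the first point
        smoothed = []
        for point in scalars:
            last = last * weight + (1 - weight) * point  # exponential moving average
            smoothed.append(last)
        return smoothed

    print(smooth([1.0, 0.0, 0.0, 0.0]))  # decays toward 0 faster than weight=0.95 would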
@@ -156,9 +155,10 @@ def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
     return smoothed


-def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]] = ["loss"]) -> None:
+def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
     import matplotlib.pyplot as plt
-    data = json.load(open(os.path.join(training_args.output_dir, TRAINER_STATE_NAME), "r"))
+    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
+        data = json.load(f)

     for key in keys:
         steps, metrics = [], []
@@ -174,9 +174,9 @@ def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]]
         plt.figure()
         plt.plot(steps, metrics, alpha=0.4, label="original")
         plt.plot(steps, smooth(metrics), label="smoothed")
-        plt.title("training {} of {}".format(key, training_args.output_dir))
+        plt.title("training {} of {}".format(key, save_dictionary))
         plt.xlabel("step")
         plt.ylabel(key)
         plt.legend()
-        plt.savefig(os.path.join(training_args.output_dir, "training_{}.png".format(key)), format="png", dpi=100)
-        print("Figure saved:", os.path.join(training_args.output_dir, "training_{}.png".format(key)))
+        plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
+        print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
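plot_loss draws from the trainer_state.json that Trainer.save_state() leaves in the run directory; the relevant part is its log_history list, whose entries contain a given key only for the steps where it was logged. A sketch of the extraction loop implied by the steps/metrics collection above, with mock data in place of a real state file:

    # Pull (step, value) pairs for one key out of a loaded trainer state.
    data = {"log_history": [{"loss": 1.2, "step": 10}, {"eval_loss": 1.1, "step": 10}]}
    key = "loss"
    steps, metrics = [], []
    for entry in data["log_history"]:
        if key in entry:  # entries without this key are skipped
            steps.append(entry["step"])
            metrics.append(entry[key])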
@@ -109,7 +109,8 @@ class PeftTrainer(Seq2SeqTrainer):
         if hasattr(model, "v_head"): # save valuehead weights
             torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))

-        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+        with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
+            f.write(self.args.to_json_string() + "\n")
         self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))

     def _load_best_model(self):
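Serializing the training arguments with to_json_string() (a stock TrainingArguments method) swaps an opaque pickled blob for human-readable JSON. One way to read such a file back, sketched with a hypothetical path; the actual filename depends on how this repository defines TRAINING_ARGS_NAME:

    from transformers import HfArgumentParser, Seq2SeqTrainingArguments

    parser = HfArgumentParser(Seq2SeqTrainingArguments)
    # "output/sft/training_args.json" is a placeholder for the saved file.
    training_args, = parser.parse_json_file("output/sft/training_args.json")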
@@ -75,7 +75,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         self.finetuning_args = finetuning_args
         self.log_callback = callbacks[0]
         self.state = TrainerState()
-        self.data_collator = self.accelerator.prepare(kwargs["data_collator"])
+        self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer

     def ppo_train(self, max_target_length: int) -> None:
         r"""
@@ -148,7 +148,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             # Compute rewards
             replace_model(unwrapped_model, target="reward")
             _, _, values = self.model(**self.prepare_model_inputs(queries, responses))
-            rewards = [reward for reward in values[:, -1]]
+            rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
             replace_model(unwrapped_model, target="default") # make sure the model is default at the end

             # Run PPO step
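This is the change the commit is named for. When the model runs in half precision, the value head emits float16 scalars, and the downstream reward bookkeeping in the PPO step is safer in full precision; casting once, before splitting into per-sequence rewards, fixes the dtype at the boundary. A toy demonstration with a mock value-head output:

    import torch

    values = torch.tensor([[0.5, 1.5], [0.25, -0.75]], dtype=torch.float16)  # mock value-head output
    rewards = [r for r in values[:, -1].to(torch.float32)]  # one float32 scalar per sequence
    print([r.dtype for r in rewards])  # [torch.float32, torch.float32]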
@@ -214,13 +214,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             return response[:, inputs["input_ids"].size(1):]
         return response

-    def prepare_model_inputs(self, queries: List[torch.Tensor], responses: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
-        input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
-        input_data = self.data_collator([{"input_ids": ids} for ids in input_ids])
-        input_data = {k: v.to(self.current_device) for k, v in input_data.items() if v is not None}
-        input_data.pop("labels", None) # we don't want to compute LM losses
-        return input_data
-
     @PPODecorators.empty_cuda_cache()
     def batched_forward_pass(
         self,
@@ -7,21 +7,23 @@ import torch
 import mdtex2html
 import gradio as gr

-from utils import ModelArguments, load_pretrained
+from utils import ModelArguments, FinetuningArguments, load_pretrained
 from transformers import HfArgumentParser
 from transformers.utils.versions import require_version


 require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems
-parser = HfArgumentParser(ModelArguments)
-model_args, = parser.parse_args_into_dataclasses()
-model, tokenizer = load_pretrained(model_args)
+parser = HfArgumentParser((ModelArguments, FinetuningArguments))
+model_args, finetuning_args = parser.parse_args_into_dataclasses()
+model, tokenizer = load_pretrained(model_args, finetuning_args)

 if torch.cuda.device_count() > 1:
     from accelerate import dispatch_model, infer_auto_device_map
     device_map = infer_auto_device_map(model)
     model = dispatch_model(model, device_map)
 else:
     model = model.cuda()

 model.eval()
@@ -74,10 +76,10 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT


 def format_example(query):
-    prompt = "Below is an instruction that describes a task. "
-    prompt += "Write a response that appropriately completes the request.\n"
-    prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
-    return prompt
+    prompt = "Below is an instruction that describes a task. "
+    prompt += "Write a response that appropriately completes the request.\n"
+    prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
+    return prompt


 def predict(input, chatbot, max_length, top_p, temperature, history):