support save full model, replace BOS token

This commit is contained in:
hiyouga 2023-06-27 21:40:11 +08:00
parent 1c732e2537
commit 2e01abfda5
3 changed files with 48 additions and 33 deletions

View File

@ -103,8 +103,9 @@ def _init_adapter(
lastest_checkpoint = None
if model_args.checkpoint_dir is not None:
if os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)) and \
not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)):
raise ValueError("Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0]))
if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
raise ValueError("The given checkpoint may be not a LoRA checkpoint, \
please specify `--finetuning_type full/freeze` instead.")
@ -449,7 +450,7 @@ def preprocess_data(
yield dialog
def preprocess_pretrain_dataset(examples):
# build grouped texts with format `[BOS] X1 X2 X3 ...` (without [EOS])
# build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
concatenated_ids = list(chain(*text_ids))
total_length = len(concatenated_ids)
@ -465,7 +466,7 @@ def preprocess_data(
}
def preprocess_supervised_dataset(examples):
# build inputs with format `X [BOS] Y [EOS]` and labels with format `[IGNORE] ... [IGNORE] Y [EOS]`
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for input with history, we build multiple input-label pairs just like:
# https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
model_inputs = {"input_ids": [], "labels": []}
@ -475,15 +476,26 @@ def preprocess_data(
for i in range(len(dialog) // 2):
source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=False)
target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
input_ids += source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
if len(source_ids) > data_args.max_source_length - 1: # bos token
source_ids = source_ids[:data_args.max_source_length - 1]
if len(target_ids) > data_args.max_target_length - 1: # eos token
target_ids = target_ids[:data_args.max_target_length - 1]
input_ids += [tokenizer.bos_token_id] + source_ids + target_ids + [tokenizer.eos_token_id]
labels += [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids[:data_args.max_source_length + data_args.max_target_length])
model_inputs["labels"].append(labels[:data_args.max_source_length + data_args.max_target_length])
if len(input_ids) > data_args.max_source_length + data_args.max_target_length:
input_ids = input_ids[:data_args.max_source_length + data_args.max_target_length]
if len(labels) > data_args.max_source_length + data_args.max_target_length:
labels = labels[:data_args.max_source_length + data_args.max_target_length]
model_inputs["input_ids"].append(input_ids)
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_unsupervised_dataset(examples):
# build inputs with format `X [BOS]` and labels with format `Y [BOS]`
# build inputs with format `<bos> X` and labels with format `<bos> Y`
model_inputs = {"input_ids": [], "labels": []}
for dialog in get_dialog(examples):
prompt, answer = "".join(dialog[:-1]), dialog[-1]
@ -496,15 +508,15 @@ def preprocess_data(
if len(target_ids) > data_args.max_target_length - 1: # bos token
target_ids = target_ids[:data_args.max_target_length - 1]
input_ids = source_ids + [tokenizer.bos_token_id]
labels = target_ids + [tokenizer.bos_token_id]
input_ids = [tokenizer.bos_token_id] + source_ids
labels = [tokenizer.bos_token_id] + target_ids
model_inputs["input_ids"].append(input_ids)
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_pairwise_dataset(examples):
# build input pairs with format `X [BOS] Y1 [EOS]` and `X [BOS] Y2 [EOS]`
# build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
model_inputs = {"accept_ids": [], "reject_ids": []}
for dialog in get_dialog(examples):
prompt, answer = "".join(dialog[:-1]), dialog[-1]
@ -520,8 +532,8 @@ def preprocess_data(
if len(reject_ids) > data_args.max_target_length - 1: # eos token
reject_ids = reject_ids[:data_args.max_target_length - 1]
accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id]
reject_ids = source_ids + [tokenizer.bos_token_id] + reject_ids + [tokenizer.eos_token_id]
accept_ids = [tokenizer.bos_token_id] + source_ids + accept_ids + [tokenizer.eos_token_id]
reject_ids = [tokenizer.bos_token_id] + source_ids + reject_ids + [tokenizer.eos_token_id]
model_inputs["accept_ids"].append(accept_ids)
model_inputs["reject_ids"].append(reject_ids)

View File

@ -5,13 +5,11 @@ import torch
import logging
from typing import Dict, List, Optional
from transformers.trainer import TRAINER_STATE_NAME
from transformers.trainer import TRAINER_STATE_NAME, WEIGHTS_NAME
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor
from peft.utils import WEIGHTS_NAME
IGNORE_INDEX = -100
VALUE_HEAD_FILE_NAME = "value_head.bin"

View File

@ -16,8 +16,6 @@ from transformers import (
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.modeling_utils import unwrap_model
from peft.utils.other import WEIGHTS_NAME
from .config import FinetuningArguments
from .other import (
@ -98,18 +96,22 @@ class PeftTrainer(Seq2SeqTrainer):
logger.info(f"Saving model checkpoint to {output_dir}")
model = unwrap_model(self.model)
if hasattr(model, "pretrained_model"): # for models with valuehead
if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
backbone_model = getattr(model, "pretrained_model")
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
else:
backbone_model = model
if hasattr(backbone_model, "peft_config"): # peft methods
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model)) # save lora weights
else:
torch.save(get_state_dict(backbone_model), os.path.join(output_dir, WEIGHTS_NAME)) # save trainable weights
if hasattr(model, "v_head"): # save valuehead weights
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
if self.finetuning_args.finetuning_type == "lora":
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
else: # freeze/full tuning
backbone_model.save_pretrained(
output_dir,
state_dict=get_state_dict(backbone_model),
safe_serialization=self.args.save_safetensors
)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
f.write(self.args.to_json_string() + "\n")
@ -122,11 +124,14 @@ class PeftTrainer(Seq2SeqTrainer):
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
"""
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
model = unwrap_model(self.model)
if hasattr(model, "peft_config"): # peft methods
model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
else:
load_trainable_params(model, self.state.best_model_checkpoint)
if hasattr(model, "v_head"):
load_valuehead_params(model, self.state.best_model_checkpoint)
model = unwrap_model(self.model)
if self.finetuning_args.finetuning_type == "lora":
model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
"summary.bias": getattr(model, "reward_head_bias")
})
else: # freeze/full-tuning
load_trainable_params(model, self.state.best_model_checkpoint)