support save full model, replace BOS token

hiyouga 2023-06-27 21:40:11 +08:00
parent 1c732e2537
commit 2e01abfda5
3 changed files with 48 additions and 33 deletions

View File

@@ -103,8 +103,9 @@ def _init_adapter(
     lastest_checkpoint = None

     if model_args.checkpoint_dir is not None:
-        if os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)) and \
-                not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
+        if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)):
+            raise ValueError("Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0]))
+        if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
             raise ValueError("The given checkpoint may be not a LoRA checkpoint, \
                              please specify `--finetuning_type full/freeze` instead.")
@@ -449,7 +450,7 @@ def preprocess_data(
             yield dialog

     def preprocess_pretrain_dataset(examples):
-        # build grouped texts with format `[BOS] X1 X2 X3 ...` (without [EOS])
+        # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
         text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
         concatenated_ids = list(chain(*text_ids))
         total_length = len(concatenated_ids)
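
The comment change reflects the convention used throughout this commit: a single leading <bos> and no trailing <eos> for pretraining blocks. A self-contained sketch of how such grouped-text packing typically works (the helper and the block size are illustrative, not taken from this file):

    from itertools import chain
    from typing import Dict, List

    def pack_pretrain_examples(text_ids: List[List[int]], bos_token_id: int,
                               block_size: int = 1024) -> Dict[str, List[List[int]]]:
        # Concatenate all tokenized texts, then cut fixed-size blocks of the form
        # `<bos> X1 X2 ...` (no <eos>); the small remainder at the end is dropped.
        concatenated_ids = list(chain(*text_ids))
        chunk = block_size - 1  # reserve one position for <bos>
        total_length = (len(concatenated_ids) // chunk) * chunk
        input_ids = [[bos_token_id] + concatenated_ids[i: i + chunk]
                     for i in range(0, total_length, chunk)]
        return {"input_ids": input_ids, "labels": [ids[:] for ids in input_ids]}
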
@@ -465,7 +466,7 @@ def preprocess_data(
         }

     def preprocess_supervised_dataset(examples):
-        # build inputs with format `X [BOS] Y [EOS]` and labels with format `[IGNORE] ... [IGNORE] Y [EOS]`
+        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
         # for input with history, we build multiple input-label pairs just like:
         # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
         model_inputs = {"input_ids": [], "labels": []}
@@ -475,15 +476,26 @@ def preprocess_data(
             for i in range(len(dialog) // 2):
                 source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=False)
                 target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
-                input_ids += source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
+
+                if len(source_ids) > data_args.max_source_length - 1: # bos token
+                    source_ids = source_ids[:data_args.max_source_length - 1]
+                if len(target_ids) > data_args.max_target_length - 1: # eos token
+                    target_ids = target_ids[:data_args.max_target_length - 1]
+
+                input_ids += [tokenizer.bos_token_id] + source_ids + target_ids + [tokenizer.eos_token_id]
                 labels += [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [tokenizer.eos_token_id]
-            model_inputs["input_ids"].append(input_ids[:data_args.max_source_length + data_args.max_target_length])
-            model_inputs["labels"].append(labels[:data_args.max_source_length + data_args.max_target_length])
+
+            if len(input_ids) > data_args.max_source_length + data_args.max_target_length:
+                input_ids = input_ids[:data_args.max_source_length + data_args.max_target_length]
+            if len(labels) > data_args.max_source_length + data_args.max_target_length:
+                labels = labels[:data_args.max_source_length + data_args.max_target_length]
+
+            model_inputs["input_ids"].append(input_ids)
+            model_inputs["labels"].append(labels)
         return model_inputs

     def preprocess_unsupervised_dataset(examples):
-        # build inputs with format `X [BOS]` and labels with format `Y [BOS]`
+        # build inputs with format `<bos> X` and labels with format `<bos> Y`
         model_inputs = {"input_ids": [], "labels": []}
         for dialog in get_dialog(examples):
             prompt, answer = "".join(dialog[:-1]), dialog[-1]
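
Net effect of this hunk: each supervised turn is now encoded as `<bos> X Y <eos>` (BOS moved from between X and Y to the front), source and target are truncated per turn before concatenation, and the running sequence is only clipped to max_source_length + max_target_length at the end. A minimal sketch of the per-turn encoding (function name and arguments are stand-ins for the surrounding code):

    from typing import List, Tuple

    IGNORE_INDEX = -100  # label value ignored by the loss

    def encode_supervised_turn(source_ids: List[int], target_ids: List[int],
                               bos_token_id: int, eos_token_id: int,
                               max_source_length: int,
                               max_target_length: int) -> Tuple[List[int], List[int]]:
        # Inputs:  <bos> X Y <eos>
        # Labels:  <ignore> ... <ignore> Y <eos>  (prompt positions masked out)
        source_ids = source_ids[:max_source_length - 1]  # reserve room for <bos>
        target_ids = target_ids[:max_target_length - 1]  # reserve room for <eos>
        input_ids = [bos_token_id] + source_ids + target_ids + [eos_token_id]
        labels = [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [eos_token_id]
        return input_ids, labels

For multi-turn history, these per-turn pairs are concatenated in order and the final input_ids/labels are clipped as in the hunk above.
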
@@ -496,15 +508,15 @@ def preprocess_data(
             if len(target_ids) > data_args.max_target_length - 1: # bos token
                 target_ids = target_ids[:data_args.max_target_length - 1]

-            input_ids = source_ids + [tokenizer.bos_token_id]
-            labels = target_ids + [tokenizer.bos_token_id]
+            input_ids = [tokenizer.bos_token_id] + source_ids
+            labels = [tokenizer.bos_token_id] + target_ids

             model_inputs["input_ids"].append(input_ids)
             model_inputs["labels"].append(labels)
         return model_inputs

     def preprocess_pairwise_dataset(examples):
-        # build input pairs with format `X [BOS] Y1 [EOS]` and `X [BOS] Y2 [EOS]`
+        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
         model_inputs = {"accept_ids": [], "reject_ids": []}
         for dialog in get_dialog(examples):
             prompt, answer = "".join(dialog[:-1]), dialog[-1]
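
In the unsupervised case only the BOS position changes: both the prompt and the reference answer now start with <bos>. With made-up token ids (1 standing in for a LLaMA-style <s>):

    bos = 1                                               # illustrative BOS id
    source_ids, target_ids = [306, 4966], [3869, 29991]   # made-up token ids
    input_ids = [bos] + source_ids   # -> [1, 306, 4966]    i.e. `<bos> X`
    labels = [bos] + target_ids      # -> [1, 3869, 29991]  i.e. `<bos> Y`
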
@@ -520,8 +532,8 @@ def preprocess_data(
             if len(reject_ids) > data_args.max_target_length - 1: # eos token
                 reject_ids = reject_ids[:data_args.max_target_length - 1]

-            accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id]
-            reject_ids = source_ids + [tokenizer.bos_token_id] + reject_ids + [tokenizer.eos_token_id]
+            accept_ids = [tokenizer.bos_token_id] + source_ids + accept_ids + [tokenizer.eos_token_id]
+            reject_ids = [tokenizer.bos_token_id] + source_ids + reject_ids + [tokenizer.eos_token_id]

             model_inputs["accept_ids"].append(accept_ids)
             model_inputs["reject_ids"].append(reject_ids)

View File

@@ -5,13 +5,11 @@ import torch
 import logging

 from typing import Dict, List, Optional
-from transformers.trainer import TRAINER_STATE_NAME
+from transformers.trainer import TRAINER_STATE_NAME, WEIGHTS_NAME
 from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.utils import LogitsProcessorList
 from transformers.generation.logits_process import LogitsProcessor

-from peft.utils import WEIGHTS_NAME
-
 IGNORE_INDEX = -100
 VALUE_HEAD_FILE_NAME = "value_head.bin"
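
The import source matters here: peft's WEIGHTS_NAME names the adapter file, while the transformers constant names the full-model file, which is what freeze/full checkpoints contain after this commit. The literal values below are my recollection of peft 0.3 / transformers 4.30 and are not part of the diff:

    from transformers.trainer import WEIGHTS_NAME       # "pytorch_model.bin" (full model)
    # previously: from peft.utils import WEIGHTS_NAME   # "adapter_model.bin" (LoRA adapter)

    print(WEIGHTS_NAME)  # the full-model file name this module works with after the change
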

View File

@@ -16,8 +16,6 @@ from transformers import (
 from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.modeling_utils import unwrap_model

-from peft.utils.other import WEIGHTS_NAME
-
 from .config import FinetuningArguments
 from .other import (
@@ -98,18 +96,22 @@ class PeftTrainer(Seq2SeqTrainer):
         logger.info(f"Saving model checkpoint to {output_dir}")
         model = unwrap_model(self.model)

-        if hasattr(model, "pretrained_model"): # for models with valuehead
+        if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
             backbone_model = getattr(model, "pretrained_model")
+            torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
         else:
             backbone_model = model

-        if hasattr(backbone_model, "peft_config"): # peft methods
-            backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model)) # save lora weights
-        else:
-            torch.save(get_state_dict(backbone_model), os.path.join(output_dir, WEIGHTS_NAME)) # save trainable weights
-
-        if hasattr(model, "v_head"): # save valuehead weights
-            torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
+        if self.finetuning_args.finetuning_type == "lora":
+            backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
+        else: # freeze/full tuning
+            backbone_model.save_pretrained(
+                output_dir,
+                state_dict=get_state_dict(backbone_model),
+                safe_serialization=self.args.save_safetensors
+            )
+            if self.tokenizer is not None:
+                self.tokenizer.save_pretrained(output_dir)

         with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
             f.write(self.args.to_json_string() + "\n")
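
Because freeze/full checkpoints are now written with save_pretrained (optionally as safetensors) and the tokenizer is saved alongside, a full-tuning output directory should be directly reloadable with plain transformers. A hedged usage sketch (the path is a placeholder; for freeze tuning the saved state dict holds only the trainable weights, so a plain reload may be incomplete):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint_dir = "path/to/output_dir"  # placeholder for the trainer's output_dir
    model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
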
@@ -122,11 +124,14 @@ class PeftTrainer(Seq2SeqTrainer):
         Subclass and override to inject custom behavior. It should not be directly used by external scripts.
         """
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
-        model = unwrap_model(self.model)
-        if hasattr(model, "peft_config"): # peft methods
-            model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
-        else:
-            load_trainable_params(model, self.state.best_model_checkpoint)
-
-        if hasattr(model, "v_head"):
-            load_valuehead_params(model, self.state.best_model_checkpoint)
+
+        model = unwrap_model(self.model)
+        if self.finetuning_args.finetuning_type == "lora":
+            model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
+            if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
+                model.v_head.load_state_dict({
+                    "summary.weight": getattr(model, "reward_head_weight"),
+                    "summary.bias": getattr(model, "reward_head_bias")
+                })
+        else: # freeze/full-tuning
+            load_trainable_params(model, self.state.best_model_checkpoint)
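
The LoRA branch now relies on load_valuehead_params() returning a success flag and stashing the reward-head tensors on the model before v_head is rebuilt from them. A plausible shape for that helper, inferred from how it is used in this hunk (the real implementation lives elsewhere in the repo and may differ):

    import os
    import torch

    VALUE_HEAD_FILE_NAME = "value_head.bin"

    def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: str) -> bool:
        # Sketch: load the saved value head and expose its tensors as buffers that
        # the trainer then copies into model.v_head (see the hunk above).
        valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
        if not os.path.exists(valuehead_file):
            return False
        state_dict = torch.load(valuehead_file, map_location="cpu")
        model.register_buffer("reward_head_weight", state_dict["summary.weight"])
        model.register_buffer("reward_head_bias", state_dict["summary.bias"])
        return True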