support save full model, replace BOS token
This commit is contained in:
parent
1c732e2537
commit
2e01abfda5
|
@ -103,8 +103,9 @@ def _init_adapter(
|
|||
lastest_checkpoint = None
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
if os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)) and \
|
||||
not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
|
||||
if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)):
|
||||
raise ValueError("Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0]))
|
||||
if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
|
||||
raise ValueError("The given checkpoint may be not a LoRA checkpoint, \
|
||||
please specify `--finetuning_type full/freeze` instead.")
|
||||
|
||||
|
@ -449,7 +450,7 @@ def preprocess_data(
|
|||
yield dialog
|
||||
|
||||
def preprocess_pretrain_dataset(examples):
|
||||
# build grouped texts with format `[BOS] X1 X2 X3 ...` (without [EOS])
|
||||
# build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
|
||||
text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
|
||||
concatenated_ids = list(chain(*text_ids))
|
||||
total_length = len(concatenated_ids)
|
||||
|
@ -465,7 +466,7 @@ def preprocess_data(
|
|||
}
|
||||
|
||||
def preprocess_supervised_dataset(examples):
|
||||
# build inputs with format `X [BOS] Y [EOS]` and labels with format `[IGNORE] ... [IGNORE] Y [EOS]`
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for input with history, we build multiple input-label pairs just like:
|
||||
# https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
|
||||
model_inputs = {"input_ids": [], "labels": []}
|
||||
|
@ -475,15 +476,26 @@ def preprocess_data(
|
|||
for i in range(len(dialog) // 2):
|
||||
source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=False)
|
||||
target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
|
||||
input_ids += source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
|
||||
|
||||
if len(source_ids) > data_args.max_source_length - 1: # bos token
|
||||
source_ids = source_ids[:data_args.max_source_length - 1]
|
||||
if len(target_ids) > data_args.max_target_length - 1: # eos token
|
||||
target_ids = target_ids[:data_args.max_target_length - 1]
|
||||
|
||||
input_ids += [tokenizer.bos_token_id] + source_ids + target_ids + [tokenizer.eos_token_id]
|
||||
labels += [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids[:data_args.max_source_length + data_args.max_target_length])
|
||||
model_inputs["labels"].append(labels[:data_args.max_source_length + data_args.max_target_length])
|
||||
if len(input_ids) > data_args.max_source_length + data_args.max_target_length:
|
||||
input_ids = input_ids[:data_args.max_source_length + data_args.max_target_length]
|
||||
if len(labels) > data_args.max_source_length + data_args.max_target_length:
|
||||
labels = labels[:data_args.max_source_length + data_args.max_target_length]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["labels"].append(labels)
|
||||
return model_inputs
|
||||
|
||||
def preprocess_unsupervised_dataset(examples):
|
||||
# build inputs with format `X [BOS]` and labels with format `Y [BOS]`
|
||||
# build inputs with format `<bos> X` and labels with format `<bos> Y`
|
||||
model_inputs = {"input_ids": [], "labels": []}
|
||||
for dialog in get_dialog(examples):
|
||||
prompt, answer = "".join(dialog[:-1]), dialog[-1]
|
||||
|
@ -496,15 +508,15 @@ def preprocess_data(
|
|||
if len(target_ids) > data_args.max_target_length - 1: # bos token
|
||||
target_ids = target_ids[:data_args.max_target_length - 1]
|
||||
|
||||
input_ids = source_ids + [tokenizer.bos_token_id]
|
||||
labels = target_ids + [tokenizer.bos_token_id]
|
||||
input_ids = [tokenizer.bos_token_id] + source_ids
|
||||
labels = [tokenizer.bos_token_id] + target_ids
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["labels"].append(labels)
|
||||
return model_inputs
|
||||
|
||||
def preprocess_pairwise_dataset(examples):
|
||||
# build input pairs with format `X [BOS] Y1 [EOS]` and `X [BOS] Y2 [EOS]`
|
||||
# build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
|
||||
model_inputs = {"accept_ids": [], "reject_ids": []}
|
||||
for dialog in get_dialog(examples):
|
||||
prompt, answer = "".join(dialog[:-1]), dialog[-1]
|
||||
|
@ -520,8 +532,8 @@ def preprocess_data(
|
|||
if len(reject_ids) > data_args.max_target_length - 1: # eos token
|
||||
reject_ids = reject_ids[:data_args.max_target_length - 1]
|
||||
|
||||
accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id]
|
||||
reject_ids = source_ids + [tokenizer.bos_token_id] + reject_ids + [tokenizer.eos_token_id]
|
||||
accept_ids = [tokenizer.bos_token_id] + source_ids + accept_ids + [tokenizer.eos_token_id]
|
||||
reject_ids = [tokenizer.bos_token_id] + source_ids + reject_ids + [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["accept_ids"].append(accept_ids)
|
||||
model_inputs["reject_ids"].append(reject_ids)
|
||||
|
|
|
@ -5,13 +5,11 @@ import torch
|
|||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from transformers.trainer import TRAINER_STATE_NAME
|
||||
from transformers.trainer import TRAINER_STATE_NAME, WEIGHTS_NAME
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.generation.utils import LogitsProcessorList
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
|
||||
from peft.utils import WEIGHTS_NAME
|
||||
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
VALUE_HEAD_FILE_NAME = "value_head.bin"
|
||||
|
|
|
@ -16,8 +16,6 @@ from transformers import (
|
|||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.modeling_utils import unwrap_model
|
||||
|
||||
from peft.utils.other import WEIGHTS_NAME
|
||||
|
||||
from .config import FinetuningArguments
|
||||
|
||||
from .other import (
|
||||
|
@ -98,18 +96,22 @@ class PeftTrainer(Seq2SeqTrainer):
|
|||
logger.info(f"Saving model checkpoint to {output_dir}")
|
||||
model = unwrap_model(self.model)
|
||||
|
||||
if hasattr(model, "pretrained_model"): # for models with valuehead
|
||||
if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
|
||||
backbone_model = getattr(model, "pretrained_model")
|
||||
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
|
||||
else:
|
||||
backbone_model = model
|
||||
|
||||
if hasattr(backbone_model, "peft_config"): # peft methods
|
||||
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model)) # save lora weights
|
||||
else:
|
||||
torch.save(get_state_dict(backbone_model), os.path.join(output_dir, WEIGHTS_NAME)) # save trainable weights
|
||||
|
||||
if hasattr(model, "v_head"): # save valuehead weights
|
||||
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
|
||||
if self.finetuning_args.finetuning_type == "lora":
|
||||
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
|
||||
else: # freeze/full tuning
|
||||
backbone_model.save_pretrained(
|
||||
output_dir,
|
||||
state_dict=get_state_dict(backbone_model),
|
||||
safe_serialization=self.args.save_safetensors
|
||||
)
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
|
||||
f.write(self.args.to_json_string() + "\n")
|
||||
|
@ -122,11 +124,14 @@ class PeftTrainer(Seq2SeqTrainer):
|
|||
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
|
||||
"""
|
||||
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
|
||||
model = unwrap_model(self.model)
|
||||
if hasattr(model, "peft_config"): # peft methods
|
||||
model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
|
||||
else:
|
||||
load_trainable_params(model, self.state.best_model_checkpoint)
|
||||
|
||||
if hasattr(model, "v_head"):
|
||||
load_valuehead_params(model, self.state.best_model_checkpoint)
|
||||
model = unwrap_model(self.model)
|
||||
if self.finetuning_args.finetuning_type == "lora":
|
||||
model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
|
||||
if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": getattr(model, "reward_head_weight"),
|
||||
"summary.bias": getattr(model, "reward_head_bias")
|
||||
})
|
||||
else: # freeze/full-tuning
|
||||
load_trainable_params(model, self.state.best_model_checkpoint)
|
||||
|
|
Loading…
Reference in New Issue