support save full model, replace BOS token
commit 2e01abfda5
parent 1c732e2537
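In short: BOS moves from the middle of each sequence to the front (`X <bos> Y <eos>` becomes `<bos> X Y <eos>`, matching LLaMA's convention, with per-side truncation now applied before the special tokens are added), and freeze/full fine-tuning now saves a complete model directory via `save_pretrained()` (weights, config, and tokenizer) instead of a bare weight file.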
@@ -103,8 +103,9 @@ def _init_adapter(
     lastest_checkpoint = None

     if model_args.checkpoint_dir is not None:
-        if os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)) and \
-                not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
+        if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)):
+            raise ValueError("Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0]))
+        if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
             raise ValueError("The given checkpoint may be not a LoRA checkpoint, \
                              please specify `--finetuning_type full/freeze` instead.")

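The combined check above is split so that a missing adapter weight file and a missing adapter config are reported separately. As a reference point, a minimal standalone sketch of the same validation (the file names are PEFT's defaults, assumed to match the `WEIGHTS_NAME`/`CONFIG_NAME` constants imported in this file; `check_lora_checkpoint` is a hypothetical helper, not part of this repo):

    import os

    ADAPTER_WEIGHTS = "adapter_model.bin"   # assumption: peft.utils.WEIGHTS_NAME
    ADAPTER_CONFIG = "adapter_config.json"  # assumption: peft.utils.CONFIG_NAME

    def check_lora_checkpoint(path: str) -> None:
        # Fail fast on the weight file, then on the adapter config.
        if not os.path.exists(os.path.join(path, ADAPTER_WEIGHTS)):
            raise ValueError("Provided path ({}) does not contain a LoRA weight.".format(path))
        if not os.path.exists(os.path.join(path, ADAPTER_CONFIG)):
            raise ValueError("The given checkpoint may be not a LoRA checkpoint, "
                             "please specify `--finetuning_type full/freeze` instead.")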
@@ -449,7 +450,7 @@ def preprocess_data(
             yield dialog

     def preprocess_pretrain_dataset(examples):
-        # build grouped texts with format `[BOS] X1 X2 X3 ...` (without [EOS])
+        # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
         text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
         concatenated_ids = list(chain(*text_ids))
         total_length = len(concatenated_ids)
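For context, the pretraining path packs every tokenized prompt into one token stream and slices it into fixed-size blocks, each prefixed with `<bos>`. A minimal sketch of that grouping (the real block-size handling lives outside this hunk; `block_size` and the reservation of one slot for BOS are assumptions):

    from itertools import chain
    from typing import Dict, List

    def group_texts(text_ids: List[List[int]], bos_token_id: int, block_size: int) -> Dict[str, list]:
        # Concatenate all tokenized examples into a single token stream.
        concatenated_ids = list(chain(*text_ids))
        # Keep only whole blocks; each block reserves one slot for <bos>.
        chunk = block_size - 1
        total_length = (len(concatenated_ids) // chunk) * chunk
        input_ids = [
            [bos_token_id] + concatenated_ids[i: i + chunk]
            for i in range(0, total_length, chunk)
        ]
        # `<bos> X1 X2 X3 ...` without <eos>; labels mirror the inputs for LM loss.
        return {"input_ids": input_ids, "labels": [ids[:] for ids in input_ids]}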
@@ -465,7 +466,7 @@ def preprocess_data(
         }

     def preprocess_supervised_dataset(examples):
-        # build inputs with format `X [BOS] Y [EOS]` and labels with format `[IGNORE] ... [IGNORE] Y [EOS]`
+        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
         # for input with history, we build multiple input-label pairs just like:
         # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
         model_inputs = {"input_ids": [], "labels": []}
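Here `<ignore>` stands for `IGNORE_INDEX` (-100, defined below in this commit), the label value that PyTorch's cross-entropy loss skips, so only the response tokens contribute to the loss:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 10)                # 4 positions, vocabulary of 10
    labels = torch.tensor([-100, -100, 3, 7])  # prompt masked, response supervised
    # Positions labeled -100 are excluded (the default ignore_index).
    loss = F.cross_entropy(logits, labels, ignore_index=-100)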
@@ -475,15 +476,26 @@ def preprocess_data(
             for i in range(len(dialog) // 2):
                 source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=False)
                 target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
-                input_ids += source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
+
+                if len(source_ids) > data_args.max_source_length - 1: # bos token
+                    source_ids = source_ids[:data_args.max_source_length - 1]
+                if len(target_ids) > data_args.max_target_length - 1: # eos token
+                    target_ids = target_ids[:data_args.max_target_length - 1]
+
+                input_ids += [tokenizer.bos_token_id] + source_ids + target_ids + [tokenizer.eos_token_id]
                 labels += [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [tokenizer.eos_token_id]

-            model_inputs["input_ids"].append(input_ids[:data_args.max_source_length + data_args.max_target_length])
-            model_inputs["labels"].append(labels[:data_args.max_source_length + data_args.max_target_length])
+            if len(input_ids) > data_args.max_source_length + data_args.max_target_length:
+                input_ids = input_ids[:data_args.max_source_length + data_args.max_target_length]
+            if len(labels) > data_args.max_source_length + data_args.max_target_length:
+                labels = labels[:data_args.max_source_length + data_args.max_target_length]
+
+            model_inputs["input_ids"].append(input_ids)
+            model_inputs["labels"].append(labels)
         return model_inputs

     def preprocess_unsupervised_dataset(examples):
-        # build inputs with format `X [BOS]` and labels with format `Y [BOS]`
+        # build inputs with format `<bos> X` and labels with format `<bos> Y`
         model_inputs = {"input_ids": [], "labels": []}
         for dialog in get_dialog(examples):
             prompt, answer = "".join(dialog[:-1]), dialog[-1]
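A worked toy example of the new supervised layout, using made-up ids (`bos = 1`, `eos = 2`) and `max_source_length = max_target_length = 4`:

    IGNORE_INDEX = -100
    bos, eos = 1, 2                    # toy special-token ids
    source_ids = [11, 12, 13, 14, 15]  # 5 prompt tokens
    target_ids = [21, 22]              # 2 response tokens

    # Truncate the source to max_source_length - 1 = 3, leaving room for <bos>.
    source_ids = source_ids[:3]

    input_ids = [bos] + source_ids + target_ids + [eos]
    labels = [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [eos]

    assert input_ids == [1, 11, 12, 13, 21, 22, 2]
    assert labels == [-100, -100, -100, -100, 21, 22, 2]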
@@ -496,15 +508,15 @@ def preprocess_data(
             if len(target_ids) > data_args.max_target_length - 1: # bos token
                 target_ids = target_ids[:data_args.max_target_length - 1]

-            input_ids = source_ids + [tokenizer.bos_token_id]
-            labels = target_ids + [tokenizer.bos_token_id]
+            input_ids = [tokenizer.bos_token_id] + source_ids
+            labels = [tokenizer.bos_token_id] + target_ids

             model_inputs["input_ids"].append(input_ids)
             model_inputs["labels"].append(labels)
         return model_inputs

     def preprocess_pairwise_dataset(examples):
-        # build input pairs with format `X [BOS] Y1 [EOS]` and `X [BOS] Y2 [EOS]`
+        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
         model_inputs = {"accept_ids": [], "reject_ids": []}
         for dialog in get_dialog(examples):
             prompt, answer = "".join(dialog[:-1]), dialog[-1]
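In the unsupervised path the prompt and the reference stay separate, and both now start with `<bos>` rather than ending with it; with the same toy ids:

    bos = 1                    # toy <bos> id
    source_ids = [11, 12, 13]  # tokenized prompt X
    target_ids = [21, 22, 23]  # tokenized reference Y

    input_ids = [bos] + source_ids  # `<bos> X`, fed to the model for generation
    labels = [bos] + target_ids     # `<bos> Y`, kept as the reference

    assert input_ids == [1, 11, 12, 13] and labels == [1, 21, 22, 23]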
@@ -520,8 +532,8 @@ def preprocess_data(
             if len(reject_ids) > data_args.max_target_length - 1: # eos token
                 reject_ids = reject_ids[:data_args.max_target_length - 1]

-            accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id]
-            reject_ids = source_ids + [tokenizer.bos_token_id] + reject_ids + [tokenizer.eos_token_id]
+            accept_ids = [tokenizer.bos_token_id] + source_ids + accept_ids + [tokenizer.eos_token_id]
+            reject_ids = [tokenizer.bos_token_id] + source_ids + reject_ids + [tokenizer.eos_token_id]

             model_inputs["accept_ids"].append(accept_ids)
             model_inputs["reject_ids"].append(reject_ids)
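In the pairwise (reward modeling) path, both candidates now share an identical `<bos> X` prefix, so the reward model scores them over the same context:

    bos, eos = 1, 2
    source_ids = [11, 12]  # shared prompt X
    accept_ids = [21, 22]  # preferred response Y1
    reject_ids = [31]      # rejected response Y2

    accept_ids = [bos] + source_ids + accept_ids + [eos]
    reject_ids = [bos] + source_ids + reject_ids + [eos]

    # Identical prefix up to the end of the prompt.
    assert accept_ids[:3] == reject_ids[:3] == [1, 11, 12]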
@@ -5,13 +5,11 @@ import torch
 import logging
 from typing import Dict, List, Optional

-from transformers.trainer import TRAINER_STATE_NAME
+from transformers.trainer import TRAINER_STATE_NAME, WEIGHTS_NAME
 from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.utils import LogitsProcessorList
 from transformers.generation.logits_process import LogitsProcessor

-from peft.utils import WEIGHTS_NAME
-

 IGNORE_INDEX = -100
 VALUE_HEAD_FILE_NAME = "value_head.bin"
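The import swap is load-bearing: the two libraries define `WEIGHTS_NAME` differently, and the helpers in this file now refer to the full-model file name (values are the libraries' defaults at the time of this commit):

    # peft.utils.WEIGHTS_NAME           == "adapter_model.bin"  (LoRA adapter only)
    # transformers.trainer.WEIGHTS_NAME == "pytorch_model.bin"  (full model weights)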
@@ -16,8 +16,6 @@ from transformers import (
 from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.modeling_utils import unwrap_model

-from peft.utils.other import WEIGHTS_NAME
-
 from .config import FinetuningArguments

 from .other import (
@@ -98,18 +96,22 @@ class PeftTrainer(Seq2SeqTrainer):
         logger.info(f"Saving model checkpoint to {output_dir}")
         model = unwrap_model(self.model)

-        if hasattr(model, "pretrained_model"): # for models with valuehead
+        if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
             backbone_model = getattr(model, "pretrained_model")
+            torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
         else:
             backbone_model = model

-        if hasattr(backbone_model, "peft_config"): # peft methods
-            backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model)) # save lora weights
-        else:
-            torch.save(get_state_dict(backbone_model), os.path.join(output_dir, WEIGHTS_NAME)) # save trainable weights
-
-        if hasattr(model, "v_head"): # save valuehead weights
-            torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
+        if self.finetuning_args.finetuning_type == "lora":
+            backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
+        else: # freeze/full tuning
+            backbone_model.save_pretrained(
+                output_dir,
+                state_dict=get_state_dict(backbone_model),
+                safe_serialization=self.args.save_safetensors
+            )
+            if self.tokenizer is not None:
+                self.tokenizer.save_pretrained(output_dir)

         with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
             f.write(self.args.to_json_string() + "\n")
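Because freeze/full checkpoints are now written with `save_pretrained()` plus the tokenizer, the output directory should be loadable as an ordinary Hugging Face model; a minimal usage sketch (the path is hypothetical). When `--save_safetensors` is enabled, the weights land in `model.safetensors` instead of `pytorch_model.bin`.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    output_dir = "path/to/full-finetuned-checkpoint"  # hypothetical path
    model = AutoModelForCausalLM.from_pretrained(output_dir)
    tokenizer = AutoTokenizer.from_pretrained(output_dir)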
@@ -122,11 +124,14 @@ class PeftTrainer(Seq2SeqTrainer):
         Subclass and override to inject custom behavior. It should not be directly used by external scripts.
         """
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
-        model = unwrap_model(self.model)
-        if hasattr(model, "peft_config"): # peft methods
-            model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
-        else:
-            load_trainable_params(model, self.state.best_model_checkpoint)

-        if hasattr(model, "v_head"):
-            load_valuehead_params(model, self.state.best_model_checkpoint)
+        model = unwrap_model(self.model)
+        if self.finetuning_args.finetuning_type == "lora":
+            model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
+            if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
+                model.v_head.load_state_dict({
+                    "summary.weight": getattr(model, "reward_head_weight"),
+                    "summary.bias": getattr(model, "reward_head_bias")
+                })
+        else: # freeze/full-tuning
+            load_trainable_params(model, self.state.best_model_checkpoint)
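The `summary.weight`/`summary.bias` keys match the single linear layer that trl's `AutoModelForCausalLMWithValueHead` uses as its value head; a sketch of the state dict being restored (shapes assume a hypothetical `hidden_size`, and the `reward_head_*` buffers are assumed to be populated by `load_valuehead_params`):

    import torch

    hidden_size = 4096  # hypothetical; matches the backbone's hidden size
    value_head_state = {
        "summary.weight": torch.zeros(1, hidden_size),  # nn.Linear(hidden_size, 1)
        "summary.bias": torch.zeros(1),
    }
    # model.v_head.load_state_dict(value_head_state) restores exactly these tensors.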