fix checkpoint loading
parent ce71cc8b6d
commit c0e5df92d6
@@ -21,8 +21,14 @@ def main():
     model = model.cuda()
     model.eval()
 
+    def format_example(query):
+        prompt = "Below is an instruction that describes a task. "
+        prompt += "Write a response that appropriately completes the request.\n"
+        prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
+        return prompt
+
+
     def predict(query, history: list):
-        inputs = tokenizer([query], return_tensors="pt")
+        inputs = tokenizer([format_example(query)], return_tensors="pt")
         inputs = inputs.to(model.device)
         gen_kwargs = {
             "do_sample": True,
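The hunk above routes every CLI query through an Alpaca-style instruction template before tokenization, instead of passing the raw query to the tokenizer. A minimal sketch of what the wrapper produces, reusing only the format_example logic shown in the diff (the example query is made up):

    def format_example(query):
        prompt = "Below is an instruction that describes a task. "
        prompt += "Write a response that appropriately completes the request.\n"
        prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
        return prompt

    # Hypothetical query, just to show the resulting prompt string.
    print(format_example("What is LoRA?"))
    # Below is an instruction that describes a task. Write a response that appropriately completes the request.
    # Instruction:
    # Human: What is LoRA?
    # Assistant: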
@@ -2,6 +2,7 @@ import os
 import sys
 import torch
 import hashlib
+from itertools import chain
 from typing import List, Literal, Optional, Tuple
 
 import transformers
@@ -84,6 +85,8 @@ def init_adapter(
             param.data = param.data.to(torch.float32)
 
     if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
+        if len(model_args.checkpoint_dir) > 1:
+            logger.warning("Only LoRA tuning accepts multiple checkpoints.")
         load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
 
     if finetuning_args.finetuning_type == "lora":
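For freeze and full tuning, only the first checkpoint directory is used, so the new check simply warns when several are passed. The load_trainable_params helper itself is not part of this diff; a rough sketch of what loading a trainable-parameters-only checkpoint usually looks like (the file name and layout here are assumptions, not this project's actual format):

    import os
    import torch

    def load_trainable_params_sketch(model, checkpoint_dir):
        # Assumed layout: a partial state dict that holds only the trainable weights.
        weights_file = os.path.join(checkpoint_dir, "pytorch_model.bin")
        state_dict = torch.load(weights_file, map_location="cpu")
        # strict=False leaves the frozen parameters (absent from the file) untouched.
        model.load_state_dict(state_dict, strict=False)
        return model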
@@ -117,6 +120,9 @@ def init_adapter(
         )
         model = get_peft_model(model, lora_config)
 
+    if model_args.checkpoint_dir is not None:
+        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
+
     return model
@@ -131,19 +137,14 @@ def load_pretrained(
 
     Support both training and inference.
     """
 
-    if (not is_trainable) and (model_args.checkpoint_dir is None):
-        logger.warning("Checkpoint is not found at evaluation, load the original model.")
-        finetuning_args = FinetuningArguments(finetuning_type="none")
-
-    if model_args.checkpoint_dir is not None: # load fine-tuned model from checkpoint
-        for checkpoint_dir in model_args.checkpoint_dir:
-            if not os.path.isfile(os.path.join(checkpoint_dir, FINETUNING_ARGS_NAME)):
-                raise ValueError("The fine-tuning arguments are not found in the provided dictionary.")
-        logger.info("Load fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
-        finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME))
-        if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) > 1:
-            logger.warning("Only LoRA tuning accepts multiple checkpoints.")
+    if finetuning_args is None: # load the fine-tuning arguments
+        if model_args.checkpoint_dir is None:
+            logger.warning("Checkpoint is not found at evaluation, load the original model.")
+            finetuning_args = FinetuningArguments(finetuning_type="none")
+        elif os.path.exists(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)):
+            finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME))
+        else:
+            raise ValueError("Missing fine-tuning arguments in the provided dictionary.")
 
     assert stage == "sft" or finetuning_args.finetuning_type == "lora", "RM and PPO training can only be performed with LoRA method."
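This hunk is the heart of the fix: load_pretrained now resolves the fine-tuning arguments only when the caller has not already supplied them, and it reads FINETUNING_ARGS_NAME from the last checkpoint directory instead of requiring the file in every directory. A condensed sketch of the new resolution order (FinetuningArguments and FINETUNING_ARGS_NAME are the names used in the diff; the helper below is illustrative, not code from the repository):

    import os

    def resolve_finetuning_args(finetuning_args, checkpoint_dirs):
        # Already provided by the caller (e.g. during training): keep as-is.
        if finetuning_args is not None:
            return finetuning_args
        # No checkpoint given: fall back to the original, un-tuned model.
        if not checkpoint_dirs:
            return FinetuningArguments(finetuning_type="none")
        # Otherwise the arguments file must sit next to the last checkpoint.
        args_file = os.path.join(checkpoint_dirs[-1], FINETUNING_ARGS_NAME)
        if os.path.exists(args_file):
            return FinetuningArguments.load_from_json(args_file)
        raise ValueError("Missing fine-tuning arguments in the provided directory.")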
@@ -350,7 +351,7 @@ def preprocess_data(
             if examples["prompt"][i] and examples["response"][i]:
                 query, answer = examples["prompt"][i], examples["response"][i]
                 if examples["query"][i]:
-                    query += examples["query"][i]
+                    query += "\n" + examples["query"][i]
                 prompt = "Below is an instruction that describes a task. "
                 prompt += "Write a response that appropriately completes the request.\n"
                 prompt += "Instruction:\n" + prefix
@@ -361,6 +362,20 @@ def preprocess_data(
                 prompt += "Human: {}\nAssistant: ".format(query)
                 yield prompt, answer
 
+    def preprocess_pretrain_dataset(examples):
+        # build grouped texts with format `<s> X1 X2 X3 ...`
+        text_ids = tokenizer(examples["prompt"])["input_ids"]
+        concatenated_ids = list(chain(*text_ids))
+        total_length = len(concatenated_ids)
+        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
+        total_length = (total_length // data_args.max_source_length) * data_args.max_source_length
+        # split by chunks of max_source_length
+        result = [concatenated_ids[i: i + data_args.max_source_length] for i in range(0, total_length, data_args.max_source_length)]
+        return {
+            "input_ids": result,
+            "labels": result.copy()
+        }
+
     def preprocess_supervised_dataset(examples):
         # build inputs with format `X <s> Y </s>` and labels with format `<ignore> ... <ignore> <s> Y </s>`
         model_inputs = {"input_ids": [], "labels": []}
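The new preprocess_pretrain_dataset concatenates every tokenized prompt into one long id sequence and slices it into fixed-size blocks, dropping the tail that does not fill a whole block. The same grouping logic on plain integer lists, with a made-up block size standing in for data_args.max_source_length, makes the chunking behaviour concrete:

    from itertools import chain

    block_size = 4                                    # stand-in for data_args.max_source_length
    text_ids = [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]  # pretend tokenizer output

    concatenated = list(chain(*text_ids))                          # [1, 2, ..., 10]
    total_length = (len(concatenated) // block_size) * block_size  # 8, so the last 2 ids are dropped
    chunks = [concatenated[i: i + block_size] for i in range(0, total_length, block_size)]

    print(chunks)  # [[1, 2, 3, 4], [5, 6, 7, 8]]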
@@ -425,7 +440,9 @@ def preprocess_data(
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
         print("label_ids:\n{}".format(example["labels"]))
-        print("labels:\n{}".format(tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]])))
+        print("labels:\n{}".format(
+            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]]))
+        )
 
     def print_pairwise_dataset_example(example):
         print("accept_ids:\n{}".format(example["accept_ids"]))
@@ -437,11 +454,11 @@ def preprocess_data(
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
 
-    if stage == "sft":
-        if (not training_args.do_train) and training_args.predict_with_generate: # with generation
-            preprocess_function = preprocess_evaluation_dataset
-        else: # without generation
-            preprocess_function = preprocess_supervised_dataset
+    if stage == "pt":
+        preprocess_function = preprocess_pretrain_dataset
+    elif stage == "sft":
+        preprocess_function = preprocess_evaluation_dataset \
+            if training_args.predict_with_generate else preprocess_supervised_dataset
     elif stage == "rm":
         preprocess_function = preprocess_pairwise_dataset
     elif stage == "ppo":
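With the pretraining stage added, the preprocessing function is now selected by stage ("pt", "sft", "rm", "ppo") rather than only by the SFT flags. An equivalent reading of the new branching as a dispatch table; this is a sketch for clarity, not code from the diff, and the "ppo" branch is truncated above, so it is omitted:

    def choose_preprocess_function(stage, predict_with_generate):
        # The values are the preprocessing callables named in the diff (placeholders here).
        dispatch = {
            "pt": preprocess_pretrain_dataset,
            "sft": (preprocess_evaluation_dataset
                    if predict_with_generate else preprocess_supervised_dataset),
            "rm": preprocess_pairwise_dataset,
        }
        return dispatch[stage]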
@@ -194,7 +194,8 @@ class FinetuningArguments:
         if self.name_module_trainable == "mlp":
             self.trainable_layers = ["layers.{:d}.mlp".format(idx) for idx in trainable_layer_ids]
         elif self.name_module_trainable == "qkv":
-            self.trainable_layers = ["layers.{:d}.attention.query_key_value".format(idx) for idx in trainable_layer_ids]
+            self.trainable_layers = ["layers.{:d}.self_attn.{}".format(idx, proj) \
+                for proj in ["k_proj", "q_proj", "v_proj", "o_proj"] for idx in trainable_layer_ids]
 
         assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
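The "qkv" freeze-tuning option now targets LLaMA-style attention projections (self_attn.k_proj/q_proj/v_proj/o_proj) instead of a fused attention.query_key_value module. Running just the new comprehension with made-up layer ids shows which module-name patterns it produces, including the iteration order:

    trainable_layer_ids = [30, 31]  # e.g. the last two layers of a 32-layer model

    trainable_layers = ["layers.{:d}.self_attn.{}".format(idx, proj)
                        for proj in ["k_proj", "q_proj", "v_proj", "o_proj"]
                        for idx in trainable_layer_ids]

    print(trainable_layers)
    # ['layers.30.self_attn.k_proj', 'layers.31.self_attn.k_proj',
    #  'layers.30.self_attn.q_proj', 'layers.31.self_attn.q_proj',
    #  'layers.30.self_attn.v_proj', 'layers.31.self_attn.v_proj',
    #  'layers.30.self_attn.o_proj', 'layers.31.self_attn.o_proj']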
@@ -9,8 +9,10 @@ import gradio as gr
 
 from utils import ModelArguments, auto_configure_device_map, load_pretrained
 from transformers import HfArgumentParser
+from transformers.utils.versions import require_version
 
 
+require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems
 parser = HfArgumentParser(ModelArguments)
 model_args, = parser.parse_args_into_dataclasses()
 model, tokenizer = load_pretrained(model_args)
@@ -71,10 +73,17 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
     return text
 
 
+def format_example(query):
+    prompt = "Below is an instruction that describes a task. "
+    prompt += "Write a response that appropriately completes the request.\n"
+    prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
+    return prompt
+
+
 def predict(input, chatbot, max_length, top_p, temperature, history):
     chatbot.append((parse_text(input), ""))
 
-    inputs = tokenizer([input], return_tensors="pt")
+    inputs = tokenizer([format_example(input)], return_tensors="pt")
     inputs = inputs.to(model.device)
     gen_kwargs = {
         "do_sample": True,