add ziya prompt template
parent 771f454ff1
commit de09ee1315
@@ -4,14 +4,16 @@


 import torch
-from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
-from transformers import HfArgumentParser
+from utils import (
+    load_pretrained,
+    prepare_infer_args,
+    get_logits_processor
+)


 def main():

-    parser = HfArgumentParser((ModelArguments, FinetuningArguments))
-    model_args, finetuning_args = parser.parse_args_into_dataclasses()
+    model_args, data_args, finetuning_args = prepare_infer_args()
     model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
     model, tokenizer = load_pretrained(model_args, finetuning_args)

@@ -24,14 +26,26 @@ def main():

     model.eval()

-    def format_example(query):
+    def format_example_alpaca(query, history):
         prompt = "Below is an instruction that describes a task. "
         prompt += "Write a response that appropriately completes the request.\n"
-        prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
+        prompt += "Instruction:\n"
+        for old_query, response in history:
+            prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
+        prompt += "Human: {}\nAssistant:".format(query)
         return prompt
+
+    def format_example_ziya(query, history):
+        prompt = ""
+        for old_query, response in history:
+            prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
+        prompt += "<human>: {}\n<bot>:".format(query)
+        return prompt
+
+    format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya

     def predict(query, history: list):
-        input_ids = tokenizer([format_example(query)], return_tensors="pt")["input_ids"]
+        input_ids = tokenizer([format_example(query, history)], return_tensors="pt")["input_ids"]
         input_ids = input_ids.to(model.device)
         gen_kwargs = {
             "do_sample": True,
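Note: to see what the two chat templates produce, here is a minimal standalone sketch that reuses the two formatting functions from this hunk on a made-up query and history (the conversation values are invented for illustration):

# Standalone sketch: render the same toy conversation with both templates.
# The sample history and query are invented illustration values.

def format_example_alpaca(query, history):
    # Alpaca-style template: instruction header plus Human/Assistant turns.
    prompt = "Below is an instruction that describes a task. "
    prompt += "Write a response that appropriately completes the request.\n"
    prompt += "Instruction:\n"
    for old_query, response in history:
        prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
    prompt += "Human: {}\nAssistant:".format(query)
    return prompt

def format_example_ziya(query, history):
    # Ziya-style template: <human>/<bot> role tags, no instruction header.
    prompt = ""
    for old_query, response in history:
        prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
    prompt += "<human>: {}\n<bot>:".format(query)
    return prompt

history = [("Hi, who are you?", "I am an assistant fine-tuned with this repository.")]
query = "What can you do?"

print(format_example_alpaca(query, history))
print("---")
print(format_example_ziya(query, history))

The alpaca template keeps the instruction header and Human/Assistant turns, while the ziya template only carries the <human>/<bot> role tags.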
@@ -65,6 +79,7 @@ def main():

         if query.strip() == "clear":
             history = []
             print("History has been removed.")
             continue

         response, history = predict(query, history)
@@ -3,19 +3,15 @@
 # Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model


-from transformers import HfArgumentParser, TrainingArguments
-from utils import ModelArguments, FinetuningArguments, load_pretrained
+from utils import load_pretrained, prepare_args


 def main():

-    parser = HfArgumentParser((ModelArguments, TrainingArguments, FinetuningArguments))
-    model_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
-
+    model_args, _, training_args, finetuning_args = prepare_args(stage="sft")
     model, tokenizer = load_pretrained(model_args, finetuning_args)
     model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
     tokenizer.save_pretrained(training_args.output_dir)

     print("model and tokenizer have been saved at:", training_args.output_dir)

-
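Note: the export step boils down to save_pretrained with a shard-size cap. A rough usage sketch follows; the "gpt2" checkpoint and the output path are placeholders standing in for load_pretrained() and the real --output_dir, not values from this commit:

# Minimal sketch of the export/reload round trip with placeholder names.
from transformers import AutoModelForCausalLM, AutoTokenizer

output_dir = "path_to_save_model"  # placeholder output directory

# In export_model.py the model/tokenizer come from load_pretrained();
# a plain Hugging Face checkpoint stands in here for illustration.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Shards larger checkpoints into files of at most 10GB, as in the script.
model.save_pretrained(output_dir, max_shard_size="10GB")
tokenizer.save_pretrained(output_dir)

# The exported directory can then be loaded like any other checkpoint.
reloaded = AutoModelForCausalLM.from_pretrained(output_dir)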
@@ -1,6 +1,7 @@
 from .common import (
     load_pretrained,
     prepare_args,
+    prepare_infer_args,
     prepare_data,
     preprocess_data
 )
@@ -13,5 +14,4 @@ from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
 from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer
 from .ppo import PPOPeftTrainer

-from .config import ModelArguments, FinetuningArguments
 from .other import get_logits_processor, plot_loss
@@ -264,6 +264,18 @@ def prepare_args(
     return model_args, data_args, training_args, finetuning_args


+def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments]:
+
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments))
+
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
+        model_args, data_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, finetuning_args = parser.parse_args_into_dataclasses()
+
+    return model_args, data_args, finetuning_args
+
+
 def prepare_data(
     model_args: ModelArguments,
     data_args: DataTrainingArguments
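Note: prepare_infer_args keeps the existing HfArgumentParser convention of accepting either a single .json file or plain --flag arguments. Below is a minimal sketch of that dispatch with a hypothetical stand-in dataclass (DemoArguments is invented for illustration; only the parser behaviour mirrors the helper above):

# Sketch of HfArgumentParser dispatch: parse a JSON file if exactly one
# .json argument is given, otherwise parse regular --flag arguments.
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DemoArguments:  # hypothetical stand-in for the repository's argument classes
    model_name_or_path: Optional[str] = field(default=None)
    prompt_template: Optional[str] = field(default="alpaca")


def parse_demo_args():
    parser = HfArgumentParser((DemoArguments,))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # e.g. `python demo.py infer_args.json`
        demo_args, = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        # e.g. `python demo.py --model_name_or_path ... --prompt_template ziya`
        demo_args, = parser.parse_args_into_dataclasses()
    return demo_args


if __name__ == "__main__":
    print(parse_demo_args())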
@@ -347,7 +359,8 @@ def preprocess_data(
     column_names = list(dataset.column_names)
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

-    def format_example(examples): # support question with a single answer or multiple answers
+    # support question with a single answer or multiple answers
+    def format_example_alpaca(examples):
         for i in range(len(examples["prompt"])):
             if examples["prompt"][i] and examples["response"][i]:
                 query, answer = examples["prompt"][i], examples["response"][i]
@@ -357,12 +370,27 @@ def preprocess_data(
                 prompt += "Write a response that appropriately completes the request.\n"
                 prompt += "Instruction:\n" + prefix
                 if examples["history"][i]:
-                    history = examples["history"][i]
-                    for old_query, response in history:
+                    for old_query, response in examples["history"][i]:
                         prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
                 prompt += "Human: {}\nAssistant: ".format(query)
                 yield prompt, answer

+    def format_example_ziya(examples):
+        for i in range(len(examples["prompt"])):
+            if examples["prompt"][i] and examples["response"][i]:
+                query, answer = examples["prompt"][i], examples["response"][i]
+                if examples["query"][i]:
+                    query += "\n" + examples["query"][i]
+                prompt = ""
+                if examples["history"][i]:
+                    for old_query, response in examples["history"][i]:
+                        prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
+                prompt += "<human>: {}\n<bot>:".format(query)
+                prompt = prefix + prompt
+                yield prompt, answer
+
+    format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
+
     def preprocess_pretrain_dataset(examples):
         # build grouped texts with format `<s> X1 X2 X3 ...` (without </s>)
         text_ids = tokenizer(examples["prompt"])["input_ids"]
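Note: a small sketch of what the dataset-side ziya formatter yields for one made-up batch. The column values below are invented, the real ones come from the configured datasets, and prefix stands in for data_args.source_prefix:

# Toy batch in the same column layout the generators above expect:
# "prompt", "query", "response", "history".
examples = {
    "prompt": ["Translate the sentence."],
    "query": ["Bonjour le monde"],
    "response": ["Hello world"],
    "history": [[("Who are you?", "A translation assistant.")]],
}

prefix = ""  # stands in for data_args.source_prefix


def format_example_ziya(examples):
    for i in range(len(examples["prompt"])):
        if examples["prompt"][i] and examples["response"][i]:
            query, answer = examples["prompt"][i], examples["response"][i]
            if examples["query"][i]:
                query += "\n" + examples["query"][i]
            prompt = ""
            if examples["history"][i]:
                for old_query, response in examples["history"][i]:
                    prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
            prompt += "<human>: {}\n<bot>:".format(query)
            prompt = prefix + prompt
            yield prompt, answer


for prompt, answer in format_example_ziya(examples):
    print(prompt)
    print("->", answer)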
@@ -136,6 +136,10 @@ class DataTrainingArguments:
         default=0,
         metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
     )
+    prompt_template: Optional[Literal["alpaca", "ziya"]] = field(
+        default="alpaca",
+        metadata={"help": "Which template to use for constructing prompts in training."}
+    )

     def __post_init__(self): # support mixing multiple datasets
         dataset_names = [ds.strip() for ds in self.dataset.split(",")]
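Note: since prompt_template is a plain HfArgumentParser field, it can be switched with --prompt_template ziya on the command line. A minimal sketch with a trimmed-down, hypothetical stand-in dataclass (real runs parse the full DataTrainingArguments):

# Sketch: the new field parsed from the command line, using a reduced
# stand-in dataclass that only carries prompt_template.
from dataclasses import dataclass, field
from typing import Literal, Optional

from transformers import HfArgumentParser


@dataclass
class PromptArguments:  # hypothetical subset, for illustration only
    prompt_template: Optional[Literal["alpaca", "ziya"]] = field(
        default="alpaca",
        metadata={"help": "Which template to use for constructing prompts in training."}
    )


parser = HfArgumentParser((PromptArguments,))
# Equivalent to passing `--prompt_template ziya` on the real command line.
prompt_args, = parser.parse_args_into_dataclasses(args=["--prompt_template", "ziya"])
print(prompt_args.prompt_template)  # -> "ziya"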
@@ -7,14 +7,12 @@ import torch
 import mdtex2html
 import gradio as gr

-from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
-from transformers import HfArgumentParser
+from utils import load_pretrained, prepare_infer_args, get_logits_processor
 from transformers.utils.versions import require_version


 require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems
-parser = HfArgumentParser((ModelArguments, FinetuningArguments))
-model_args, finetuning_args = parser.parse_args_into_dataclasses()
+model_args, data_args, finetuning_args = prepare_infer_args()
 model, tokenizer = load_pretrained(model_args, finetuning_args)

 if torch.cuda.device_count() > 1:
@@ -75,17 +73,31 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
     return text


-def format_example(query):
+def format_example_alpaca(query, history):
     prompt = "Below is an instruction that describes a task. "
     prompt += "Write a response that appropriately completes the request.\n"
-    prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
+    prompt += "Instruction:\n"
+    for old_query, response in history:
+        prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
+    prompt += "Human: {}\nAssistant:".format(query)
     return prompt


+def format_example_ziya(query, history):
+    prompt = ""
+    for old_query, response in history:
+        prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
+    prompt += "<human>: {}\n<bot>:".format(query)
+    return prompt
+
+
+format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
+
+
 def predict(input, chatbot, max_length, top_p, temperature, history):
     chatbot.append((parse_text(input), ""))

-    input_ids = tokenizer([format_example(input)], return_tensors="pt")["input_ids"]
+    input_ids = tokenizer([format_example(input, history)], return_tensors="pt")["input_ids"]
     input_ids = input_ids.to(model.device)
     gen_kwargs = {
         "do_sample": True,
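Note: a rough sketch of the generation step those gen_kwargs feed into. The tiny public checkpoint, the sampling values, and the prompt-trimming step are illustrative assumptions, not code from this commit:

# Sketch: tokenize a ziya-formatted prompt, sample a continuation, and decode
# only the newly generated tokens. "sshleifer/tiny-gpt2" is a placeholder model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
model.eval()

prompt = "<human>: Hello there\n<bot>:"
input_ids = tokenizer([prompt], return_tensors="pt")["input_ids"]

gen_kwargs = {
    "do_sample": True,      # sampling, as in the demos above
    "top_p": 0.7,           # illustrative values, not necessarily the demo defaults
    "temperature": 0.95,
    "max_new_tokens": 32,
}

with torch.no_grad():
    outputs = model.generate(input_ids=input_ids, **gen_kwargs)

# Drop the echoed prompt tokens and decode only the continuation.
new_tokens = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))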