fix inference, add prompt template
This commit is contained in:
parent
13d1f0709c
commit
5d021d4ad5
|
@ -23,7 +23,9 @@ from fastapi import FastAPI, Request
|
|||
from utils import (
|
||||
load_pretrained,
|
||||
prepare_infer_args,
|
||||
get_logits_processor
|
||||
get_logits_processor,
|
||||
prompt_template_alpaca,
|
||||
prompt_template_ziya
|
||||
)
|
||||
|
||||
|
||||
|
@ -96,23 +98,6 @@ async def create_item(request: Request):
|
|||
if __name__ == "__main__":
|
||||
model_args, data_args, finetuning_args = prepare_infer_args()
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||
|
||||
def format_example_alpaca(query, history):
|
||||
prompt = "Below is an instruction that describes a task. "
|
||||
prompt += "Write a response that appropriately completes the request.\n"
|
||||
prompt += "Instruction:\n"
|
||||
for old_query, response in history:
|
||||
prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
|
||||
prompt += "Human: {}\nAssistant:".format(query)
|
||||
return prompt
|
||||
|
||||
def format_example_ziya(query, history):
|
||||
prompt = ""
|
||||
for old_query, response in history:
|
||||
prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
|
||||
prompt += "<human>: {}\n<bot>:".format(query)
|
||||
return prompt
|
||||
|
||||
format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
|
||||
format_example = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
|
||||
|
||||
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
|
||||
|
|
|
@ -6,7 +6,9 @@
|
|||
from utils import (
|
||||
load_pretrained,
|
||||
prepare_infer_args,
|
||||
get_logits_processor
|
||||
get_logits_processor,
|
||||
prompt_template_alpaca,
|
||||
prompt_template_ziya
|
||||
)
|
||||
from threading import Thread
|
||||
from transformers import TextIteratorStreamer
|
||||
|
@ -18,23 +20,7 @@ def main():
|
|||
model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||
|
||||
def format_example_alpaca(query, history):
|
||||
prompt = "Below is an instruction that describes a task. "
|
||||
prompt += "Write a response that appropriately completes the request.\n"
|
||||
prompt += "Instruction:\n"
|
||||
for old_query, response in history:
|
||||
prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
|
||||
prompt += "Human: {}\nAssistant:".format(query)
|
||||
return prompt
|
||||
|
||||
def format_example_ziya(query, history):
|
||||
prompt = ""
|
||||
for old_query, response in history:
|
||||
prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
|
||||
prompt += "<human>: {}\n<bot>:".format(query)
|
||||
return prompt
|
||||
|
||||
format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
|
||||
format_example = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
|
||||
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
def predict_and_print(query, history: list):
|
||||
|
|
|
@ -14,4 +14,6 @@ from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
|
|||
from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer
|
||||
from .ppo import PPOPeftTrainer
|
||||
|
||||
from .template import prompt_template_alpaca, prompt_template_ziya
|
||||
|
||||
from .other import get_logits_processor, plot_loss
|
||||
|
|
|
@ -37,6 +37,11 @@ from .config import (
|
|||
FinetuningArguments
|
||||
)
|
||||
|
||||
from .template import (
|
||||
prompt_template_alpaca,
|
||||
prompt_template_ziya
|
||||
)
|
||||
|
||||
from .other import (
|
||||
get_logger,
|
||||
load_trainable_params,
|
||||
|
@ -224,6 +229,7 @@ def load_pretrained(
|
|||
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False) # fix all model params
|
||||
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
|
||||
|
||||
print_trainable_params(model)
|
||||
|
||||
|
@ -395,39 +401,19 @@ def preprocess_data(
|
|||
|
||||
column_names = list(dataset.column_names)
|
||||
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||
prompt_template = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
|
||||
|
||||
# support question with a single answer or multiple answers
|
||||
def format_example_alpaca(examples):
|
||||
def format_example(examples):
|
||||
for i in range(len(examples["prompt"])):
|
||||
if examples["prompt"][i] and examples["response"][i]:
|
||||
query, answer = examples["prompt"][i], examples["response"][i]
|
||||
if examples["query"][i]:
|
||||
query += "\n" + examples["query"][i]
|
||||
prompt = "Below is an instruction that describes a task. "
|
||||
prompt += "Write a response that appropriately completes the request.\n"
|
||||
prompt += "Instruction:\n" + prefix
|
||||
if examples["history"][i]:
|
||||
for old_query, response in examples["history"][i]:
|
||||
prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
|
||||
prompt += "Human: {}\nAssistant: ".format(query)
|
||||
yield prompt, answer
|
||||
|
||||
def format_example_ziya(examples):
|
||||
for i in range(len(examples["prompt"])):
|
||||
if examples["prompt"][i] and examples["response"][i]:
|
||||
query, answer = examples["prompt"][i], examples["response"][i]
|
||||
if examples["query"][i]:
|
||||
query += "\n" + examples["query"][i]
|
||||
prompt = ""
|
||||
if examples["history"][i]:
|
||||
for old_query, response in examples["history"][i]:
|
||||
prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
|
||||
prompt += "<human>: {}\n<bot>:".format(query)
|
||||
prompt = prompt_template(query, examples["history"][i])
|
||||
prompt = prefix + prompt
|
||||
yield prompt, answer
|
||||
|
||||
format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
|
||||
|
||||
def preprocess_pretrain_dataset(examples):
|
||||
# build grouped texts with format `<s> X1 X2 X3 ...` (without </s>)
|
||||
text_ids = tokenizer(examples["prompt"])["input_ids"]
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
def prompt_template_alpaca(query, history=None):
|
||||
prompt = ""
|
||||
if history:
|
||||
for old_query, response in history:
|
||||
prompt += "Human:{}\nAssistant:{}\n".format(old_query, response)
|
||||
prompt += "Human:{}\nAssistant:".format(query)
|
||||
return prompt
|
||||
|
||||
|
||||
def prompt_template_ziya(query, history=None):
|
||||
prompt = ""
|
||||
if history:
|
||||
for old_query, response in history:
|
||||
prompt += "<human>:{}\n<bot>:{}\n".format(old_query, response)
|
||||
prompt += "<human>:{}\n<bot>:".format(query)
|
||||
return prompt
|
|
@ -7,7 +7,14 @@ import mdtex2html
|
|||
import gradio as gr
|
||||
|
||||
from threading import Thread
|
||||
from utils import load_pretrained, prepare_infer_args, get_logits_processor
|
||||
from utils import (
|
||||
load_pretrained,
|
||||
prepare_infer_args,
|
||||
get_logits_processor,
|
||||
prompt_template_alpaca,
|
||||
prompt_template_ziya
|
||||
)
|
||||
|
||||
from transformers import TextIteratorStreamer
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
@ -18,26 +25,7 @@ require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
|
|||
model_args, data_args, finetuning_args = prepare_infer_args()
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||
|
||||
|
||||
def format_example_alpaca(query, history):
|
||||
prompt = "Below is an instruction that describes a task. "
|
||||
prompt += "Write a response that appropriately completes the request.\n"
|
||||
prompt += "Instruction:\n"
|
||||
for old_query, response in history:
|
||||
prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
|
||||
prompt += "Human: {}\nAssistant:".format(query)
|
||||
return prompt
|
||||
|
||||
|
||||
def format_example_ziya(query, history):
|
||||
prompt = ""
|
||||
for old_query, response in history:
|
||||
prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
|
||||
prompt += "<human>: {}\n<bot>:".format(query)
|
||||
return prompt
|
||||
|
||||
|
||||
format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
|
||||
format_example = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
|
||||
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue