forked from p04798526/LLaMA-Factory-Mirror
add prompt template class
This commit is contained in:
parent
5d021d4ad5
commit
909af8f496
|
@ -21,11 +21,10 @@ import datetime
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
|
|
||||||
from utils import (
|
from utils import (
|
||||||
|
Template,
|
||||||
load_pretrained,
|
load_pretrained,
|
||||||
prepare_infer_args,
|
prepare_infer_args,
|
||||||
get_logits_processor,
|
get_logits_processor
|
||||||
prompt_template_alpaca,
|
|
||||||
prompt_template_ziya
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -43,7 +42,7 @@ app = FastAPI()
|
||||||
|
|
||||||
@app.post("/")
|
@app.post("/")
|
||||||
async def create_item(request: Request):
|
async def create_item(request: Request):
|
||||||
global model, tokenizer, format_example
|
global model, tokenizer, prompt_template
|
||||||
|
|
||||||
# Parse the request JSON
|
# Parse the request JSON
|
||||||
json_post_raw = await request.json()
|
json_post_raw = await request.json()
|
||||||
|
@ -53,7 +52,7 @@ async def create_item(request: Request):
|
||||||
history = json_post_list.get("history")
|
history = json_post_list.get("history")
|
||||||
|
|
||||||
# Tokenize the input prompt
|
# Tokenize the input prompt
|
||||||
input_ids = tokenizer([format_example(prompt, history)], return_tensors="pt")["input_ids"]
|
input_ids = tokenizer([prompt_template.get_prompt(prompt, history)], return_tensors="pt")["input_ids"]
|
||||||
input_ids = input_ids.to(model.device)
|
input_ids = input_ids.to(model.device)
|
||||||
|
|
||||||
# Generation arguments
|
# Generation arguments
|
||||||
|
@ -98,6 +97,6 @@ async def create_item(request: Request):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
model_args, data_args, finetuning_args = prepare_infer_args()
|
model_args, data_args, finetuning_args = prepare_infer_args()
|
||||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||||
format_example = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
|
prompt_template = Template(data_args.prompt_template)
|
||||||
|
|
||||||
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
|
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
|
||||||
|
|
|
@ -4,11 +4,10 @@
|
||||||
|
|
||||||
|
|
||||||
from utils import (
|
from utils import (
|
||||||
|
Template,
|
||||||
load_pretrained,
|
load_pretrained,
|
||||||
prepare_infer_args,
|
prepare_infer_args,
|
||||||
get_logits_processor,
|
get_logits_processor
|
||||||
prompt_template_alpaca,
|
|
||||||
prompt_template_ziya
|
|
||||||
)
|
)
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from transformers import TextIteratorStreamer
|
from transformers import TextIteratorStreamer
|
||||||
|
@ -20,11 +19,11 @@ def main():
|
||||||
model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
|
model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
|
||||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||||
|
|
||||||
format_example = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
|
prompt_template = Template(data_args.prompt_template)
|
||||||
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
|
||||||
def predict_and_print(query, history: list):
|
def predict_and_print(query, history: list):
|
||||||
input_ids = tokenizer([format_example(query, history)], return_tensors="pt")["input_ids"]
|
input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"]
|
||||||
input_ids = input_ids.to(model.device)
|
input_ids = input_ids.to(model.device)
|
||||||
gen_kwargs = {
|
gen_kwargs = {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
|
|
|
@ -14,6 +14,6 @@ from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
|
||||||
from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer
|
from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer
|
||||||
from .ppo import PPOPeftTrainer
|
from .ppo import PPOPeftTrainer
|
||||||
|
|
||||||
from .template import prompt_template_alpaca, prompt_template_ziya
|
from .template import Template
|
||||||
|
|
||||||
from .other import get_logits_processor, plot_loss
|
from .other import get_logits_processor, plot_loss
|
||||||
|
|
|
@ -29,6 +29,8 @@ from peft import (
|
||||||
get_peft_model
|
get_peft_model
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from peft.utils import CONFIG_NAME
|
||||||
|
|
||||||
from trl import AutoModelForCausalLMWithValueHead
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
|
|
||||||
from .config import (
|
from .config import (
|
||||||
|
@ -37,10 +39,7 @@ from .config import (
|
||||||
FinetuningArguments
|
FinetuningArguments
|
||||||
)
|
)
|
||||||
|
|
||||||
from .template import (
|
from .template import Template
|
||||||
prompt_template_alpaca,
|
|
||||||
prompt_template_ziya
|
|
||||||
)
|
|
||||||
|
|
||||||
from .other import (
|
from .other import (
|
||||||
get_logger,
|
get_logger,
|
||||||
|
@ -102,6 +101,9 @@ def _init_adapter(
|
||||||
logger.info("Fine-tuning method: LoRA")
|
logger.info("Fine-tuning method: LoRA")
|
||||||
lastest_checkpoint = None
|
lastest_checkpoint = None
|
||||||
|
|
||||||
|
assert os.path.exists(model_args.checkpoint_dir[0], CONFIG_NAME), \
|
||||||
|
"The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
|
||||||
|
|
||||||
if model_args.checkpoint_dir is not None:
|
if model_args.checkpoint_dir is not None:
|
||||||
if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
|
if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
|
||||||
checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
||||||
|
@ -401,7 +403,7 @@ def preprocess_data(
|
||||||
|
|
||||||
column_names = list(dataset.column_names)
|
column_names = list(dataset.column_names)
|
||||||
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||||
prompt_template = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
|
prompt_template = Template(data_args.prompt_template)
|
||||||
|
|
||||||
# support question with a single answer or multiple answers
|
# support question with a single answer or multiple answers
|
||||||
def format_example(examples):
|
def format_example(examples):
|
||||||
|
@ -410,8 +412,7 @@ def preprocess_data(
|
||||||
query, answer = examples["prompt"][i], examples["response"][i]
|
query, answer = examples["prompt"][i], examples["response"][i]
|
||||||
if examples["query"][i]:
|
if examples["query"][i]:
|
||||||
query += "\n" + examples["query"][i]
|
query += "\n" + examples["query"][i]
|
||||||
prompt = prompt_template(query, examples["history"][i])
|
prompt = prompt_template.get_prompt(query, examples["history"][i], prefix)
|
||||||
prompt = prefix + prompt
|
|
||||||
yield prompt, answer
|
yield prompt, answer
|
||||||
|
|
||||||
def preprocess_pretrain_dataset(examples):
|
def preprocess_pretrain_dataset(examples):
|
||||||
|
|
|
@ -141,7 +141,7 @@ class DataTrainingArguments:
|
||||||
default=0,
|
default=0,
|
||||||
metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
|
metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
|
||||||
)
|
)
|
||||||
prompt_template: Optional[Literal["alpaca", "ziya"]] = field(
|
prompt_template: Optional[Literal["alpaca", "vicuna", "ziya"]] = field(
|
||||||
default="alpaca",
|
default="alpaca",
|
||||||
metadata={"help": "Which template to use for constructing prompts in training."}
|
metadata={"help": "Which template to use for constructing prompts in training."}
|
||||||
)
|
)
|
||||||
|
|
|
@ -10,7 +10,7 @@ from transformers.modeling_utils import PreTrainedModel
|
||||||
from transformers.generation.utils import LogitsProcessorList
|
from transformers.generation.utils import LogitsProcessorList
|
||||||
from transformers.generation.logits_process import LogitsProcessor
|
from transformers.generation.logits_process import LogitsProcessor
|
||||||
|
|
||||||
from peft.utils.other import WEIGHTS_NAME
|
from peft.utils import WEIGHTS_NAME
|
||||||
|
|
||||||
|
|
||||||
IGNORE_INDEX = -100
|
IGNORE_INDEX = -100
|
||||||
|
|
|
@ -1,14 +1,43 @@
|
||||||
def prompt_template_alpaca(query, history=None):
|
from typing import Optional
|
||||||
prompt = ""
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Template:
|
||||||
|
|
||||||
|
name: str
|
||||||
|
|
||||||
|
def get_prompt(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> str:
|
||||||
|
return getattr(self, "_format_{}".format(self.name))(query, history, prefix)
|
||||||
|
|
||||||
|
def _format_alpaca(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
|
||||||
|
if prefix:
|
||||||
|
prompt = prefix
|
||||||
|
else:
|
||||||
|
prompt = "Below is an instruction that describes a task. "
|
||||||
|
prompt += "Write a response that appropriately completes the request.\n"
|
||||||
|
prompt += "Instruction:\n"
|
||||||
if history:
|
if history:
|
||||||
for old_query, response in history:
|
for old_query, response in history:
|
||||||
prompt += "Human:{}\nAssistant:{}\n".format(old_query, response)
|
prompt += "Human:{}\nAssistant:{}\n".format(old_query, response)
|
||||||
prompt += "Human:{}\nAssistant:".format(query)
|
prompt += "Human:{}\nAssistant:".format(query)
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
def _format_vicuna(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
|
||||||
|
if prefix:
|
||||||
|
prompt = prefix
|
||||||
|
else:
|
||||||
|
prompt = "A chat between a curious user and an artificial intelligence assistant. "
|
||||||
|
prompt += "The assistant gives helpful, detailed, and polite answers to the user's questions. "
|
||||||
|
if history:
|
||||||
|
for old_query, response in history:
|
||||||
|
prompt += "USER: {} ASSISTANT: {}</s>".format(old_query, response)
|
||||||
|
prompt += "USER: {} ASSISTANT:".format(query)
|
||||||
|
return prompt
|
||||||
|
|
||||||
def prompt_template_ziya(query, history=None):
|
|
||||||
prompt = ""
|
def _format_ziya(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
|
||||||
|
prompt = prefix
|
||||||
if history:
|
if history:
|
||||||
for old_query, response in history:
|
for old_query, response in history:
|
||||||
prompt += "<human>:{}\n<bot>:{}\n".format(old_query, response)
|
prompt += "<human>:{}\n<bot>:{}\n".format(old_query, response)
|
||||||
|
|
|
@ -8,11 +8,10 @@ import gradio as gr
|
||||||
|
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from utils import (
|
from utils import (
|
||||||
|
Template,
|
||||||
load_pretrained,
|
load_pretrained,
|
||||||
prepare_infer_args,
|
prepare_infer_args,
|
||||||
get_logits_processor,
|
get_logits_processor
|
||||||
prompt_template_alpaca,
|
|
||||||
prompt_template_ziya
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from transformers import TextIteratorStreamer
|
from transformers import TextIteratorStreamer
|
||||||
|
@ -25,7 +24,7 @@ require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
|
||||||
model_args, data_args, finetuning_args = prepare_infer_args()
|
model_args, data_args, finetuning_args = prepare_infer_args()
|
||||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||||
|
|
||||||
format_example = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
|
prompt_template = Template(data_args.prompt_template)
|
||||||
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
|
||||||
|
|
||||||
|
@ -81,7 +80,7 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
|
||||||
def predict(query, chatbot, max_length, top_p, temperature, history):
|
def predict(query, chatbot, max_length, top_p, temperature, history):
|
||||||
chatbot.append((parse_text(query), ""))
|
chatbot.append((parse_text(query), ""))
|
||||||
|
|
||||||
input_ids = tokenizer([format_example(query, history)], return_tensors="pt")["input_ids"]
|
input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"]
|
||||||
input_ids = input_ids.to(model.device)
|
input_ids = input_ids.to(model.device)
|
||||||
gen_kwargs = {
|
gen_kwargs = {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
|
|
Loading…
Reference in New Issue