diff --git a/src/api_demo.py b/src/api_demo.py
index 678b5a54..ac761aaf 100644
--- a/src/api_demo.py
+++ b/src/api_demo.py
@@ -21,11 +21,10 @@
 import datetime
 from fastapi import FastAPI, Request

 from utils import (
+    Template,
     load_pretrained,
     prepare_infer_args,
-    get_logits_processor,
-    prompt_template_alpaca,
-    prompt_template_ziya
+    get_logits_processor
 )


@@ -43,7 +42,7 @@ app = FastAPI()

 @app.post("/")
 async def create_item(request: Request):
-    global model, tokenizer, format_example
+    global model, tokenizer, prompt_template

     # Parse the request JSON
     json_post_raw = await request.json()
@@ -53,7 +52,7 @@ async def create_item(request: Request):
     history = json_post_list.get("history")

     # Tokenize the input prompt
-    input_ids = tokenizer([format_example(prompt, history)], return_tensors="pt")["input_ids"]
+    input_ids = tokenizer([prompt_template.get_prompt(prompt, history)], return_tensors="pt")["input_ids"]
     input_ids = input_ids.to(model.device)

     # Generation arguments
@@ -98,6 +97,6 @@ async def create_item(request: Request):
 if __name__ == "__main__":
     model_args, data_args, finetuning_args = prepare_infer_args()
     model, tokenizer = load_pretrained(model_args, finetuning_args)
-    format_example = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
+    prompt_template = Template(data_args.prompt_template)

     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
diff --git a/src/cli_demo.py b/src/cli_demo.py
index fd0a1c12..31299c30 100644
--- a/src/cli_demo.py
+++ b/src/cli_demo.py
@@ -4,11 +4,10 @@

 from utils import (
+    Template,
     load_pretrained,
     prepare_infer_args,
-    get_logits_processor,
-    prompt_template_alpaca,
-    prompt_template_ziya
+    get_logits_processor
 )
 from threading import Thread
 from transformers import TextIteratorStreamer

@@ -20,11 +19,11 @@ def main():

     model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
     model, tokenizer = load_pretrained(model_args, finetuning_args)
-    format_example = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
+    prompt_template = Template(data_args.prompt_template)
     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

     def predict_and_print(query, history: list):
-        input_ids = tokenizer([format_example(query, history)], return_tensors="pt")["input_ids"]
+        input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"]
         input_ids = input_ids.to(model.device)
         gen_kwargs = {
             "input_ids": input_ids,
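For reference, the refactored endpoint in src/api_demo.py can be exercised with any HTTP client once the server is started through the `__main__` block above (uvicorn on port 8000). A minimal client sketch, assuming a local server; the exact response schema lies outside the hunks shown here:

import requests

# Hypothetical request against the FastAPI demo server from src/api_demo.py.
# "prompt" and "history" mirror the fields read via json_post_list.get(...) above;
# the structure of the returned JSON is not part of this diff, so it is printed as-is.
payload = {
    "prompt": "What is the capital of France?",
    "history": [],  # earlier (query, response) pairs, if any
}
response = requests.post("http://127.0.0.1:8000/", json=payload)
print(response.json())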
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
index 152052ee..f2db999d 100644
--- a/src/utils/__init__.py
+++ b/src/utils/__init__.py
@@ -14,6 +14,6 @@ from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
 from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer
 from .ppo import PPOPeftTrainer

-from .template import prompt_template_alpaca, prompt_template_ziya
+from .template import Template

 from .other import get_logits_processor, plot_loss
diff --git a/src/utils/common.py b/src/utils/common.py
index 9354f18b..02880ed8 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -29,6 +29,8 @@ from peft import (
     get_peft_model
 )

+from peft.utils import CONFIG_NAME
+
 from trl import AutoModelForCausalLMWithValueHead

 from .config import (
@@ -37,10 +39,7 @@ from .config import (
     FinetuningArguments
 )

-from .template import (
-    prompt_template_alpaca,
-    prompt_template_ziya
-)
+from .template import Template

 from .other import (
     get_logger,
@@ -102,6 +101,9 @@ def _init_adapter(
         logger.info("Fine-tuning method: LoRA")
         lastest_checkpoint = None

         if model_args.checkpoint_dir is not None:
+            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
+                "The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
+
             if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
                 checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
@@ -401,7 +403,7 @@ def preprocess_data(

     column_names = list(dataset.column_names)
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
-    prompt_template = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
+    prompt_template = Template(data_args.prompt_template)

     # support question with a single answer or multiple answers
     def format_example(examples):
@@ -410,8 +412,7 @@ def preprocess_data(
             query, answer = examples["prompt"][i], examples["response"][i]
             if examples["query"][i]:
                 query += "\n" + examples["query"][i]
-            prompt = prompt_template(query, examples["history"][i])
-            prompt = prefix + prompt
+            prompt = prompt_template.get_prompt(query, examples["history"][i], prefix)
             yield prompt, answer

     def preprocess_pretrain_dataset(examples):
diff --git a/src/utils/config.py b/src/utils/config.py
index ef29a174..e7a7ae2c 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -141,7 +141,7 @@ class DataTrainingArguments:
         default=0,
         metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
     )
-    prompt_template: Optional[Literal["alpaca", "ziya"]] = field(
+    prompt_template: Optional[Literal["alpaca", "vicuna", "ziya"]] = field(
         default="alpaca",
         metadata={"help": "Which template to use for constructing prompts in training."}
     )
diff --git a/src/utils/other.py b/src/utils/other.py
index c8512892..5675e3fd 100644
--- a/src/utils/other.py
+++ b/src/utils/other.py
@@ -10,7 +10,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.utils import LogitsProcessorList
 from transformers.generation.logits_process import LogitsProcessor

-from peft.utils.other import WEIGHTS_NAME
+from peft.utils import WEIGHTS_NAME


 IGNORE_INDEX = -100
diff --git a/src/utils/template.py b/src/utils/template.py
index 2cafd5e5..42391826 100644
--- a/src/utils/template.py
+++ b/src/utils/template.py
@@ -1,16 +1,45 @@
-def prompt_template_alpaca(query, history=None):
-    prompt = ""
-    if history:
-        for old_query, response in history:
-            prompt += "Human:{}\nAssistant:{}\n".format(old_query, response)
-    prompt += "Human:{}\nAssistant:".format(query)
-    return prompt
+from typing import Optional
+from dataclasses import dataclass


-def prompt_template_ziya(query, history=None):
-    prompt = ""
-    if history:
-        for old_query, response in history:
-            prompt += "<human>:{}\n<bot>:{}\n".format(old_query, response)
-    prompt += "<human>:{}\n<bot>:".format(query)
-    return prompt
+@dataclass
+class Template:
+
+    name: str
+
+    def get_prompt(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> str:
+        return getattr(self, "_format_{}".format(self.name))(query, history, prefix)
+
+    def _format_alpaca(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
+        if prefix:
+            prompt = prefix
+        else:
+            prompt = "Below is an instruction that describes a task. "
+            prompt += "Write a response that appropriately completes the request.\n"
+            prompt += "Instruction:\n"
+        if history:
+            for old_query, response in history:
+                prompt += "Human:{}\nAssistant:{}\n".format(old_query, response)
+        prompt += "Human:{}\nAssistant:".format(query)
+        return prompt
+
+    def _format_vicuna(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
+        if prefix:
+            prompt = prefix
+        else:
+            prompt = "A chat between a curious user and an artificial intelligence assistant. "
+            prompt += "The assistant gives helpful, detailed, and polite answers to the user's questions. "
+        if history:
+            for old_query, response in history:
+                prompt += "USER: {} ASSISTANT: {}</s>".format(old_query, response)
+        prompt += "USER: {} ASSISTANT:".format(query)
+        return prompt
+
+
+    def _format_ziya(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
+        prompt = prefix
+        if history:
+            for old_query, response in history:
+                prompt += "<human>:{}\n<bot>:{}\n".format(old_query, response)
+        prompt += "<human>:{}\n<bot>:".format(query)
+        return prompt
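To illustrate the new interface introduced above (a sketch, not part of the diff): Template dispatches on its name field, so every caller only needs get_prompt.

from utils import Template  # re-exported in src/utils/__init__.py above

# Hypothetical two-turn conversation; any of "alpaca", "vicuna" or "ziya" may be passed,
# matching the choices added to DataTrainingArguments.prompt_template.
template = Template("vicuna")
history = [("Hello!", "Hi, how can I help you?")]  # earlier (query, response) pairs
prompt = template.get_prompt("Explain LoRA in one sentence.", history)
# get_prompt renders the default preamble, the history turns and the new query
# exactly as _format_vicuna does above, ending with "ASSISTANT:" for generation.
print(prompt)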
" + prompt += "Write a response that appropriately completes the request.\n" + prompt += "Instruction:\n" + if history: + for old_query, response in history: + prompt += "Human:{}\nAssistant:{}\n".format(old_query, response) + prompt += "Human:{}\nAssistant:".format(query) + return prompt + + def _format_vicuna(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: + if prefix: + prompt = prefix + else: + prompt = "A chat between a curious user and an artificial intelligence assistant. " + prompt += "The assistant gives helpful, detailed, and polite answers to the user's questions. " + if history: + for old_query, response in history: + prompt += "USER: {} ASSISTANT: {}".format(old_query, response) + prompt += "USER: {} ASSISTANT:".format(query) + return prompt + + + def _format_ziya(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: + prompt = prefix + if history: + for old_query, response in history: + prompt += ":{}\n:{}\n".format(old_query, response) + prompt += ":{}\n:".format(query) + return prompt diff --git a/src/web_demo.py b/src/web_demo.py index 9b69c3e6..8a8934fb 100644 --- a/src/web_demo.py +++ b/src/web_demo.py @@ -8,11 +8,10 @@ import gradio as gr from threading import Thread from utils import ( + Template, load_pretrained, prepare_infer_args, - get_logits_processor, - prompt_template_alpaca, - prompt_template_ziya + get_logits_processor ) from transformers import TextIteratorStreamer @@ -25,7 +24,7 @@ require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0") model_args, data_args, finetuning_args = prepare_infer_args() model, tokenizer = load_pretrained(model_args, finetuning_args) -format_example = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya +prompt_template = Template(data_args.prompt_template) streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) @@ -81,7 +80,7 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT def predict(query, chatbot, max_length, top_p, temperature, history): chatbot.append((parse_text(query), "")) - input_ids = tokenizer([format_example(query, history)], return_tensors="pt")["input_ids"] + input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"] input_ids = input_ids.to(model.device) gen_kwargs = { "input_ids": input_ids,