From b6faf0207d5b637722a1fd45984d27b3ac095fd4 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 14 Jun 2023 22:27:39 +0800
Subject: [PATCH] support multiturn training like FastChat

---
 src/api_demo.py       |   6 ++
 src/cli_demo.py       |  12 ++-
 src/utils/common.py   |  50 ++++++-----
 src/utils/template.py | 198 +++++++++++++++++++++++++-----------------
 src/web_demo.py       |   8 +-
 5 files changed, 166 insertions(+), 108 deletions(-)

diff --git a/src/api_demo.py b/src/api_demo.py
index 0797da84..0ff46636 100644
--- a/src/api_demo.py
+++ b/src/api_demo.py
@@ -50,6 +50,9 @@ async def create_item(request: Request):
     json_post_list = json.loads(json_post)
     prompt = json_post_list.get("prompt")
     history = json_post_list.get("history")
+    max_new_tokens = json_post_list.get("max_new_tokens", None)
+    top_p = json_post_list.get("top_p", None)
+    temperature = json_post_list.get("temperature", None)
 
     # Tokenize the input prompt
     input_ids = tokenizer([prompt_template.get_prompt(prompt, history)], return_tensors="pt")["input_ids"]
@@ -59,6 +62,9 @@ async def create_item(request: Request):
     gen_kwargs = generating_args.to_dict()
     gen_kwargs["input_ids"] = input_ids
     gen_kwargs["logits_processor"] = get_logits_processor()
+    gen_kwargs["max_new_tokens"] = max_new_tokens if max_new_tokens else gen_kwargs["max_new_tokens"]
+    gen_kwargs["top_p"] = top_p if top_p else gen_kwargs["top_p"]
+    gen_kwargs["temperature"] = temperature if temperature else gen_kwargs["temperature"]
 
     # Generate response
     with torch.no_grad():
diff --git a/src/cli_demo.py b/src/cli_demo.py
index 9bcfec53..7dd92d45 100644
--- a/src/cli_demo.py
+++ b/src/cli_demo.py
@@ -16,23 +16,27 @@ from transformers import TextIteratorStreamer
 
 
 def main():
     model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
-    model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
     model, tokenizer = load_pretrained(model_args, finetuning_args)
+    model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
     prompt_template = Template(data_args.prompt_template)
-    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
-    def predict_and_print(query, history: list):
+    def predict_and_print(query, history: list) -> list:
         input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"]
         input_ids = input_ids.to(model.device)
+
+        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+
         gen_kwargs = generating_args.to_dict()
         gen_kwargs["input_ids"] = input_ids
         gen_kwargs["logits_processor"] = get_logits_processor()
         gen_kwargs["streamer"] = streamer
+
         thread = Thread(target=model.generate, kwargs=gen_kwargs)
         thread.start()
+
+        print("{}: ".format(model_name), end="", flush=True)
         response = ""
-        print("{}: ".format(model_name), end="")
         for new_text in streamer:
             print(new_text, end="", flush=True)
             response += new_text
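
Note: the api_demo.py hunk above lets a request optionally override max_new_tokens, top_p and temperature, falling back to the configured generating arguments when a field is missing or falsy. A minimal sketch of that fallback pattern; the default values and the request payload below are made up for illustration, not taken from the patch:

    # Hypothetical defaults standing in for generating_args.to_dict().
    gen_kwargs = {"max_new_tokens": 512, "top_p": 0.7, "temperature": 0.95}
    request = {"prompt": "Hello", "temperature": 0.2}  # only overrides temperature

    for key in ("max_new_tokens", "top_p", "temperature"):
        value = request.get(key, None)
        # Same `x if x else default` guard as in the patch: a missing or falsy value keeps the default.
        gen_kwargs[key] = value if value else gen_kwargs[key]

    # gen_kwargs is now {"max_new_tokens": 512, "top_p": 0.7, "temperature": 0.2}

Note that a value of 0 is falsy and therefore also falls back to the default, which is usually the intended behavior for these sampling parameters.
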
diff --git a/src/utils/common.py b/src/utils/common.py
index 6cd43aa9..e63c1659 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -421,18 +421,17 @@ def preprocess_data(
     prompt_template = Template(data_args.prompt_template)
 
     # support question with a single answer or multiple answers
-    def format_example(examples):
+    def get_dialog(examples):
         for i in range(len(examples["prompt"])):
             if examples["prompt"][i] and examples["response"][i]:
                 query, answer = examples["prompt"][i], examples["response"][i]
-                if examples["query"][i]:
-                    query += "\n" + examples["query"][i]
-                prompt = prompt_template.get_prompt(query, examples["history"][i], prefix)
-                yield prompt, answer
+                query = query + "\n" + examples["query"][i] if examples["query"][i] else query
+                dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
+                yield dialog
 
     def preprocess_pretrain_dataset(examples):
-        # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
-        text_ids = tokenizer(examples["prompt"])["input_ids"]
+        # build grouped texts with format `X1 X2 X3 ...` (without [BOS] and [EOS])
+        text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
         concatenated_ids = list(chain(*text_ids))
         total_length = len(concatenated_ids)
         # we drop the small remainder, and if the total_length < block_size, we exclude this batch
@@ -446,28 +445,29 @@ def preprocess_data(
         }
 
     def preprocess_supervised_dataset(examples):
-        # build inputs with format `X <bos> Y <eos>` and labels with format `<ignore> ... <ignore> <bos> Y <eos>`
+        # build inputs with format `X [BOS] Y [EOS]` and labels with format `[IGNORE] ... [IGNORE] Y [EOS]`
+        # for input with history, we build multiple input-label pairs just like:
+        # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
         model_inputs = {"input_ids": [], "labels": []}
-        for prompt, answer in format_example(examples):
-            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
-            target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
+        for dialog in get_dialog(examples):
+            input_ids, labels = [], []
 
-            if len(source_ids) > data_args.max_source_length - 1: # bos token
-                source_ids = source_ids[:data_args.max_source_length - 1]
-            if len(target_ids) > data_args.max_target_length - 1: # eos token
-                target_ids = target_ids[:data_args.max_target_length - 1]
+            for i in range(len(dialog) // 2):
+                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=False)
+                target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
+                input_ids += source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
+                labels += [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [tokenizer.eos_token_id]
 
-            input_ids = source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
-            labels = [IGNORE_INDEX] * len(source_ids) + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
-
-            model_inputs["input_ids"].append(input_ids)
-            model_inputs["labels"].append(labels)
+            model_inputs["input_ids"].append(input_ids[:data_args.max_source_length + data_args.max_target_length])
+            model_inputs["labels"].append(labels[:data_args.max_source_length + data_args.max_target_length])
         return model_inputs
 
     def preprocess_unsupervised_dataset(examples):
-        # build inputs with format `X <bos>` and labels with format `Y <bos>`
+        # build inputs with format `X [BOS]` and labels with format `Y [BOS]`
         model_inputs = {"input_ids": [], "labels": []}
-        for prompt, answer in format_example(examples):
+        for dialog in get_dialog(examples):
+            prompt, answer = "".join(dialog[:-1]), dialog[-1]
+
             source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
             target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
 
@@ -484,9 +484,11 @@ def preprocess_data(
         return model_inputs
 
     def preprocess_pairwise_dataset(examples):
-        # build input pairs with format `X <bos> Y1 <eos>` and `X <bos> Y2 <eos>`
+        # build input pairs with format `X [BOS] Y1 [EOS]` and `X [BOS] Y2 [EOS]`
         model_inputs = {"accept_ids": [], "reject_ids": []}
-        for prompt, answer in format_example(examples):
+        for dialog in get_dialog(examples):
+            prompt, answer = "".join(dialog[:-1]), dialog[-1]
+
             source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
             accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
             reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
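
Note: with get_dialog, a multi-turn sample arrives as a flat list [query_1, response_1, query_2, response_2, ...], and preprocess_supervised_dataset concatenates every turn into one sequence while masking query tokens with IGNORE_INDEX so the loss only covers the responses, following the FastChat reference linked above. A rough sketch of that masking logic with a toy whitespace "tokenizer"; the toy encoder, the BOS/EOS ids and the IGNORE_INDEX value are illustrative stand-ins, not the project's real ones:

    IGNORE_INDEX = -100
    BOS, EOS = 1, 2

    def toy_encode(text):
        # stand-in for tokenizer.encode(..., add_special_tokens=False)
        return [hash(tok) % 1000 + 10 for tok in text.split()]

    dialog = ["Human: hi Assistant: ", "hello!", "Human: how are you Assistant: ", "fine."]
    input_ids, labels = [], []
    for i in range(len(dialog) // 2):
        source_ids = toy_encode(dialog[2 * i])
        target_ids = toy_encode(dialog[2 * i + 1])
        input_ids += source_ids + [BOS] + target_ids + [EOS]
        # query tokens and the BOS are ignored; only response tokens and the EOS are supervised
        labels += [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [EOS]

    assert len(input_ids) == len(labels)

Both sequences are then truncated to max_source_length + max_target_length, as in the patch.
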
prompt, answer = "".join(dialog[:-1]), dialog[-1] + source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False) reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False) diff --git a/src/utils/template.py b/src/utils/template.py index 64134e0c..5354f32a 100644 --- a/src/utils/template.py +++ b/src/utils/template.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import List, Optional from dataclasses import dataclass @@ -8,89 +8,131 @@ class Template: name: str def __post_init__(self): - assert hasattr(self, "_format_{}".format(self.name)), "Template {} does not exist.".format(self.name) + + if self.name == "vanilla": + r""" + Supports language model inference without histories. + """ + self._register_template( + prefix="", + prompt="", + sep="", + use_history=False + ) + + elif self.name == "alpaca": + r""" + Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff + https://github.com/ymcui/Chinese-LLaMA-Alpaca + """ + self._register_template( + prefix="Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n", + prompt="### Instruction:\n{query}\n\n### Response:\n", + sep="\n\n", + use_history=True + ) + + elif self.name == "vicuna": + r""" + Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1 + https://huggingface.co/lmsys/vicuna-13b-delta-v1.1 + """ + self._register_template( + prefix="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + prompt="USER: {query} ASSISTANT: ", + sep="", + use_history=True + ) + + elif self.name == "belle": + r""" + Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B + """ + self._register_template( + prefix="", + prompt="Human: {query}\n\nBelle: ", + sep="\n\n", + use_history=True + ) + + elif self.name == "linly": + r""" + Supports: https://github.com/CVI-SZU/Linly + """ + self._register_template( + prefix="", + prompt="User: {query}\nBot: ", + sep="\n", + use_history=True + ) + + elif self.name == "billa": + r""" + Supports: https://github.com/Neutralzz/BiLLa + """ + self._register_template( + prefix="", + prompt="Human: {query}\nAssistant: ", + sep="\n", + use_history=True + ) + + elif self.name == "ziya": + r""" + Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1 + """ + self._register_template( + prefix="", + prompt=":{query}\n:", + sep="\n", + use_history=True + ) + + elif self.name == "aquila": + r""" + Supports: https://huggingface.co/qhduan/aquilachat-7b + """ + self._register_template( + prefix="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + prompt="Human: {query}\nAssistant: ", + sep="###", + use_history=True + ) + + else: + raise ValueError("Template {} does not exist.".format(self.name)) def get_prompt(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> str: - return getattr(self, "_format_{}".format(self.name))(query, history, prefix) - - def _format_vanilla(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: r""" - Use for language model inference without histories. + Returns a string containing prompt without response. 
""" - return query + return "".join(self._format_example(query, history, prefix)) - def _format_alpaca(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: + def get_dialog(self, query: str, resp: str, history: Optional[list] = None, prefix: Optional[str] = "") -> List[str]: r""" - Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff - https://github.com/ymcui/Chinese-LLaMA-Alpaca + Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response. """ - if prefix: - prompt = prefix - else: - prompt = "Below is an instruction that describes a task. " - prompt += "Write a response that appropriately completes the request.\n\n" - if history: - for old_query, response in history: - prompt += "### Instruction:\n{}\n\n### Response:\n{}\n\n".format(old_query, response) - prompt += "### Instruction:\n{}\n\n### Response:\n".format(query) - return prompt + return self._format_example(query, history, prefix) + [resp] - def _format_vicuna(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: - r""" - Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1 - https://huggingface.co/lmsys/vicuna-13b-delta-v1.1 - """ - if prefix: - prompt = prefix - else: - prompt = "A chat between a curious user and an artificial intelligence assistant. " - prompt += "The assistant gives helpful, detailed, and polite answers to the user's questions. " - if history: - for old_query, response in history: - prompt += "USER: {} ASSISTANT: {}".format(old_query, response) - prompt += "USER: {} ASSISTANT: ".format(query) - return prompt + def _register_template(self, prefix: str, prompt: str, sep: str, use_history: Optional[bool] = True) -> None: + self.prefix = prefix + self.prompt = prompt + self.sep = sep + self.use_history = use_history - def _format_belle(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: - r""" - Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B - """ - prompt = prefix - if history: - for old_query, response in history: - prompt += "Human: {}\n\nBelle: {}\n\n".format(old_query, response) - prompt += "Human: {}\n\nBelle: ".format(query) - return prompt - - def _format_linly(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: - r""" - Supports: https://github.com/CVI-SZU/Linly - """ - prompt = prefix - if history: - for old_query, response in history: - prompt += "User: {}\nBot: {}\n".format(old_query, response) - prompt += "User: {}\nBot: ".format(query) - return prompt - - def _format_billa(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: - r""" - Supports: https://github.com/Neutralzz/BiLLa - """ - prompt = prefix - if history: - for old_query, response in history: - prompt += "Human: {}\nAssistant: {}\n".format(old_query, response) - prompt += "Human: {}\nAssistant: ".format(query) - return prompt - - def _format_ziya(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: - r""" - Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1 - """ - prompt = prefix - if history: - for old_query, response in history: - prompt += ":{}\n:{}\n".format(old_query, response) - prompt += ":{}\n:".format(query) - return prompt + def _format_example(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> List[str]: + prefix = prefix if prefix else self.prefix + history = history if (history and self.use_history) else [] + history = history + [(query, "")] + convs = [] + for 
diff --git a/src/web_demo.py b/src/web_demo.py
index 2cceddd3..ba886ab1 100644
--- a/src/web_demo.py
+++ b/src/web_demo.py
@@ -25,7 +25,6 @@ model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
 
 model, tokenizer = load_pretrained(model_args, finetuning_args)
 prompt_template = Template(data_args.prompt_template)
-streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
 
 def postprocess(self, y):
@@ -82,9 +81,12 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
     input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"]
     input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+
     gen_kwargs = {
         "input_ids": input_ids,
-        "do_sample": True,
+        "do_sample": generating_args.do_sample,
         "top_p": top_p,
         "temperature": temperature,
         "num_beams": generating_args.num_beams,
@@ -93,8 +95,10 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
         "logits_processor": get_logits_processor(),
         "streamer": streamer
     }
+
     thread = Thread(target=model.generate, kwargs=gen_kwargs)
     thread.start()
+
     response = ""
     for new_text in streamer:
         response += new_text
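
Note: both cli_demo.py and web_demo.py now construct the TextIteratorStreamer inside the per-turn handler instead of sharing one module-level instance, so every generation gets a fresh iterator. The common pattern is to run model.generate in a background thread with the streamer in gen_kwargs and consume the streamer on the main thread. A condensed sketch of that pattern; model and tokenizer loading and the full gen_kwargs are assumed from context and not shown:

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_reply(model, tokenizer, prompt: str) -> str:
        input_ids = tokenizer([prompt], return_tensors="pt")["input_ids"].to(model.device)
        # a fresh streamer per call, as in the patched demos
        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = {"input_ids": input_ids, "streamer": streamer, "max_new_tokens": 256}

        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()  # generation runs in the background

        response = ""
        for new_text in streamer:  # yields decoded text chunks as they are produced
            print(new_text, end="", flush=True)
            response += new_text
        return response
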