diff --git a/src/llmtuner/chat/chat_model.py b/src/llmtuner/chat/chat_model.py
index 521270f5..88846dee 100644
--- a/src/llmtuner/chat/chat_model.py
+++ b/src/llmtuner/chat/chat_model.py
@@ -4,7 +4,7 @@ from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
 from threading import Thread
 from transformers import GenerationConfig, TextIteratorStreamer

-from ..data import get_template_and_fix_tokenizer
+from ..data import get_template_and_fix_tokenizer, Role
 from ..extras.misc import get_logits_processor
 from ..model import dispatch_model, load_model_and_tokenizer
 from ..hparams import get_infer_args
@@ -36,10 +36,19 @@ class ChatModel:
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
+        tools: Optional[str] = None,
         **input_kwargs
     ) -> Tuple[Dict[str, Any], int]:
+        messages = []
+        if history is not None:
+            for old_prompt, old_response in history:
+                messages.append({"role": Role.USER, "content": old_prompt})
+                messages.append({"role": Role.ASSISTANT, "content": old_response})
+
+        messages.append({"role": Role.USER, "content": query})
+        messages.append({"role": Role.ASSISTANT, "content": ""})
         prompt, _ = self.template.encode_oneturn(
-            tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
+            tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
         )
         prompt_length = len(prompt)
         input_ids = torch.tensor([prompt], device=self.model.device)
@@ -90,6 +99,7 @@ class ChatModel:
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
+        tools: Optional[str] = None,
         **input_kwargs
     ) -> List[Response]:
         r"""
@@ -97,7 +107,7 @@ class ChatModel:

         Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
         """
-        gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
+        gen_kwargs, prompt_length = self._process_args(query, history, system, tools, **input_kwargs)
         generate_output = self.model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
         response = self.tokenizer.batch_decode(
@@ -122,9 +132,10 @@ class ChatModel:
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
+        tools: Optional[str] = None,
         **input_kwargs
     ) -> Generator[str, None, None]:
-        gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
+        gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer

diff --git a/src/llmtuner/data/__init__.py b/src/llmtuner/data/__init__.py
index 3709b6e1..85be70b7 100644
--- a/src/llmtuner/data/__init__.py
+++ b/src/llmtuner/data/__init__.py
@@ -1,6 +1,6 @@
 from .loader import get_dataset
 from .template import get_template_and_fix_tokenizer, templates
-from .utils import split_dataset
+from .utils import split_dataset, Role

-__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset"]
+__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset", "Role"]
diff --git a/src/llmtuner/data/preprocess.py b/src/llmtuner/data/preprocess.py
index 45ae8626..1e6eadb1 100644
--- a/src/llmtuner/data/preprocess.py
+++ b/src/llmtuner/data/preprocess.py
@@ -97,7 +97,7 @@ def preprocess_packed_supervised_dataset(
         messages = examples["prompt"][i] + examples["response"][i]
         for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
-            tokenizer, messages, examples["system"][i], examples["tool"][i], 1_000_000
+            tokenizer, messages, examples["system"][i], examples["tool"][i]
         )):
             if data_args.train_on_prompt:
                 source_mask = source_ids
diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py
index 5690e773..935d951c 100644
--- a/src/llmtuner/data/template.py
+++ b/src/llmtuner/data/template.py
@@ -33,13 +33,13 @@ class Template:
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
         system: str,
-        tool: str,
-        cutoff_len: int
+        tools: str,
+        cutoff_len: Optional[int] = 1_000_000
     ) -> Tuple[List[int], List[int]]:
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
         """
-        encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len)
+        encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len)
         prompt_ids = []
         for query_ids, resp_ids in encoded_pairs[:-1]:
             prompt_ids = prompt_ids + query_ids + resp_ids
@@ -52,13 +52,13 @@ class Template:
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
         system: str,
-        tool: str,
-        cutoff_len: int
+        tools: str,
+        cutoff_len: Optional[int] = 1_000_000
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
         """
-        encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len)
+        encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len)
         return encoded_pairs

     def _encode(
@@ -66,7 +66,7 @@ class Template:
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
         system: str,
-        tool: str,
+        tools: str,
         cutoff_len: int
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
@@ -78,8 +78,8 @@ class Template:
         encoded_messages = []
         for i, message in enumerate(messages):
             elements = []
-            if i == 0 and (system or tool):
-                tool_text = self.format_tool(content=tool)[0] if tool else ""
+            if i == 0 and (system or tools):
+                tool_text = self.format_tool(content=tools)[0] if tools else ""
                 elements += self.format_system(content=(system + tool_text))
             elif i > 0 and i % 2 == 0:
                 elements += self.separator
@@ -131,7 +131,7 @@ class Llama2Template(Template):
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
         system: str,
-        tool: str,
+        tools: str,
         cutoff_len: int
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
@@ -144,8 +144,8 @@ class Llama2Template(Template):
         for i, message in enumerate(messages):
             elements = []
             system_text = ""
-            if i == 0 and (system or tool):
-                tool_text = self.format_tool(content=tool)[0] if tool else ""
+            if i == 0 and (system or tools):
+                tool_text = self.format_tool(content=tools)[0] if tools else ""
                 system_text = self.format_system(content=(system + tool_text))[0]
             elif i > 0 and i % 2 == 0:
                 elements += self.separator
diff --git a/src/llmtuner/eval/evaluator.py b/src/llmtuner/eval/evaluator.py
index 251dfc0b..1cb55b38 100644
--- a/src/llmtuner/eval/evaluator.py
+++ b/src/llmtuner/eval/evaluator.py
@@ -65,17 +65,17 @@ class Evaluator:
         inputs, outputs, labels = [], [], []
         for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
             support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
-            query, resp, history = self.eval_template.format_example(
+            messages = self.eval_template.format_example(
                 target_data=dataset[self.data_args.split][i],
                 support_set=support_set,
-                subject_name=categorys[subject]["name"],
-                use_history=self.template.use_history
+                subject_name=categorys[subject]["name"]
             )
+
             input_ids, _ = self.template.encode_oneturn(
-                tokenizer=self.tokenizer, query=query, resp=resp, history=history
+                tokenizer=self.tokenizer, messages=messages
             )
             inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
-            labels.append(resp)
+            labels.append(messages[-1]["content"])

         for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
             batch_input = self.tokenizer.pad(
diff --git a/src/llmtuner/eval/template.py b/src/llmtuner/eval/template.py
index 924a3c8b..5514e5d5 100644
--- a/src/llmtuner/eval/template.py
+++ b/src/llmtuner/eval/template.py
@@ -2,6 +2,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Tuple

 from ..extras.constants import CHOICES
+from ..data import Role

 if TYPE_CHECKING:
     from datasets import Dataset
@@ -28,20 +29,23 @@ class EvalTemplate:
         support_set: "Dataset",
         subject_name: str,
         use_history: bool
-    ) -> Tuple[str, str, List[Tuple[str, str]]]:
-        query, resp = self.parse_example(target_data)
-        history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
+    ) -> List[Dict[str, str]]:
+        messages = []
+        for k in range(len(support_set)):
+            prompt, response = self.parse_example(support_set[k])
+            messages.append({"role": Role.USER, "content": prompt})
+            messages.append({"role": Role.ASSISTANT, "content": response})

-        if len(history):
-            temp = history.pop(0)
-            history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
-        else:
-            query = self.system.format(subject=subject_name) + query
+        prompt, response = self.parse_example(target_data)
+        messages.append({"role": Role.USER, "content": prompt})
+        messages.append({"role": Role.ASSISTANT, "content": response})
+
+        messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]

         if not use_history:
-            query = "\n\n".join(["".join(item) for item in history] + [query])
-            history = []
-            return query.strip(), resp, history
+            messages = [{"role": Role.USER, "content": "\n\n".join([message["content"] for message in messages[:-1]])}]
+
+        return messages


 eval_templates: Dict[str, "EvalTemplate"] = {}
diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py
index a6681665..e211bb2a 100644
--- a/src/llmtuner/webui/chatter.py
+++ b/src/llmtuner/webui/chatter.py
@@ -105,6 +105,7 @@ class WebChatModel(ChatModel):
         query: str,
         history: List[Tuple[str, str]],
         system: str,
+        tools: str,
         max_new_tokens: int,
         top_p: float,
         temperature: float
@@ -112,7 +113,7 @@ class WebChatModel(ChatModel):
         chatbot.append([query, ""])
         response = ""
         for new_text in self.stream_chat(
-            query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            query, history, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
         ):
             response += new_text
             new_history = history + [(query, response)]
diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py
index ee128aca..aa087536 100644
--- a/src/llmtuner/webui/components/chatbot.py
+++ b/src/llmtuner/webui/components/chatbot.py
@@ -18,6 +18,7 @@ def create_chat_box(
     with gr.Row():
         with gr.Column(scale=4):
             system = gr.Textbox(show_label=False)
+            tools = gr.Textbox(show_label=False, lines=2)
             query = gr.Textbox(show_label=False, lines=8)
             submit_btn = gr.Button(variant="primary")

@@ -30,7 +31,7 @@ def create_chat_box(

     submit_btn.click(
         engine.chatter.predict,
-        [chatbot, query, history, system, max_new_tokens, top_p, temperature],
+        [chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
         [chatbot, history],
         show_progress=True
     ).then(
@@ -41,6 +42,7 @@ def create_chat_box(

     return chat_box, chatbot, history, dict(
         system=system,
+        tools=tools,
         query=query,
         submit_btn=submit_btn,
         clear_btn=clear_btn,
diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py
index d6f9d31f..9ba08e25 100644
--- a/src/llmtuner/webui/locales.py
+++ b/src/llmtuner/webui/locales.py
@@ -521,6 +521,14 @@ LOCALES = {
             "placeholder": "系统提示词(非必填)"
         }
     },
+    "tools": {
+        "en": {
+            "placeholder": "Tools (optional)"
+        },
+        "zh": {
+            "placeholder": "工具列表(非必填)"
+        }
+    },
     "query": {
         "en": {
             "placeholder": "Input..."
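
Usage sketch (not part of the patch): a minimal example of how the new `tools` argument might thread through `ChatModel.chat` after this change. The import path, constructor arguments, and `response_text` field are assumptions based on the surrounding code (the docstring above lists `response_text` among the returned fields); the tool spec itself is illustrative, since the patch treats `tools` as an opaque string whose schema is whatever the template's `format_tool` expects.

import json

from llmtuner.chat import ChatModel  # assumed public import path for the patched class

# Illustrative tool description; `tools` is passed through as a plain string
# and only rendered by the template's format_tool when it is non-empty.
tools = json.dumps([{
    "name": "get_weather",  # hypothetical tool, for demonstration only
    "description": "Query the current weather for a city",
    "parameters": {"city": {"type": "string"}},
}])

chat_model = ChatModel(dict(
    model_name_or_path="path/to/model",  # placeholder checkpoint path
    template="default",
))

# `tools` now rides alongside `system` through _process_args into
# Template.encode_oneturn, which folds it into the system prompt.
responses = chat_model.chat(
    query="What is the weather in Paris?",
    history=[("Hello!", "Hi! How can I help you?")],
    system="You are a helpful assistant.",
    tools=tools,
)
print(responses[0].response_text)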