diff --git a/src/cli_demo.py b/src/cli_demo.py index 39704153..09d444c6 100644 --- a/src/cli_demo.py +++ b/src/cli_demo.py @@ -13,7 +13,7 @@ except ImportError: def main(): chat_model = ChatModel() - history = [] + messages = [] print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.") while True: @@ -37,12 +37,13 @@ def main(): print("Assistant: ", end="", flush=True) response = "" - for new_text in chat_model.stream_chat(query, history): + for new_text in chat_model.stream_chat(messages): print(new_text, end="", flush=True) response += new_text print() - history = history + [(query, response)] + messages.append({"role": "user", "content": query}) + messages.append({"role": "assistant", "content": response}) if __name__ == "__main__": diff --git a/src/llmtuner/chat/chat_model.py b/src/llmtuner/chat/chat_model.py index f6cbbddc..cfa2700b 100644 --- a/src/llmtuner/chat/chat_model.py +++ b/src/llmtuner/chat/chat_model.py @@ -1,11 +1,11 @@ from dataclasses import dataclass from threading import Thread -from typing import Any, Dict, Generator, List, Literal, Optional, Tuple +from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Tuple import torch from transformers import GenerationConfig, TextIteratorStreamer -from ..data import Role, get_template_and_fix_tokenizer +from ..data import get_template_and_fix_tokenizer from ..extras.misc import get_logits_processor from ..hparams import get_infer_args from ..model import dispatch_model, load_model_and_tokenizer @@ -32,20 +32,11 @@ class ChatModel: def _process_args( self, - query: str, - history: Optional[List[Tuple[str, str]]] = None, + messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, **input_kwargs, ) -> Tuple[Dict[str, Any], int]: - messages = [] - if history is not None: - for old_prompt, old_response in history: - messages.append({"role": Role.USER, "content": old_prompt}) - messages.append({"role": Role.ASSISTANT, "content": old_response}) - - messages.append({"role": Role.USER, "content": query}) - messages.append({"role": Role.ASSISTANT, "content": ""}) prompt, _ = self.template.encode_oneturn( tokenizer=self.tokenizer, messages=messages, system=system, tools=tools ) @@ -97,18 +88,12 @@ class ChatModel: @torch.inference_mode() def chat( self, - query: str, - history: Optional[List[Tuple[str, str]]] = None, + messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, **input_kwargs, ) -> List[Response]: - r""" - Args: query, history, system, **input_kwargs - - Returns: [(response_text, prompt_length, response_length)] * n (default n=1) - """ - gen_kwargs, prompt_length = self._process_args(query, history, system, tools, **input_kwargs) + gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs) generate_output = self.model.generate(**gen_kwargs) response_ids = generate_output[:, prompt_length:] response = self.tokenizer.batch_decode( @@ -132,13 +117,12 @@ class ChatModel: @torch.inference_mode() def stream_chat( self, - query: str, - history: Optional[List[Tuple[str, str]]] = None, + messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, **input_kwargs, ) -> Generator[str, None, None]: - gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs) + gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs) streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) gen_kwargs["streamer"] = streamer diff --git a/src/llmtuner/data/formatter.py b/src/llmtuner/data/formatter.py index 6451d8e3..078539c2 100644 --- a/src/llmtuner/data/formatter.py +++ b/src/llmtuner/data/formatter.py @@ -1,6 +1,11 @@ import json -from dataclasses import dataclass -from typing import Any, Dict, List, Literal, Union +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Dict, List, Literal, Set, Sequence, Tuple, Union + + +SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] JSON_FORMAT_PROMPT = ( @@ -18,30 +23,85 @@ TOOL_SYSTEM_PROMPT = ( ) -@dataclass -class StringFormatter: - container: List[Union[str, Dict[str, str]]] +def default_tool_formatter(tools: List[Dict[str, Any]]) -> str: + tool_text = "" + tool_names = [] + for tool in tools: + param_text = "" + for name, param in tool["parameters"]["properties"].items(): + required = ", required" if name in tool["parameters"].get("required", []) else "" + enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else "" + param_text += " - {name} ({type}{required}): {desc}{enum}\n".format( + name=name, + type=param.get("type", ""), + required=required, + desc=param.get("description", ""), + enum=enum, + ) - def __call__(self, **kwargs) -> List[Union[str, Dict[str, str]]]: + tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format( + name=tool["name"], desc=tool.get("description", ""), args=param_text + ) + tool_names.append(tool["name"]) + + return TOOL_SYSTEM_PROMPT.format( + tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT + ) + + +def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]: + regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL) + action_match = re.search(regex, content) + if not action_match: + return content + + tool_name = action_match.group(1).strip() + tool_input = action_match.group(2).strip().strip('"').strip("```") + try: + arguments = json.loads(tool_input) + except json.JSONDecodeError: + return content + + return tool_name, json.dumps(arguments, ensure_ascii=False) + + +@dataclass +class Formatter(ABC): + slots: SLOTS = field(default_factory=list) + tool_format: Literal["default"] = "default" + + @abstractmethod + def apply(self, **kwargs) -> SLOTS: + ... + + +@dataclass +class EmptyFormatter(Formatter): + def apply(self, **kwargs) -> SLOTS: + return self.slots + + +@dataclass +class StringFormatter(Formatter): + def apply(self, **kwargs) -> SLOTS: elements = [] - for elem in self.container: - if isinstance(elem, str): + for slot in self.slots: + if isinstance(slot, str): for name, value in kwargs.items(): - elem = elem.replace("{{" + name + "}}", value) - elements.append(elem) - elif isinstance(elem, (dict, set)): - elements.append(elem) + slot = slot.replace("{{" + name + "}}", value, 1) + elements.append(slot) + elif isinstance(slot, (dict, set)): + elements.append(slot) else: - raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem))) + raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) return elements @dataclass -class FunctionFormatter: - container: List[Union[str, Dict[str, str]]] - - def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]: +class FunctionFormatter(Formatter): + def apply(self, **kwargs) -> SLOTS: + content = kwargs.pop("content") try: function = json.loads(content) name = function["name"] @@ -50,55 +110,36 @@ class FunctionFormatter: name, arguments = "", "" elements = [] - for elem in self.container: - if isinstance(elem, str): - elem = elem.replace("{{name}}", name) - elem = elem.replace("{{arguments}}", arguments) - elements.append(elem) - elif isinstance(elem, (dict, set)): - elements.append(elem) + for slot in self.slots: + if isinstance(slot, str): + slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments) + elements.append(slot) + elif isinstance(slot, (dict, set)): + elements.append(slot) else: - raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem))) + raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) return elements @dataclass -class ToolFormatter: - type: Literal["default"] - - def _default(self, tools: List[Dict[str, Any]]) -> str: - tool_text = "" - tool_names = [] - for tool in tools: - param_text = "" - for name, param in tool["parameters"]["properties"].items(): - required = ", required" if name in tool["parameters"].get("required", []) else "" - enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else "" - param_text += " - {name} ({type}{required}): {desc}{enum}\n".format( - name=name, - type=param.get("type", ""), - required=required, - desc=param.get("description", ""), - enum=enum, - ) - - tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format( - name=tool["name"], desc=tool.get("description", ""), args=param_text - ) - tool_names.append(tool["name"]) - - return TOOL_SYSTEM_PROMPT.format( - tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT - ) - - def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]: +class ToolFormatter(Formatter): + def apply(self, **kwargs) -> SLOTS: + content = kwargs.pop("content") try: tools = json.loads(content) if not len(tools): return [""] - if self.type == "default": - return [self._default(tools)] + if self.tool_format == "default": + return [default_tool_formatter(tools)] + else: + raise NotImplementedError except Exception: return [""] + + def extract(self, content: str) -> Union[str, Tuple[str, str]]: + if self.tool_format == "default": + return default_tool_extractor(content) + else: + raise NotImplementedError diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py index 924007bd..77d053f3 100644 --- a/src/llmtuner/data/template.py +++ b/src/llmtuner/data/template.py @@ -1,31 +1,34 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union from ..extras.logging import get_logger -from .formatter import FunctionFormatter, StringFormatter, ToolFormatter -from .utils import Role +from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter +from .utils import Role, infer_max_len if TYPE_CHECKING: from transformers import PreTrainedTokenizer + from .formatter import Formatter + logger = get_logger(__name__) @dataclass class Template: - format_user: Callable - format_assistant: Callable - format_system: Callable - format_tool: Callable - format_observation: Callable - format_function: Callable - system: str - separator: List[Union[str, Dict[str, str]]] + format_user: "Formatter" + format_assistant: "Formatter" + format_system: "Formatter" + format_function: "Formatter" + format_observation: "Formatter" + format_tools: "Formatter" + format_separator: "Formatter" + default_system: str stop_words: List[str] efficient_eos: bool replace_eos: bool + force_system: bool def encode_oneturn( self, @@ -34,14 +37,15 @@ class Template: system: Optional[str] = None, tools: Optional[str] = None, cutoff_len: Optional[int] = 1_000_000, + reserved_label_len: Optional[int] = 16, ) -> Tuple[List[int], List[int]]: r""" Returns a single pair of token ids representing prompt and response respectively. """ - encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len) + encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) prompt_ids = [] for query_ids, resp_ids in encoded_pairs[:-1]: - prompt_ids = prompt_ids + query_ids + resp_ids + prompt_ids += query_ids + resp_ids prompt_ids = prompt_ids + encoded_pairs[-1][0] answer_ids = encoded_pairs[-1][1] return prompt_ids, answer_ids @@ -50,15 +54,15 @@ class Template: self, tokenizer: "PreTrainedTokenizer", messages: List[Dict[str, str]], - system: str, - tools: str, + system: Optional[str] = None, + tools: Optional[str] = None, cutoff_len: Optional[int] = 1_000_000, - ) -> List[Tuple[List[int], List[int]]]: + reserved_label_len: Optional[int] = 16, + ) -> Sequence[Tuple[List[int], List[int]]]: r""" Returns multiple pairs of token ids representing prompts and responses respectively. """ - encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len) - return encoded_pairs + return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) def _encode( self, @@ -67,48 +71,37 @@ class Template: system: str, tools: str, cutoff_len: int, - ) -> List[Tuple[List[int], List[int]]]: + reserved_label_len: int, + ) -> Sequence[Tuple[List[int], List[int]]]: r""" Encodes formatted inputs to pairs of token ids. - Turn 0: system + query resp + eos - Turn t: sep + query resp + eos + Turn 0: system + query resp + Turn t: sep + query resp """ - system = system or self.system + system = system or self.default_system encoded_messages = [] for i, message in enumerate(messages): elements = [] - if i == 0 and (system or tools): - tool_text = self.format_tool(content=tools)[0] if tools else "" - elements += self.format_system(content=(system + tool_text)) + if i == 0 and (system or tools or self.force_system): + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + elements += self.format_system.apply(content=(system + tool_text)) elif i > 0 and i % 2 == 0: - elements += self.separator + elements += self.format_separator.apply() if message["role"] == Role.USER: - elements += self.format_user(content=message["content"], idx=str(i // 2)) + elements += self.format_user.apply(content=message["content"], idx=str(i // 2)) elif message["role"] == Role.ASSISTANT: - elements += self.format_assistant(content=message["content"]) + elements += self.format_assistant.apply(content=message["content"]) elif message["role"] == Role.OBSERVATION: - elements += self.format_observation(content=message["content"]) + elements += self.format_observation.apply(content=message["content"]) elif message["role"] == Role.FUNCTION: - elements += self.format_function(content=message["content"]) + elements += self.format_function.apply(content=message["content"]) + else: + raise NotImplementedError encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) - # TODO: need to improve - encoded_pairs = [] - total_length = 0 - for i in range(0, len(encoded_messages), 2): - if total_length >= cutoff_len: - break - - encoded_messages[i] = encoded_messages[i][: cutoff_len - total_length] - total_length += len(encoded_messages[i]) - - encoded_messages[i + 1] = encoded_messages[i + 1][: max(1, cutoff_len - total_length)] - total_length += len(encoded_messages[i + 1]) - encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1])) - - return encoded_pairs + return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) def _convert_elements_to_ids( self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]] @@ -120,19 +113,44 @@ class Template: for elem in elements: if isinstance(elem, str): if len(elem) != 0: - token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False) + token_ids += tokenizer.encode(elem, add_special_tokens=False) elif isinstance(elem, dict): - token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] + token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))] elif isinstance(elem, set): if "bos_token" in elem and tokenizer.bos_token_id: - token_ids = token_ids + [tokenizer.bos_token_id] + token_ids += [tokenizer.bos_token_id] elif "eos_token" in elem and tokenizer.eos_token_id: - token_ids = token_ids + [tokenizer.eos_token_id] + token_ids += [tokenizer.eos_token_id] else: raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem))) return token_ids + def _make_pairs( + self, + encoded_messages: Sequence[List[int]], + cutoff_len: int, + reserved_label_len: int, + ) -> Sequence[Tuple[List[int], List[int]]]: + encoded_pairs = [] + total_length = 0 + for i in range(0, len(encoded_messages), 2): + if total_length >= cutoff_len: + break + + max_source_len, max_target_len = infer_max_len( + source_len=len(encoded_messages[i]), + target_len=len(encoded_messages[i + 1]), + cutoff_len=(cutoff_len - total_length), + reserved_label_len=reserved_label_len, + ) + encoded_messages[i] = encoded_messages[i][: max_source_len] + encoded_messages[i + 1] = encoded_messages[i + 1][: max_target_len] + total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1]) + encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1])) + + return encoded_pairs + @dataclass class Llama2Template(Template): @@ -143,49 +161,38 @@ class Llama2Template(Template): system: str, tools: str, cutoff_len: int, - ) -> List[Tuple[List[int], List[int]]]: + reserved_label_len: int, + ) -> Sequence[Tuple[List[int], List[int]]]: r""" Encodes formatted inputs to pairs of token ids. - Turn 0: system + query resp + eos - Turn t: sep + query resp + eos + Turn 0: system + query resp + Turn t: sep + query resp """ - system = system or self.system + system = system or self.default_system encoded_messages = [] for i, message in enumerate(messages): elements = [] system_text = "" - if i == 0 and (system or tools): - tool_text = self.format_tool(content=tools)[0] if tools else "" - system_text = self.format_system(content=(system + tool_text))[0] + if i == 0 and (system or tools or self.force_system): + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + system_text = self.format_system.apply(content=(system + tool_text))[0] elif i > 0 and i % 2 == 0: - elements += self.separator + elements += self.format_separator.apply() if message["role"] == Role.USER: - elements += self.format_user(content=system_text + message["content"], idx=str(i // 2)) + elements += self.format_user.apply(content=system_text + message["content"]) elif message["role"] == Role.ASSISTANT: - elements += self.format_assistant(content=message["content"]) + elements += self.format_assistant.apply(content=message["content"]) elif message["role"] == Role.OBSERVATION: - elements += self.format_observation(content=message["content"]) + elements += self.format_observation.apply(content=message["content"]) elif message["role"] == Role.FUNCTION: - elements += self.format_function(content=message["content"]) + elements += self.format_function.apply(content=message["content"]) + else: + raise NotImplementedError encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) - # TODO: need to improve - encoded_pairs = [] - total_length = 0 - for i in range(0, len(encoded_messages), 2): - if total_length >= cutoff_len: - break - - encoded_messages[i] = encoded_messages[i][: cutoff_len - total_length] - total_length += len(encoded_messages[i]) - - encoded_messages[i + 1] = encoded_messages[i + 1][: max(1, cutoff_len - total_length)] - total_length += len(encoded_messages[i + 1]) - encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1])) - - return encoded_pairs + return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) templates: Dict[str, Template] = {} @@ -193,32 +200,39 @@ templates: Dict[str, Template] = {} def register_template( name: str, - format_user: Optional[Callable] = None, - format_assistant: Optional[Callable] = None, - format_system: Optional[Callable] = None, - format_tool: Optional[Callable] = None, - format_observation: Optional[Callable] = None, - format_function: Optional[Callable] = None, - system: Optional[str] = "", - separator: Optional[List[Union[str, Dict[str, str]]]] = "", + format_user: Optional["Formatter"] = None, + format_assistant: Optional["Formatter"] = None, + format_system: Optional["Formatter"] = None, + format_function: Optional["Formatter"] = None, + format_observation: Optional["Formatter"] = None, + format_tools: Optional["Formatter"] = None, + format_separator: Optional["Formatter"] = None, + default_system: Optional[str] = "", stop_words: Optional[List[str]] = [], efficient_eos: Optional[bool] = False, replace_eos: Optional[bool] = False, + force_system: Optional[bool] = False, ) -> None: + eos_slots = [] if efficient_eos else [{"eos_token"}] template_class = Llama2Template if name.startswith("llama2") else Template + default_user_formatter = StringFormatter(slots=["{{content}}"]) + default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots) + default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots) + default_tool_formatter = ToolFormatter(slots="default") + default_separator_formatter = EmptyFormatter() templates[name] = template_class( - format_user=format_user or StringFormatter(container=["{{content}}"]), - format_assistant=format_assistant or StringFormatter(container=["{{content}}", {"eos_token"}]), - format_system=format_system or StringFormatter(container=["{{content}}"]), - format_tool=format_tool or ToolFormatter(type="default"), - format_observation=format_observation or format_user, - format_function=format_function - or FunctionFormatter(container=["Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"}]), - system=system, - separator=separator, + format_user=format_user or default_user_formatter, + format_assistant=format_assistant or default_assistant_formatter, + format_system=format_system or default_user_formatter, + format_function=format_function or default_function_formatter, + format_observation=format_observation or format_user or default_user_formatter, + format_tools=format_tools or default_tool_formatter, + format_separator=format_separator or default_separator_formatter, + default_system=default_system, stop_words=stop_words, efficient_eos=efficient_eos, replace_eos=replace_eos, + force_system=force_system, ) @@ -257,23 +271,22 @@ def get_template_and_fix_tokenizer(name: str, tokenizer: "PreTrainedTokenizer") register_template( name="alpaca", - format_user=StringFormatter(container=["### Instruction:\n{{content}}\n\n### Response:\n"]), - system=( + format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]), + format_separator=EmptyFormatter(slots=["\n\n"]), + default_system=( "Below is an instruction that describes a task. " "Write a response that appropriately completes the request." ), - separator=["\n\n"], ) register_template( name="aquila", - format_user=StringFormatter(container=["Human: {{content}}###Assistant:"]), - format_assistant=StringFormatter(container=["{{content}}"]), - system=( + format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]), + format_separator=EmptyFormatter(slots=["###"]), + default_system=( "A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions." ), - separator=["###"], stop_words=[""], efficient_eos=True, ) @@ -281,51 +294,53 @@ register_template( register_template( name="baichuan", - format_user=StringFormatter(container=[{"token": ""}, "{{content}}", {"token": ""}]), - format_assistant=StringFormatter(container=["{{content}}"]), + format_user=StringFormatter(slots=[{"token": ""}, "{{content}}", {"token": ""}]), efficient_eos=True, ) register_template( name="baichuan2", - format_user=StringFormatter(container=[{"token": ""}, "{{content}}", {"token": ""}]), - format_assistant=StringFormatter(container=["{{content}}"]), + format_user=StringFormatter(slots=[{"token": ""}, "{{content}}", {"token": ""}]), efficient_eos=True, ) register_template( - name="belle", format_user=StringFormatter(container=["Human: {{content}}\n\nBelle: "]), separator=["\n\n"] + name="belle", + format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]), + format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), + format_separator=EmptyFormatter(slots=["\n\n"]), + force_system=True, ) register_template( name="bluelm", - format_user=StringFormatter(container=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]), + format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]), ) register_template( name="chatglm2", - format_user=StringFormatter(container=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), - format_assistant=StringFormatter(container=["{{content}}"]), - format_system=StringFormatter(container=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), - separator=["\n\n"], + format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), + format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), + format_separator=EmptyFormatter(slots=["\n\n"]), efficient_eos=True, + force_system=True, ) register_template( name="chatglm3", - format_user=StringFormatter(container=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), - format_assistant=StringFormatter(container=["\n" "{{content}}"]), + format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), + format_assistant=StringFormatter(slots=["\n", "{{content}}"]), format_system=StringFormatter( - container=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"] + slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"] ), - format_observation=StringFormatter(container=[{"token": "<|observation|>"}, "\n", "{{content}}"]), - format_function=FunctionFormatter(container=["{{name}}\n{{arguments}}"]), - system=( + format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), + format_observation=StringFormatter(slots=[{"token": "<|observation|>"}, "\n", "{{content}}"]), + default_system=( "You are ChatGLM3, a large language model trained by Zhipu.AI. " "Follow the user's instructions carefully. Respond using markdown." ), @@ -335,24 +350,30 @@ register_template( register_template( - name="codegeex2", format_system=StringFormatter(container=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]) + name="codegeex2", + format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), + force_system=True, ) -register_template(name="deepseek", format_user=StringFormatter(container=["User: {{content}}\n\nAssistant:"])) +register_template( + name="deepseek", + format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]), + format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), + force_system=True, +) register_template( name="deepseekcoder", - format_user=StringFormatter(container=["### Instruction:\n{{content}}\n### Response:\n"]), - format_assistant=StringFormatter(container=["{{content}}"]), - system=( + format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:\n"]), + format_separator=EmptyFormatter(slots=["\n", {"token": "<|EOT|>"}, "\n"]), + default_system=( "You are an AI programming assistant, utilizing the Deepseek Coder model, " "developed by Deepseek Company, and you only answer questions related to computer science. " "For politically sensitive questions, security and privacy issues, " "and other non-computer science questions, you will refuse to answer\n" ), - separator=["\n", {"token": "<|EOT|>"}, "\n"], stop_words=["<|EOT|>"], efficient_eos=True, ) @@ -360,29 +381,23 @@ register_template( register_template( name="default", - format_user=StringFormatter(container=["Human: {{content}}\nAssistant: "]), - system=( - "A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions.\n" - ), - separator=["\n"], + format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]), + format_separator=EmptyFormatter(slots=["\n"]), ) register_template( name="falcon", - format_user=StringFormatter(container=["User: {{content}}\nFalcon:"]), - format_assistant=StringFormatter(container=["{{content}}"]), - separator=["\n"], + format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]), + format_separator=EmptyFormatter(slots=["\n"]), efficient_eos=True, ) register_template( name="intern", - format_user=StringFormatter(container=["<|User|>:{{content}}", {"token": ""}, "\n<|Bot|>:"]), - format_assistant=StringFormatter(container=["{{content}}"]), - separator=[{"token": ""}, "\n"], + format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": ""}, "\n<|Bot|>:"]), + format_separator=EmptyFormatter(slots=[{"token": ""}, "\n"]), stop_words=[""], efficient_eos=True, ) @@ -390,38 +405,26 @@ register_template( register_template( name="intern2", - format_user=StringFormatter( - container=[ - {"token": "[UNUSED_TOKEN_146]"}, - "user\n{{content}}", - {"token": "[UNUSED_TOKEN_145]"}, - "\n", - {"token": "[UNUSED_TOKEN_146]"}, - "assistant\n", - ] - ), - format_assistant=StringFormatter(container=["{{content}}"]), - format_system=StringFormatter( - container=[{"token": "[UNUSED_TOKEN_146]"}, "system\n{{content}}", {"token": "[UNUSED_TOKEN_145]"}, "\n"] - ), - system=( + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + default_system=( "You are an AI assistant whose name is InternLM (书生·浦语).\n" "- InternLM (书生·浦语) is a conversational language model that is developed " "by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n" "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen " "by the user such as English and 中文." ), - separator=[{"token": "[UNUSED_TOKEN_145]"}, "\n"], - stop_words=["[UNUSED_TOKEN_145]"], - efficient_eos=True, + stop_words=["<|im_end|>"], + replace_eos=True, ) register_template( name="llama2", - format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]), - format_system=StringFormatter(container=["<>\n{{content}}\n<>\n\n"]), - system=( + format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), + format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]), + default_system=( "You are a helpful, respectful and honest assistant. " "Always answer as helpfully as possible, while being safe. " "Your answers should not include any harmful, unethical, " @@ -436,51 +439,60 @@ register_template( register_template( name="llama2_zh", - format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]), - format_system=StringFormatter(container=["<>\n{{content}}\n<>\n\n"]), - system="You are a helpful assistant. 你是一个乐于助人的助手。", + format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), + format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]), + default_system="You are a helpful assistant. 你是一个乐于助人的助手。", ) -register_template(name="mistral", format_user=StringFormatter(container=["[INST] {{content}} [/INST]"])) +register_template( + name="mistral", + format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), + format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), + force_system=True, +) register_template( name="openchat", format_user=StringFormatter( - container=["GPT4 Correct User: {{content}}", {"token": "<|end_of_turn|>"}, "GPT4 Correct Assistant:"] + slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"] ), - format_assistant=StringFormatter(container=["{{content}}"]), - separator=[{"token": "<|end_of_turn|>"}], - stop_words=["<|end_of_turn|>"], - efficient_eos=True, + format_assistant=StringFormatter(slots=["{{content}}"]), + format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), + force_system=True, ) register_template( name="qwen", - format_user=StringFormatter(container=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), - format_system=StringFormatter(container=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - system="You are a helpful assistant.", - separator=["\n"], + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + default_system="You are a helpful assistant.", stop_words=["<|im_end|>"], replace_eos=True, ) -register_template(name="solar", format_user=StringFormatter(container=["### User:\n{{content}}\n\n### Assistant:\n"])) +register_template( + name="solar", + format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]), + format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]), + efficient_eos=True, +) register_template( name="starchat", format_user=StringFormatter( - container=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}] + slots=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}] ), - format_assistant=StringFormatter(container=["{{content}}"]), - format_system=StringFormatter(container=[{"token": "<|system|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n"]), - separator=[{"token": "<|end|>"}, "\n"], + format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n"]), + format_separator=EmptyFormatter(slots=["\n"]), stop_words=["<|end|>"], - efficient_eos=True, + replace_eos=True, + force_system=True, ) @@ -489,8 +501,8 @@ register_template(name="vanilla") register_template( name="vicuna", - format_user=StringFormatter(container=["USER: {{content}} ASSISTANT:"]), - system=( + format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), + default_system=( "A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions." ), @@ -499,8 +511,8 @@ register_template( register_template( name="xuanyuan", - format_user=StringFormatter(container=["Human: {{content}} Assistant:"]), - system=( + format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]), + default_system=( "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头," "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、" "不安全、有争议、政治敏感等相关的话题、问题和指示。\n" @@ -508,14 +520,15 @@ register_template( ) -register_template(name="xverse", format_user=StringFormatter(container=["Human: {{content}}\n\nAssistant: "])) +register_template(name="xverse", format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "])) register_template( name="yayi", - format_user=StringFormatter(container=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]), - format_system=StringFormatter(container=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]), - system=( + format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]), + format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]), + format_separator=EmptyFormatter(slots=["\n\n"]), + default_system=( "You are a helpful, respectful and honest assistant named YaYi " "developed by Beijing Wenge Technology Co.,Ltd. " "Always answer as helpfully as possible, while being safe. " @@ -526,15 +539,14 @@ register_template( "explain why instead of answering something not correct. " "If you don't know the answer to a question, please don't share false information." ), - separator=["\n\n"], stop_words=["<|End|>"], ) register_template( name="yi", - format_user=StringFormatter(container=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), - separator=["\n"], + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_separator=EmptyFormatter(slots=["\n"]), stop_words=["<|im_end|>"], replace_eos=True, ) @@ -542,8 +554,8 @@ register_template( register_template( name="yuan", - format_user=StringFormatter(container=["{{content}}", {"token": ""}]), - separator=["\n"], + format_user=StringFormatter(slots=["{{content}}", {"token": ""}]), + format_separator=EmptyFormatter(slots=["\n"]), stop_words=[""], replace_eos=True, ) @@ -551,18 +563,14 @@ register_template( register_template( name="zephyr", - format_user=StringFormatter(container=["<|user|>\n{{content}}<|assistant|>"]), - format_system=StringFormatter( - container=[ - "<|system|>\n{{content}}", - ] - ), - system="You are a friendly chatbot who always responds in the style of a pirate", + format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]), + default_system="You are a friendly chatbot who always responds in the style of a pirate", ) register_template( name="ziya", - format_user=StringFormatter(container=[{"token": ""}, ":{{content}}\n", {"token": ""}, ":"]), - separator=["\n"], + format_user=StringFormatter(slots=[{"token": ""}, ":{{content}}\n", {"token": ""}, ":"]), + format_separator=EmptyFormatter(slots=["\n"]), ) diff --git a/src/llmtuner/data/utils.py b/src/llmtuner/data/utils.py index 53b8054e..b8dfa123 100644 --- a/src/llmtuner/data/utils.py +++ b/src/llmtuner/data/utils.py @@ -38,10 +38,10 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None: logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0])) -def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]: - max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len))) - max_target_len = max(max_target_len, data_args.reserved_label_len) - max_source_len = data_args.cutoff_len - max_target_len +def infer_max_len(source_len: int, target_len: int, cutoff_len: int, reserved_label_len: int) -> Tuple[int, int]: + max_target_len = int(cutoff_len * (target_len / (source_len + target_len))) + max_target_len = max(max_target_len, reserved_label_len) + max_source_len = cutoff_len - max_target_len return max_source_len, max_target_len