This commit is contained in:
hiyouga 2024-01-20 23:22:09 +08:00
parent 71cfdc2658
commit cf818a2598
5 changed files with 316 additions and 282 deletions

View File

@ -13,7 +13,7 @@ except ImportError:
def main(): def main():
chat_model = ChatModel() chat_model = ChatModel()
history = [] messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.") print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True: while True:
@ -37,12 +37,13 @@ def main():
print("Assistant: ", end="", flush=True) print("Assistant: ", end="", flush=True)
response = "" response = ""
for new_text in chat_model.stream_chat(query, history): for new_text in chat_model.stream_chat(messages):
print(new_text, end="", flush=True) print(new_text, end="", flush=True)
response += new_text response += new_text
print() print()
history = history + [(query, response)] messages.append({"role": "user", "content": query})
messages.append({"role": "assistant", "content": response})
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,11 +1,11 @@
from dataclasses import dataclass from dataclasses import dataclass
from threading import Thread from threading import Thread
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Tuple
import torch import torch
from transformers import GenerationConfig, TextIteratorStreamer from transformers import GenerationConfig, TextIteratorStreamer
from ..data import Role, get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_logits_processor from ..extras.misc import get_logits_processor
from ..hparams import get_infer_args from ..hparams import get_infer_args
from ..model import dispatch_model, load_model_and_tokenizer from ..model import dispatch_model, load_model_and_tokenizer
@ -32,20 +32,11 @@ class ChatModel:
def _process_args( def _process_args(
self, self,
query: str, messages: Sequence[Dict[str, str]],
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
**input_kwargs, **input_kwargs,
) -> Tuple[Dict[str, Any], int]: ) -> Tuple[Dict[str, Any], int]:
messages = []
if history is not None:
for old_prompt, old_response in history:
messages.append({"role": Role.USER, "content": old_prompt})
messages.append({"role": Role.ASSISTANT, "content": old_response})
messages.append({"role": Role.USER, "content": query})
messages.append({"role": Role.ASSISTANT, "content": ""})
prompt, _ = self.template.encode_oneturn( prompt, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=messages, system=system, tools=tools tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
) )
@ -97,18 +88,12 @@ class ChatModel:
@torch.inference_mode() @torch.inference_mode()
def chat( def chat(
self, self,
query: str, messages: Sequence[Dict[str, str]],
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
**input_kwargs, **input_kwargs,
) -> List[Response]: ) -> List[Response]:
r""" gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs)
Args: query, history, system, **input_kwargs
Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
"""
gen_kwargs, prompt_length = self._process_args(query, history, system, tools, **input_kwargs)
generate_output = self.model.generate(**gen_kwargs) generate_output = self.model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:] response_ids = generate_output[:, prompt_length:]
response = self.tokenizer.batch_decode( response = self.tokenizer.batch_decode(
@ -132,13 +117,12 @@ class ChatModel:
@torch.inference_mode() @torch.inference_mode()
def stream_chat( def stream_chat(
self, self,
query: str, messages: Sequence[Dict[str, str]],
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
**input_kwargs, **input_kwargs,
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs) gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer gen_kwargs["streamer"] = streamer

View File

@ -1,6 +1,11 @@
import json import json
from dataclasses import dataclass import re
from typing import Any, Dict, List, Literal, Union from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Set, Sequence, Tuple, Union
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
JSON_FORMAT_PROMPT = ( JSON_FORMAT_PROMPT = (
@ -18,56 +23,7 @@ TOOL_SYSTEM_PROMPT = (
) )
@dataclass def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
class StringFormatter:
container: List[Union[str, Dict[str, str]]]
def __call__(self, **kwargs) -> List[Union[str, Dict[str, str]]]:
elements = []
for elem in self.container:
if isinstance(elem, str):
for name, value in kwargs.items():
elem = elem.replace("{{" + name + "}}", value)
elements.append(elem)
elif isinstance(elem, (dict, set)):
elements.append(elem)
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
return elements
@dataclass
class FunctionFormatter:
container: List[Union[str, Dict[str, str]]]
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
try:
function = json.loads(content)
name = function["name"]
arguments = json.dumps(function["arguments"], ensure_ascii=False)
except Exception:
name, arguments = "", ""
elements = []
for elem in self.container:
if isinstance(elem, str):
elem = elem.replace("{{name}}", name)
elem = elem.replace("{{arguments}}", arguments)
elements.append(elem)
elif isinstance(elem, (dict, set)):
elements.append(elem)
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
return elements
@dataclass
class ToolFormatter:
type: Literal["default"]
def _default(self, tools: List[Dict[str, Any]]) -> str:
tool_text = "" tool_text = ""
tool_names = [] tool_names = []
for tool in tools: for tool in tools:
@ -92,13 +48,98 @@ class ToolFormatter:
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
) )
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
action_match = re.search(regex, content)
if not action_match:
return content
tool_name = action_match.group(1).strip()
tool_input = action_match.group(2).strip().strip('"').strip("```")
try:
arguments = json.loads(tool_input)
except json.JSONDecodeError:
return content
return tool_name, json.dumps(arguments, ensure_ascii=False)
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: Literal["default"] = "default"
@abstractmethod
def apply(self, **kwargs) -> SLOTS:
...
@dataclass
class EmptyFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS:
return self.slots
@dataclass
class StringFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS:
elements = []
for slot in self.slots:
if isinstance(slot, str):
for name, value in kwargs.items():
slot = slot.replace("{{" + name + "}}", value, 1)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
return elements
@dataclass
class FunctionFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
function = json.loads(content)
name = function["name"]
arguments = json.dumps(function["arguments"], ensure_ascii=False)
except Exception:
name, arguments = "", ""
elements = []
for slot in self.slots:
if isinstance(slot, str):
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
return elements
@dataclass
class ToolFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try: try:
tools = json.loads(content) tools = json.loads(content)
if not len(tools): if not len(tools):
return [""] return [""]
if self.type == "default": if self.tool_format == "default":
return [self._default(tools)] return [default_tool_formatter(tools)]
else:
raise NotImplementedError
except Exception: except Exception:
return [""] return [""]
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
if self.tool_format == "default":
return default_tool_extractor(content)
else:
raise NotImplementedError

View File

@ -1,31 +1,34 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from ..extras.logging import get_logger from ..extras.logging import get_logger
from .formatter import FunctionFormatter, StringFormatter, ToolFormatter from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .utils import Role from .utils import Role, infer_max_len
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from .formatter import Formatter
logger = get_logger(__name__) logger = get_logger(__name__)
@dataclass @dataclass
class Template: class Template:
format_user: Callable format_user: "Formatter"
format_assistant: Callable format_assistant: "Formatter"
format_system: Callable format_system: "Formatter"
format_tool: Callable format_function: "Formatter"
format_observation: Callable format_observation: "Formatter"
format_function: Callable format_tools: "Formatter"
system: str format_separator: "Formatter"
separator: List[Union[str, Dict[str, str]]] default_system: str
stop_words: List[str] stop_words: List[str]
efficient_eos: bool efficient_eos: bool
replace_eos: bool replace_eos: bool
force_system: bool
def encode_oneturn( def encode_oneturn(
self, self,
@ -34,14 +37,15 @@ class Template:
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
cutoff_len: Optional[int] = 1_000_000, cutoff_len: Optional[int] = 1_000_000,
reserved_label_len: Optional[int] = 16,
) -> Tuple[List[int], List[int]]: ) -> Tuple[List[int], List[int]]:
r""" r"""
Returns a single pair of token ids representing prompt and response respectively. Returns a single pair of token ids representing prompt and response respectively.
""" """
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len) encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
prompt_ids = [] prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]: for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids = prompt_ids + query_ids + resp_ids prompt_ids += query_ids + resp_ids
prompt_ids = prompt_ids + encoded_pairs[-1][0] prompt_ids = prompt_ids + encoded_pairs[-1][0]
answer_ids = encoded_pairs[-1][1] answer_ids = encoded_pairs[-1][1]
return prompt_ids, answer_ids return prompt_ids, answer_ids
@ -50,15 +54,15 @@ class Template:
self, self,
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
system: str, system: Optional[str] = None,
tools: str, tools: Optional[str] = None,
cutoff_len: Optional[int] = 1_000_000, cutoff_len: Optional[int] = 1_000_000,
) -> List[Tuple[List[int], List[int]]]: reserved_label_len: Optional[int] = 16,
) -> Sequence[Tuple[List[int], List[int]]]:
r""" r"""
Returns multiple pairs of token ids representing prompts and responses respectively. Returns multiple pairs of token ids representing prompts and responses respectively.
""" """
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len) return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
return encoded_pairs
def _encode( def _encode(
self, self,
@ -67,48 +71,37 @@ class Template:
system: str, system: str,
tools: str, tools: str,
cutoff_len: int, cutoff_len: int,
) -> List[Tuple[List[int], List[int]]]: reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
r""" r"""
Encodes formatted inputs to pairs of token ids. Encodes formatted inputs to pairs of token ids.
Turn 0: system + query resp + eos Turn 0: system + query resp
Turn t: sep + query resp + eos Turn t: sep + query resp
""" """
system = system or self.system system = system or self.default_system
encoded_messages = [] encoded_messages = []
for i, message in enumerate(messages): for i, message in enumerate(messages):
elements = [] elements = []
if i == 0 and (system or tools): if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tool(content=tools)[0] if tools else "" tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system(content=(system + tool_text)) elements += self.format_system.apply(content=(system + tool_text))
elif i > 0 and i % 2 == 0: elif i > 0 and i % 2 == 0:
elements += self.separator elements += self.format_separator.apply()
if message["role"] == Role.USER: if message["role"] == Role.USER:
elements += self.format_user(content=message["content"], idx=str(i // 2)) elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT: elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant(content=message["content"]) elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION: elif message["role"] == Role.OBSERVATION:
elements += self.format_observation(content=message["content"]) elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION: elif message["role"] == Role.FUNCTION:
elements += self.format_function(content=message["content"]) elements += self.format_function.apply(content=message["content"])
else:
raise NotImplementedError
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
# TODO: need to improve return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
encoded_pairs = []
total_length = 0
for i in range(0, len(encoded_messages), 2):
if total_length >= cutoff_len:
break
encoded_messages[i] = encoded_messages[i][: cutoff_len - total_length]
total_length += len(encoded_messages[i])
encoded_messages[i + 1] = encoded_messages[i + 1][: max(1, cutoff_len - total_length)]
total_length += len(encoded_messages[i + 1])
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
return encoded_pairs
def _convert_elements_to_ids( def _convert_elements_to_ids(
self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]] self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
@ -120,19 +113,44 @@ class Template:
for elem in elements: for elem in elements:
if isinstance(elem, str): if isinstance(elem, str):
if len(elem) != 0: if len(elem) != 0:
token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False) token_ids += tokenizer.encode(elem, add_special_tokens=False)
elif isinstance(elem, dict): elif isinstance(elem, dict):
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
elif isinstance(elem, set): elif isinstance(elem, set):
if "bos_token" in elem and tokenizer.bos_token_id: if "bos_token" in elem and tokenizer.bos_token_id:
token_ids = token_ids + [tokenizer.bos_token_id] token_ids += [tokenizer.bos_token_id]
elif "eos_token" in elem and tokenizer.eos_token_id: elif "eos_token" in elem and tokenizer.eos_token_id:
token_ids = token_ids + [tokenizer.eos_token_id] token_ids += [tokenizer.eos_token_id]
else: else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem))) raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
return token_ids return token_ids
def _make_pairs(
self,
encoded_messages: Sequence[List[int]],
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
encoded_pairs = []
total_length = 0
for i in range(0, len(encoded_messages), 2):
if total_length >= cutoff_len:
break
max_source_len, max_target_len = infer_max_len(
source_len=len(encoded_messages[i]),
target_len=len(encoded_messages[i + 1]),
cutoff_len=(cutoff_len - total_length),
reserved_label_len=reserved_label_len,
)
encoded_messages[i] = encoded_messages[i][: max_source_len]
encoded_messages[i + 1] = encoded_messages[i + 1][: max_target_len]
total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
return encoded_pairs
@dataclass @dataclass
class Llama2Template(Template): class Llama2Template(Template):
@ -143,49 +161,38 @@ class Llama2Template(Template):
system: str, system: str,
tools: str, tools: str,
cutoff_len: int, cutoff_len: int,
) -> List[Tuple[List[int], List[int]]]: reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
r""" r"""
Encodes formatted inputs to pairs of token ids. Encodes formatted inputs to pairs of token ids.
Turn 0: system + query resp + eos Turn 0: system + query resp
Turn t: sep + query resp + eos Turn t: sep + query resp
""" """
system = system or self.system system = system or self.default_system
encoded_messages = [] encoded_messages = []
for i, message in enumerate(messages): for i, message in enumerate(messages):
elements = [] elements = []
system_text = "" system_text = ""
if i == 0 and (system or tools): if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tool(content=tools)[0] if tools else "" tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system(content=(system + tool_text))[0] system_text = self.format_system.apply(content=(system + tool_text))[0]
elif i > 0 and i % 2 == 0: elif i > 0 and i % 2 == 0:
elements += self.separator elements += self.format_separator.apply()
if message["role"] == Role.USER: if message["role"] == Role.USER:
elements += self.format_user(content=system_text + message["content"], idx=str(i // 2)) elements += self.format_user.apply(content=system_text + message["content"])
elif message["role"] == Role.ASSISTANT: elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant(content=message["content"]) elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION: elif message["role"] == Role.OBSERVATION:
elements += self.format_observation(content=message["content"]) elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION: elif message["role"] == Role.FUNCTION:
elements += self.format_function(content=message["content"]) elements += self.format_function.apply(content=message["content"])
else:
raise NotImplementedError
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
# TODO: need to improve return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
encoded_pairs = []
total_length = 0
for i in range(0, len(encoded_messages), 2):
if total_length >= cutoff_len:
break
encoded_messages[i] = encoded_messages[i][: cutoff_len - total_length]
total_length += len(encoded_messages[i])
encoded_messages[i + 1] = encoded_messages[i + 1][: max(1, cutoff_len - total_length)]
total_length += len(encoded_messages[i + 1])
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
return encoded_pairs
templates: Dict[str, Template] = {} templates: Dict[str, Template] = {}
@ -193,32 +200,39 @@ templates: Dict[str, Template] = {}
def register_template( def register_template(
name: str, name: str,
format_user: Optional[Callable] = None, format_user: Optional["Formatter"] = None,
format_assistant: Optional[Callable] = None, format_assistant: Optional["Formatter"] = None,
format_system: Optional[Callable] = None, format_system: Optional["Formatter"] = None,
format_tool: Optional[Callable] = None, format_function: Optional["Formatter"] = None,
format_observation: Optional[Callable] = None, format_observation: Optional["Formatter"] = None,
format_function: Optional[Callable] = None, format_tools: Optional["Formatter"] = None,
system: Optional[str] = "", format_separator: Optional["Formatter"] = None,
separator: Optional[List[Union[str, Dict[str, str]]]] = "", default_system: Optional[str] = "",
stop_words: Optional[List[str]] = [], stop_words: Optional[List[str]] = [],
efficient_eos: Optional[bool] = False, efficient_eos: Optional[bool] = False,
replace_eos: Optional[bool] = False, replace_eos: Optional[bool] = False,
force_system: Optional[bool] = False,
) -> None: ) -> None:
eos_slots = [] if efficient_eos else [{"eos_token"}]
template_class = Llama2Template if name.startswith("llama2") else Template template_class = Llama2Template if name.startswith("llama2") else Template
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
default_tool_formatter = ToolFormatter(slots="default")
default_separator_formatter = EmptyFormatter()
templates[name] = template_class( templates[name] = template_class(
format_user=format_user or StringFormatter(container=["{{content}}"]), format_user=format_user or default_user_formatter,
format_assistant=format_assistant or StringFormatter(container=["{{content}}", {"eos_token"}]), format_assistant=format_assistant or default_assistant_formatter,
format_system=format_system or StringFormatter(container=["{{content}}"]), format_system=format_system or default_user_formatter,
format_tool=format_tool or ToolFormatter(type="default"), format_function=format_function or default_function_formatter,
format_observation=format_observation or format_user, format_observation=format_observation or format_user or default_user_formatter,
format_function=format_function format_tools=format_tools or default_tool_formatter,
or FunctionFormatter(container=["Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"}]), format_separator=format_separator or default_separator_formatter,
system=system, default_system=default_system,
separator=separator,
stop_words=stop_words, stop_words=stop_words,
efficient_eos=efficient_eos, efficient_eos=efficient_eos,
replace_eos=replace_eos, replace_eos=replace_eos,
force_system=force_system,
) )
@ -257,23 +271,22 @@ def get_template_and_fix_tokenizer(name: str, tokenizer: "PreTrainedTokenizer")
register_template( register_template(
name="alpaca", name="alpaca",
format_user=StringFormatter(container=["### Instruction:\n{{content}}\n\n### Response:\n"]), format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
system=( format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request." "Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
), ),
separator=["\n\n"],
) )
register_template( register_template(
name="aquila", name="aquila",
format_user=StringFormatter(container=["Human: {{content}}###Assistant:"]), format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
format_assistant=StringFormatter(container=["{{content}}"]), format_separator=EmptyFormatter(slots=["###"]),
system=( default_system=(
"A chat between a curious human and an artificial intelligence assistant. " "A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions." "The assistant gives helpful, detailed, and polite answers to the human's questions."
), ),
separator=["###"],
stop_words=["</s>"], stop_words=["</s>"],
efficient_eos=True, efficient_eos=True,
) )
@ -281,51 +294,53 @@ register_template(
register_template( register_template(
name="baichuan", name="baichuan",
format_user=StringFormatter(container=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]), format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
format_assistant=StringFormatter(container=["{{content}}"]),
efficient_eos=True, efficient_eos=True,
) )
register_template( register_template(
name="baichuan2", name="baichuan2",
format_user=StringFormatter(container=[{"token": "<reserved_106>"}, "{{content}}", {"token": "<reserved_107>"}]), format_user=StringFormatter(slots=[{"token": "<reserved_106>"}, "{{content}}", {"token": "<reserved_107>"}]),
format_assistant=StringFormatter(container=["{{content}}"]),
efficient_eos=True, efficient_eos=True,
) )
register_template( register_template(
name="belle", format_user=StringFormatter(container=["Human: {{content}}\n\nBelle: "]), separator=["\n\n"] name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
force_system=True,
) )
register_template( register_template(
name="bluelm", name="bluelm",
format_user=StringFormatter(container=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]), format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
) )
register_template( register_template(
name="chatglm2", name="chatglm2",
format_user=StringFormatter(container=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_assistant=StringFormatter(container=["{{content}}"]), format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
format_system=StringFormatter(container=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), format_separator=EmptyFormatter(slots=["\n\n"]),
separator=["\n\n"],
efficient_eos=True, efficient_eos=True,
force_system=True,
) )
register_template( register_template(
name="chatglm3", name="chatglm3",
format_user=StringFormatter(container=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(container=["\n" "{{content}}"]), format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter( format_system=StringFormatter(
container=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"] slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
), ),
format_observation=StringFormatter(container=[{"token": "<|observation|>"}, "\n", "{{content}}"]), format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_function=FunctionFormatter(container=["{{name}}\n{{arguments}}"]), format_observation=StringFormatter(slots=[{"token": "<|observation|>"}, "\n", "{{content}}"]),
system=( default_system=(
"You are ChatGLM3, a large language model trained by Zhipu.AI. " "You are ChatGLM3, a large language model trained by Zhipu.AI. "
"Follow the user's instructions carefully. Respond using markdown." "Follow the user's instructions carefully. Respond using markdown."
), ),
@ -335,24 +350,30 @@ register_template(
register_template( register_template(
name="codegeex2", format_system=StringFormatter(container=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]) name="codegeex2",
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
force_system=True,
) )
register_template(name="deepseek", format_user=StringFormatter(container=["User: {{content}}\n\nAssistant:"])) register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
register_template( register_template(
name="deepseekcoder", name="deepseekcoder",
format_user=StringFormatter(container=["### Instruction:\n{{content}}\n### Response:\n"]), format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:\n"]),
format_assistant=StringFormatter(container=["{{content}}"]), format_separator=EmptyFormatter(slots=["\n", {"token": "<|EOT|>"}, "\n"]),
system=( default_system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, " "You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. " "developed by Deepseek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, " "For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer\n" "and other non-computer science questions, you will refuse to answer\n"
), ),
separator=["\n", {"token": "<|EOT|>"}, "\n"],
stop_words=["<|EOT|>"], stop_words=["<|EOT|>"],
efficient_eos=True, efficient_eos=True,
) )
@ -360,29 +381,23 @@ register_template(
register_template( register_template(
name="default", name="default",
format_user=StringFormatter(container=["Human: {{content}}\nAssistant: "]), format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
system=( format_separator=EmptyFormatter(slots=["\n"]),
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
),
separator=["\n"],
) )
register_template( register_template(
name="falcon", name="falcon",
format_user=StringFormatter(container=["User: {{content}}\nFalcon:"]), format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
format_assistant=StringFormatter(container=["{{content}}"]), format_separator=EmptyFormatter(slots=["\n"]),
separator=["\n"],
efficient_eos=True, efficient_eos=True,
) )
register_template( register_template(
name="intern", name="intern",
format_user=StringFormatter(container=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]), format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
format_assistant=StringFormatter(container=["{{content}}"]), format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
separator=[{"token": "<eoa>"}, "\n"],
stop_words=["<eoa>"], stop_words=["<eoa>"],
efficient_eos=True, efficient_eos=True,
) )
@ -390,38 +405,26 @@ register_template(
register_template( register_template(
name="intern2", name="intern2",
format_user=StringFormatter( format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
container=[ format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
{"token": "[UNUSED_TOKEN_146]"}, format_separator=EmptyFormatter(slots=["\n"]),
"user\n{{content}}", default_system=(
{"token": "[UNUSED_TOKEN_145]"},
"\n",
{"token": "[UNUSED_TOKEN_146]"},
"assistant\n",
]
),
format_assistant=StringFormatter(container=["{{content}}"]),
format_system=StringFormatter(
container=[{"token": "[UNUSED_TOKEN_146]"}, "system\n{{content}}", {"token": "[UNUSED_TOKEN_145]"}, "\n"]
),
system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n" "You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed " "- InternLM (书生·浦语) is a conversational language model that is developed "
"by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n" "by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
"- InternLM (书生·浦语) can understand and communicate fluently in the language chosen " "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
"by the user such as English and 中文." "by the user such as English and 中文."
), ),
separator=[{"token": "[UNUSED_TOKEN_145]"}, "\n"], stop_words=["<|im_end|>"],
stop_words=["[UNUSED_TOKEN_145]"], replace_eos=True,
efficient_eos=True,
) )
register_template( register_template(
name="llama2", name="llama2",
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]), format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]), format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
system=( default_system=(
"You are a helpful, respectful and honest assistant. " "You are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. " "Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, " "Your answers should not include any harmful, unethical, "
@ -436,51 +439,60 @@ register_template(
register_template( register_template(
name="llama2_zh", name="llama2_zh",
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]), format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]), format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
system="You are a helpful assistant. 你是一个乐于助人的助手。", default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
) )
register_template(name="mistral", format_user=StringFormatter(container=["[INST] {{content}} [/INST]"])) register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
register_template( register_template(
name="openchat", name="openchat",
format_user=StringFormatter( format_user=StringFormatter(
container=["GPT4 Correct User: {{content}}", {"token": "<|end_of_turn|>"}, "GPT4 Correct Assistant:"] slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]
), ),
format_assistant=StringFormatter(container=["{{content}}"]), format_assistant=StringFormatter(slots=["{{content}}"]),
separator=[{"token": "<|end_of_turn|>"}], format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
stop_words=["<|end_of_turn|>"], force_system=True,
efficient_eos=True,
) )
register_template( register_template(
name="qwen", name="qwen",
format_user=StringFormatter(container=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(container=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
system="You are a helpful assistant.", format_separator=EmptyFormatter(slots=["\n"]),
separator=["\n"], default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True, replace_eos=True,
) )
register_template(name="solar", format_user=StringFormatter(container=["### User:\n{{content}}\n\n### Assistant:\n"])) register_template(
name="solar",
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
efficient_eos=True,
)
register_template( register_template(
name="starchat", name="starchat",
format_user=StringFormatter( format_user=StringFormatter(
container=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}] slots=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}]
), ),
format_assistant=StringFormatter(container=["{{content}}"]), format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n"]),
format_system=StringFormatter(container=[{"token": "<|system|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n"]), format_separator=EmptyFormatter(slots=["\n"]),
separator=[{"token": "<|end|>"}, "\n"],
stop_words=["<|end|>"], stop_words=["<|end|>"],
efficient_eos=True, replace_eos=True,
force_system=True,
) )
@ -489,8 +501,8 @@ register_template(name="vanilla")
register_template( register_template(
name="vicuna", name="vicuna",
format_user=StringFormatter(container=["USER: {{content}} ASSISTANT:"]), format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
system=( default_system=(
"A chat between a curious user and an artificial intelligence assistant. " "A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions." "The assistant gives helpful, detailed, and polite answers to the user's questions."
), ),
@ -499,8 +511,8 @@ register_template(
register_template( register_template(
name="xuanyuan", name="xuanyuan",
format_user=StringFormatter(container=["Human: {{content}} Assistant:"]), format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
system=( default_system=(
"以下是用户和人工智能助手之间的对话。用户以Human开头人工智能助手以Assistant开头" "以下是用户和人工智能助手之间的对话。用户以Human开头人工智能助手以Assistant开头"
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、" "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n" "不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
@ -508,14 +520,15 @@ register_template(
) )
register_template(name="xverse", format_user=StringFormatter(container=["Human: {{content}}\n\nAssistant: "])) register_template(name="xverse", format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]))
register_template( register_template(
name="yayi", name="yayi",
format_user=StringFormatter(container=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]), format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
format_system=StringFormatter(container=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]), format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
system=( format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"You are a helpful, respectful and honest assistant named YaYi " "You are a helpful, respectful and honest assistant named YaYi "
"developed by Beijing Wenge Technology Co.,Ltd. " "developed by Beijing Wenge Technology Co.,Ltd. "
"Always answer as helpfully as possible, while being safe. " "Always answer as helpfully as possible, while being safe. "
@ -526,15 +539,14 @@ register_template(
"explain why instead of answering something not correct. " "explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information." "If you don't know the answer to a question, please don't share false information."
), ),
separator=["\n\n"],
stop_words=["<|End|>"], stop_words=["<|End|>"],
) )
register_template( register_template(
name="yi", name="yi",
format_user=StringFormatter(container=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
separator=["\n"], format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True, replace_eos=True,
) )
@ -542,8 +554,8 @@ register_template(
register_template( register_template(
name="yuan", name="yuan",
format_user=StringFormatter(container=["{{content}}", {"token": "<sep>"}]), format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
separator=["\n"], format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<eod>"], stop_words=["<eod>"],
replace_eos=True, replace_eos=True,
) )
@ -551,18 +563,14 @@ register_template(
register_template( register_template(
name="zephyr", name="zephyr",
format_user=StringFormatter(container=["<|user|>\n{{content}}</s><|assistant|>"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
format_system=StringFormatter( format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
container=[ default_system="You are a friendly chatbot who always responds in the style of a pirate",
"<|system|>\n{{content}}</s>",
]
),
system="You are a friendly chatbot who always responds in the style of a pirate",
) )
register_template( register_template(
name="ziya", name="ziya",
format_user=StringFormatter(container=[{"token": "<human>"}, ":{{content}}\n", {"token": "<bot>"}, ":"]), format_user=StringFormatter(slots=[{"token": "<human>"}, ":{{content}}\n", {"token": "<bot>"}, ":"]),
separator=["\n"], format_separator=EmptyFormatter(slots=["\n"]),
) )

View File

@ -38,10 +38,10 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0])) logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]: def infer_max_len(source_len: int, target_len: int, cutoff_len: int, reserved_label_len: int) -> Tuple[int, int]:
max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len))) max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, data_args.reserved_label_len) max_target_len = max(max_target_len, reserved_label_len)
max_source_len = data_args.cutoff_len - max_target_len max_source_len = cutoff_len - max_target_len
return max_source_len, max_target_len return max_source_len, max_target_len