This commit is contained in:
hiyouga 2024-01-20 23:22:09 +08:00
parent 71cfdc2658
commit cf818a2598
5 changed files with 316 additions and 282 deletions

View File

@ -13,7 +13,7 @@ except ImportError:
def main():
chat_model = ChatModel()
history = []
messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
@ -37,12 +37,13 @@ def main():
print("Assistant: ", end="", flush=True)
response = ""
for new_text in chat_model.stream_chat(query, history):
for new_text in chat_model.stream_chat(messages):
print(new_text, end="", flush=True)
response += new_text
print()
history = history + [(query, response)]
messages.append({"role": "user", "content": query})
messages.append({"role": "assistant", "content": response})
if __name__ == "__main__":

View File

@ -1,11 +1,11 @@
from dataclasses import dataclass
from threading import Thread
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Tuple
import torch
from transformers import GenerationConfig, TextIteratorStreamer
from ..data import Role, get_template_and_fix_tokenizer
from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_logits_processor
from ..hparams import get_infer_args
from ..model import dispatch_model, load_model_and_tokenizer
@ -32,20 +32,11 @@ class ChatModel:
def _process_args(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> Tuple[Dict[str, Any], int]:
messages = []
if history is not None:
for old_prompt, old_response in history:
messages.append({"role": Role.USER, "content": old_prompt})
messages.append({"role": Role.ASSISTANT, "content": old_response})
messages.append({"role": Role.USER, "content": query})
messages.append({"role": Role.ASSISTANT, "content": ""})
prompt, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
)
@ -97,18 +88,12 @@ class ChatModel:
@torch.inference_mode()
def chat(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> List[Response]:
r"""
Args: query, history, system, **input_kwargs
Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
"""
gen_kwargs, prompt_length = self._process_args(query, history, system, tools, **input_kwargs)
gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs)
generate_output = self.model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
response = self.tokenizer.batch_decode(
@ -132,13 +117,12 @@ class ChatModel:
@torch.inference_mode()
def stream_chat(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> Generator[str, None, None]:
gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs)
gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer

View File

@ -1,6 +1,11 @@
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Union
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Set, Sequence, Tuple, Union
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
JSON_FORMAT_PROMPT = (
@ -18,30 +23,85 @@ TOOL_SYSTEM_PROMPT = (
)
@dataclass
class StringFormatter:
container: List[Union[str, Dict[str, str]]]
def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
for tool in tools:
param_text = ""
for name, param in tool["parameters"]["properties"].items():
required = ", required" if name in tool["parameters"].get("required", []) else ""
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
param_text += " - {name} ({type}{required}): {desc}{enum}\n".format(
name=name,
type=param.get("type", ""),
required=required,
desc=param.get("description", ""),
enum=enum,
)
def __call__(self, **kwargs) -> List[Union[str, Dict[str, str]]]:
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
name=tool["name"], desc=tool.get("description", ""), args=param_text
)
tool_names.append(tool["name"])
return TOOL_SYSTEM_PROMPT.format(
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
)
def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
action_match = re.search(regex, content)
if not action_match:
return content
tool_name = action_match.group(1).strip()
tool_input = action_match.group(2).strip().strip('"').strip("```")
try:
arguments = json.loads(tool_input)
except json.JSONDecodeError:
return content
return tool_name, json.dumps(arguments, ensure_ascii=False)
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: Literal["default"] = "default"
@abstractmethod
def apply(self, **kwargs) -> SLOTS:
...
@dataclass
class EmptyFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS:
return self.slots
@dataclass
class StringFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS:
elements = []
for elem in self.container:
if isinstance(elem, str):
for slot in self.slots:
if isinstance(slot, str):
for name, value in kwargs.items():
elem = elem.replace("{{" + name + "}}", value)
elements.append(elem)
elif isinstance(elem, (dict, set)):
elements.append(elem)
slot = slot.replace("{{" + name + "}}", value, 1)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
return elements
@dataclass
class FunctionFormatter:
container: List[Union[str, Dict[str, str]]]
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
class FunctionFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
function = json.loads(content)
name = function["name"]
@ -50,55 +110,36 @@ class FunctionFormatter:
name, arguments = "", ""
elements = []
for elem in self.container:
if isinstance(elem, str):
elem = elem.replace("{{name}}", name)
elem = elem.replace("{{arguments}}", arguments)
elements.append(elem)
elif isinstance(elem, (dict, set)):
elements.append(elem)
for slot in self.slots:
if isinstance(slot, str):
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
return elements
@dataclass
class ToolFormatter:
type: Literal["default"]
def _default(self, tools: List[Dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
for tool in tools:
param_text = ""
for name, param in tool["parameters"]["properties"].items():
required = ", required" if name in tool["parameters"].get("required", []) else ""
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
param_text += " - {name} ({type}{required}): {desc}{enum}\n".format(
name=name,
type=param.get("type", ""),
required=required,
desc=param.get("description", ""),
enum=enum,
)
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
name=tool["name"], desc=tool.get("description", ""), args=param_text
)
tool_names.append(tool["name"])
return TOOL_SYSTEM_PROMPT.format(
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
)
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
class ToolFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
tools = json.loads(content)
if not len(tools):
return [""]
if self.type == "default":
return [self._default(tools)]
if self.tool_format == "default":
return [default_tool_formatter(tools)]
else:
raise NotImplementedError
except Exception:
return [""]
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
if self.tool_format == "default":
return default_tool_extractor(content)
else:
raise NotImplementedError

View File

@ -1,31 +1,34 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from ..extras.logging import get_logger
from .formatter import FunctionFormatter, StringFormatter, ToolFormatter
from .utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .utils import Role, infer_max_len
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from .formatter import Formatter
logger = get_logger(__name__)
@dataclass
class Template:
format_user: Callable
format_assistant: Callable
format_system: Callable
format_tool: Callable
format_observation: Callable
format_function: Callable
system: str
separator: List[Union[str, Dict[str, str]]]
format_user: "Formatter"
format_assistant: "Formatter"
format_system: "Formatter"
format_function: "Formatter"
format_observation: "Formatter"
format_tools: "Formatter"
format_separator: "Formatter"
default_system: str
stop_words: List[str]
efficient_eos: bool
replace_eos: bool
force_system: bool
def encode_oneturn(
self,
@ -34,14 +37,15 @@ class Template:
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: Optional[int] = 1_000_000,
reserved_label_len: Optional[int] = 16,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len)
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids = prompt_ids + query_ids + resp_ids
prompt_ids += query_ids + resp_ids
prompt_ids = prompt_ids + encoded_pairs[-1][0]
answer_ids = encoded_pairs[-1][1]
return prompt_ids, answer_ids
@ -50,15 +54,15 @@ class Template:
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
tools: str,
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: Optional[int] = 1_000_000,
) -> List[Tuple[List[int], List[int]]]:
reserved_label_len: Optional[int] = 16,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len)
return encoded_pairs
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
def _encode(
self,
@ -67,48 +71,37 @@ class Template:
system: str,
tools: str,
cutoff_len: int,
) -> List[Tuple[List[int], List[int]]]:
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: system + query resp + eos
Turn t: sep + query resp + eos
Turn 0: system + query resp
Turn t: sep + query resp
"""
system = system or self.system
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
if i == 0 and (system or tools):
tool_text = self.format_tool(content=tools)[0] if tools else ""
elements += self.format_system(content=(system + tool_text))
if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
elif i > 0 and i % 2 == 0:
elements += self.separator
elements += self.format_separator.apply()
if message["role"] == Role.USER:
elements += self.format_user(content=message["content"], idx=str(i // 2))
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant(content=message["content"])
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation(content=message["content"])
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION:
elements += self.format_function(content=message["content"])
elements += self.format_function.apply(content=message["content"])
else:
raise NotImplementedError
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
# TODO: need to improve
encoded_pairs = []
total_length = 0
for i in range(0, len(encoded_messages), 2):
if total_length >= cutoff_len:
break
encoded_messages[i] = encoded_messages[i][: cutoff_len - total_length]
total_length += len(encoded_messages[i])
encoded_messages[i + 1] = encoded_messages[i + 1][: max(1, cutoff_len - total_length)]
total_length += len(encoded_messages[i + 1])
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
return encoded_pairs
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
def _convert_elements_to_ids(
self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
@ -120,19 +113,44 @@ class Template:
for elem in elements:
if isinstance(elem, str):
if len(elem) != 0:
token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False)
token_ids += tokenizer.encode(elem, add_special_tokens=False)
elif isinstance(elem, dict):
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
elif isinstance(elem, set):
if "bos_token" in elem and tokenizer.bos_token_id:
token_ids = token_ids + [tokenizer.bos_token_id]
token_ids += [tokenizer.bos_token_id]
elif "eos_token" in elem and tokenizer.eos_token_id:
token_ids = token_ids + [tokenizer.eos_token_id]
token_ids += [tokenizer.eos_token_id]
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
return token_ids
def _make_pairs(
self,
encoded_messages: Sequence[List[int]],
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
encoded_pairs = []
total_length = 0
for i in range(0, len(encoded_messages), 2):
if total_length >= cutoff_len:
break
max_source_len, max_target_len = infer_max_len(
source_len=len(encoded_messages[i]),
target_len=len(encoded_messages[i + 1]),
cutoff_len=(cutoff_len - total_length),
reserved_label_len=reserved_label_len,
)
encoded_messages[i] = encoded_messages[i][: max_source_len]
encoded_messages[i + 1] = encoded_messages[i + 1][: max_target_len]
total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
return encoded_pairs
@dataclass
class Llama2Template(Template):
@ -143,49 +161,38 @@ class Llama2Template(Template):
system: str,
tools: str,
cutoff_len: int,
) -> List[Tuple[List[int], List[int]]]:
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: system + query resp + eos
Turn t: sep + query resp + eos
Turn 0: system + query resp
Turn t: sep + query resp
"""
system = system or self.system
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
system_text = ""
if i == 0 and (system or tools):
tool_text = self.format_tool(content=tools)[0] if tools else ""
system_text = self.format_system(content=(system + tool_text))[0]
if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0]
elif i > 0 and i % 2 == 0:
elements += self.separator
elements += self.format_separator.apply()
if message["role"] == Role.USER:
elements += self.format_user(content=system_text + message["content"], idx=str(i // 2))
elements += self.format_user.apply(content=system_text + message["content"])
elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant(content=message["content"])
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation(content=message["content"])
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION:
elements += self.format_function(content=message["content"])
elements += self.format_function.apply(content=message["content"])
else:
raise NotImplementedError
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
# TODO: need to improve
encoded_pairs = []
total_length = 0
for i in range(0, len(encoded_messages), 2):
if total_length >= cutoff_len:
break
encoded_messages[i] = encoded_messages[i][: cutoff_len - total_length]
total_length += len(encoded_messages[i])
encoded_messages[i + 1] = encoded_messages[i + 1][: max(1, cutoff_len - total_length)]
total_length += len(encoded_messages[i + 1])
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
return encoded_pairs
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
templates: Dict[str, Template] = {}
@ -193,32 +200,39 @@ templates: Dict[str, Template] = {}
def register_template(
name: str,
format_user: Optional[Callable] = None,
format_assistant: Optional[Callable] = None,
format_system: Optional[Callable] = None,
format_tool: Optional[Callable] = None,
format_observation: Optional[Callable] = None,
format_function: Optional[Callable] = None,
system: Optional[str] = "",
separator: Optional[List[Union[str, Dict[str, str]]]] = "",
format_user: Optional["Formatter"] = None,
format_assistant: Optional["Formatter"] = None,
format_system: Optional["Formatter"] = None,
format_function: Optional["Formatter"] = None,
format_observation: Optional["Formatter"] = None,
format_tools: Optional["Formatter"] = None,
format_separator: Optional["Formatter"] = None,
default_system: Optional[str] = "",
stop_words: Optional[List[str]] = [],
efficient_eos: Optional[bool] = False,
replace_eos: Optional[bool] = False,
force_system: Optional[bool] = False,
) -> None:
eos_slots = [] if efficient_eos else [{"eos_token"}]
template_class = Llama2Template if name.startswith("llama2") else Template
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
default_tool_formatter = ToolFormatter(slots="default")
default_separator_formatter = EmptyFormatter()
templates[name] = template_class(
format_user=format_user or StringFormatter(container=["{{content}}"]),
format_assistant=format_assistant or StringFormatter(container=["{{content}}", {"eos_token"}]),
format_system=format_system or StringFormatter(container=["{{content}}"]),
format_tool=format_tool or ToolFormatter(type="default"),
format_observation=format_observation or format_user,
format_function=format_function
or FunctionFormatter(container=["Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"}]),
system=system,
separator=separator,
format_user=format_user or default_user_formatter,
format_assistant=format_assistant or default_assistant_formatter,
format_system=format_system or default_user_formatter,
format_function=format_function or default_function_formatter,
format_observation=format_observation or format_user or default_user_formatter,
format_tools=format_tools or default_tool_formatter,
format_separator=format_separator or default_separator_formatter,
default_system=default_system,
stop_words=stop_words,
efficient_eos=efficient_eos,
replace_eos=replace_eos,
force_system=force_system,
)
@ -257,23 +271,22 @@ def get_template_and_fix_tokenizer(name: str, tokenizer: "PreTrainedTokenizer")
register_template(
name="alpaca",
format_user=StringFormatter(container=["### Instruction:\n{{content}}\n\n### Response:\n"]),
system=(
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
),
separator=["\n\n"],
)
register_template(
name="aquila",
format_user=StringFormatter(container=["Human: {{content}}###Assistant:"]),
format_assistant=StringFormatter(container=["{{content}}"]),
system=(
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
format_separator=EmptyFormatter(slots=["###"]),
default_system=(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
),
separator=["###"],
stop_words=["</s>"],
efficient_eos=True,
)
@ -281,51 +294,53 @@ register_template(
register_template(
name="baichuan",
format_user=StringFormatter(container=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
format_assistant=StringFormatter(container=["{{content}}"]),
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
efficient_eos=True,
)
register_template(
name="baichuan2",
format_user=StringFormatter(container=[{"token": "<reserved_106>"}, "{{content}}", {"token": "<reserved_107>"}]),
format_assistant=StringFormatter(container=["{{content}}"]),
format_user=StringFormatter(slots=[{"token": "<reserved_106>"}, "{{content}}", {"token": "<reserved_107>"}]),
efficient_eos=True,
)
register_template(
name="belle", format_user=StringFormatter(container=["Human: {{content}}\n\nBelle: "]), separator=["\n\n"]
name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
force_system=True,
)
register_template(
name="bluelm",
format_user=StringFormatter(container=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)
register_template(
name="chatglm2",
format_user=StringFormatter(container=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_assistant=StringFormatter(container=["{{content}}"]),
format_system=StringFormatter(container=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
separator=["\n\n"],
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
force_system=True,
)
register_template(
name="chatglm3",
format_user=StringFormatter(container=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(container=["\n" "{{content}}"]),
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter(
container=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
),
format_observation=StringFormatter(container=[{"token": "<|observation|>"}, "\n", "{{content}}"]),
format_function=FunctionFormatter(container=["{{name}}\n{{arguments}}"]),
system=(
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter(slots=[{"token": "<|observation|>"}, "\n", "{{content}}"]),
default_system=(
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
"Follow the user's instructions carefully. Respond using markdown."
),
@ -335,24 +350,30 @@ register_template(
register_template(
name="codegeex2", format_system=StringFormatter(container=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"])
name="codegeex2",
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
force_system=True,
)
register_template(name="deepseek", format_user=StringFormatter(container=["User: {{content}}\n\nAssistant:"]))
register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
register_template(
name="deepseekcoder",
format_user=StringFormatter(container=["### Instruction:\n{{content}}\n### Response:\n"]),
format_assistant=StringFormatter(container=["{{content}}"]),
system=(
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:\n"]),
format_separator=EmptyFormatter(slots=["\n", {"token": "<|EOT|>"}, "\n"]),
default_system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer\n"
),
separator=["\n", {"token": "<|EOT|>"}, "\n"],
stop_words=["<|EOT|>"],
efficient_eos=True,
)
@ -360,29 +381,23 @@ register_template(
register_template(
name="default",
format_user=StringFormatter(container=["Human: {{content}}\nAssistant: "]),
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
),
separator=["\n"],
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
format_separator=EmptyFormatter(slots=["\n"]),
)
register_template(
name="falcon",
format_user=StringFormatter(container=["User: {{content}}\nFalcon:"]),
format_assistant=StringFormatter(container=["{{content}}"]),
separator=["\n"],
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
format_separator=EmptyFormatter(slots=["\n"]),
efficient_eos=True,
)
register_template(
name="intern",
format_user=StringFormatter(container=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
format_assistant=StringFormatter(container=["{{content}}"]),
separator=[{"token": "<eoa>"}, "\n"],
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
stop_words=["<eoa>"],
efficient_eos=True,
)
@ -390,38 +405,26 @@ register_template(
register_template(
name="intern2",
format_user=StringFormatter(
container=[
{"token": "[UNUSED_TOKEN_146]"},
"user\n{{content}}",
{"token": "[UNUSED_TOKEN_145]"},
"\n",
{"token": "[UNUSED_TOKEN_146]"},
"assistant\n",
]
),
format_assistant=StringFormatter(container=["{{content}}"]),
format_system=StringFormatter(
container=[{"token": "[UNUSED_TOKEN_146]"}, "system\n{{content}}", {"token": "[UNUSED_TOKEN_145]"}, "\n"]
),
system=(
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed "
"by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
"- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
"by the user such as English and 中文."
),
separator=[{"token": "[UNUSED_TOKEN_145]"}, "\n"],
stop_words=["[UNUSED_TOKEN_145]"],
efficient_eos=True,
stop_words=["<|im_end|>"],
replace_eos=True,
)
register_template(
name="llama2",
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]),
format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
system=(
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
default_system=(
"You are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
@ -436,51 +439,60 @@ register_template(
register_template(
name="llama2_zh",
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]),
format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
system="You are a helpful assistant. 你是一个乐于助人的助手。",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
)
register_template(name="mistral", format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]))
register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
register_template(
name="openchat",
format_user=StringFormatter(
container=["GPT4 Correct User: {{content}}", {"token": "<|end_of_turn|>"}, "GPT4 Correct Assistant:"]
slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]
),
format_assistant=StringFormatter(container=["{{content}}"]),
separator=[{"token": "<|end_of_turn|>"}],
stop_words=["<|end_of_turn|>"],
efficient_eos=True,
format_assistant=StringFormatter(slots=["{{content}}"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
register_template(
name="qwen",
format_user=StringFormatter(container=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(container=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
system="You are a helpful assistant.",
separator=["\n"],
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
)
register_template(name="solar", format_user=StringFormatter(container=["### User:\n{{content}}\n\n### Assistant:\n"]))
register_template(
name="solar",
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
efficient_eos=True,
)
register_template(
name="starchat",
format_user=StringFormatter(
container=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}]
slots=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}]
),
format_assistant=StringFormatter(container=["{{content}}"]),
format_system=StringFormatter(container=[{"token": "<|system|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n"]),
separator=[{"token": "<|end|>"}, "\n"],
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|end|>"],
efficient_eos=True,
replace_eos=True,
force_system=True,
)
@ -489,8 +501,8 @@ register_template(name="vanilla")
register_template(
name="vicuna",
format_user=StringFormatter(container=["USER: {{content}} ASSISTANT:"]),
system=(
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
@ -499,8 +511,8 @@ register_template(
register_template(
name="xuanyuan",
format_user=StringFormatter(container=["Human: {{content}} Assistant:"]),
system=(
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
default_system=(
"以下是用户和人工智能助手之间的对话。用户以Human开头人工智能助手以Assistant开头"
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
@ -508,14 +520,15 @@ register_template(
)
register_template(name="xverse", format_user=StringFormatter(container=["Human: {{content}}\n\nAssistant: "]))
register_template(name="xverse", format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]))
register_template(
name="yayi",
format_user=StringFormatter(container=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
format_system=StringFormatter(container=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
system=(
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"You are a helpful, respectful and honest assistant named YaYi "
"developed by Beijing Wenge Technology Co.,Ltd. "
"Always answer as helpfully as possible, while being safe. "
@ -526,15 +539,14 @@ register_template(
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
separator=["\n\n"],
stop_words=["<|End|>"],
)
register_template(
name="yi",
format_user=StringFormatter(container=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
separator=["\n"],
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
)
@ -542,8 +554,8 @@ register_template(
register_template(
name="yuan",
format_user=StringFormatter(container=["{{content}}", {"token": "<sep>"}]),
separator=["\n"],
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<eod>"],
replace_eos=True,
)
@ -551,18 +563,14 @@ register_template(
register_template(
name="zephyr",
format_user=StringFormatter(container=["<|user|>\n{{content}}</s><|assistant|>"]),
format_system=StringFormatter(
container=[
"<|system|>\n{{content}}</s>",
]
),
system="You are a friendly chatbot who always responds in the style of a pirate",
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
default_system="You are a friendly chatbot who always responds in the style of a pirate",
)
register_template(
name="ziya",
format_user=StringFormatter(container=[{"token": "<human>"}, ":{{content}}\n", {"token": "<bot>"}, ":"]),
separator=["\n"],
format_user=StringFormatter(slots=[{"token": "<human>"}, ":{{content}}\n", {"token": "<bot>"}, ":"]),
format_separator=EmptyFormatter(slots=["\n"]),
)

View File

@ -38,10 +38,10 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, data_args.reserved_label_len)
max_source_len = data_args.cutoff_len - max_target_len
def infer_max_len(source_len: int, target_len: int, cutoff_len: int, reserved_label_len: int) -> Tuple[int, int]:
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, reserved_label_len)
max_source_len = cutoff_len - max_target_len
return max_source_len, max_target_len