fix #2260
This commit is contained in:
parent
71cfdc2658
commit
cf818a2598
|
@ -13,7 +13,7 @@ except ImportError:
|
|||
|
||||
def main():
|
||||
chat_model = ChatModel()
|
||||
history = []
|
||||
messages = []
|
||||
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
|
||||
|
||||
while True:
|
||||
|
@ -37,12 +37,13 @@ def main():
|
|||
print("Assistant: ", end="", flush=True)
|
||||
|
||||
response = ""
|
||||
for new_text in chat_model.stream_chat(query, history):
|
||||
for new_text in chat_model.stream_chat(messages):
|
||||
print(new_text, end="", flush=True)
|
||||
response += new_text
|
||||
print()
|
||||
|
||||
history = history + [(query, response)]
|
||||
messages.append({"role": "user", "content": query})
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
from dataclasses import dataclass
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
|
||||
from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
|
||||
from ..data import Role, get_template_and_fix_tokenizer
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.misc import get_logits_processor
|
||||
from ..hparams import get_infer_args
|
||||
from ..model import dispatch_model, load_model_and_tokenizer
|
||||
|
@ -32,20 +32,11 @@ class ChatModel:
|
|||
|
||||
def _process_args(
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
**input_kwargs,
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
messages = []
|
||||
if history is not None:
|
||||
for old_prompt, old_response in history:
|
||||
messages.append({"role": Role.USER, "content": old_prompt})
|
||||
messages.append({"role": Role.ASSISTANT, "content": old_response})
|
||||
|
||||
messages.append({"role": Role.USER, "content": query})
|
||||
messages.append({"role": Role.ASSISTANT, "content": ""})
|
||||
prompt, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
|
||||
)
|
||||
|
@ -97,18 +88,12 @@ class ChatModel:
|
|||
@torch.inference_mode()
|
||||
def chat(
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
**input_kwargs,
|
||||
) -> List[Response]:
|
||||
r"""
|
||||
Args: query, history, system, **input_kwargs
|
||||
|
||||
Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
|
||||
"""
|
||||
gen_kwargs, prompt_length = self._process_args(query, history, system, tools, **input_kwargs)
|
||||
gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs)
|
||||
generate_output = self.model.generate(**gen_kwargs)
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
response = self.tokenizer.batch_decode(
|
||||
|
@ -132,13 +117,12 @@ class ChatModel:
|
|||
@torch.inference_mode()
|
||||
def stream_chat(
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
**input_kwargs,
|
||||
) -> Generator[str, None, None]:
|
||||
gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs)
|
||||
gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs)
|
||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Literal, Union
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Set, Sequence, Tuple, Union
|
||||
|
||||
|
||||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||
|
||||
|
||||
JSON_FORMAT_PROMPT = (
|
||||
|
@ -18,56 +23,7 @@ TOOL_SYSTEM_PROMPT = (
|
|||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StringFormatter:
|
||||
container: List[Union[str, Dict[str, str]]]
|
||||
|
||||
def __call__(self, **kwargs) -> List[Union[str, Dict[str, str]]]:
|
||||
elements = []
|
||||
for elem in self.container:
|
||||
if isinstance(elem, str):
|
||||
for name, value in kwargs.items():
|
||||
elem = elem.replace("{{" + name + "}}", value)
|
||||
elements.append(elem)
|
||||
elif isinstance(elem, (dict, set)):
|
||||
elements.append(elem)
|
||||
else:
|
||||
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
|
||||
|
||||
return elements
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionFormatter:
|
||||
container: List[Union[str, Dict[str, str]]]
|
||||
|
||||
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
|
||||
try:
|
||||
function = json.loads(content)
|
||||
name = function["name"]
|
||||
arguments = json.dumps(function["arguments"], ensure_ascii=False)
|
||||
except Exception:
|
||||
name, arguments = "", ""
|
||||
|
||||
elements = []
|
||||
for elem in self.container:
|
||||
if isinstance(elem, str):
|
||||
elem = elem.replace("{{name}}", name)
|
||||
elem = elem.replace("{{arguments}}", arguments)
|
||||
elements.append(elem)
|
||||
elif isinstance(elem, (dict, set)):
|
||||
elements.append(elem)
|
||||
else:
|
||||
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
|
||||
|
||||
return elements
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolFormatter:
|
||||
type: Literal["default"]
|
||||
|
||||
def _default(self, tools: List[Dict[str, Any]]) -> str:
|
||||
def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
tool_names = []
|
||||
for tool in tools:
|
||||
|
@ -92,13 +48,98 @@ class ToolFormatter:
|
|||
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
|
||||
)
|
||||
|
||||
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
|
||||
|
||||
def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
|
||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
|
||||
action_match = re.search(regex, content)
|
||||
if not action_match:
|
||||
return content
|
||||
|
||||
tool_name = action_match.group(1).strip()
|
||||
tool_input = action_match.group(2).strip().strip('"').strip("```")
|
||||
try:
|
||||
arguments = json.loads(tool_input)
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
return tool_name, json.dumps(arguments, ensure_ascii=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Formatter(ABC):
|
||||
slots: SLOTS = field(default_factory=list)
|
||||
tool_format: Literal["default"] = "default"
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmptyFormatter(Formatter):
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
return self.slots
|
||||
|
||||
|
||||
@dataclass
|
||||
class StringFormatter(Formatter):
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
elements = []
|
||||
for slot in self.slots:
|
||||
if isinstance(slot, str):
|
||||
for name, value in kwargs.items():
|
||||
slot = slot.replace("{{" + name + "}}", value, 1)
|
||||
elements.append(slot)
|
||||
elif isinstance(slot, (dict, set)):
|
||||
elements.append(slot)
|
||||
else:
|
||||
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||
|
||||
return elements
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionFormatter(Formatter):
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
try:
|
||||
function = json.loads(content)
|
||||
name = function["name"]
|
||||
arguments = json.dumps(function["arguments"], ensure_ascii=False)
|
||||
except Exception:
|
||||
name, arguments = "", ""
|
||||
|
||||
elements = []
|
||||
for slot in self.slots:
|
||||
if isinstance(slot, str):
|
||||
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
||||
elements.append(slot)
|
||||
elif isinstance(slot, (dict, set)):
|
||||
elements.append(slot)
|
||||
else:
|
||||
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||
|
||||
return elements
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolFormatter(Formatter):
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
try:
|
||||
tools = json.loads(content)
|
||||
if not len(tools):
|
||||
return [""]
|
||||
|
||||
if self.type == "default":
|
||||
return [self._default(tools)]
|
||||
if self.tool_format == "default":
|
||||
return [default_tool_formatter(tools)]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
except Exception:
|
||||
return [""]
|
||||
|
||||
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
|
||||
if self.tool_format == "default":
|
||||
return default_tool_extractor(content)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1,31 +1,34 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .formatter import FunctionFormatter, StringFormatter, ToolFormatter
|
||||
from .utils import Role
|
||||
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
||||
from .utils import Role, infer_max_len
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from .formatter import Formatter
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Template:
|
||||
format_user: Callable
|
||||
format_assistant: Callable
|
||||
format_system: Callable
|
||||
format_tool: Callable
|
||||
format_observation: Callable
|
||||
format_function: Callable
|
||||
system: str
|
||||
separator: List[Union[str, Dict[str, str]]]
|
||||
format_user: "Formatter"
|
||||
format_assistant: "Formatter"
|
||||
format_system: "Formatter"
|
||||
format_function: "Formatter"
|
||||
format_observation: "Formatter"
|
||||
format_tools: "Formatter"
|
||||
format_separator: "Formatter"
|
||||
default_system: str
|
||||
stop_words: List[str]
|
||||
efficient_eos: bool
|
||||
replace_eos: bool
|
||||
force_system: bool
|
||||
|
||||
def encode_oneturn(
|
||||
self,
|
||||
|
@ -34,14 +37,15 @@ class Template:
|
|||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
cutoff_len: Optional[int] = 1_000_000,
|
||||
reserved_label_len: Optional[int] = 16,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
r"""
|
||||
Returns a single pair of token ids representing prompt and response respectively.
|
||||
"""
|
||||
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len)
|
||||
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
|
||||
prompt_ids = []
|
||||
for query_ids, resp_ids in encoded_pairs[:-1]:
|
||||
prompt_ids = prompt_ids + query_ids + resp_ids
|
||||
prompt_ids += query_ids + resp_ids
|
||||
prompt_ids = prompt_ids + encoded_pairs[-1][0]
|
||||
answer_ids = encoded_pairs[-1][1]
|
||||
return prompt_ids, answer_ids
|
||||
|
@ -50,15 +54,15 @@ class Template:
|
|||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: List[Dict[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
cutoff_len: Optional[int] = 1_000_000,
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
reserved_label_len: Optional[int] = 16,
|
||||
) -> Sequence[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Returns multiple pairs of token ids representing prompts and responses respectively.
|
||||
"""
|
||||
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len)
|
||||
return encoded_pairs
|
||||
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
|
||||
|
||||
def _encode(
|
||||
self,
|
||||
|
@ -67,48 +71,37 @@ class Template:
|
|||
system: str,
|
||||
tools: str,
|
||||
cutoff_len: int,
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
reserved_label_len: int,
|
||||
) -> Sequence[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Encodes formatted inputs to pairs of token ids.
|
||||
Turn 0: system + query resp + eos
|
||||
Turn t: sep + query resp + eos
|
||||
Turn 0: system + query resp
|
||||
Turn t: sep + query resp
|
||||
"""
|
||||
system = system or self.system
|
||||
system = system or self.default_system
|
||||
encoded_messages = []
|
||||
for i, message in enumerate(messages):
|
||||
elements = []
|
||||
if i == 0 and (system or tools):
|
||||
tool_text = self.format_tool(content=tools)[0] if tools else ""
|
||||
elements += self.format_system(content=(system + tool_text))
|
||||
if i == 0 and (system or tools or self.force_system):
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
elements += self.format_system.apply(content=(system + tool_text))
|
||||
elif i > 0 and i % 2 == 0:
|
||||
elements += self.separator
|
||||
elements += self.format_separator.apply()
|
||||
|
||||
if message["role"] == Role.USER:
|
||||
elements += self.format_user(content=message["content"], idx=str(i // 2))
|
||||
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
|
||||
elif message["role"] == Role.ASSISTANT:
|
||||
elements += self.format_assistant(content=message["content"])
|
||||
elements += self.format_assistant.apply(content=message["content"])
|
||||
elif message["role"] == Role.OBSERVATION:
|
||||
elements += self.format_observation(content=message["content"])
|
||||
elements += self.format_observation.apply(content=message["content"])
|
||||
elif message["role"] == Role.FUNCTION:
|
||||
elements += self.format_function(content=message["content"])
|
||||
elements += self.format_function.apply(content=message["content"])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
||||
|
||||
# TODO: need to improve
|
||||
encoded_pairs = []
|
||||
total_length = 0
|
||||
for i in range(0, len(encoded_messages), 2):
|
||||
if total_length >= cutoff_len:
|
||||
break
|
||||
|
||||
encoded_messages[i] = encoded_messages[i][: cutoff_len - total_length]
|
||||
total_length += len(encoded_messages[i])
|
||||
|
||||
encoded_messages[i + 1] = encoded_messages[i + 1][: max(1, cutoff_len - total_length)]
|
||||
total_length += len(encoded_messages[i + 1])
|
||||
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
|
||||
|
||||
return encoded_pairs
|
||||
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
|
||||
|
||||
def _convert_elements_to_ids(
|
||||
self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
|
||||
|
@ -120,19 +113,44 @@ class Template:
|
|||
for elem in elements:
|
||||
if isinstance(elem, str):
|
||||
if len(elem) != 0:
|
||||
token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False)
|
||||
token_ids += tokenizer.encode(elem, add_special_tokens=False)
|
||||
elif isinstance(elem, dict):
|
||||
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
|
||||
token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
|
||||
elif isinstance(elem, set):
|
||||
if "bos_token" in elem and tokenizer.bos_token_id:
|
||||
token_ids = token_ids + [tokenizer.bos_token_id]
|
||||
token_ids += [tokenizer.bos_token_id]
|
||||
elif "eos_token" in elem and tokenizer.eos_token_id:
|
||||
token_ids = token_ids + [tokenizer.eos_token_id]
|
||||
token_ids += [tokenizer.eos_token_id]
|
||||
else:
|
||||
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
|
||||
|
||||
return token_ids
|
||||
|
||||
def _make_pairs(
|
||||
self,
|
||||
encoded_messages: Sequence[List[int]],
|
||||
cutoff_len: int,
|
||||
reserved_label_len: int,
|
||||
) -> Sequence[Tuple[List[int], List[int]]]:
|
||||
encoded_pairs = []
|
||||
total_length = 0
|
||||
for i in range(0, len(encoded_messages), 2):
|
||||
if total_length >= cutoff_len:
|
||||
break
|
||||
|
||||
max_source_len, max_target_len = infer_max_len(
|
||||
source_len=len(encoded_messages[i]),
|
||||
target_len=len(encoded_messages[i + 1]),
|
||||
cutoff_len=(cutoff_len - total_length),
|
||||
reserved_label_len=reserved_label_len,
|
||||
)
|
||||
encoded_messages[i] = encoded_messages[i][: max_source_len]
|
||||
encoded_messages[i + 1] = encoded_messages[i + 1][: max_target_len]
|
||||
total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
|
||||
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
|
||||
|
||||
return encoded_pairs
|
||||
|
||||
|
||||
@dataclass
|
||||
class Llama2Template(Template):
|
||||
|
@ -143,49 +161,38 @@ class Llama2Template(Template):
|
|||
system: str,
|
||||
tools: str,
|
||||
cutoff_len: int,
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
reserved_label_len: int,
|
||||
) -> Sequence[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Encodes formatted inputs to pairs of token ids.
|
||||
Turn 0: system + query resp + eos
|
||||
Turn t: sep + query resp + eos
|
||||
Turn 0: system + query resp
|
||||
Turn t: sep + query resp
|
||||
"""
|
||||
system = system or self.system
|
||||
system = system or self.default_system
|
||||
encoded_messages = []
|
||||
for i, message in enumerate(messages):
|
||||
elements = []
|
||||
system_text = ""
|
||||
if i == 0 and (system or tools):
|
||||
tool_text = self.format_tool(content=tools)[0] if tools else ""
|
||||
system_text = self.format_system(content=(system + tool_text))[0]
|
||||
if i == 0 and (system or tools or self.force_system):
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
system_text = self.format_system.apply(content=(system + tool_text))[0]
|
||||
elif i > 0 and i % 2 == 0:
|
||||
elements += self.separator
|
||||
elements += self.format_separator.apply()
|
||||
|
||||
if message["role"] == Role.USER:
|
||||
elements += self.format_user(content=system_text + message["content"], idx=str(i // 2))
|
||||
elements += self.format_user.apply(content=system_text + message["content"])
|
||||
elif message["role"] == Role.ASSISTANT:
|
||||
elements += self.format_assistant(content=message["content"])
|
||||
elements += self.format_assistant.apply(content=message["content"])
|
||||
elif message["role"] == Role.OBSERVATION:
|
||||
elements += self.format_observation(content=message["content"])
|
||||
elements += self.format_observation.apply(content=message["content"])
|
||||
elif message["role"] == Role.FUNCTION:
|
||||
elements += self.format_function(content=message["content"])
|
||||
elements += self.format_function.apply(content=message["content"])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
||||
|
||||
# TODO: need to improve
|
||||
encoded_pairs = []
|
||||
total_length = 0
|
||||
for i in range(0, len(encoded_messages), 2):
|
||||
if total_length >= cutoff_len:
|
||||
break
|
||||
|
||||
encoded_messages[i] = encoded_messages[i][: cutoff_len - total_length]
|
||||
total_length += len(encoded_messages[i])
|
||||
|
||||
encoded_messages[i + 1] = encoded_messages[i + 1][: max(1, cutoff_len - total_length)]
|
||||
total_length += len(encoded_messages[i + 1])
|
||||
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
|
||||
|
||||
return encoded_pairs
|
||||
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
|
||||
|
||||
|
||||
templates: Dict[str, Template] = {}
|
||||
|
@ -193,32 +200,39 @@ templates: Dict[str, Template] = {}
|
|||
|
||||
def register_template(
|
||||
name: str,
|
||||
format_user: Optional[Callable] = None,
|
||||
format_assistant: Optional[Callable] = None,
|
||||
format_system: Optional[Callable] = None,
|
||||
format_tool: Optional[Callable] = None,
|
||||
format_observation: Optional[Callable] = None,
|
||||
format_function: Optional[Callable] = None,
|
||||
system: Optional[str] = "",
|
||||
separator: Optional[List[Union[str, Dict[str, str]]]] = "",
|
||||
format_user: Optional["Formatter"] = None,
|
||||
format_assistant: Optional["Formatter"] = None,
|
||||
format_system: Optional["Formatter"] = None,
|
||||
format_function: Optional["Formatter"] = None,
|
||||
format_observation: Optional["Formatter"] = None,
|
||||
format_tools: Optional["Formatter"] = None,
|
||||
format_separator: Optional["Formatter"] = None,
|
||||
default_system: Optional[str] = "",
|
||||
stop_words: Optional[List[str]] = [],
|
||||
efficient_eos: Optional[bool] = False,
|
||||
replace_eos: Optional[bool] = False,
|
||||
force_system: Optional[bool] = False,
|
||||
) -> None:
|
||||
eos_slots = [] if efficient_eos else [{"eos_token"}]
|
||||
template_class = Llama2Template if name.startswith("llama2") else Template
|
||||
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
||||
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
|
||||
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
|
||||
default_tool_formatter = ToolFormatter(slots="default")
|
||||
default_separator_formatter = EmptyFormatter()
|
||||
templates[name] = template_class(
|
||||
format_user=format_user or StringFormatter(container=["{{content}}"]),
|
||||
format_assistant=format_assistant or StringFormatter(container=["{{content}}", {"eos_token"}]),
|
||||
format_system=format_system or StringFormatter(container=["{{content}}"]),
|
||||
format_tool=format_tool or ToolFormatter(type="default"),
|
||||
format_observation=format_observation or format_user,
|
||||
format_function=format_function
|
||||
or FunctionFormatter(container=["Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"}]),
|
||||
system=system,
|
||||
separator=separator,
|
||||
format_user=format_user or default_user_formatter,
|
||||
format_assistant=format_assistant or default_assistant_formatter,
|
||||
format_system=format_system or default_user_formatter,
|
||||
format_function=format_function or default_function_formatter,
|
||||
format_observation=format_observation or format_user or default_user_formatter,
|
||||
format_tools=format_tools or default_tool_formatter,
|
||||
format_separator=format_separator or default_separator_formatter,
|
||||
default_system=default_system,
|
||||
stop_words=stop_words,
|
||||
efficient_eos=efficient_eos,
|
||||
replace_eos=replace_eos,
|
||||
force_system=force_system,
|
||||
)
|
||||
|
||||
|
||||
|
@ -257,23 +271,22 @@ def get_template_and_fix_tokenizer(name: str, tokenizer: "PreTrainedTokenizer")
|
|||
|
||||
register_template(
|
||||
name="alpaca",
|
||||
format_user=StringFormatter(container=["### Instruction:\n{{content}}\n\n### Response:\n"]),
|
||||
system=(
|
||||
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
default_system=(
|
||||
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
|
||||
),
|
||||
separator=["\n\n"],
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="aquila",
|
||||
format_user=StringFormatter(container=["Human: {{content}}###Assistant:"]),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
system=(
|
||||
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
|
||||
format_separator=EmptyFormatter(slots=["###"]),
|
||||
default_system=(
|
||||
"A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||
),
|
||||
separator=["###"],
|
||||
stop_words=["</s>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
@ -281,51 +294,53 @@ register_template(
|
|||
|
||||
register_template(
|
||||
name="baichuan",
|
||||
format_user=StringFormatter(container=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="baichuan2",
|
||||
format_user=StringFormatter(container=[{"token": "<reserved_106>"}, "{{content}}", {"token": "<reserved_107>"}]),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
format_user=StringFormatter(slots=[{"token": "<reserved_106>"}, "{{content}}", {"token": "<reserved_107>"}]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="belle", format_user=StringFormatter(container=["Human: {{content}}\n\nBelle: "]), separator=["\n\n"]
|
||||
name="belle",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="bluelm",
|
||||
format_user=StringFormatter(container=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
|
||||
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="chatglm2",
|
||||
format_user=StringFormatter(container=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
format_system=StringFormatter(container=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||
separator=["\n\n"],
|
||||
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
|
||||
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
efficient_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="chatglm3",
|
||||
format_user=StringFormatter(container=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||
format_assistant=StringFormatter(container=["\n" "{{content}}"]),
|
||||
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
||||
format_system=StringFormatter(
|
||||
container=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
|
||||
slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
|
||||
),
|
||||
format_observation=StringFormatter(container=[{"token": "<|observation|>"}, "\n", "{{content}}"]),
|
||||
format_function=FunctionFormatter(container=["{{name}}\n{{arguments}}"]),
|
||||
system=(
|
||||
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
||||
format_observation=StringFormatter(slots=[{"token": "<|observation|>"}, "\n", "{{content}}"]),
|
||||
default_system=(
|
||||
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
|
||||
"Follow the user's instructions carefully. Respond using markdown."
|
||||
),
|
||||
|
@ -335,24 +350,30 @@ register_template(
|
|||
|
||||
|
||||
register_template(
|
||||
name="codegeex2", format_system=StringFormatter(container=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"])
|
||||
name="codegeex2",
|
||||
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(name="deepseek", format_user=StringFormatter(container=["User: {{content}}\n\nAssistant:"]))
|
||||
register_template(
|
||||
name="deepseek",
|
||||
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="deepseekcoder",
|
||||
format_user=StringFormatter(container=["### Instruction:\n{{content}}\n### Response:\n"]),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
system=(
|
||||
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n", {"token": "<|EOT|>"}, "\n"]),
|
||||
default_system=(
|
||||
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
|
||||
"developed by Deepseek Company, and you only answer questions related to computer science. "
|
||||
"For politically sensitive questions, security and privacy issues, "
|
||||
"and other non-computer science questions, you will refuse to answer\n"
|
||||
),
|
||||
separator=["\n", {"token": "<|EOT|>"}, "\n"],
|
||||
stop_words=["<|EOT|>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
@ -360,29 +381,23 @@ register_template(
|
|||
|
||||
register_template(
|
||||
name="default",
|
||||
format_user=StringFormatter(container=["Human: {{content}}\nAssistant: "]),
|
||||
system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
),
|
||||
separator=["\n"],
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="falcon",
|
||||
format_user=StringFormatter(container=["User: {{content}}\nFalcon:"]),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
separator=["\n"],
|
||||
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="intern",
|
||||
format_user=StringFormatter(container=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
separator=[{"token": "<eoa>"}, "\n"],
|
||||
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
|
||||
format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
|
||||
stop_words=["<eoa>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
@ -390,38 +405,26 @@ register_template(
|
|||
|
||||
register_template(
|
||||
name="intern2",
|
||||
format_user=StringFormatter(
|
||||
container=[
|
||||
{"token": "[UNUSED_TOKEN_146]"},
|
||||
"user\n{{content}}",
|
||||
{"token": "[UNUSED_TOKEN_145]"},
|
||||
"\n",
|
||||
{"token": "[UNUSED_TOKEN_146]"},
|
||||
"assistant\n",
|
||||
]
|
||||
),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
format_system=StringFormatter(
|
||||
container=[{"token": "[UNUSED_TOKEN_146]"}, "system\n{{content}}", {"token": "[UNUSED_TOKEN_145]"}, "\n"]
|
||||
),
|
||||
system=(
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system=(
|
||||
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
|
||||
"- InternLM (书生·浦语) is a conversational language model that is developed "
|
||||
"by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
|
||||
"- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
|
||||
"by the user such as English and 中文."
|
||||
),
|
||||
separator=[{"token": "[UNUSED_TOKEN_145]"}, "\n"],
|
||||
stop_words=["[UNUSED_TOKEN_145]"],
|
||||
efficient_eos=True,
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="llama2",
|
||||
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]),
|
||||
format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||
system=(
|
||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||
default_system=(
|
||||
"You are a helpful, respectful and honest assistant. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, "
|
||||
|
@ -436,51 +439,60 @@ register_template(
|
|||
|
||||
register_template(
|
||||
name="llama2_zh",
|
||||
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]),
|
||||
format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||
system="You are a helpful assistant. 你是一个乐于助人的助手。",
|
||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||
default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
|
||||
)
|
||||
|
||||
|
||||
register_template(name="mistral", format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]))
|
||||
register_template(
|
||||
name="mistral",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="openchat",
|
||||
format_user=StringFormatter(
|
||||
container=["GPT4 Correct User: {{content}}", {"token": "<|end_of_turn|>"}, "GPT4 Correct Assistant:"]
|
||||
slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]
|
||||
),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
separator=[{"token": "<|end_of_turn|>"}],
|
||||
stop_words=["<|end_of_turn|>"],
|
||||
efficient_eos=True,
|
||||
format_assistant=StringFormatter(slots=["{{content}}"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="qwen",
|
||||
format_user=StringFormatter(container=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(container=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
system="You are a helpful assistant.",
|
||||
separator=["\n"],
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="You are a helpful assistant.",
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(name="solar", format_user=StringFormatter(container=["### User:\n{{content}}\n\n### Assistant:\n"]))
|
||||
register_template(
|
||||
name="solar",
|
||||
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
|
||||
format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="starchat",
|
||||
format_user=StringFormatter(
|
||||
container=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}]
|
||||
slots=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}]
|
||||
),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
format_system=StringFormatter(container=[{"token": "<|system|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n"]),
|
||||
separator=[{"token": "<|end|>"}, "\n"],
|
||||
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|end|>"],
|
||||
efficient_eos=True,
|
||||
replace_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
|
@ -489,8 +501,8 @@ register_template(name="vanilla")
|
|||
|
||||
register_template(
|
||||
name="vicuna",
|
||||
format_user=StringFormatter(container=["USER: {{content}} ASSISTANT:"]),
|
||||
system=(
|
||||
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||
default_system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
|
@ -499,8 +511,8 @@ register_template(
|
|||
|
||||
register_template(
|
||||
name="xuanyuan",
|
||||
format_user=StringFormatter(container=["Human: {{content}} Assistant:"]),
|
||||
system=(
|
||||
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
|
||||
default_system=(
|
||||
"以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
|
||||
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
|
||||
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
|
||||
|
@ -508,14 +520,15 @@ register_template(
|
|||
)
|
||||
|
||||
|
||||
register_template(name="xverse", format_user=StringFormatter(container=["Human: {{content}}\n\nAssistant: "]))
|
||||
register_template(name="xverse", format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]))
|
||||
|
||||
|
||||
register_template(
|
||||
name="yayi",
|
||||
format_user=StringFormatter(container=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
|
||||
format_system=StringFormatter(container=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
|
||||
system=(
|
||||
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
|
||||
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
default_system=(
|
||||
"You are a helpful, respectful and honest assistant named YaYi "
|
||||
"developed by Beijing Wenge Technology Co.,Ltd. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
|
@ -526,15 +539,14 @@ register_template(
|
|||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information."
|
||||
),
|
||||
separator=["\n\n"],
|
||||
stop_words=["<|End|>"],
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="yi",
|
||||
format_user=StringFormatter(container=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
separator=["\n"],
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
@ -542,8 +554,8 @@ register_template(
|
|||
|
||||
register_template(
|
||||
name="yuan",
|
||||
format_user=StringFormatter(container=["{{content}}", {"token": "<sep>"}]),
|
||||
separator=["\n"],
|
||||
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<eod>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
@ -551,18 +563,14 @@ register_template(
|
|||
|
||||
register_template(
|
||||
name="zephyr",
|
||||
format_user=StringFormatter(container=["<|user|>\n{{content}}</s><|assistant|>"]),
|
||||
format_system=StringFormatter(
|
||||
container=[
|
||||
"<|system|>\n{{content}}</s>",
|
||||
]
|
||||
),
|
||||
system="You are a friendly chatbot who always responds in the style of a pirate",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
||||
default_system="You are a friendly chatbot who always responds in the style of a pirate",
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="ziya",
|
||||
format_user=StringFormatter(container=[{"token": "<human>"}, ":{{content}}\n", {"token": "<bot>"}, ":"]),
|
||||
separator=["\n"],
|
||||
format_user=StringFormatter(slots=[{"token": "<human>"}, ":{{content}}\n", {"token": "<bot>"}, ":"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
)
|
||||
|
|
|
@ -38,10 +38,10 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
|
|||
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
|
||||
|
||||
|
||||
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
|
||||
max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
|
||||
max_target_len = max(max_target_len, data_args.reserved_label_len)
|
||||
max_source_len = data_args.cutoff_len - max_target_len
|
||||
def infer_max_len(source_len: int, target_len: int, cutoff_len: int, reserved_label_len: int) -> Tuple[int, int]:
|
||||
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
|
||||
max_target_len = max(max_target_len, reserved_label_len)
|
||||
max_source_len = cutoff_len - max_target_len
|
||||
return max_source_len, max_target_len
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue