parent d4e2af1fa4
commit 1771251ce3
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from enum import Enum, unique
-from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Sequence, Set, Union
 
 from datasets import concatenate_datasets, interleave_datasets
 
@@ -30,6 +30,9 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
+
+
 @unique
 class Role(str, Enum):
     USER = "user"
@@ -39,13 +42,6 @@ class Role(str, Enum):
     OBSERVATION = "observation"
 
 
-def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
-    max_target_len = int(max_len * (target_len / (source_len + target_len)))
-    max_target_len = max(max_target_len, reserved_label_len)
-    max_source_len = max_len - min(max_target_len, target_len)
-    return max_source_len, max_target_len
-
-
 def merge_dataset(
     all_datasets: List[Union["Dataset", "IterableDataset"]],
     data_args: "DataArguments",
@@ -16,97 +16,10 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
-
-
-SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
-
-
-DEFAULT_TOOL_PROMPT = (
-    "You have access to the following tools:\n{tool_text}"
-    "Use the following format if using a tool:\n"
-    "```\n"
-    "Action: tool name (one of [{tool_names}]).\n"
-    "Action Input: the input to the tool, in a JSON format representing the kwargs "
-    """(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
-    "```\n"
-)
-
-
-GLM4_TOOL_PROMPT = (
-    "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
-    "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
-)
-
-
-def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
-    tool_text = ""
-    tool_names = []
-    for tool in tools:
-        param_text = ""
-        for name, param in tool["parameters"]["properties"].items():
-            required = ", required" if name in tool["parameters"].get("required", []) else ""
-            enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
-            items = (
-                ", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
-            )
-            param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
-                name=name,
-                type=param.get("type", ""),
-                required=required,
-                desc=param.get("description", ""),
-                enum=enum,
-                items=items,
-            )
-
-        tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
-            name=tool["name"], desc=tool.get("description", ""), args=param_text
-        )
-        tool_names.append(tool["name"])
-
-    return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
-
-
-def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
-    regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
-    action_match: List[Tuple[str, str]] = re.findall(regex, content)
-    if not action_match:
-        return content
-
-    results = []
-    for match in action_match:
-        tool_name = match[0].strip()
-        tool_input = match[1].strip().strip('"').strip("```")
-        try:
-            arguments = json.loads(tool_input)
-            results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
-        except json.JSONDecodeError:
-            return content
-
-    return results
-
-
-def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
-    tool_text = ""
-    for tool in tools:
-        tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
-            name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
-        )
-
-    return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
-
-
-def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
-    if "\n" not in content:
-        return content
-
-    tool_name, tool_input = content.split("\n", maxsplit=1)
-    try:
-        arguments = json.loads(tool_input)
-    except json.JSONDecodeError:
-        return content
-
-    return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
+from typing import List, Literal, Optional, Tuple, Union
+
+from .data_utils import SLOTS
+from .tool_utils import DefaultToolUtils, GLM4ToolUtils
 
 
 @dataclass
@@ -168,15 +81,12 @@ class StringFormatter(Formatter):
 @dataclass
 class FunctionFormatter(Formatter):
     def __post_init__(self):
-        has_name, has_args = False, False
-        for slot in filter(lambda s: isinstance(s, str), self.slots):
-            if "{{name}}" in slot:
-                has_name = True
-            if "{{arguments}}" in slot:
-                has_args = True
-
-        if not has_name or not has_args:
-            raise ValueError("Name and arguments placeholders are required in the function formatter.")
+        if self.tool_format == "default":
+            self.slots = DefaultToolUtils.get_function_slots() + self.slots
+        elif self.tool_format == "glm4":
+            self.slots = GLM4ToolUtils.get_function_slots() + self.slots
+        else:
+            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
 
     def apply(self, **kwargs) -> SLOTS:
         content = kwargs.pop("content")
@@ -210,11 +120,11 @@ class FunctionFormatter(Formatter):
 class ToolFormatter(Formatter):
     def __post_init__(self):
         if self.tool_format == "default":
-            self._tool_formatter = default_tool_formatter
-            self._tool_extractor = default_tool_extractor
+            self._tool_formatter = DefaultToolUtils.tool_formatter
+            self._tool_extractor = DefaultToolUtils.tool_extractor
         elif self.tool_format == "glm4":
-            self._tool_formatter = glm4_tool_formatter
-            self._tool_extractor = glm4_tool_extractor
+            self._tool_formatter = GLM4ToolUtils.tool_formatter
+            self._tool_extractor = GLM4ToolUtils.tool_extractor
         else:
             raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
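The net effect for callers of FunctionFormatter, as a minimal sketch (this mirrors the updated formatter test later in this commit, assuming this commit's llamafactory package is importable): the Action/Action Input scaffold now comes from the tool format rather than from hand-written slots.

import json

from llamafactory.data.formatter import FunctionFormatter

# tool_format="default" prepends DefaultToolUtils.get_function_slots() to the given slots
formatter = FunctionFormatter(slots=[], tool_format="default")
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
assert formatter.apply(content=tool_calls) == [
    'Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n'
]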
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
 
 
 if TYPE_CHECKING:
@@ -55,12 +55,8 @@ def _encode_feedback_example(
     else:
         kl_messages = prompt + [kl_response[1]]
 
-    prompt_ids, response_ids = template.encode_oneturn(
-        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
-    _, kl_response_ids = template.encode_oneturn(
-        tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
+    prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
+    _, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
 
     if template.efficient_eos:
         response_ids += [tokenizer.eos_token_id]
@@ -70,6 +66,12 @@ def _encode_feedback_example(
         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
         prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
 
+    # do not consider the kl_response
+    source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
+    prompt_ids = prompt_ids[:source_len]
+    response_ids = response_ids[:target_len]
+    kl_response_ids = kl_response_ids[:target_len]
+
     input_ids = prompt_ids + response_ids
     labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
     kl_input_ids = prompt_ids + kl_response_ids
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
 
 
 if TYPE_CHECKING:
@@ -44,12 +44,8 @@ def _encode_pairwise_example(
 
     chosen_messages = prompt + [response[0]]
     rejected_messages = prompt + [response[1]]
-    prompt_ids, chosen_ids = template.encode_oneturn(
-        tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
-    _, rejected_ids = template.encode_oneturn(
-        tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
+    prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
+    _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
 
     if template.efficient_eos:
         chosen_ids += [tokenizer.eos_token_id]
@@ -59,6 +55,13 @@ def _encode_pairwise_example(
         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
         prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
 
+    source_len, target_len = infer_seqlen(
+        len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
+    )  # consider the response is more important
+    prompt_ids = prompt_ids[:source_len]
+    chosen_ids = chosen_ids[:target_len]
+    rejected_ids = rejected_ids[:target_len]
+
     chosen_input_ids = prompt_ids + chosen_ids
     chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
     rejected_input_ids = prompt_ids + rejected_ids
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import bisect
-from typing import TYPE_CHECKING, List, Sequence
+from typing import TYPE_CHECKING, List, Sequence, Tuple
 
 from ...extras.packages import is_pillow_available
 
@@ -76,3 +76,16 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") ->
     """
     image_seq_length = getattr(processor, "image_seq_length")
     return [0] * image_seq_length + [1] * (input_len - image_seq_length)
+
+
+def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
+    if target_len * 2 < cutoff_len:  # truncate source
+        max_target_len = cutoff_len
+    elif source_len * 2 < cutoff_len:  # truncate target
+        max_target_len = cutoff_len - source_len
+    else:  # truncate both
+        max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
+
+    new_target_len = min(max_target_len, target_len)
+    new_source_len = max(cutoff_len - new_target_len, 0)
+    return new_source_len, new_target_len
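A quick sanity check of the new truncation rule, as a minimal sketch (the function body is copied from the hunk above; the input/output pairs match the parametrized test added at the end of this commit):

from typing import Tuple


def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
    # copied verbatim from the hunk above
    if target_len * 2 < cutoff_len:  # truncate source
        max_target_len = cutoff_len
    elif source_len * 2 < cutoff_len:  # truncate target
        max_target_len = cutoff_len - source_len
    else:  # truncate both
        max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))

    new_target_len = min(max_target_len, target_len)
    new_source_len = max(cutoff_len - new_target_len, 0)
    return new_source_len, new_target_len


# both sides overflow: the budget is split proportionally (3:2 here)
assert infer_seqlen(3000, 2000, 1000) == (600, 400)
# the target is short: it is kept whole and the source absorbs the remaining budget
assert infer_seqlen(1000, 100, 1000) == (900, 100)
# the source is short: the mirror case
assert infer_seqlen(100, 1000, 1000) == (100, 900)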
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen
 
 
 if TYPE_CHECKING:
@@ -51,10 +51,17 @@ def _encode_supervised_example(
         input_ids += [image_token_id] * getattr(processor, "image_seq_length")
         labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
 
-    encoded_pairs = template.encode_multiturn(
-        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
+    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
+    total_length = 1 if template.efficient_eos else 0
     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
+        if total_length >= data_args.cutoff_len:
+            break
+
+        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
+        source_ids = source_ids[:source_len]
+        target_ids = target_ids[:target_len]
+        total_length += source_len + target_len
+
         if data_args.train_on_prompt:
             source_mask = source_ids
         elif turn_idx != 0 and template.efficient_eos:
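Roughly how the per-turn budgeting above plays out, as a toy sketch of just the loop (not the full function; the import path is the one used by the new test at the end of this commit):

from llamafactory.data.processors.processor_utils import infer_seqlen

cutoff_len = 1000
turns = [(600, 600), (400, 400)]  # (len(source_ids), len(target_ids)) per turn

total_length = 0
for source_len, target_len in turns:
    if total_length >= cutoff_len:
        break  # later turns are dropped once the budget is spent
    source_len, target_len = infer_seqlen(source_len, target_len, cutoff_len - total_length)
    total_length += source_len + target_len
    print(source_len, target_len)  # first turn: 500 500; the second turn never prints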
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
 from ...extras.logging import get_logger
 from ..data_utils import Role
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
 
 
 if TYPE_CHECKING:
@@ -47,9 +47,7 @@ def _encode_unsupervised_example(
     else:
         messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
 
-    input_ids, labels = template.encode_oneturn(
-        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
+    input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
     if template.efficient_eos:
         labels += [tokenizer.eos_token_id]
 
@@ -57,6 +55,9 @@ def _encode_unsupervised_example(
         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
         input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
 
+    source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len)
+    input_ids = input_ids[:source_len]
+    labels = labels[:target_len]
     return input_ids, labels
@@ -16,7 +16,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
 
 from ..extras.logging import get_logger
-from .data_utils import Role, infer_max_len
+from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
 
 
@@ -48,36 +48,33 @@ class Template:
     def encode_oneturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: List[Dict[str, str]],
+        messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        cutoff_len: int = 1_000_000,
-        reserved_label_len: int = 1,
     ) -> Tuple[List[int], List[int]]:
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
         """
-        encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
         prompt_ids = []
-        for query_ids, resp_ids in encoded_pairs[:-1]:
-            prompt_ids += query_ids + resp_ids
+        for encoded_ids in encoded_messages[:-1]:
+            prompt_ids += encoded_ids
 
-        prompt_ids = prompt_ids + encoded_pairs[-1][0]
-        answer_ids = encoded_pairs[-1][1]
+        answer_ids = encoded_messages[-1]
         return prompt_ids, answer_ids
 
     def encode_multiturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: List[Dict[str, str]],
+        messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        cutoff_len: int = 1_000_000,
-        reserved_label_len: int = 1,
-    ) -> Sequence[Tuple[List[int], List[int]]]:
+    ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
         """
-        return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
 
     def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
         r"""
@@ -88,16 +85,14 @@ class Template:
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: List[Dict[str, str]],
+        messages: Sequence[Dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        cutoff_len: int,
-        reserved_label_len: int,
-    ) -> Sequence[Tuple[List[int], List[int]]]:
+    ) -> List[List[int]]:
         r"""
         Encodes formatted inputs to pairs of token ids.
-        Turn 0: system + query           resp
+        Turn 0: prefix + system + query  resp
         Turn t: sep + query              resp
         """
         system = system or self.default_system
         encoded_messages = []
@@ -106,10 +101,9 @@ class Template:
 
             if i == 0:
                 elements += self.format_prefix.apply()
-            if i == 0 and (system or tools):
-                tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
-                elements += self.format_system.apply(content=(system + tool_text))
+                if system or tools:
+                    tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+                    elements += self.format_system.apply(content=(system + tool_text))
 
             if i > 0 and i % 2 == 0:
                 elements += self.format_separator.apply()
@@ -127,11 +121,9 @@ class Template:
 
             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
 
-        return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
+        return encoded_messages
 
-    def _convert_elements_to_ids(
-        self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
-    ) -> List[int]:
+    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
         r"""
         Converts elements to token ids.
         """
@@ -152,60 +144,32 @@ class Template:
 
         return token_ids
 
-    def _make_pairs(
-        self,
-        encoded_messages: Sequence[List[int]],
-        cutoff_len: int,
-        reserved_label_len: int,
-    ) -> Sequence[Tuple[List[int], List[int]]]:
-        encoded_pairs = []
-        total_length = 0
-        for i in range(0, len(encoded_messages), 2):
-            if total_length >= cutoff_len:
-                break
-
-            max_source_len, max_target_len = infer_max_len(
-                source_len=len(encoded_messages[i]),
-                target_len=len(encoded_messages[i + 1]),
-                max_len=(cutoff_len - total_length),
-                reserved_label_len=reserved_label_len,
-            )
-            source_ids = encoded_messages[i][:max_source_len]
-            target_ids = encoded_messages[i + 1][:max_target_len]
-            total_length += len(source_ids) + len(target_ids)
-            encoded_pairs.append((source_ids, target_ids))
-
-        return encoded_pairs
-
-
 @dataclass
 class Llama2Template(Template):
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: List[Dict[str, str]],
+        messages: Sequence[Dict[str, str]],
         system: str,
         tools: str,
-        cutoff_len: int,
-        reserved_label_len: int,
-    ) -> Sequence[Tuple[List[int], List[int]]]:
+    ) -> List[List[int]]:
         r"""
         Encodes formatted inputs to pairs of token ids.
-        Turn 0: system + query           resp
+        Turn 0: prefix + system + query  resp
         Turn t: sep + query              resp
         """
         system = system or self.default_system
         encoded_messages = []
         for i, message in enumerate(messages):
             elements = []
 
+            system_text = ""
             if i == 0:
                 elements += self.format_prefix.apply()
-            system_text = ""
-            if i == 0 and (system or tools):
-                tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
-                system_text = self.format_system.apply(content=(system + tool_text))[0]
+                if system or tools:
+                    tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+                    system_text = self.format_system.apply(content=(system + tool_text))[0]
 
             if i > 0 and i % 2 == 0:
                 elements += self.format_separator.apply()
@@ -223,7 +187,7 @@ class Llama2Template(Template):
 
             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
 
-        return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
+        return encoded_messages
 
 
 TEMPLATES: Dict[str, Template] = {}
@@ -240,7 +204,7 @@ def _register_template(
     format_separator: Optional["Formatter"] = None,
     format_prefix: Optional["Formatter"] = None,
     default_system: str = "",
-    stop_words: List[str] = [],
+    stop_words: Sequence[str] = [],
     image_token: str = "<image>",
     efficient_eos: bool = False,
     replace_eos: bool = False,
@@ -275,9 +239,7 @@ def _register_template(
     template_class = Llama2Template if name.startswith("llama2") else Template
     default_user_formatter = StringFormatter(slots=["{{content}}"])
     default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
-    default_function_formatter = FunctionFormatter(
-        slots=["Action: {{name}}\nAction Input: {{arguments}}\n"] + eos_slots
-    )
+    default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
     default_tool_formatter = ToolFormatter(tool_format="default")
     default_separator_formatter = EmptyFormatter()
    default_prefix_formatter = EmptyFormatter()
@@ -390,7 +352,9 @@ def get_template_and_fix_tokenizer(
 
     if tool_format is not None:
         logger.info("Using tool format: {}.".format(tool_format))
+        eos_slots = [] if template.efficient_eos else [{"eos_token"}]
         template.format_tools = ToolFormatter(tool_format=tool_format)
+        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
 
     stop_words = template.stop_words
     if template.replace_eos:
@@ -506,10 +470,11 @@ _register_template(
     format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
     format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
     format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
-    format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
+    format_function=FunctionFormatter(slots=[], tool_format="glm4"),
     format_observation=StringFormatter(
         slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
     ),
+    format_tools=ToolFormatter(tool_format="glm4"),
     format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
     stop_words=["<|user|>", "<|observation|>"],
     efficient_eos=True,
@@ -603,16 +568,15 @@ _register_template(
 _register_template(
     name="deepseekcoder",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
-    format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
-    format_separator=EmptyFormatter(slots=["\n<|EOT|>\n"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     default_system=(
         "You are an AI programming assistant, utilizing the Deepseek Coder model, "
         "developed by Deepseek Company, and you only answer questions related to computer science. "
         "For politically sensitive questions, security and privacy issues, "
         "and other non-computer science questions, you will refuse to answer\n"
     ),
-    stop_words=["<|EOT|>"],
-    efficient_eos=True,
 )
@@ -662,7 +626,7 @@ _register_template(
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
     format_assistant=StringFormatter(slots=["\n{{content}}"]),
     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
-    format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
+    format_function=FunctionFormatter(slots=[], tool_format="glm4"),
     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
     format_tools=ToolFormatter(tool_format="glm4"),
     format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
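To make the new return convention concrete, a small sketch with toy token lists (not real tokenizer output): _encode now yields one id list per message, and the two public methods slice that flat list.

# [user_0, assistant_0, user_1, assistant_1] as toy token ids
encoded_messages = [[1, 2], [3], [4, 5], [6, 7]]

# encode_multiturn pairs consecutive messages into (query_ids, resp_ids)
pairs = [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
assert pairs == [([1, 2], [3]), ([4, 5], [6, 7])]

# encode_oneturn folds everything before the final reply into the prompt
prompt_ids = []
for encoded_ids in encoded_messages[:-1]:
    prompt_ids += encoded_ids
answer_ids = encoded_messages[-1]
assert (prompt_ids, answer_ids) == ([1, 2, 3, 4, 5], [6, 7])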
@@ -0,0 +1,140 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import re
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple, Union
+
+from .data_utils import SLOTS
+
+
+DEFAULT_TOOL_PROMPT = (
+    "You have access to the following tools:\n{tool_text}"
+    "Use the following format if using a tool:\n"
+    "```\n"
+    "Action: tool name (one of [{tool_names}]).\n"
+    "Action Input: the input to the tool, in a JSON format representing the kwargs "
+    """(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
+    "```\n"
+)
+
+
+GLM4_TOOL_PROMPT = (
+    "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
+    "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
+)
+
+
+@dataclass
+class ToolUtils(ABC):
+    @staticmethod
+    @abstractmethod
+    def get_function_slots() -> SLOTS: ...
+
+    @staticmethod
+    @abstractmethod
+    def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
+
+    @staticmethod
+    @abstractmethod
+    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
+
+
+class DefaultToolUtils(ToolUtils):
+    @staticmethod
+    def get_function_slots() -> SLOTS:
+        return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
+
+    @staticmethod
+    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+        tool_text = ""
+        tool_names = []
+        for tool in tools:
+            param_text = ""
+            for name, param in tool["parameters"]["properties"].items():
+                required, enum, items = "", "", ""
+                if name in tool["parameters"].get("required", []):
+                    required = ", required"
+
+                if param.get("enum", None):
+                    enum = ", should be one of [{}]".format(", ".join(param["enum"]))
+
+                if param.get("items", None):
+                    items = ", where each item should be {}".format(param["items"].get("type", ""))
+
+                param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
+                    name=name,
+                    type=param.get("type", ""),
+                    required=required,
+                    desc=param.get("description", ""),
+                    enum=enum,
+                    items=items,
+                )
+
+            tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
+                name=tool["name"], desc=tool.get("description", ""), args=param_text
+            )
+            tool_names.append(tool["name"])
+
+        return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
+
+    @staticmethod
+    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+        regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
+        action_match: List[Tuple[str, str]] = re.findall(regex, content)
+        if not action_match:
+            return content
+
+        results = []
+        for match in action_match:
+            tool_name = match[0].strip()
+            tool_input = match[1].strip().strip('"').strip("```")
+            try:
+                arguments = json.loads(tool_input)
+                results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
+            except json.JSONDecodeError:
+                return content
+
+        return results
+
+
+class GLM4ToolUtils(ToolUtils):
+    @staticmethod
+    def get_function_slots() -> SLOTS:
+        return ["{{name}}\n{{arguments}}"]
+
+    @staticmethod
+    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+        tool_text = ""
+        for tool in tools:
+            tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
+                name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
+            )
+
+        return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
+
+    @staticmethod
+    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+        if "\n" not in content:
+            return content
+
+        tool_name, tool_input = content.split("\n", maxsplit=1)
+        try:
+            arguments = json.loads(tool_input)
+        except json.JSONDecodeError:
+            return content
+
+        return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
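A round-trip usage sketch for the relocated utilities (assumes this commit's llamafactory is importable; the weather tool schema is a made-up example):

from llamafactory.data.tool_utils import DefaultToolUtils

# a hypothetical OpenAI-style tool definition
tools = [{
    "name": "get_weather",
    "description": "Query the weather of a city.",
    "parameters": {
        "type": "object",
        "properties": {"city": {"type": "string", "description": "City name."}},
        "required": ["city"],
    },
}]

# tool_formatter renders the system prompt that lists the available tools
system_prompt = DefaultToolUtils.tool_formatter(tools)
assert "get_weather" in system_prompt

# tool_extractor parses a model reply back into (name, json_arguments) pairs
reply = 'Action: get_weather\nAction Input: {"city": "Beijing"}'
assert DefaultToolUtils.tool_extractor(reply) == [("get_weather", '{"city": "Beijing"}')]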
@@ -45,10 +45,6 @@ class DataArguments:
         default=1024,
         metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
     )
-    reserved_label_len: int = field(
-        default=1,
-        metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."},
-    )
     train_on_prompt: bool = field(
         default=False,
         metadata={"help": "Whether to disable the mask on the prompt or not."},
@@ -111,9 +107,6 @@ class DataArguments:
     )
 
     def __post_init__(self):
-        if self.reserved_label_len >= self.cutoff_len:
-            raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
-
         if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
             raise ValueError("Streaming mode should have an integer val size.")
@@ -28,7 +28,7 @@ def test_string_formatter():
 
 
 def test_function_formatter():
-    formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"])
+    formatter = FunctionFormatter(slots=[], tool_format="default")
     tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
     assert formatter.apply(content=tool_calls) == [
         """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n"""
@@ -36,7 +36,7 @@ def test_function_formatter():
 
 
 def test_multi_function_formatter():
-    formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"])
+    formatter = FunctionFormatter(slots=[], tool_format="default")
     tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2)
     assert formatter.apply(content=tool_calls) == [
         """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
@@ -0,0 +1,32 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+import pytest
+
+from llamafactory.data.processors.processor_utils import infer_seqlen
+
+
+@pytest.mark.parametrize(
+    "test_input,test_output",
+    [
+        ((3000, 2000, 1000), (600, 400)),
+        ((2000, 3000, 1000), (400, 600)),
+        ((1000, 100, 1000), (900, 100)),
+        ((100, 1000, 1000), (100, 900)),
+    ],
+)
+def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]):
+    assert test_output == infer_seqlen(*test_input)
@@ -21,15 +21,60 @@ from llamafactory.data import get_template_and_fix_tokenizer
 
 TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
 
+MESSAGES = [
+    {"role": "user", "content": "How are you"},
+    {"role": "assistant", "content": "I am fine!"},
+    {"role": "user", "content": "你好"},
+    {"role": "assistant", "content": "很高兴认识你!"},
+]
+
+
+def test_encode_oneturn():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+    template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
+    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+    assert tokenizer.decode(prompt_ids) == (
+        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
+        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    )
+    assert tokenizer.decode(answer_ids) == "很高兴认识你!<|eot_id|>"
+
+
+def test_encode_multiturn():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+    template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
+    encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
+    assert tokenizer.decode(encoded_pairs[0][0]) == (
+        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    )
+    assert tokenizer.decode(encoded_pairs[0][1]) == "I am fine!<|eot_id|>"
+    assert tokenizer.decode(encoded_pairs[1][0]) == (
+        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    )
+    assert tokenizer.decode(encoded_pairs[1][1]) == "很高兴认识你!<|eot_id|>"
+
+
 def test_jinja_template():
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
     ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
     get_template_and_fix_tokenizer(tokenizer, name="llama3")
     assert tokenizer.chat_template != ref_tokenizer.chat_template
-
-    messages = [
-        {"role": "user", "content": "hi!"},
-        {"role": "assistant", "content": "hello there"},
-    ]
-    assert tokenizer.apply_chat_template(messages) == ref_tokenizer.apply_chat_template(messages)
+    assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
+
+
+def test_qwen_template():
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
+    template = get_template_and_fix_tokenizer(tokenizer, name="qwen")
+    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+    assert tokenizer.decode(prompt_ids) == (
+        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+        "<|im_start|>user\nHow are you<|im_end|>\n"
+        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
+        "<|im_start|>user\n你好<|im_end|>\n"
+        "<|im_start|>assistant\n"
+    )
+    assert tokenizer.decode(answer_ids) == "很高兴认识你!<|im_end|>"