support function calling

hiyouga 2024-01-18 09:54:23 +08:00
parent 28135d787d
commit d9f1cae351
69 changed files with 1329 additions and 1085 deletions

View File

@@ -165,9 +165,13 @@
"hf_hub_url": "HuggingFaceH4/ultrachat_200k",
"ms_hub_url": "AI-ModelScope/ultrachat_200k",
"columns": {
"messages": "messages",
"role": "role",
"content": "content"
"messages": "messages"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "human",
"assistant_tag": "assistant"
},
"formatting": "sharegpt"
},
@@ -180,9 +184,13 @@
"hf_hub_url": "lmsys/lmsys-chat-1m",
"ms_hub_url": "AI-ModelScope/lmsys-chat-1m",
"columns": {
"messages": "conversation",
"role": "role",
"content": "content"
"messages": "conversation"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "human",
"assistant_tag": "assistant"
},
"formatting": "sharegpt"
},
@@ -190,6 +198,14 @@
"hf_hub_url": "WizardLM/WizardLM_evol_instruct_V2_196k",
"formatting": "sharegpt"
},
"glaive_func_call": {
"file_name": "glaive_func_call.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"tool": "tools"
}
},
"hh_rlhf_en": {
"script_url": "hh_rlhf_en",
"columns": {

View File

@@ -0,0 +1,68 @@
[
{
"conversations": [
{
"from": "human",
"value": "I need a new password. Can you generate one for me?"
},
{
"from": "gpt",
"value": "Of course. How long would you like your password to be? And would you like it to include symbols?"
},
{
"from": "human",
"value": "I would like it to be 12 characters long and yes, please include symbols."
},
{
"from": "function_call",
"value": "{\"name\": \"generate_password\", \"arguments\": {\"length\": 12, \"include_symbols\": true}}"
},
{
"from": "observation",
"value": "{\"password\": \"4&7j#9@1Q6*\"}"
},
{
"from": "gpt",
"value": "Here is your new password: 4&7j#9@1Q6*. Please make sure to save it in a secure location."
}
],
"tools": "[{\"name\": \"generate_password\", \"description\": \"Generate a random password\", \"parameters\": {\"type\": \"object\", \"properties\": {\"length\": {\"type\": \"integer\", \"description\": \"The length of the password\"}, \"include_symbols\": {\"type\": \"boolean\", \"description\": \"Whether to include symbols in the password\"}}, \"required\": [\"length\"]}}, {\"name\": \"create_task\", \"description\": \"Create a new task in a task management system\", \"parameters\": {\"type\": \"object\", \"properties\": {\"title\": {\"type\": \"string\", \"description\": \"The title of the task\"}, \"due_date\": {\"type\": \"string\", \"format\": \"date\", \"description\": \"The due date of the task\"}, \"priority\": {\"type\": \"string\", \"enum\": [\"low\", \"medium\", \"high\"], \"description\": \"The priority of the task\"}}, \"required\": [\"title\", \"due_date\", \"priority\"]}}]"
},
{
"conversations": [
{
"from": "human",
"value": "Can you tell me the latest news headlines for the United States?"
},
{
"from": "function_call",
"value": "{\"name\": \"get_news_headlines\", \"arguments\": {\"country\": \"United States\"}}"
},
{
"from": "observation",
"value": "{\"headlines\": [\"Biden announces new vaccine mandates\", \"Hurricane Ida devastates Louisiana\", \"Apple unveils new iPhone\", \"NASA's Perseverance rover collects first Mars rock sample\"]}"
},
{
"from": "gpt",
"value": "Here are the latest news headlines for the United States:"
},
{
"from": "human",
"value": "That's interesting. What about the news in France?"
},
{
"from": "function_call",
"value": "{\"name\": \"get_news_headlines\", \"arguments\": {\"country\": \"France\"}}"
},
{
"from": "observation",
"value": "{\"headlines\": [\"France recalls ambassadors to US and Australia\", \"French election: Macron's party braces for tough fight\", \"Louvre Museum to undergo major overhaul\", \"France to offer free birth control to all women under 25\"]}"
},
{
"from": "gpt",
"value": "Here are the latest news headlines for France:"
}
],
"tools": "[{\"name\": \"get_news_headlines\", \"description\": \"Get the latest news headlines\", \"parameters\": {\"type\": \"object\", \"properties\": {\"country\": {\"type\": \"string\", \"description\": \"The country for which to fetch news\"}}, \"required\": [\"country\"]}}]"
}
]
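Illustrative note (not part of the commit): in the records above, each function_call or observation turn stores a JSON string in its value field, and the tools field is itself a JSON-encoded list of schemas. A minimal decoding sketch, assuming the file lives at data/glaive_func_call.json:

import json

with open("data/glaive_func_call.json", encoding="utf-8") as f:  # path is an assumption
    records = json.load(f)

example = records[0]
tools = json.loads(example["tools"])  # list of tool schemas (generate_password, create_task)
for turn in example["conversations"]:
    if turn["from"] == "function_call":
        call = json.loads(turn["value"])  # e.g. {"name": "generate_password", "arguments": {...}}
        print("call:", call["name"], call["arguments"])
    elif turn["from"] == "observation":
        print("tool result:", json.loads(turn["value"]))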

View File

@@ -9,7 +9,6 @@ scipy
einops
sentencepiece
protobuf
tiktoken
jieba
rouge-chinese
nltk

View File

@@ -8,3 +8,12 @@ from llmtuner.webui import create_ui, create_web_demo
__version__ = "0.4.0"
__all__ = [
"create_app",
"ChatModel",
"Evaluator",
"export_model",
"run_exp",
"create_ui",
"create_web_demo"
]

View File

@@ -1 +1,4 @@
from llmtuner.api.app import create_app
from .app import create_app
__all__ = ["create_app"]

View File

@@ -5,7 +5,7 @@ from typing import List, Tuple
from pydantic import BaseModel
from contextlib import asynccontextmanager
from llmtuner.api.protocol import (
from .protocol import (
Role,
Finish,
ModelCard,
@@ -21,9 +21,9 @@ from llmtuner.api.protocol import (
ScoreEvaluationRequest,
ScoreEvaluationResponse
)
from llmtuner.chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.extras.packages import (
from ..chat import ChatModel
from ..extras.misc import torch_gc
from ..extras.packages import (
is_fastapi_availble, is_starlette_available, is_uvicorn_available
)

View File

@@ -1,15 +1,17 @@
import time
from enum import Enum
from enum import Enum, unique
from pydantic import BaseModel, Field
from typing import List, Optional
@unique
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
@unique
class Finish(str, Enum):
STOP = "stop"
LENGTH = "length"

View File

@@ -1 +1,4 @@
from llmtuner.chat.chat_model import ChatModel
from .chat_model import ChatModel
__all__ = ["ChatModel"]

View File

@@ -1,13 +1,13 @@
import torch
import tiktoken
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
from threading import Thread
from transformers import GenerationConfig, TextIteratorStreamer
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.extras.misc import get_logits_processor
from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer
from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_logits_processor
from ..model import dispatch_model, load_model_and_tokenizer
from ..hparams import get_infer_args
@dataclass
@@ -139,11 +139,6 @@ class ChatModel:
batch_input: List[str],
**input_kwargs
) -> List[float]:
if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=True)
max_length = input_kwargs.pop("max_length", None)
device = getattr(self.model.pretrained_model, "device", "cuda")
@@ -153,7 +148,7 @@ class ChatModel:
truncation=True,
max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
return_tensors="pt",
**kwargs
add_special_tokens=True
).to(device)
input_ids: torch.Tensor = inputs["input_ids"]

View File

@@ -1,4 +1,6 @@
from llmtuner.data.loader import get_dataset
from llmtuner.data.preprocess import preprocess_dataset
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.data.utils import split_dataset
from .loader import get_dataset
from .template import get_template_and_fix_tokenizer, templates
from .utils import split_dataset
__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset"]

View File

@@ -0,0 +1,106 @@
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union
from .utils import Role
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from ..hparams import DataArguments
from .parser import DatasetAttr
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tool": []}
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
if dataset_attr.history:
for old_prompt, old_response in examples[dataset_attr.history][i]:
prompt.append({"role": Role.USER, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT, "content": old_response})
instruction = examples[dataset_attr.prompt][i]
if dataset_attr.query and examples[dataset_attr.query][i]:
instruction += "\n" + examples[dataset_attr.query][i]
prompt.append({"role": Role.USER, "content": instruction})
if isinstance(examples[dataset_attr.response][i], list):
response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
else:
response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tool"].append("")
return outputs
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tool": []}
tag_mapping = {
dataset_attr.user_tag: Role.USER,
dataset_attr.assistant_tag: Role.ASSISTANT,
dataset_attr.observation_tag: Role.OBSERVATION,
dataset_attr.function_tag: Role.FUNCTION
}
for i, messages in enumerate(examples[dataset_attr.messages]):
messages = messages[:len(messages) // 2 * 2] # truncate to an even number of turns
if len(messages) == 0:
continue
prompt = []
response = []
for turn_idx, message in enumerate(messages):
if turn_idx % 2 == 0:
accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag]
else:
accept_tags = [dataset_attr.assistant_tag, dataset_attr.function_tag]
if message[dataset_attr.role_tag] not in accept_tags:
raise ValueError("Invalid role tag.")
prompt.append({"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]})
last_message = prompt.pop(-1)
response.append(last_message)
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tool"].append(examples[dataset_attr.tool][i] if dataset_attr.tool else "")
return outputs
def align_dataset(
dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
prompt: [{"role": "user", "content": "..."}]
response: [{"role": "assistant", "content": "..."}]
system: "..."
tool: "..."
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Converting format of dataset"
)
return dataset.map(
convert_func,
batched=True,
remove_columns=column_names,
**kwargs
)
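Illustrative note (not part of the diff): with the default sharegpt tags (from/value, human/gpt, observation/function_call), convert_sharegpt should align the first glaive_func_call record roughly as sketched below; content strings are shortened here and the Role import path is assumed.

from llmtuner.data.utils import Role  # assumed module path

aligned_prompt = [
    {"role": Role.USER, "content": "I need a new password. Can you generate one for me?"},
    {"role": Role.ASSISTANT, "content": "Of course. How long would you like your password to be? ..."},
    {"role": Role.USER, "content": "I would like it to be 12 characters long and yes, please include symbols."},
    {"role": Role.FUNCTION, "content": "{\"name\": \"generate_password\", \"arguments\": {\"length\": 12, \"include_symbols\": true}}"},
    {"role": Role.OBSERVATION, "content": "{\"password\": \"4&7j#9@1Q6*\"}"},
]
aligned_response = [
    {"role": Role.ASSISTANT, "content": "Here is your new password: 4&7j#9@1Q6*. ..."},
]
aligned_system = ""  # this dataset defines no system column
aligned_tool = "[{\"name\": \"generate_password\", ...}]"  # the raw tools string, truncated here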

View File

@@ -0,0 +1,99 @@
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Union
JSON_FORMAT_PROMPT = (
""", in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
)
TOOL_SYSTEM_PROMPT = (
"You have access to the following tools:\n{tool_text}"
"Use the following format to answer the question:\n"
"```\n"
"Action: the action to take, should be one of [{tool_names}] if using a tool.\n"
"Action Input: the input to the action{format_prompt}.\n"
"```"
)
@dataclass
class StringFormatter:
container: List[Union[str, Dict[str, str]]]
def __call__(self, **kwargs) -> List[Union[str, Dict[str, str]]]:
elements = []
for elem in self.container:
if isinstance(elem, str):
for name, value in kwargs.items():
elem = elem.replace("{{" + name + "}}", value)
elements.append(elem)
elif isinstance(elem, (dict, set)):
elements.append(elem)
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
return elements
@dataclass
class FunctionFormatter:
container: List[Union[str, Dict[str, str]]]
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
try:
function = json.loads(content)
name = json.dumps(function["name"], ensure_ascii=False)
arguments = json.dumps(function["arguments"], ensure_ascii=False)
except json.JSONDecodeError:
name, arguments = "", ""
elements = []
for elem in self.container:
if isinstance(elem, str):
elem = elem.replace("{{name}}", name)
elem = elem.replace("{{arguments}}", arguments)
elements.append(elem)
elif isinstance(elem, (dict, set)):
elements.append(elem)
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
return elements
@dataclass
class ToolFormatter:
type: Literal["default"]
def _default(self, tools: List[Dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
for tool in tools:
param_text = ""
for name, param in tool["parameters"]["properties"].items():
required = ", required" if name in tool["parameters"].get("required", []) else ""
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
param_text += " - {name} ({type}{required}): {desc}{enum}\n".format(
name=name, type=param.get("type", ""), required=required, desc=param.get("description", ""), enum=enum
)
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
name=tool["name"], desc=tool.get("description", ""), args=param_text
)
tool_names.append(tool["name"])
return TOOL_SYSTEM_PROMPT.format(
tool_text=tool_text,
tool_names=", ".join(tool_names),
format_prompt=JSON_FORMAT_PROMPT
)
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
try:
tools = json.loads(content)
if self.type == "default":
return [self._default(tools)]
except json.JSONDecodeError:
return [""]
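Illustrative note (not part of the commit): rendering a single schema from the example dataset through ToolFormatter(type="default") should produce roughly the system prompt shown in the comments below, assuming the new module is importable as llmtuner.data.formatter.

import json

from llmtuner.data.formatter import ToolFormatter  # assumed import path

tool_schema = json.dumps([{
    "name": "generate_password",
    "description": "Generate a random password",
    "parameters": {
        "type": "object",
        "properties": {
            "length": {"type": "integer", "description": "The length of the password"},
            "include_symbols": {"type": "boolean", "description": "Whether to include symbols in the password"}
        },
        "required": ["length"]
    }
}])
print(ToolFormatter(type="default")(tool_schema)[0])
# Roughly prints:
# You have access to the following tools:
# > Tool Name: generate_password
# Tool Description: Generate a random password
# Tool Args:
#  - length (integer, required): The length of the password
#  - include_symbols (boolean): Whether to include symbols in the password
#
# Use the following format to answer the question:
# ```
# Action: the action to take, should be one of [generate_password] if using a tool.
# Action Input: the input to the action, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```).
# ```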

View File

@@ -1,160 +1,114 @@
import os
from typing import TYPE_CHECKING, Any, Dict, List, Union
from typing import TYPE_CHECKING, List, Literal, Union
from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
from llmtuner.data.utils import checksum
from llmtuner.extras.constants import FILEEXT2TYPE
from llmtuner.extras.logging import get_logger
from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from .utils import checksum
from .parser import get_dataset_list
from .aligner import align_dataset
from .template import get_template_and_fix_tokenizer
from .preprocess import get_preprocess_and_print_func
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from llmtuner.hparams import ModelArguments, DataArguments
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from .parser import DatasetAttr
from ..hparams import ModelArguments, DataArguments
logger = get_logger(__name__)
def get_dataset(
def load_single_dataset(
dataset_attr: "DatasetAttr",
model_args: "ModelArguments",
data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
max_samples = data_args.max_samples
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
data_args: "DataArguments",
):
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
data_path = dataset_attr.dataset_name
data_name = dataset_attr.subset
data_dir = dataset_attr.folder
if data_args.cache_path is not None:
if os.path.exists(data_args.cache_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.cache_path)
if data_args.streaming:
dataset = dataset.to_iterable_dataset()
return dataset
elif data_args.streaming:
raise ValueError("Turn off dataset streaming to save cache files.")
elif dataset_attr.load_from == "script":
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
data_name = dataset_attr.subset
data_dir = dataset_attr.folder
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
data_path = dataset_attr.dataset_name
data_name = dataset_attr.subset
data_dir = dataset_attr.folder
elif dataset_attr.load_from == "script":
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
data_name = dataset_attr.subset
elif dataset_attr.load_from == "file":
data_files = []
local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
if os.path.isdir(local_path): # is directory
for file_name in os.listdir(local_path):
data_files.append(os.path.join(local_path, file_name))
if data_path is None:
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
else:
assert data_path == FILEEXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
elif os.path.isfile(local_path): # is file
data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
raise ValueError("File not found.")
assert data_path, "File extension must be txt, csv, json or jsonl."
checksum(data_files, dataset_attr.dataset_sha1)
elif dataset_attr.load_from == "file":
data_files = []
local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
if os.path.isdir(local_path): # is directory
for file_name in os.listdir(local_path):
data_files.append(os.path.join(local_path, file_name))
if data_path is None:
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
raise ValueError("File types should be identical.")
elif os.path.isfile(local_path): # is file
data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
raise NotImplementedError
raise ValueError("File not found.")
if dataset_attr.load_from == "ms_hub":
try:
from modelscope import MsDataset
from modelscope.utils.config_ds import MS_DATASETS_CACHE
if data_path is None:
raise ValueError("File extension must be txt, csv, json or jsonl.")
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
dataset = MsDataset.load(
dataset_name=data_path,
subset_name=data_name,
data_dir=data_dir,
data_files=data_files,
split=data_args.split,
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
).to_hf_dataset()
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")
else:
dataset = load_dataset(
path=data_path,
name=data_name,
checksum(data_files, dataset_attr.dataset_sha1)
else:
raise NotImplementedError
if dataset_attr.load_from == "ms_hub":
try:
from modelscope import MsDataset
from modelscope.utils.config_ds import MS_DATASETS_CACHE
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
dataset = MsDataset.load(
dataset_name=data_path,
subset_name=data_name,
data_dir=data_dir,
data_files=data_files,
split=data_args.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
)
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
).to_hf_dataset()
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")
else:
dataset = load_dataset(
path=data_path,
name=data_name,
data_dir=data_dir,
data_files=data_files,
split=data_args.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
)
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if max_samples is not None: # truncate dataset
dataset = dataset.select(range(min(len(dataset), max_samples)))
if data_args.max_samples is not None: # truncate dataset
num_samples = min(data_args.max_samples, len(dataset))
dataset = dataset.select(range(num_samples))
def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# convert dataset from sharegpt format to alpaca format
outputs = {"prompt": [], "query": [], "response": [], "history": [], "system": []}
for i, msg_list in enumerate(examples[dataset_attr.messages]):
msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
if len(msg_list) == 0:
continue
return align_dataset(dataset, dataset_attr, data_args)
msg_pairs = []
user_role, assistant_role = None, None
for idx in range(0, len(msg_list), 2):
if user_role is None and assistant_role is None:
user_role = msg_list[idx][dataset_attr.role]
assistant_role = msg_list[idx + 1][dataset_attr.role]
else:
if (
msg_list[idx][dataset_attr.role] != user_role
or msg_list[idx+1][dataset_attr.role] != assistant_role
):
raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))
if len(msg_pairs) != 0:
outputs["prompt"].append(msg_pairs[-1][0])
outputs["query"].append("")
outputs["response"].append(msg_pairs[-1][1])
outputs["history"].append(msg_pairs[:-1] if len(msg_pairs) > 1 else None)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
return outputs
if dataset_attr.formatting == "sharegpt": # convert format
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Converting format of dataset"
)
dataset = dataset.map(
convert_format,
batched=True,
remove_columns=column_names,
**kwargs
)
else:
for column_name in ["prompt", "query", "response", "history", "system"]: # align dataset
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
all_datasets.append(dataset)
if len(data_args.dataset_list) == 1:
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]],
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments"
) -> Union["Dataset", "IterableDataset"]:
if len(all_datasets) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
@@ -166,8 +120,72 @@ def get_dataset(
return interleave_datasets(
datasets=all_datasets,
probabilities=data_args.interleave_probs,
seed=data_args.seed,
seed=training_args.seed,
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
)
else:
raise ValueError("Unknown mixing strategy.")
def get_dataset(
model_args: "ModelArguments",
data_args: "DataArguments",
tokenizer: "PreTrainedTokenizer",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
# split: Optional[str] = "train", # TODO: add split
) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
# Load from cache
if data_args.cache_path is not None:
if os.path.exists(data_args.cache_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.cache_path)
if data_args.streaming:
dataset = dataset.to_iterable_dataset()
return dataset
if data_args.streaming:
raise ValueError("Turn off dataset streaming to save cache files.")
with training_args.main_process_first(desc="load dataset"):
all_datasets = []
for dataset_attr in get_dataset_list(data_args): # TODO: add split
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
dataset = merge_dataset(all_datasets, data_args, training_args)
with training_args.main_process_first(desc="pre-process dataset"):
preprocess_func, print_function = get_preprocess_and_print_func(
tokenizer, template, data_args, training_args, stage
)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Running tokenizer on dataset"
)
dataset = dataset.map(
preprocess_func,
batched=True,
remove_columns=column_names,
**kwargs
)
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
if training_args.should_save:
dataset.save_to_disk(data_args.cache_path)
logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
if training_args.should_log:
try:
print_function(next(iter(dataset)))
except StopIteration:
raise RuntimeError("Empty dataset!")
return dataset

src/llmtuner/data/parser.py Normal file
View File

@@ -0,0 +1,101 @@
import os
import json
from typing import TYPE_CHECKING, List, Literal, Optional
from dataclasses import dataclass
from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope
if TYPE_CHECKING:
from ..hparams import DataArguments
@dataclass
class DatasetAttr:
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None
subset: Optional[str] = None
folder: Optional[str] = None
ranking: Optional[bool] = False
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
system: Optional[str] = None
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
messages: Optional[str] = "conversations"
tool: Optional[str] = None
role_tag: Optional[str] = "from"
content_tag: Optional[str] = "value"
user_tag: Optional[str] = "human"
assistant_tag: Optional[str] = "gpt"
observation_tag: Optional[str] = "observation"
function_tag: Optional[str] = "function_call"
def __repr__(self) -> str:
return self.dataset_name
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
try:
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
dataset_info = json.load(f)
except Exception as err:
if data_args.dataset is not None:
raise ValueError("Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)))
dataset_info = None
if data_args.interleave_probs is not None:
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
dataset_list: List[DatasetAttr] = []
for name in dataset_names:
if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
has_hf_url = "hf_hub_url" in dataset_info[name]
has_ms_url = "ms_hub_url" in dataset_info[name]
if has_hf_url or has_ms_url:
if (use_modelscope() and has_ms_url) or (not has_hf_url):
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
else:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
else:
dataset_attr = DatasetAttr(
"file",
dataset_name=dataset_info[name]["file_name"],
dataset_sha1=dataset_info[name].get("file_sha1", None)
)
dataset_attr.subset = dataset_info[name].get("subset", None)
dataset_attr.folder = dataset_info[name].get("folder", None)
dataset_attr.ranking = dataset_info[name].get("ranking", False)
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
if "columns" in dataset_info[name]:
if dataset_attr.formatting == "alpaca":
column_names = ["prompt", "query", "response", "history"]
else:
column_names = ["messages", "tool"]
column_names += ["system"]
for column_name in column_names:
setattr(dataset_attr, column_name, dataset_info[name]["columns"].get(column_name, None))
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
for tag in ["role_tag", "content_tag", "user_tag", "assistant_tag", "observation_tag", "function_tag"]:
setattr(dataset_attr, tag, dataset_info[name]["tags"].get(tag, None))
dataset_list.append(dataset_attr)
return dataset_list
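Illustrative note (not part of the diff): the glaive_func_call entry added to dataset_info.json above should resolve to roughly the following DatasetAttr, with the tag defaults left untouched because no "tags" block is given.

from llmtuner.data.parser import DatasetAttr  # assumed import path

attr = DatasetAttr("file", dataset_name="glaive_func_call.json")
attr.formatting = "sharegpt"
attr.messages = "conversations"  # from the "columns" block
attr.tool = "tools"              # from the "columns" block
attr.system = None               # "system" is absent from "columns"
# defaults kept: role_tag="from", content_tag="value", user_tag="human",
# assistant_tag="gpt", observation_tag="observation", function_tag="function_call"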

View File

@@ -1,272 +1,241 @@
import os
import tiktoken
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger
from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from llmtuner.hparams import DataArguments
from ..hparams import DataArguments
from .template import Template
logger = get_logger(__name__)
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
for i in range(len(examples["prompt"])):
query, response = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
history = examples["history"][i] if "history" in examples else None
system = examples["system"][i] if "system" in examples else None
yield query, response, history, system
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, data_args.reserved_label_len)
max_source_len = data_args.cutoff_len - max_target_len
return max_source_len, max_target_len
def preprocess_dataset(
dataset: Union["Dataset", "IterableDataset"],
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...`
text_examples = [examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))]
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
for i in range(len(tokenized_examples["input_ids"])):
tokenized_examples["input_ids"][i] += [tokenizer.eos_token_id]
tokenized_examples["attention_mask"][i] += [1]
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
result = {
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
return result
def preprocess_supervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"]
) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
return dataset # already preprocessed
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
continue
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...`
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=True)
if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer
add_eos_token_flag = getattr(tokenizer, "add_eos_token")
setattr(tokenizer, "add_eos_token", True)
tokenized_examples = tokenizer(examples["prompt"], **kwargs)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
result = {
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
# make sure the saved tokenizer is the same as the original one
if hasattr(tokenizer, "add_eos_token"):
setattr(tokenizer, "add_eos_token", add_eos_token_flag)
return result
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
continue
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, query, response, history, system
)):
source_len, target_len = len(source_ids), len(target_ids)
max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
if source_len > max_source_len:
source_ids = source_ids[:max_source_len]
if target_len > max_target_len:
target_ids = target_ids[:max_target_len]
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
if len(input_ids) > data_args.cutoff_len:
input_ids = input_ids[:data_args.cutoff_len]
labels = labels[:data_args.cutoff_len]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = [], []
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
continue
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
)):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, query, response, history, system
)):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
total_length = len(input_ids)
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
for i in range(0, total_length, block_size):
model_inputs["input_ids"].append(input_ids[i: i + block_size])
model_inputs["attention_mask"].append([1] * block_size)
model_inputs["labels"].append(labels[i: i + block_size])
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return model_inputs
return model_inputs
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and query != ""):
continue
def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
input_ids, labels = [], []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
continue
input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
messages = examples["prompt"][i] + examples["response"][i]
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, messages, examples["system"][i], examples["tool"][i], 1_000_000
)):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if len(input_ids) > data_args.cutoff_len:
input_ids = input_ids[:data_args.cutoff_len]
if len(labels) > data_args.cutoff_len:
labels = labels[:data_args.cutoff_len]
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
total_length = len(input_ids)
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
for i in range(0, total_length, block_size):
model_inputs["input_ids"].append(input_ids[i: i + block_size])
model_inputs["attention_mask"].append([1] * block_size)
model_inputs["labels"].append(labels[i: i + block_size])
return model_inputs
return model_inputs
def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
continue
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
def preprocess_unsupervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
continue
source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
if source_len > max_source_len:
prompt_ids = prompt_ids[:max_source_len]
if target_len > max_target_len:
chosen_ids = chosen_ids[:max_target_len]
rejected_ids = rejected_ids[:max_target_len]
model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
))
def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
print("prompt_ids:\n{}".format(example["prompt_ids"]))
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
print("chosen_ids:\n{}".format(example["chosen_ids"]))
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
print("rejected_ids:\n{}".format(example["rejected_ids"]))
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
if stage == "pt":
preprocess_func = preprocess_pretrain_dataset
print_function = print_unsupervised_dataset_example
elif stage == "sft" and not training_args.predict_with_generate:
preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
print_function = print_supervised_dataset_example
elif stage == "rm":
preprocess_func = preprocess_pairwise_dataset
print_function = print_pairwise_dataset_example
else:
preprocess_func = preprocess_unsupervised_dataset
print_function = print_unsupervised_dataset_example
with training_args.main_process_first(desc="dataset map pre-processing"):
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Running tokenizer on dataset"
)
dataset = dataset.map(
preprocess_func,
batched=True,
remove_columns=column_names,
**kwargs
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = template.encode_oneturn(
tokenizer, messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
)
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
if training_args.should_save:
dataset.save_to_disk(data_args.cache_path)
logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
if training_args.should_log:
try:
print_function(next(iter(dataset)))
except StopIteration:
raise RuntimeError("Empty dataset!")
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return dataset
return model_inputs
def preprocess_pairwise_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) < 2:
continue
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
prompt_ids, chosen_ids = template.encode_oneturn(
tokenizer, chosen_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
)
_, rejected_ids = template.encode_oneturn(
tokenizer, rejected_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
))
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("prompt_ids:\n{}".format(example["prompt_ids"]))
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
print("chosen_ids:\n{}".format(example["chosen_ids"]))
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
print("rejected_ids:\n{}".format(example["rejected_ids"]))
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
def get_preprocess_and_print_func(
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
) -> Tuple[Callable, Callable]:
if stage == "pt":
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not training_args.predict_with_generate:
if data_args.sft_packing:
preprocess_func = partial(
preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
else:
preprocess_func = partial(
preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
elif stage == "rm":
preprocess_func = partial(preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
else:
preprocess_func = partial(preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
return preprocess_func, print_function
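Illustrative note (not part of the commit): a toy example of the supervised label masking used above, for one turn with train_on_prompt disabled and a template without efficient_eos; the IGNORE_INDEX value is assumed here.

IGNORE_INDEX = -100              # assumed value; the real constant lives in extras.constants
source_ids = [101, 102, 103]     # made-up token ids for the prompt side of the turn
target_ids = [201, 202]          # made-up token ids for the response side

input_ids = source_ids + target_ids                      # [101, 102, 103, 201, 202]
labels = [IGNORE_INDEX] * len(source_ids) + target_ids   # [-100, -100, -100, 201, 202]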

View File

@@ -1,8 +1,10 @@
import tiktoken
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from ..extras.logging import get_logger
from .utils import Role
from .formatter import StringFormatter, FunctionFormatter, ToolFormatter
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@@ -14,28 +16,30 @@ logger = get_logger(__name__)
@dataclass
class Template:
prefix: List[Union[str, Dict[str, str]]]
prompt: List[Union[str, Dict[str, str]]]
format_user: Callable
format_assistant: Callable
format_system: Callable
format_tool: Callable
format_observation: Callable
format_function: Callable
system: str
sep: List[Union[str, Dict[str, str]]]
separator: List[Union[str, Dict[str, str]]]
stop_words: List[str]
use_history: bool
efficient_eos: bool
replace_eos: bool
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None
messages: List[Dict[str, str]],
system: str,
tool: str,
cutoff_len: int
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
system, history = self._format(query, resp, history, system)
encoded_pairs = self._encode(tokenizer, system, history)
encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids = prompt_ids + query_ids + resp_ids
@@ -46,109 +50,75 @@ class Template:
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None
messages: List[Dict[str, str]],
system: str,
tool: str,
cutoff_len: int
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
system, history = self._format(query, resp, history, system)
encoded_pairs = self._encode(tokenizer, system, history)
encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len)
return encoded_pairs
def _format(
self,
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None
) -> Tuple[str, List[Tuple[str, str]]]:
r"""
Aligns inputs to the standard format.
"""
system = system or self.system # use system if provided
history = history if (history and self.use_history) else []
history = history + [(query, resp)]
return system, history
def _get_special_ids(
self,
tokenizer: "PreTrainedTokenizer"
) -> Tuple[List[int], List[int]]:
if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
bos_ids = [tokenizer.bos_token_id]
else: # baichuan, gpt2, qwen, yi models have no bos token
bos_ids = []
if tokenizer.eos_token_id is None:
raise ValueError("EOS token is required.")
if self.efficient_eos:
eos_ids = []
else:
eos_ids = [tokenizer.eos_token_id]
return bos_ids, eos_ids
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
history: List[Tuple[str, str]]
tool: str,
cutoff_len: int
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: bos + prefix + sep + query resp + eos
Turn t: sep + bos + query resp + eos
Turn 0: system + query resp + eos
Turn t: sep + query resp + eos
"""
bos_ids, eos_ids = self._get_special_ids(tokenizer)
sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
encoded_pairs = []
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0:
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
if len(prefix_ids) != 0: # has prefix
prefix_ids = bos_ids + prefix_ids + sep_ids
else:
prefix_ids = bos_ids
else:
prefix_ids = sep_ids + bos_ids
system = system or self.system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
if i == 0 and (system or tool):
tool_text = self.format_tool(content=tool)[0] if tool else ""
elements += self.format_system(content=(system + tool_text))
elif i > 0 and i % 2 == 0:
elements += self.separator
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx+1))
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
return encoded_pairs
if message["role"] == Role.USER:
elements += self.format_user(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant(content=message["content"])
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation(content=message["content"])
elif message["role"] == Role.FUNCTION:
elements += self.format_function(content=message["content"])
def _convert_inputs_to_ids(
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)]
def _convert_elements_to_ids(
self,
tokenizer: "PreTrainedTokenizer",
context: List[Union[str, Dict[str, str]]],
system: Optional[str] = None,
query: Optional[str] = None,
idx: Optional[str] = None
elements: List[Union[str, Dict[str, str]]]
) -> List[int]:
r"""
Converts context to token ids.
Converts elements to token ids.
"""
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=False)
token_ids = []
for elem in context:
for elem in elements:
if isinstance(elem, str):
elem = elem.replace("{{system}}", system, 1) if system is not None else elem
elem = elem.replace("{{query}}", query, 1) if query is not None else elem
elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
if len(elem) != 0:
token_ids = token_ids + tokenizer.encode(elem, **kwargs)
token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False)
elif isinstance(elem, dict):
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
elif isinstance(elem, set):
if "bos_token" in elem and tokenizer.bos_token_id:
token_ids = token_ids + [tokenizer.bos_token_id]
elif "eos_token" in elem and tokenizer.eos_token_id:
token_ids = token_ids + [tokenizer.eos_token_id]
else:
raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem)))
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
return token_ids
@@ -159,23 +129,39 @@ class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
history: List[Tuple[str, str]]
tool: str,
cutoff_len: int
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: bos + prefix + query resp + eos
Turn t: bos + query resp + eos
Turn 0: system + query resp + eos
Turn t: sep + query resp + eos
"""
bos_ids, eos_ids = self._get_special_ids(tokenizer)
encoded_pairs = []
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0: # llama2 template has no sep_ids
query = self.prefix[0].replace("{{system}}", system) + query
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
return encoded_pairs
system = system or self.system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
system_text = ""
if i == 0 and (system or tool):
tool_text = self.format_tool(content=tool)[0] if tool else ""
system_text = self.format_system(content=(system + tool_text))[0]
elif i > 0 and i % 2 == 0:
elements += self.separator
if message["role"] == Role.USER:
elements += self.format_user(content=system_text + message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant(content=message["content"])
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation(content=message["content"])
elif message["role"] == Role.FUNCTION:
elements += self.format_function(content=message["content"])
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)]
templates: Dict[str, Template] = {}
@@ -183,23 +169,33 @@ templates: Dict[str, Template] = {}
def register_template(
name: str,
prefix: List[Union[str, Dict[str, str]]],
prompt: List[Union[str, Dict[str, str]]],
system: str,
sep: List[Union[str, Dict[str, str]]],
format_user: Optional[Callable] = None,
format_assistant: Optional[Callable] = None,
format_system: Optional[Callable] = None,
format_tool: Optional[Callable] = None,
format_observation: Optional[Callable] = None,
format_function: Optional[Callable] = None,
system: Optional[str] = "",
separator: Optional[List[Union[str, Dict[str, str]]]] = "",
stop_words: Optional[List[str]] = [],
use_history: Optional[bool] = True,
efficient_eos: Optional[bool] = False,
replace_eos: Optional[bool] = False
) -> None:
template_class = Llama2Template if name.startswith("llama2") else Template
templates[name] = template_class(
prefix=prefix,
prompt=prompt,
format_user=format_user or StringFormatter(container=["{{content}}"]),
format_assistant=format_assistant or StringFormatter(container=[
"{{content}}", {"eos_token"}
]),
format_system=format_system or StringFormatter(container=["{{content}}"]),
format_tool=format_tool or ToolFormatter(type="default"),
format_observation=format_observation or format_user,
format_function=format_function or FunctionFormatter(container=[
"Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"}
]),
system=system,
sep=sep,
separator=separator,
stop_words=stop_words,
use_history=use_history,
efficient_eos=efficient_eos,
replace_eos=replace_eos
)
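# Illustrative sketch, not part of this commit: with the default FunctionFormatter above,
# a function-call message whose content is a JSON object with "name" and "arguments"
# (hypothetical values below) would presumably render as
#   "Action: get_weather\nAction Input: {\"city\": \"Paris\"}" followed by the eos token.
# function_content = '{"name": "get_weather", "arguments": {"city": "Paris"}}'
# elements = template.format_function(content=function_content)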
@ -244,17 +240,14 @@ def get_template_and_fix_tokenizer(
register_template(
name="alpaca",
prefix=[
"{{system}}"
],
prompt=[
"### Instruction:\n{{query}}\n\n### Response:\n"
],
format_user=StringFormatter(container=[
"### Instruction:\n{{content}}\n\n### Response:\n"
]),
system=(
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request."
),
sep=[
separator=[
"\n\n"
]
)
@ -262,17 +255,14 @@ register_template(
register_template(
name="aquila",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}###Assistant:"
],
format_user=StringFormatter(container=[
"Human: {{content}}###Assistant:"
]),
system=(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
),
sep=[
separator=[
"###"
],
stop_words=[
@ -284,46 +274,32 @@ register_template(
register_template(
name="baichuan",
prefix=[
"{{system}}"
],
prompt=[
{"token": "<reserved_102>"}, # user token
"{{query}}",
{"token": "<reserved_103>"} # assistant token
],
system="",
sep=[],
format_user=StringFormatter(container=[
{"token": "<reserved_102>"},
"{{content}}",
{"token": "<reserved_103>"}
]),
efficient_eos=True
)
register_template(
name="baichuan2",
prefix=[
"{{system}}"
],
prompt=[
{"token": "<reserved_106>"}, # user token
"{{query}}",
{"token": "<reserved_107>"} # assistant token
],
system="",
sep=[],
format_user=StringFormatter(container=[
{"token": "<reserved_106>"},
"{{content}}",
{"token": "<reserved_107>"}
]),
efficient_eos=True
)
register_template(
name="belle",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\n\nBelle: "
],
system="",
sep=[
format_user=StringFormatter(container=[
"Human: {{content}}\n\nBelle: "
]),
separator=[
"\n\n"
]
)
@ -331,31 +307,25 @@ register_template(
register_template(
name="bluelm",
prefix=[
"{{system}}"
],
prompt=[
format_user=StringFormatter(container=[
{"token": "[|Human|]:"},
"{{query}}",
"{{content}}",
{"token": "[|AI|]:"}
],
system="",
sep=[]
])
)
register_template(
name="chatglm2",
prefix=[
format_user=StringFormatter(container=[
"[Round {{idx}}]\n\n问:{{content}}\n\n答:"
]),
format_system=StringFormatter(container=[
{"token": "[gMASK]"},
{"token": "sop"},
"{{system}}"
],
prompt=[
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
],
system="",
sep=[
"{{content}}"
]),
separator=[
"\n\n"
],
efficient_eos=True
@ -364,53 +334,35 @@ register_template(
register_template(
name="chatglm3",
prefix=[
{"token": "[gMASK]"},
{"token": "sop"},
{"token": "<|system|>"},
"\n",
"{{system}}"
],
prompt=[
format_user=StringFormatter(container=[
{"token": "<|user|>"},
"\n",
"{{query}}",
{"token": "<|assistant|>"},
"\n" # add an extra newline to avoid error in ChatGLM's process_response method
],
system=(
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
"Follow the user's instructions carefully. Respond using markdown."
),
sep=[],
stop_words=[
"<|user|>",
"<|observation|>"
],
efficient_eos=True
)
register_template(
name="chatglm3_raw", # the raw template for tool tuning
prefix=[
{"token": "[gMASK]"},
{"token": "sop"},
{"token": "<|system|>"},
"\n",
"{{system}}"
],
prompt=[
{"token": "<|user|>"},
"\n",
"{{query}}",
"{{content}}",
{"token": "<|assistant|>"}
],
]),
format_assistant=StringFormatter(container=[
"\n"
"{{content}}"
]),
format_system=StringFormatter(container=[
{"token": "[gMASK]"},
{"token": "sop"},
{"token": "<|system|>"},
"\n",
"{{content}}"
]),
format_observation=StringFormatter(container=[
{"token": "<|observation|>"},
"\n",
"{{content}}"
]),
format_function=FunctionFormatter(container=[
"{{name}}\n{{arguments}}"
]),
system=(
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
"Follow the user's instructions carefully. Respond using markdown."
),
sep=[],
stop_words=[
"<|user|>",
"<|observation|>"
@ -421,47 +373,34 @@ register_template(
register_template(
name="codegeex2",
prefix=[
format_system=StringFormatter(container=[
{"token": "[gMASK]"},
{"token": "sop"},
"{{system}}"
],
prompt=[
"{{query}}"
],
system="",
sep=[]
"{{content}}"
])
)
register_template(
name="deepseek",
prefix=[
"{{system}}"
],
prompt=[
"User: {{query}}\n\nAssistant:"
],
system="",
sep=[]
format_user=StringFormatter(container=[
"User: {{content}}\n\nAssistant:"
])
)
register_template(
name="deepseekcoder",
prefix=[
"{{system}}"
],
prompt=[
"### Instruction:\n{{query}}\n### Response:\n"
],
format_user=StringFormatter(container=[
"### Instruction:\n{{content}}\n### Response:\n"
]),
system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer\n"
),
sep=[
separator=[
"\n",
{"token": "<|EOT|>"},
"\n"
@ -475,17 +414,14 @@ register_template(
register_template(
name="default",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\nAssistant:"
],
format_user=StringFormatter(container=[
"Human: {{content}}\nAssistant: "
]),
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
"The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
),
sep=[
separator=[
"\n"
]
)
@ -493,14 +429,10 @@ register_template(
register_template(
name="falcon",
prefix=[
"{{system}}"
],
prompt=[
"User: {{query}}\nFalcon:"
],
system="",
sep=[
format_user=StringFormatter(container=[
"User: {{content}}\nFalcon:"
]),
separator=[
"\n"
],
efficient_eos=True
@ -509,16 +441,12 @@ register_template(
register_template(
name="intern",
prefix=[
"{{system}}"
],
prompt=[
"<|User|>:{{query}}",
format_user=StringFormatter(container=[
"<|User|>:{{content}}",
{"token": "<eoh>"},
"\n<|Bot|>:"
],
system="",
sep=[
]),
separator=[
{"token": "<eoa>"},
"\n"
],
@ -529,14 +457,44 @@ register_template(
)
register_template(
name="intern2",
format_user=StringFormatter(container=[
{"token": "[UNUSED_TOKEN_146]"},
"user\n{{content}}",
{"token": "[UNUSED_TOKEN_145]"},
"\n",
{"token": "[UNUSED_TOKEN_146]"},
"assistant\n"
]),
format_system=StringFormatter(container=[
{"token": "[UNUSED_TOKEN_146]"},
"system\n{{content}}",
{"token": "[UNUSED_TOKEN_145]"},
"\n"
]),
system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed "
"by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
"- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
"by the user such as English and 中文."
),
separator=[
{"token": "[UNUSED_TOKEN_145]"},
"\n"
],
stop_words=[
"[UNUSED_TOKEN_145]"
],
efficient_eos=True
)
register_template(
name="llama2",
prefix=[
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST]"
],
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]),
format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
system=(
"You are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
@ -546,49 +504,32 @@ register_template(
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
sep=[]
)
)
register_template(
name="llama2_zh",
prefix=[
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST]"
],
system="You are a helpful assistant. 你是一个乐于助人的助手。",
sep=[]
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]),
format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
system="You are a helpful assistant. 你是一个乐于助人的助手。"
)
register_template(
name="mistral",
prefix=[
"{{system}}"
],
prompt=[
"[INST] {{query}} [/INST]"
],
system="",
sep=[]
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"])
)
register_template(
name="openchat",
prefix=[
"{{system}}"
],
prompt=[
"GPT4 Correct User: {{query}}",
format_user=StringFormatter(container=[
"GPT4 Correct User: {{content}}",
{"token": "<|end_of_turn|>"},
"GPT4 Correct Assistant:"
],
system="",
sep=[
]),
separator=[
{"token": "<|end_of_turn|>"}
],
stop_words=[
@ -600,14 +541,14 @@ register_template(
register_template(
name="qwen",
prefix=[
"<|im_start|>system\n{{system}}<|im_end|>"
],
prompt=[
"<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n"
],
format_user=StringFormatter(container=[
"<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
]),
format_system=StringFormatter(container=[
"<|im_start|>system\n{{content}}<|im_end|>\n"
]),
system="You are a helpful assistant.",
sep=[
separator=[
"\n"
],
stop_words=[
@ -619,32 +560,28 @@ register_template(
register_template(
name="solar",
prefix=[
"{{system}}"
],
prompt=[
"### User:\n{{query}}\n\n### Assistant:\n"
],
system="",
sep=[]
format_user=StringFormatter(container=[
"### User:\n{{content}}\n\n### Assistant:\n"
])
)
register_template(
name="starchat",
prefix=[
{"token": "<|system|>"},
"\n{{system}}",
],
prompt=[
format_user=StringFormatter(container=[
{"token": "<|user|>"},
"\n{{query}}",
"\n{{content}}",
{"token": "<|end|>"},
"\n",
{"token": "<|assistant|>"}
],
system="",
sep=[
]),
format_system=StringFormatter(container=[
{"token": "<|system|>"},
"\n{{content}}",
{"token": "<|end|>"},
"\n"
]),
separator=[
{"token": "<|end|>"},
"\n"
],
@ -656,75 +593,55 @@ register_template(
register_template(
name="vanilla",
prefix=[],
prompt=[
"{{query}}"
],
system="",
sep=[],
use_history=False
name="vanilla"
)
register_template(
name="vicuna",
prefix=[
"{{system}}"
],
prompt=[
"USER: {{query}} ASSISTANT:"
],
format_user=StringFormatter(container=[
"USER: {{content}} ASSISTANT:"
]),
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
sep=[]
)
)
register_template(
name="xuanyuan",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}} Assistant:"
],
format_user=StringFormatter(container=[
"Human: {{content}} Assistant:"
]),
system=(
"以下是用户和人工智能助手之间的对话。用户以Human开头人工智能助手以Assistant开头"
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
),
sep=[]
)
)
register_template(
name="xverse",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\n\nAssistant: "
],
system="",
sep=[]
format_user=StringFormatter(container=[
"Human: {{content}}\n\nAssistant: "
])
)
register_template(
name="yayi",
prefix=[
{"token": "<|System|>"},
":\n{{system}}"
],
prompt=[
format_user=StringFormatter(container=[
{"token": "<|Human|>"},
":\n{{query}}\n\n",
":\n{{content}}\n\n",
{"token": "<|YaYi|>"},
":"
],
]),
format_system=StringFormatter(container=[
{"token": "<|System|>"},
":\n{{content}}\n\n"
]),
system=(
"You are a helpful, respectful and honest assistant named YaYi "
"developed by Beijing Wenge Technology Co.,Ltd. "
@ -736,7 +653,7 @@ register_template(
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
sep=[
separator=[
"\n\n"
],
stop_words=[
@ -747,14 +664,10 @@ register_template(
register_template(
name="yi",
prefix=[
"{{system}}"
],
prompt=[
"<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n"
],
system="",
sep=[
format_user=StringFormatter(container=[
"<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
]),
separator=[
"\n"
],
stop_words=[
@ -766,15 +679,11 @@ register_template(
register_template(
name="yuan",
prefix=[
"{{system}}"
],
prompt=[
"{{query}}",
format_user=StringFormatter(container=[
"{{content}}",
{"token": "<sep>"}
],
system="",
sep=[
]),
separator=[
"\n"
],
stop_words=[
@ -786,30 +695,25 @@ register_template(
register_template(
name="zephyr",
prefix=[
"<|system|>\n{{system}}</s>",
],
prompt=[
"<|user|>\n{{query}}</s><|assistant|>"
],
system="You are a friendly chatbot who always responds in the style of a pirate",
sep=[]
format_user=StringFormatter(container=[
"<|user|>\n{{content}}</s><|assistant|>"
]),
format_system=StringFormatter(container=[
"<|system|>\n{{content}}</s>",
]),
system="You are a friendly chatbot who always responds in the style of a pirate"
)
register_template(
name="ziya",
prefix=[
"{{system}}"
],
prompt=[
format_user=StringFormatter(container=[
{"token": "<human>"},
":{{query}}\n",
":{{content}}\n",
{"token": "<bot>"},
":"
],
system="",
sep=[
]),
separator=[
"\n"
]
)

View File

@ -1,7 +1,8 @@
import hashlib
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from llmtuner.extras.logging import get_logger
from ..extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
@ -12,6 +13,14 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
@unique
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
OBSERVATION = "observation"
FUNCTION = "function"
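# Illustrative sketch, not part of this commit: sharegpt-style "from" tags could be mapped
# onto these roles roughly as follows (the tag names are assumptions for illustration).
_EXAMPLE_TAG_TO_ROLE = {
    "human": Role.USER,
    "gpt": Role.ASSISTANT,
    "observation": Role.OBSERVATION,
    "function_call": Role.FUNCTION,
}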
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
if file_sha1 is None:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
@ -27,6 +36,13 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, data_args.reserved_label_len)
max_source_len = data_args.cutoff_len - max_target_len
return max_source_len, max_target_len
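# Worked example (illustrative): with cutoff_len=4096, reserved_label_len=16,
# source_len=3000 and target_len=1000:
#   max_target_len = int(4096 * 1000 / 4000) = 1024, then max(1024, 16) = 1024
#   max_source_len = 4096 - 1024 = 3072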
def split_dataset(
dataset: Union["Dataset", "IterableDataset"],
data_args: "DataArguments",

View File

@ -1 +1,4 @@
from llmtuner.eval.evaluator import Evaluator
from .evaluator import Evaluator
__all__ = ["Evaluator"]

View File

@ -3,7 +3,6 @@
import os
import json
import torch
import tiktoken
import numpy as np
from tqdm import tqdm, trange
from typing import Any, Dict, List, Optional
@ -11,10 +10,11 @@ from typing import Any, Dict, List, Optional
from datasets import load_dataset
from transformers.utils import cached_file
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.eval.template import get_eval_template
from llmtuner.extras.constants import CHOICES, SUBJECTS
from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer
from ..data import get_template_and_fix_tokenizer
from .template import get_eval_template
from ..extras.constants import CHOICES, SUBJECTS
from ..hparams import get_eval_args
from ..model import dispatch_model, load_model_and_tokenizer
class Evaluator:
@ -26,15 +26,9 @@ class Evaluator:
self.model = dispatch_model(self.model)
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = self._encode_choices()
def _encode_choices(self) -> List[int]:
if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=False)
return [self.tokenizer.encode(self.eval_template.prefix + ch, **kwargs)[-1] for ch in CHOICES]
self.choice_inputs = [self.tokenizer.encode(
self.eval_template.prefix + ch, add_special_tokens=False
)[-1] for ch in CHOICES]
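# Illustrative note, not part of this commit: choice_inputs holds one token id per answer
# letter in CHOICES, taken as the last token of encoding eval_template.prefix + letter,
# so downstream scoring can presumably read each letter's logit directly, e.g.:
# id_for_A = self.tokenizer.encode(self.eval_template.prefix + "A", add_special_tokens=False)[-1]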
@torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:

View File

@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple
from llmtuner.extras.constants import CHOICES
from ..extras.constants import CHOICES
if TYPE_CHECKING:
from datasets import Dataset
@ -44,7 +44,7 @@ class EvalTemplate:
return query.strip(), resp, history
eval_templates: Dict[str, EvalTemplate] = {}
eval_templates: Dict[str, "EvalTemplate"] = {}
def register_eval_template(
@ -62,7 +62,7 @@ def register_eval_template(
)
def get_eval_template(name: str) -> EvalTemplate:
def get_eval_template(name: str) -> "EvalTemplate":
eval_template = eval_templates.get(name, None)
assert eval_template is not None, "Template {} does not exist.".format(name)
return eval_template

View File

@ -6,9 +6,9 @@ from datetime import timedelta
from transformers import TrainerCallback
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
from llmtuner.extras.constants import LOG_FILE_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import fix_valuehead_checkpoint
from .constants import LOG_FILE_NAME
from .logging import get_logger
from .misc import fix_valuehead_checkpoint
if TYPE_CHECKING:

View File

@ -5,6 +5,8 @@ from typing import Dict, Optional
CHOICES = ["A", "B", "C", "D"]
DATA_CONFIG = "dataset_info.json"
DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)

View File

@ -13,8 +13,8 @@ from transformers.utils import (
)
from peft import PeftModel
from llmtuner.extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
from llmtuner.extras.logging import get_logger
from .constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
from .logging import get_logger
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()

View File

@ -10,7 +10,7 @@ try:
except ImportError:
print("Please upgrade `transformers`.")
from llmtuner.extras.packages import is_flash_attn2_available
from ..packages import is_flash_attn2_available
if is_flash_attn2_available():

View File

@ -4,8 +4,8 @@ import json
from typing import List, Optional
from transformers.trainer import TRAINER_STATE_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.packages import is_matplotlib_available
from .logging import get_logger
from .packages import is_matplotlib_available
if is_matplotlib_available():
import matplotlib.pyplot as plt

View File

@ -3,3 +3,16 @@ from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
from .parser import get_train_args, get_infer_args, get_eval_args
__all__ = [
"DataArguments",
"EvaluationArguments",
"FinetuningArguments",
"GeneratingArguments",
"ModelArguments",
"get_train_args",
"get_infer_args",
"get_eval_args"
]

View File

@ -1,40 +1,7 @@
import os
import json
from typing import List, Literal, Optional
from typing import Literal, Optional
from dataclasses import dataclass, field
DATA_CONFIG = "dataset_info.json"
def use_modelscope() -> bool:
return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
@dataclass
class DatasetAttr:
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None
subset: Optional[str] = None
folder: Optional[str] = None
ranking: Optional[bool] = False
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
messages: Optional[str] = "conversations"
role: Optional[str] = "from"
content: Optional[str] = "value"
system: Optional[str] = None
def __repr__(self) -> str:
return self.dataset_name
@dataclass
class DataArguments:
r"""
@ -126,64 +93,3 @@ class DataArguments:
if self.streaming and self.max_samples is not None:
raise ValueError("`max_samples` is incompatible with `streaming`.")
def init_for_training(self, seed: int): # support mixing multiple datasets
self.seed = seed
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
try:
with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f:
dataset_info = json.load(f)
except Exception as err:
if self.dataset is not None:
raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err)))
dataset_info = None
if self.interleave_probs is not None:
self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
self.dataset_list: List[DatasetAttr] = []
for name in dataset_names:
if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
has_hf_url = "hf_hub_url" in dataset_info[name]
has_ms_url = "ms_hub_url" in dataset_info[name]
if has_hf_url or has_ms_url:
if (use_modelscope() and has_ms_url) or (not has_hf_url):
dataset_attr = DatasetAttr(
"ms_hub",
dataset_name=dataset_info[name]["ms_hub_url"]
)
else:
dataset_attr = DatasetAttr(
"hf_hub",
dataset_name=dataset_info[name]["hf_hub_url"]
)
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr(
"script",
dataset_name=dataset_info[name]["script_url"]
)
else:
dataset_attr = DatasetAttr(
"file",
dataset_name=dataset_info[name]["file_name"],
dataset_sha1=dataset_info[name].get("file_sha1", None)
)
if "columns" in dataset_info[name]:
dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
dataset_attr.role = dataset_info[name]["columns"].get("role", None)
dataset_attr.content = dataset_info[name]["columns"].get("content", None)
dataset_attr.system = dataset_info[name]["columns"].get("system", None)
dataset_attr.subset = dataset_info[name].get("subset", None)
dataset_attr.folder = dataset_info[name].get("folder", None)
dataset_attr.ranking = dataset_info[name].get("ranking", False)
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
self.dataset_list.append(dataset_attr)

View File

@ -43,13 +43,5 @@ class EvaluationArguments:
)
def __post_init__(self):
task_available = []
for folder in os.listdir(self.task_dir):
if os.path.isdir(os.path.join(self.task_dir, folder)):
task_available.append(folder)
if self.task not in task_available:
raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir))
if self.save_dir is not None and os.path.exists(self.save_dir):
raise ValueError("`save_dir` already exists, use another one.")

View File

@ -8,14 +8,12 @@ from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import (
ModelArguments,
DataArguments,
EvaluationArguments,
FinetuningArguments,
GeneratingArguments
)
from ..extras.logging import get_logger
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
logger = get_logger(__name__)
@ -107,8 +105,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
_set_transformers_logging()
# Check arguments
data_args.init_for_training(training_args.seed)
if finetuning_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.")

View File

@ -1,5 +1,5 @@
# Level: loader > adapter > parser, utils
from .loader import load_model_and_tokenizer
from .utils import dispatch_model, get_modelcard_args, load_valuehead_params
from llmtuner.model.loader import load_model_and_tokenizer
from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
from llmtuner.model.utils import dispatch_model, get_modelcard_args, load_valuehead_params
__all__ = ["load_model_and_tokenizer", "dispatch_model", "get_modelcard_args", "load_valuehead_params"]

View File

@ -3,12 +3,12 @@ from typing import TYPE_CHECKING
from transformers.integrations import is_deepspeed_zero3_enabled
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
from llmtuner.extras.logging import get_logger
from llmtuner.model.utils import find_all_linear_modules
from ..extras.logging import get_logger
from .utils import find_all_linear_modules
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from llmtuner.hparams import ModelArguments, FinetuningArguments
from ..hparams import ModelArguments, FinetuningArguments
logger = get_logger(__name__)

View File

@ -4,15 +4,15 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from llmtuner.model.adapter import init_adapter
from llmtuner.model.patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
from llmtuner.model.utils import load_valuehead_params, register_autoclass
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from .adapter import init_adapter
from .patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
from .utils import load_valuehead_params, register_autoclass
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
from llmtuner.hparams import ModelArguments, FinetuningArguments
from ..hparams import ModelArguments, FinetuningArguments
logger = get_logger(__name__)

View File

@ -10,15 +10,15 @@ from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTra
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from llmtuner.extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import get_current_device, infer_optim_dtype
from llmtuner.extras.packages import is_flash_attn2_available
from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
from ..extras.logging import get_logger
from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import ModelArguments
from ..hparams import ModelArguments
logger = get_logger(__name__)

View File

@ -4,13 +4,13 @@ from typing import TYPE_CHECKING, Any, Dict, List
from transformers import PreTrainedModel
from transformers.utils import cached_file
from llmtuner.extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import get_current_device
from ..extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
from ..extras.logging import get_logger
from ..extras.misc import get_current_device
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from ..hparams import ModelArguments, DataArguments, FinetuningArguments
logger = get_logger(__name__)

View File

@ -1 +1,4 @@
from llmtuner.train.tuner import export_model, run_exp
from .tuner import export_model, run_exp
__all__ = ["export_model", "run_exp"]

View File

@ -1 +1,4 @@
from llmtuner.train.dpo.workflow import run_dpo
from .workflow import run_dpo
__all__ = ["run_dpo"]

View File

@ -5,7 +5,7 @@ from transformers import BatchEncoding, Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
from llmtuner.extras.constants import IGNORE_INDEX
from ...extras.constants import IGNORE_INDEX
if TYPE_CHECKING:
from transformers import PreTrainedModel

View File

@ -3,18 +3,18 @@
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments
from llmtuner.model import load_model_and_tokenizer
from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
from llmtuner.train.dpo.trainer import CustomDPOTrainer
from llmtuner.train.utils import create_modelcard_and_push, create_ref_model
from ...data import get_dataset, split_dataset
from ...extras.constants import IGNORE_INDEX
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model_and_tokenizer
from ...train.dpo.collator import DPODataCollatorWithPadding
from ...train.dpo.trainer import CustomDPOTrainer
from ...train.utils import create_modelcard_and_push, create_ref_model
if TYPE_CHECKING:
from transformers import TrainerCallback
from llmtuner.hparams import DataArguments, FinetuningArguments
from ...hparams import DataArguments, FinetuningArguments
def run_dpo(
@ -24,9 +24,8 @@ def run_dpo(
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="rm")
data_collator = DPODataCollatorWithPadding(
tokenizer=tokenizer,
pad_to_multiple_of=8,

View File

@ -1 +1,4 @@
from llmtuner.train.ppo.workflow import run_ppo
from .workflow import run_ppo
__all__ = ["run_ppo"]

View File

@ -13,15 +13,15 @@ from transformers.trainer_pt_utils import remove_dummy_checkpoint
from trl import PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from llmtuner.extras.callbacks import LogCallback, FixValueHeadModelCallback
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
from ...extras.callbacks import LogCallback, FixValueHeadModelCallback
from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
from .utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments
from ...hparams import ModelArguments, FinetuningArguments, GeneratingArguments
logger = get_logger(__name__)

View File

@ -2,7 +2,7 @@ import json
import torch
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
from llmtuner.extras.packages import is_requests_available
from ...extras.packages import is_requests_available
if TYPE_CHECKING:
from transformers import PreTrainedModel

View File

@ -7,17 +7,17 @@ from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorWithPadding
from transformers.optimization import get_scheduler
from llmtuner.data import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import FixValueHeadModelCallback
from llmtuner.extras.misc import fix_valuehead_checkpoint
from llmtuner.extras.ploting import plot_loss
from llmtuner.model import load_model_and_tokenizer
from llmtuner.train.utils import create_ref_model, create_reward_model
from llmtuner.train.ppo.trainer import CustomPPOTrainer
from ...data import get_dataset
from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model_and_tokenizer
from ...train.utils import create_ref_model, create_reward_model
from ...train.ppo.trainer import CustomPPOTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_ppo(
@ -28,9 +28,8 @@ def run_ppo(
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="ppo")
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

View File

@ -1 +1,4 @@
from llmtuner.train.pt.workflow import run_pt
from .workflow import run_pt
__all__ = ["run_pt"]

View File

@ -4,14 +4,14 @@ import math
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForLanguageModeling, Trainer
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.ploting import plot_loss
from llmtuner.model import load_model_and_tokenizer
from llmtuner.train.utils import create_modelcard_and_push
from ...data import get_dataset, split_dataset
from ...extras.ploting import plot_loss
from ...model import load_model_and_tokenizer
from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from ...hparams import ModelArguments, DataArguments, FinetuningArguments
def run_pt(
@ -21,9 +21,8 @@ def run_pt(
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="pt")
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Initialize our Trainer

View File

@ -1 +1,4 @@
from llmtuner.train.rm.workflow import run_rm
from .workflow import run_rm
__all__ = ["run_rm"]

View File

@ -4,7 +4,7 @@ import torch
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from transformers import Trainer
from llmtuner.extras.logging import get_logger
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput

View File

@ -3,19 +3,19 @@
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import FixValueHeadModelCallback
from llmtuner.extras.misc import fix_valuehead_checkpoint
from llmtuner.extras.ploting import plot_loss
from llmtuner.model import load_model_and_tokenizer
from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.train.rm.metric import compute_accuracy
from llmtuner.train.rm.trainer import PairwiseTrainer
from llmtuner.train.utils import create_modelcard_and_push
from ...data import get_dataset, split_dataset
from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model_and_tokenizer
from ...train.rm.collator import PairwiseDataCollatorWithPadding
from ...train.rm.metric import compute_accuracy
from ...train.rm.trainer import PairwiseTrainer
from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING:
from transformers import TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from ...hparams import ModelArguments, DataArguments, FinetuningArguments
def run_rm(
@ -25,9 +25,8 @@ def run_rm(
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="rm")
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
# Update arguments

View File

@ -1 +1,4 @@
from llmtuner.train.sft.workflow import run_sft
from .workflow import run_sft
__all__ = ["run_sft"]

View File

@ -2,8 +2,8 @@ import numpy as np
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.packages import (
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import (
is_jieba_available, is_nltk_available, is_rouge_available
)

View File

@ -6,8 +6,8 @@ import torch.nn as nn
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from transformers import Seq2SeqTrainer
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput

View File

@ -3,18 +3,19 @@
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss
from llmtuner.model import load_model_and_tokenizer
from llmtuner.train.sft.metric import ComputeMetrics
from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer
from llmtuner.train.utils import create_modelcard_and_push
from ...data import get_dataset, split_dataset
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import get_logits_processor
from ...extras.ploting import plot_loss
from ...model import load_model_and_tokenizer
from ...train.sft.metric import ComputeMetrics
from ...train.sft.trainer import CustomSeq2SeqTrainer
from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING:
from transformers import TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_sft(
@ -25,9 +26,8 @@ def run_sft(
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="sft")
if training_args.predict_with_generate:
tokenizer.padding_side = "left" # use left-padding in generation

View File

@ -2,14 +2,15 @@ import torch
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from transformers import PreTrainedModel
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.logging import get_logger
from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer
from llmtuner.train.pt import run_pt
from llmtuner.train.sft import run_sft
from llmtuner.train.rm import run_rm
from llmtuner.train.ppo import run_ppo
from llmtuner.train.dpo import run_dpo
from ..extras.callbacks import LogCallback
from ..extras.logging import get_logger
from ..hparams import get_train_args, get_infer_args
from ..model import load_model_and_tokenizer
from .pt import run_pt
from .sft import run_sft
from .rm import run_rm
from .ppo import run_ppo
from .dpo import run_dpo
if TYPE_CHECKING:
from transformers import TrainerCallback

View File

@ -1,15 +1,15 @@
import torch
from typing import TYPE_CHECKING, Optional, Union
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params
from ..extras.logging import get_logger
from ..hparams import ModelArguments, FinetuningArguments
from ..model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, Trainer
from transformers.modeling_utils import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import DataArguments
from ..hparams import DataArguments
logger = get_logger(__name__)

View File

@ -1 +1,4 @@
from llmtuner.webui.interface import create_ui, create_web_demo
from .interface import create_ui, create_web_demo
__all__ = ["create_ui", "create_web_demo"]

View File

@ -2,14 +2,14 @@ import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from llmtuner.chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.hparams import GeneratingArguments
from llmtuner.webui.common import get_save_dir
from llmtuner.webui.locales import ALERTS
from ..chat import ChatModel
from ..extras.misc import torch_gc
from ..hparams import GeneratingArguments
from .common import get_save_dir
from .locales import ALERTS
if TYPE_CHECKING:
from llmtuner.webui.manager import Manager
from .manager import Manager
class WebChatModel(ChatModel):

View File

@ -5,7 +5,8 @@ from collections import defaultdict
from typing import Any, Dict, Optional
from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
from llmtuner.extras.constants import (
from ..extras.constants import (
DATA_CONFIG,
DEFAULT_MODULE,
DEFAULT_TEMPLATE,
PEFT_METHODS,
@ -13,8 +14,7 @@ from llmtuner.extras.constants import (
TRAINING_STAGES,
DownloadSource
)
from llmtuner.extras.misc import use_modelscope
from llmtuner.hparams.data_args import DATA_CONFIG
from ..extras.misc import use_modelscope
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}

View File

@ -1,6 +1,11 @@
from llmtuner.webui.components.top import create_top
from llmtuner.webui.components.train import create_train_tab
from llmtuner.webui.components.eval import create_eval_tab
from llmtuner.webui.components.infer import create_infer_tab
from llmtuner.webui.components.export import create_export_tab
from llmtuner.webui.components.chatbot import create_chat_box
from .top import create_top
from .train import create_train_tab
from .eval import create_eval_tab
from .infer import create_infer_tab
from .export import create_export_tab
from .chatbot import create_chat_box
__all__ = [
"create_top", "create_train_tab", "create_eval_tab", "create_infer_tab", "create_export_tab", "create_chat_box"
]

View File

@ -4,7 +4,8 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple
if TYPE_CHECKING:
from gradio.blocks import Block
from gradio.components import Component
from llmtuner.webui.engine import Engine
from ..engine import Engine
def create_chat_box(

View File

@ -3,7 +3,7 @@ import json
import gradio as gr
from typing import TYPE_CHECKING, Any, Dict, Tuple
from llmtuner.webui.common import DATA_CONFIG
from ...extras.constants import DATA_CONFIG
if TYPE_CHECKING:
from gradio.components import Component

View File

@ -1,12 +1,13 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from ..common import list_dataset, DEFAULT_DATA_DIR
from .data import create_preview_box
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.engine import Engine
from ..engine import Engine
def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:

View File

@ -1,13 +1,14 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict, Generator, List
from llmtuner.train import export_model
from llmtuner.webui.common import get_save_dir
from llmtuner.webui.locales import ALERTS
from ...train import export_model
from ..common import get_save_dir
from ..locales import ALERTS
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.engine import Engine
from ..engine import Engine
GPTQ_BITS = ["8", "4", "3", "2"]

View File

@ -1,11 +1,12 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.webui.components.chatbot import create_chat_box
from .chatbot import create_chat_box
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.engine import Engine
from ..engine import Engine
def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:

View File

@ -1,10 +1,10 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.data.template import templates
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.webui.common import get_model_path, get_template, list_adapters, save_config
from llmtuner.webui.utils import can_quantize
from ...data import templates
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ..common import get_model_path, get_template, list_adapters, save_config
from ..utils import can_quantize
if TYPE_CHECKING:
from gradio.components import Component

View File

@ -2,14 +2,15 @@ import gradio as gr
from typing import TYPE_CHECKING, Dict
from transformers.trainer_utils import SchedulerType
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.webui.common import list_adapters, list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.utils import gen_plot
from ...extras.constants import TRAINING_STAGES
from ..common import list_adapters, list_dataset, DEFAULT_DATA_DIR
from ..components.data import create_preview_box
from ..utils import gen_plot
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.engine import Engine
from ..engine import Engine
def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:

View File

@ -2,12 +2,12 @@ import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import Any, Dict, Generator, Optional
from llmtuner.webui.chatter import WebChatModel
from llmtuner.webui.common import get_model_path, list_dataset, load_config
from llmtuner.webui.locales import LOCALES
from llmtuner.webui.manager import Manager
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import get_time
from .chatter import WebChatModel
from .common import get_model_path, list_dataset, load_config
from .locales import LOCALES
from .manager import Manager
from .runner import Runner
from .utils import get_time
class Engine:

View File

@ -2,7 +2,7 @@ import gradio as gr
from typing import Optional
from transformers.utils.versions import require_version
from llmtuner.webui.components import (
from .components import (
create_top,
create_train_tab,
create_eval_tab,
@ -10,9 +10,9 @@ from llmtuner.webui.components import (
create_export_tab,
create_chat_box
)
from llmtuner.webui.common import save_config
from llmtuner.webui.css import CSS
from llmtuner.webui.engine import Engine
from .common import save_config
from .css import CSS
from .engine import Engine
require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")

View File

@ -9,17 +9,17 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
import transformers
from transformers.trainer import TRAINING_ARGS_NAME
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import get_device_count, torch_gc
from llmtuner.train import run_exp
from llmtuner.webui.common import get_module, get_save_dir, load_config
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
from ..extras.callbacks import LogCallback
from ..extras.constants import TRAINING_STAGES
from ..extras.logging import LoggerHandler
from ..extras.misc import get_device_count, torch_gc
from ..train import run_exp
from .common import get_module, get_save_dir, load_config
from .locales import ALERTS
from .utils import gen_cmd, get_eval_results, update_process_bar
if TYPE_CHECKING:
from llmtuner.webui.manager import Manager
from .manager import Manager
class Runner:

View File

@ -4,12 +4,12 @@ import gradio as gr
from typing import TYPE_CHECKING, Any, Dict
from datetime import datetime
from llmtuner.extras.packages import is_matplotlib_available
from llmtuner.extras.ploting import smooth
from llmtuner.webui.common import get_save_dir
from ..extras.packages import is_matplotlib_available
from ..extras.ploting import smooth
from .common import get_save_dir
if TYPE_CHECKING:
from llmtuner.extras.callbacks import LogCallback
from ..extras.callbacks import LogCallback
if is_matplotlib_available():
import matplotlib.figure

View File

@ -1,6 +1,7 @@
# coding=utf-8
# Converts the InternLM2 model into the same format as LLaMA2.
# Usage: python llamafy_internlm2.py --input_dir input --output_dir output --shard_size 10GB
# Warning: We have found that the converted model cannot infer correctly. It will be fixed later.
import os
import fire
@ -43,19 +44,18 @@ def save_weight(
llama2_state_dict[key.replace("output", "lm_head")] = value
elif "tok_embeddings" in key:
llama2_state_dict[key.replace("tok_embeddings", "embed_tokens")] = value
elif "attention_norm" in key:
llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value
elif "wqkv" in key:
proj_size = value.size(0)
num_q_heads = internlm2_config_dict["num_attention_heads"]
num_kv_heads = internlm2_config_dict["num_key_value_heads"]
q_size = proj_size // (num_q_heads + 2 * num_kv_heads) * num_q_heads
kv_size = proj_size // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[q_size:q_size+kv_size, ...]
llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size+kv_size:, ...]
elif "wo" in key:
llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value
elif "attention_norm" in key:
llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value
elif "ffn_norm" in key:
llama2_state_dict[key.replace("ffn_norm", "post_attention_layernorm")] = value
elif "w1" in key: