diff --git a/data/dataset_info.json b/data/dataset_info.json index bc031d76..f5fe4edd 100644 --- a/data/dataset_info.json +++ b/data/dataset_info.json @@ -165,9 +165,13 @@ "hf_hub_url": "HuggingFaceH4/ultrachat_200k", "ms_hub_url": "AI-ModelScope/ultrachat_200k", "columns": { - "messages": "messages", - "role": "role", - "content": "content" + "messages": "messages" + }, + "tags": { + "role_tag": "role", + "content_tag": "content", + "user_tag": "human", + "assistant_tag": "assistant" }, "formatting": "sharegpt" }, @@ -180,9 +184,13 @@ "hf_hub_url": "lmsys/lmsys-chat-1m", "ms_hub_url": "AI-ModelScope/lmsys-chat-1m", "columns": { - "messages": "conversation", - "role": "role", - "content": "content" + "messages": "conversation" + }, + "tags": { + "role_tag": "role", + "content_tag": "content", + "user_tag": "human", + "assistant_tag": "assistant" }, "formatting": "sharegpt" }, @@ -190,6 +198,14 @@ "hf_hub_url": "WizardLM/WizardLM_evol_instruct_V2_196k", "formatting": "sharegpt" }, + "glaive_func_call": { + "file_name": "glaive_func_call.json", + "formatting": "sharegpt", + "columns": { + "messages": "conversations", + "tool": "tools" + } + }, "hh_rlhf_en": { "script_url": "hh_rlhf_en", "columns": { diff --git a/data/glaive_func_call.json b/data/glaive_func_call.json new file mode 100644 index 00000000..e1d6cce7 --- /dev/null +++ b/data/glaive_func_call.json @@ -0,0 +1,68 @@ +[ + { + "conversations": [ + { + "from": "human", + "value": "I need a new password. Can you generate one for me?" + }, + { + "from": "gpt", + "value": "Of course. How long would you like your password to be? And would you like it to include symbols?" + }, + { + "from": "human", + "value": "I would like it to be 12 characters long and yes, please include symbols." + }, + { + "from": "function_call", + "value": "{\"name\": \"generate_password\", \"arguments\": {\"length\": 12, \"include_symbols\": true}}" + }, + { + "from": "observation", + "value": "{\"password\": \"4&7j#9@1Q6*\"}" + }, + { + "from": "gpt", + "value": "Here is your new password: 4&7j#9@1Q6*. Please make sure to save it in a secure location." + } + ], + "tools": "[{\"name\": \"generate_password\", \"description\": \"Generate a random password\", \"parameters\": {\"type\": \"object\", \"properties\": {\"length\": {\"type\": \"integer\", \"description\": \"The length of the password\"}, \"include_symbols\": {\"type\": \"boolean\", \"description\": \"Whether to include symbols in the password\"}}, \"required\": [\"length\"]}}, {\"name\": \"create_task\", \"description\": \"Create a new task in a task management system\", \"parameters\": {\"type\": \"object\", \"properties\": {\"title\": {\"type\": \"string\", \"description\": \"The title of the task\"}, \"due_date\": {\"type\": \"string\", \"format\": \"date\", \"description\": \"The due date of the task\"}, \"priority\": {\"type\": \"string\", \"enum\": [\"low\", \"medium\", \"high\"], \"description\": \"The priority of the task\"}}, \"required\": [\"title\", \"due_date\", \"priority\"]}}]" + }, + { + "conversations": [ + { + "from": "human", + "value": "Can you tell me the latest news headlines for the United States?" 
+ }, + { + "from": "function_call", + "value": "{\"name\": \"get_news_headlines\", \"arguments\": {\"country\": \"United States\"}}" + }, + { + "from": "observation", + "value": "{\"headlines\": [\"Biden announces new vaccine mandates\", \"Hurricane Ida devastates Louisiana\", \"Apple unveils new iPhone\", \"NASA's Perseverance rover collects first Mars rock sample\"]}" + }, + { + "from": "gpt", + "value": "Here are the latest news headlines for the United States:" + }, + { + "from": "human", + "value": "That's interesting. What about the news in France?" + }, + { + "from": "function_call", + "value": "{\"name\": \"get_news_headlines\", \"arguments\": {\"country\": \"France\"}}" + }, + { + "from": "observation", + "value": "{\"headlines\": [\"France recalls ambassadors to US and Australia\", \"French election: Macron's party braces for tough fight\", \"Louvre Museum to undergo major overhaul\", \"France to offer free birth control to all women under 25\"]}" + }, + { + "from": "gpt", + "value": "Here are the latest news headlines for France:" + } + ], + "tools": "[{\"name\": \"get_news_headlines\", \"description\": \"Get the latest news headlines\", \"parameters\": {\"type\": \"object\", \"properties\": {\"country\": {\"type\": \"string\", \"description\": \"The country for which to fetch news\"}}, \"required\": [\"country\"]}}]" + } +] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index ce3c92a3..0e9090e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,6 @@ scipy einops sentencepiece protobuf -tiktoken jieba rouge-chinese nltk diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py index 932dd56b..bdbd5af2 100644 --- a/src/llmtuner/__init__.py +++ b/src/llmtuner/__init__.py @@ -8,3 +8,12 @@ from llmtuner.webui import create_ui, create_web_demo __version__ = "0.4.0" +__all__ = [ + "create_app", + "ChatModel", + "Evaluator", + "export_model", + "run_exp", + "create_ui", + "create_web_demo" +] diff --git a/src/llmtuner/api/__init__.py b/src/llmtuner/api/__init__.py index b3ce183a..d7059fbd 100644 --- a/src/llmtuner/api/__init__.py +++ b/src/llmtuner/api/__init__.py @@ -1 +1,4 @@ -from llmtuner.api.app import create_app +from .app import create_app + + +__all__ = ["create_app"] diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py index f130eab6..f8115227 100644 --- a/src/llmtuner/api/app.py +++ b/src/llmtuner/api/app.py @@ -5,7 +5,7 @@ from typing import List, Tuple from pydantic import BaseModel from contextlib import asynccontextmanager -from llmtuner.api.protocol import ( +from .protocol import ( Role, Finish, ModelCard, @@ -21,9 +21,9 @@ from llmtuner.api.protocol import ( ScoreEvaluationRequest, ScoreEvaluationResponse ) -from llmtuner.chat import ChatModel -from llmtuner.extras.misc import torch_gc -from llmtuner.extras.packages import ( +from ..chat import ChatModel +from ..extras.misc import torch_gc +from ..extras.packages import ( is_fastapi_availble, is_starlette_available, is_uvicorn_available ) diff --git a/src/llmtuner/api/protocol.py b/src/llmtuner/api/protocol.py index a5b5c81d..42569f84 100644 --- a/src/llmtuner/api/protocol.py +++ b/src/llmtuner/api/protocol.py @@ -1,15 +1,17 @@ import time -from enum import Enum +from enum import Enum, unique from pydantic import BaseModel, Field from typing import List, Optional +@unique class Role(str, Enum): USER = "user" ASSISTANT = "assistant" SYSTEM = "system" +@unique class Finish(str, Enum): STOP = "stop" LENGTH = "length" diff --git a/src/llmtuner/chat/__init__.py 
b/src/llmtuner/chat/__init__.py index f86efe96..702d0ac7 100644 --- a/src/llmtuner/chat/__init__.py +++ b/src/llmtuner/chat/__init__.py @@ -1 +1,4 @@ -from llmtuner.chat.chat_model import ChatModel +from .chat_model import ChatModel + + +__all__ = ["ChatModel"] diff --git a/src/llmtuner/chat/chat_model.py b/src/llmtuner/chat/chat_model.py index 0c2f9c92..521270f5 100644 --- a/src/llmtuner/chat/chat_model.py +++ b/src/llmtuner/chat/chat_model.py @@ -1,13 +1,13 @@ import torch -import tiktoken from dataclasses import dataclass from typing import Any, Dict, Generator, List, Literal, Optional, Tuple from threading import Thread from transformers import GenerationConfig, TextIteratorStreamer -from llmtuner.data.template import get_template_and_fix_tokenizer -from llmtuner.extras.misc import get_logits_processor -from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer +from ..data import get_template_and_fix_tokenizer +from ..extras.misc import get_logits_processor +from ..model import dispatch_model, load_model_and_tokenizer +from ..hparams import get_infer_args @dataclass @@ -139,11 +139,6 @@ class ChatModel: batch_input: List[str], **input_kwargs ) -> List[float]: - if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen) - kwargs = dict(allowed_special="all") - else: - kwargs = dict(add_special_tokens=True) - max_length = input_kwargs.pop("max_length", None) device = getattr(self.model.pretrained_model, "device", "cuda") @@ -153,7 +148,7 @@ class ChatModel: truncation=True, max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024), return_tensors="pt", - **kwargs + add_special_tokens=True ).to(device) input_ids: torch.Tensor = inputs["input_ids"] diff --git a/src/llmtuner/data/__init__.py b/src/llmtuner/data/__init__.py index 35f7caa3..3709b6e1 100644 --- a/src/llmtuner/data/__init__.py +++ b/src/llmtuner/data/__init__.py @@ -1,4 +1,6 @@ -from llmtuner.data.loader import get_dataset -from llmtuner.data.preprocess import preprocess_dataset -from llmtuner.data.template import get_template_and_fix_tokenizer -from llmtuner.data.utils import split_dataset +from .loader import get_dataset +from .template import get_template_and_fix_tokenizer, templates +from .utils import split_dataset + + +__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset"] diff --git a/src/llmtuner/data/aligner.py b/src/llmtuner/data/aligner.py new file mode 100644 index 00000000..e1e53a0c --- /dev/null +++ b/src/llmtuner/data/aligner.py @@ -0,0 +1,106 @@ +from functools import partial +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from .utils import Role + + +if TYPE_CHECKING: + from datasets import Dataset, IterableDataset + + from ..hparams import DataArguments + from .parser import DatasetAttr + + +def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]: + outputs = {"prompt": [], "response": [], "system": [], "tool": []} + for i in range(len(examples[dataset_attr.prompt])): + prompt = [] + if dataset_attr.history: + for old_prompt, old_response in examples[dataset_attr.history][i]: + prompt.append({"role": Role.USER, "content": old_prompt}) + prompt.append({"role": Role.ASSISTANT, "content": old_response}) + + instruction = examples[dataset_attr.prompt][i] + if dataset_attr.query and examples[dataset_attr.query][i]: + instruction += "\n" + examples[dataset_attr.query][i] + prompt.append({"role": Role.USER, "content": 
instruction}) + + if isinstance(examples[dataset_attr.response][i], list): + response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]] + else: + response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}] + + outputs["prompt"].append(prompt) + outputs["response"].append(response) + outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "") + outputs["tool"].append("") + + return outputs + + +def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]: + outputs = {"prompt": [], "response": [], "system": [], "tool": []} + tag_mapping = { + dataset_attr.user_tag: Role.USER, + dataset_attr.assistant_tag: Role.ASSISTANT, + dataset_attr.observation_tag: Role.OBSERVATION, + dataset_attr.function_tag: Role.FUNCTION + } + for i, messages in enumerate(examples[dataset_attr.messages]): + messages = messages[:len(messages) // 2 * 2] # should be multiples of 2 + if len(messages) == 0: + continue + + prompt = [] + response = [] + for turn_idx, message in enumerate(messages): + if turn_idx % 2 == 0: + accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag] + else: + accept_tags = [dataset_attr.assistant_tag, dataset_attr.function_tag] + + if message[dataset_attr.role_tag] not in accept_tags: + raise ValueError("Invalid role tag.") + + prompt.append({"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}) + + last_message = prompt.pop(-1) + response.append(last_message) + outputs["prompt"].append(prompt) + outputs["response"].append(response) + outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "") + outputs["tool"].append(examples[dataset_attr.tool][i] if dataset_attr.tool else "") + + return outputs + + +def align_dataset( + dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments" +) -> Union["Dataset", "IterableDataset"]: + r""" + Aligned dataset: + prompt: [{"role": "user", "content": "..."}] + response: [{"role": "assistant", "content": "..."}] + system: "..." + tool: "..." + """ + if dataset_attr.formatting == "alpaca": + convert_func = partial(convert_alpaca, dataset_attr=dataset_attr) + else: + convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr) + + column_names = list(next(iter(dataset)).keys()) + kwargs = {} + if not data_args.streaming: + kwargs = dict( + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=(not data_args.overwrite_cache), + desc="Converting format of dataset" + ) + + return dataset.map( + convert_func, + batched=True, + remove_columns=column_names, + **kwargs + ) diff --git a/src/llmtuner/data/formatter.py b/src/llmtuner/data/formatter.py new file mode 100644 index 00000000..36484d3f --- /dev/null +++ b/src/llmtuner/data/formatter.py @@ -0,0 +1,99 @@ +import json +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Union + + +JSON_FORMAT_PROMPT = ( + """, in a JSON format representing the kwargs (e.g. 
```{"input": "hello world", "num_beams": 5}```)""" +) + + +TOOL_SYSTEM_PROMPT = ( + "You have access to the following tools:\n{tool_text}" + "Use the following format to answer the question:\n" + "```\n" + "Action: the action to take, should be one of [{tool_names}] if using a tool.\n" + "Action Input: the input to the action{format_prompt}.\n" + "```" +) + + +@dataclass +class StringFormatter: + container: List[Union[str, Dict[str, str]]] + + def __call__(self, **kwargs) -> List[Union[str, Dict[str, str]]]: + elements = [] + for elem in self.container: + if isinstance(elem, str): + for name, value in kwargs.items(): + elem = elem.replace("{{" + name + "}}", value) + elements.append(elem) + elif isinstance(elem, (dict, set)): + elements.append(elem) + else: + raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem))) + + return elements + + +@dataclass +class FunctionFormatter: + container: List[Union[str, Dict[str, str]]] + + def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]: + try: + function = json.loads(content) + name = json.dumps(function["name"], ensure_ascii=False) + arguments = json.dumps(function["arguments"], ensure_ascii=False) + except json.JSONDecodeError: + name, arguments = "", "" + + elements = [] + for elem in self.container: + if isinstance(elem, str): + elem = elem.replace("{{name}}", name) + elem = elem.replace("{{arguments}}", arguments) + elements.append(elem) + elif isinstance(elem, (dict, set)): + elements.append(elem) + else: + raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem))) + + return elements + + +@dataclass +class ToolFormatter: + type: Literal["default"] + + def _default(self, tools: List[Dict[str, Any]]) -> str: + tool_text = "" + tool_names = [] + for tool in tools: + param_text = "" + for name, param in tool["parameters"]["properties"].items(): + required = ", required" if name in tool["parameters"].get("required", []) else "" + enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else "" + param_text += " - {name} ({type}{required}): {desc}{enum}\n".format( + name=name, type=param.get("type", ""), required=required, desc=param.get("description", ""), enum=enum + ) + + tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format( + name=tool["name"], desc=tool.get("description", ""), args=param_text + ) + tool_names.append(tool["name"]) + + return TOOL_SYSTEM_PROMPT.format( + tool_text=tool_text, + tool_names=", ".join(tool_names), + format_prompt=JSON_FORMAT_PROMPT + ) + + def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]: + try: + tools = json.loads(content) + if self.type == "default": + return [self._default(tools)] + except json.JSONDecodeError: + return [""] diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py index f9019c8b..87f42558 100644 --- a/src/llmtuner/data/loader.py +++ b/src/llmtuner/data/loader.py @@ -1,160 +1,114 @@ import os -from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing import TYPE_CHECKING, List, Literal, Union from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk -from llmtuner.data.utils import checksum -from llmtuner.extras.constants import FILEEXT2TYPE -from llmtuner.extras.logging import get_logger +from ..extras.constants import FILEEXT2TYPE +from ..extras.logging import get_logger +from .utils import checksum +from .parser import get_dataset_list +from .aligner 
import align_dataset +from .template import get_template_and_fix_tokenizer +from .preprocess import get_preprocess_and_print_func + if TYPE_CHECKING: from datasets import Dataset, IterableDataset - from llmtuner.hparams import ModelArguments, DataArguments + from transformers import Seq2SeqTrainingArguments + from transformers.tokenization_utils import PreTrainedTokenizer + + from .parser import DatasetAttr + from ..hparams import ModelArguments, DataArguments logger = get_logger(__name__) -def get_dataset( +def load_single_dataset( + dataset_attr: "DatasetAttr", model_args: "ModelArguments", - data_args: "DataArguments" -) -> Union["Dataset", "IterableDataset"]: - max_samples = data_args.max_samples - all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets + data_args: "DataArguments", +): + data_path, data_name, data_dir, data_files = None, None, None, None + if dataset_attr.load_from in ["hf_hub", "ms_hub"]: + data_path = dataset_attr.dataset_name + data_name = dataset_attr.subset + data_dir = dataset_attr.folder - if data_args.cache_path is not None: - if os.path.exists(data_args.cache_path): - logger.warning("Loading dataset from disk will ignore other data arguments.") - dataset = load_from_disk(data_args.cache_path) - if data_args.streaming: - dataset = dataset.to_iterable_dataset() - return dataset - elif data_args.streaming: - raise ValueError("Turn off dataset streaming to save cache files.") + elif dataset_attr.load_from == "script": + data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + data_name = dataset_attr.subset + data_dir = dataset_attr.folder - for dataset_attr in data_args.dataset_list: - logger.info("Loading dataset {}...".format(dataset_attr)) - - data_path, data_name, data_dir, data_files = None, None, None, None - if dataset_attr.load_from in ["hf_hub", "ms_hub"]: - data_path = dataset_attr.dataset_name - data_name = dataset_attr.subset - data_dir = dataset_attr.folder - elif dataset_attr.load_from == "script": - data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) - data_name = dataset_attr.subset - elif dataset_attr.load_from == "file": - data_files = [] - local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) - if os.path.isdir(local_path): # is directory - for file_name in os.listdir(local_path): - data_files.append(os.path.join(local_path, file_name)) - if data_path is None: - data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None) - else: - assert data_path == FILEEXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical." - elif os.path.isfile(local_path): # is file - data_files.append(local_path) - data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None) - else: - raise ValueError("File not found.") - - assert data_path, "File extension must be txt, csv, json or jsonl." 
- checksum(data_files, dataset_attr.dataset_sha1) + elif dataset_attr.load_from == "file": + data_files = [] + local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + if os.path.isdir(local_path): # is directory + for file_name in os.listdir(local_path): + data_files.append(os.path.join(local_path, file_name)) + if data_path is None: + data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None) + elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None): + raise ValueError("File types should be identical.") + elif os.path.isfile(local_path): # is file + data_files.append(local_path) + data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None) else: - raise NotImplementedError + raise ValueError("File not found.") - if dataset_attr.load_from == "ms_hub": - try: - from modelscope import MsDataset - from modelscope.utils.config_ds import MS_DATASETS_CACHE + if data_path is None: + raise ValueError("File extension must be txt, csv, json or jsonl.") - cache_dir = model_args.cache_dir or MS_DATASETS_CACHE - dataset = MsDataset.load( - dataset_name=data_path, - subset_name=data_name, - data_dir=data_dir, - data_files=data_files, - split=data_args.split, - cache_dir=cache_dir, - token=model_args.ms_hub_token, - use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")) - ).to_hf_dataset() - except ImportError: - raise ImportError("Please install modelscope via `pip install modelscope -U`") - else: - dataset = load_dataset( - path=data_path, - name=data_name, + checksum(data_files, dataset_attr.dataset_sha1) + else: + raise NotImplementedError + + if dataset_attr.load_from == "ms_hub": + try: + from modelscope import MsDataset + from modelscope.utils.config_ds import MS_DATASETS_CACHE + + cache_dir = model_args.cache_dir or MS_DATASETS_CACHE + dataset = MsDataset.load( + dataset_name=data_path, + subset_name=data_name, data_dir=data_dir, data_files=data_files, split=data_args.split, - cache_dir=model_args.cache_dir, - token=model_args.hf_hub_token, - streaming=(data_args.streaming and (dataset_attr.load_from != "file")) - ) + cache_dir=cache_dir, + token=model_args.ms_hub_token, + use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")) + ).to_hf_dataset() + except ImportError: + raise ImportError("Please install modelscope via `pip install modelscope -U`") + else: + dataset = load_dataset( + path=data_path, + name=data_name, + data_dir=data_dir, + data_files=data_files, + split=data_args.split, + cache_dir=model_args.cache_dir, + token=model_args.hf_hub_token, + streaming=(data_args.streaming and (dataset_attr.load_from != "file")) + ) - if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True - dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter + if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True + dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter - if max_samples is not None: # truncate dataset - dataset = dataset.select(range(min(len(dataset), max_samples))) + if data_args.max_samples is not None: # truncate dataset + num_samples = min(data_args.max_samples, len(dataset)) + dataset = dataset.select(range(num_samples)) - def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - # convert dataset from sharegpt format to alpaca format - outputs = {"prompt": [], "query": [], "response": [], "history": [], "system": []} - for i, msg_list in 
enumerate(examples[dataset_attr.messages]): - msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2 - if len(msg_list) == 0: - continue + return align_dataset(dataset, dataset_attr, data_args) - msg_pairs = [] - user_role, assistant_role = None, None - for idx in range(0, len(msg_list), 2): - if user_role is None and assistant_role is None: - user_role = msg_list[idx][dataset_attr.role] - assistant_role = msg_list[idx + 1][dataset_attr.role] - else: - if ( - msg_list[idx][dataset_attr.role] != user_role - or msg_list[idx+1][dataset_attr.role] != assistant_role - ): - raise ValueError("Only accepts conversation in u/a/u/a/u/a order.") - msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content])) - if len(msg_pairs) != 0: - outputs["prompt"].append(msg_pairs[-1][0]) - outputs["query"].append("") - outputs["response"].append(msg_pairs[-1][1]) - outputs["history"].append(msg_pairs[:-1] if len(msg_pairs) > 1 else None) - outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "") - - return outputs - - if dataset_attr.formatting == "sharegpt": # convert format - column_names = list(next(iter(dataset)).keys()) - kwargs = {} - if not data_args.streaming: - kwargs = dict( - num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=(not data_args.overwrite_cache), - desc="Converting format of dataset" - ) - - dataset = dataset.map( - convert_format, - batched=True, - remove_columns=column_names, - **kwargs - ) - else: - for column_name in ["prompt", "query", "response", "history", "system"]: # align dataset - if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name: - dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name) - - all_datasets.append(dataset) - - if len(data_args.dataset_list) == 1: +def merge_dataset( + all_datasets: List[Union["Dataset", "IterableDataset"]], + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments" +) -> Union["Dataset", "IterableDataset"]: + if len(all_datasets) == 1: return all_datasets[0] elif data_args.mix_strategy == "concat": if data_args.streaming: @@ -166,8 +120,72 @@ def get_dataset( return interleave_datasets( datasets=all_datasets, probabilities=data_args.interleave_probs, - seed=data_args.seed, + seed=training_args.seed, stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted" ) else: raise ValueError("Unknown mixing strategy.") + + +def get_dataset( + model_args: "ModelArguments", + data_args: "DataArguments", + tokenizer: "PreTrainedTokenizer", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["pt", "sft", "rm", "ppo"], + # split: Optional[str] = "train", # TODO: add split +) -> Union["Dataset", "IterableDataset"]: + template = get_template_and_fix_tokenizer(data_args.template, tokenizer) + if data_args.train_on_prompt and template.efficient_eos: + raise ValueError("Current template does not support `train_on_prompt`.") + + # Load from cache + if data_args.cache_path is not None: + if os.path.exists(data_args.cache_path): + logger.warning("Loading dataset from disk will ignore other data arguments.") + dataset = load_from_disk(data_args.cache_path) + if data_args.streaming: + dataset = dataset.to_iterable_dataset() + return dataset + + if data_args.streaming: + raise ValueError("Turn off dataset streaming to save cache files.") + + with training_args.main_process_first(desc="load dataset"): + all_datasets = [] + for dataset_attr 
in get_dataset_list(data_args): # TODO: add split + all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args)) + dataset = merge_dataset(all_datasets, data_args, training_args) + + with training_args.main_process_first(desc="pre-process dataset"): + preprocess_func, print_function = get_preprocess_and_print_func( + tokenizer, template, data_args, training_args, stage + ) + column_names = list(next(iter(dataset)).keys()) + kwargs = {} + if not data_args.streaming: + kwargs = dict( + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=(not data_args.overwrite_cache), + desc="Running tokenizer on dataset" + ) + + dataset = dataset.map( + preprocess_func, + batched=True, + remove_columns=column_names, + **kwargs + ) + + if data_args.cache_path is not None and not os.path.exists(data_args.cache_path): + if training_args.should_save: + dataset.save_to_disk(data_args.cache_path) + logger.info("Dataset cache saved at {}.".format(data_args.cache_path)) + + if training_args.should_log: + try: + print_function(next(iter(dataset))) + except StopIteration: + raise RuntimeError("Empty dataset!") + + return dataset diff --git a/src/llmtuner/data/parser.py b/src/llmtuner/data/parser.py new file mode 100644 index 00000000..842461d4 --- /dev/null +++ b/src/llmtuner/data/parser.py @@ -0,0 +1,101 @@ +import os +import json +from typing import TYPE_CHECKING, List, Literal, Optional +from dataclasses import dataclass + +from ..extras.constants import DATA_CONFIG +from ..extras.misc import use_modelscope + +if TYPE_CHECKING: + from ..hparams import DataArguments + + +@dataclass +class DatasetAttr: + + load_from: Literal["hf_hub", "ms_hub", "script", "file"] + dataset_name: Optional[str] = None + dataset_sha1: Optional[str] = None + subset: Optional[str] = None + folder: Optional[str] = None + ranking: Optional[bool] = False + formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca" + + system: Optional[str] = None + + prompt: Optional[str] = "instruction" + query: Optional[str] = "input" + response: Optional[str] = "output" + history: Optional[str] = None + + messages: Optional[str] = "conversations" + tool: Optional[str] = None + + role_tag: Optional[str] = "from" + content_tag: Optional[str] = "value" + user_tag: Optional[str] = "human" + assistant_tag: Optional[str] = "gpt" + observation_tag: Optional[str] = "observation" + function_tag: Optional[str] = "function_call" + + def __repr__(self) -> str: + return self.dataset_name + + +def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: + dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else [] + try: + with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f: + dataset_info = json.load(f) + except Exception as err: + if data_args.dataset is not None: + raise ValueError("Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))) + dataset_info = None + + if data_args.interleave_probs is not None: + data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")] + + dataset_list: List[DatasetAttr] = [] + for name in dataset_names: + if name not in dataset_info: + raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG)) + + has_hf_url = "hf_hub_url" in dataset_info[name] + has_ms_url = "ms_hub_url" in dataset_info[name] + + if has_hf_url or has_ms_url: + if (use_modelscope() and has_ms_url) or (not has_hf_url): + dataset_attr = DatasetAttr("ms_hub", 
dataset_name=dataset_info[name]["ms_hub_url"]) + else: + dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) + elif "script_url" in dataset_info[name]: + dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) + else: + dataset_attr = DatasetAttr( + "file", + dataset_name=dataset_info[name]["file_name"], + dataset_sha1=dataset_info[name].get("file_sha1", None) + ) + + dataset_attr.subset = dataset_info[name].get("subset", None) + dataset_attr.folder = dataset_info[name].get("folder", None) + dataset_attr.ranking = dataset_info[name].get("ranking", False) + dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca") + + if "columns" in dataset_info[name]: + if dataset_attr.formatting == "alpaca": + column_names = ["prompt", "query", "response", "history"] + else: + column_names = ["messages", "tool"] + + column_names += ["system"] + for column_name in column_names: + setattr(dataset_attr, column_name, dataset_info[name]["columns"].get(column_name, None)) + + if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]: + for tag in ["role_tag", "content_tag", "user_tag", "assistant_tag", "observation_tag", "function_tag"]: + setattr(dataset_attr, tag, dataset_info[name]["tags"].get(tag, None)) + + dataset_list.append(dataset_attr) + + return dataset_list diff --git a/src/llmtuner/data/preprocess.py b/src/llmtuner/data/preprocess.py index 6f98c8f5..45ae8626 100644 --- a/src/llmtuner/data/preprocess.py +++ b/src/llmtuner/data/preprocess.py @@ -1,272 +1,241 @@ -import os -import tiktoken +from functools import partial from itertools import chain -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple -from llmtuner.data.template import get_template_and_fix_tokenizer -from llmtuner.extras.constants import IGNORE_INDEX -from llmtuner.extras.logging import get_logger +from ..extras.constants import IGNORE_INDEX +from ..extras.logging import get_logger if TYPE_CHECKING: - from datasets import Dataset, IterableDataset from transformers import Seq2SeqTrainingArguments from transformers.tokenization_utils import PreTrainedTokenizer - from llmtuner.hparams import DataArguments + + from ..hparams import DataArguments + from .template import Template logger = get_logger(__name__) -def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]: - for i in range(len(examples["prompt"])): - query, response = examples["prompt"][i], examples["response"][i] - query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query - history = examples["history"][i] if "history" in examples else None - system = examples["system"][i] if "system" in examples else None - yield query, response, history, system - - -def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]: - max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len))) - max_target_len = max(max_target_len, data_args.reserved_label_len) - max_source_len = data_args.cutoff_len - max_target_len - return max_source_len, max_target_len - - -def preprocess_dataset( - dataset: Union["Dataset", "IterableDataset"], +def preprocess_pretrain_dataset( + examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", + data_args: "DataArguments" +) -> Dict[str, List[List[int]]]: + # build grouped texts with format `X1 X2 X3 ...` + text_examples = 
[examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))] + tokenized_examples = tokenizer(text_examples, add_special_tokens=False) + for i in range(len(tokenized_examples["input_ids"])): + tokenized_examples["input_ids"][i] += [tokenizer.eos_token_id] + tokenized_examples["attention_mask"][i] += [1] + + concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} + total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) + block_size = data_args.cutoff_len + # we drop the small remainder, and if the total_length < block_size, we exclude this batch + total_length = (total_length // block_size) * block_size + # split by chunks of cutoff_len + result = { + k: [t[i: i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + return result + + +def preprocess_supervised_dataset( + examples: Dict[str, List[Any]], + tokenizer: "PreTrainedTokenizer", + template: "Template", data_args: "DataArguments", - training_args: "Seq2SeqTrainingArguments", - stage: Literal["pt", "sft", "rm", "ppo"] -) -> Union["Dataset", "IterableDataset"]: - template = get_template_and_fix_tokenizer(data_args.template, tokenizer) +) -> Dict[str, List[List[int]]]: + # build inputs with format ` X Y ` and labels with format ` ... Y ` + # for multiturn examples, we only mask the prompt part in each prompt-response pair. + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} - if data_args.cache_path is not None and os.path.exists(data_args.cache_path): - return dataset # already preprocessed + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1: + continue - if data_args.train_on_prompt and template.efficient_eos: - raise ValueError("Current template does not support `train_on_prompt`.") - - def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: - # build grouped texts with format `X1 X2 X3 ...` - if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen) - kwargs = dict(allowed_special="all") - else: - kwargs = dict(add_special_tokens=True) - - if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer - add_eos_token_flag = getattr(tokenizer, "add_eos_token") - setattr(tokenizer, "add_eos_token", True) - - tokenized_examples = tokenizer(examples["prompt"], **kwargs) - concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} - total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) - block_size = data_args.cutoff_len - # we drop the small remainder, and if the total_length < block_size, we exclude this batch - total_length = (total_length // block_size) * block_size - # split by chunks of cutoff_len - result = { - k: [t[i: i + block_size] for i in range(0, total_length, block_size)] - for k, t in concatenated_examples.items() - } - # make sure the saved tokenizer is the same as the original one - if hasattr(tokenizer, "add_eos_token"): - setattr(tokenizer, "add_eos_token", add_eos_token_flag) - return result - - def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: - # build inputs with format ` X Y ` and labels with format ` ... Y ` - # for multiturn examples, we only mask the prompt part in each prompt-response pair. 
- model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} - - for query, response, history, system in construct_example(examples): - if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""): - continue - - input_ids, labels = [], [] - for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn( - tokenizer, query, response, history, system - )): - source_len, target_len = len(source_ids), len(target_ids) - max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args) - if source_len > max_source_len: - source_ids = source_ids[:max_source_len] - if target_len > max_target_len: - target_ids = target_ids[:max_target_len] - - if data_args.train_on_prompt: - source_mask = source_ids - elif turn_idx != 0 and template.efficient_eos: - source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) - else: - source_mask = [IGNORE_INDEX] * len(source_ids) - - input_ids += source_ids + target_ids - labels += source_mask + target_ids - - if template.efficient_eos: - input_ids += [tokenizer.eos_token_id] - labels += [tokenizer.eos_token_id] - - if len(input_ids) > data_args.cutoff_len: - input_ids = input_ids[:data_args.cutoff_len] - labels = labels[:data_args.cutoff_len] - - model_inputs["input_ids"].append(input_ids) - model_inputs["attention_mask"].append([1] * len(input_ids)) - model_inputs["labels"].append(labels) - - return model_inputs - - def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: - # build inputs with format ` X1 Y1 X2 Y2 ` - # and labels with format ` ... Y1 ... Y2 ` - model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + messages = examples["prompt"][i] + examples["response"][i] input_ids, labels = [], [] - for query, response, history, system in construct_example(examples): - if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""): - continue + for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn( + tokenizer, messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len + )): + if data_args.train_on_prompt: + source_mask = source_ids + elif turn_idx != 0 and template.efficient_eos: + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) + else: + source_mask = [IGNORE_INDEX] * len(source_ids) - for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn( - tokenizer, query, response, history, system - )): - if data_args.train_on_prompt: - source_mask = source_ids - elif turn_idx != 0 and template.efficient_eos: - source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) - else: - source_mask = [IGNORE_INDEX] * len(source_ids) - input_ids += source_ids + target_ids - labels += source_mask + target_ids + input_ids += source_ids + target_ids + labels += source_mask + target_ids if template.efficient_eos: input_ids += [tokenizer.eos_token_id] labels += [tokenizer.eos_token_id] - total_length = len(input_ids) - block_size = data_args.cutoff_len - # we drop the small remainder, and if the total_length < block_size, we exclude this batch - total_length = (total_length // block_size) * block_size - # split by chunks of cutoff_len - for i in range(0, total_length, block_size): - model_inputs["input_ids"].append(input_ids[i: i + block_size]) - model_inputs["attention_mask"].append([1] * block_size) - model_inputs["labels"].append(labels[i: i + block_size]) + 
model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) - return model_inputs + return model_inputs - def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: - # build inputs with format ` X` and labels with format `Y ` - model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} - for query, response, history, system in construct_example(examples): - if not (isinstance(query, str) and query != ""): - continue +def preprocess_packed_supervised_dataset( + examples: Dict[str, List[Any]], + tokenizer: "PreTrainedTokenizer", + template: "Template", + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + # build inputs with format ` X1 Y1 X2 Y2 ` + # and labels with format ` ... Y1 ... Y2 ` + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + input_ids, labels = [], [] + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1: + continue - input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system) + messages = examples["prompt"][i] + examples["response"][i] + for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn( + tokenizer, messages, examples["system"][i], examples["tool"][i], 1_000_000 + )): + if data_args.train_on_prompt: + source_mask = source_ids + elif turn_idx != 0 and template.efficient_eos: + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) + else: + source_mask = [IGNORE_INDEX] * len(source_ids) - if template.efficient_eos: - labels += [tokenizer.eos_token_id] + input_ids += source_ids + target_ids + labels += source_mask + target_ids - if len(input_ids) > data_args.cutoff_len: - input_ids = input_ids[:data_args.cutoff_len] - if len(labels) > data_args.cutoff_len: - labels = labels[:data_args.cutoff_len] + if template.efficient_eos: + input_ids += [tokenizer.eos_token_id] + labels += [tokenizer.eos_token_id] - model_inputs["input_ids"].append(input_ids) - model_inputs["attention_mask"].append([1] * len(input_ids)) - model_inputs["labels"].append(labels) + total_length = len(input_ids) + block_size = data_args.cutoff_len + # we drop the small remainder, and if the total_length < block_size, we exclude this batch + total_length = (total_length // block_size) * block_size + # split by chunks of cutoff_len + for i in range(0, total_length, block_size): + model_inputs["input_ids"].append(input_ids[i: i + block_size]) + model_inputs["attention_mask"].append([1] * block_size) + model_inputs["labels"].append(labels[i: i + block_size]) - return model_inputs + return model_inputs - def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: - # build input pairs with format ` X`, `Y1 ` and `Y2 ` - model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []} - for query, response, history, system in construct_example(examples): - if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1): - continue - prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system) - _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system) +def preprocess_unsupervised_dataset( + examples: Dict[str, List[Any]], + tokenizer: "PreTrainedTokenizer", + template: "Template", + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + # build inputs 
with format ` X` and labels with format `Y ` + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} - if template.efficient_eos: - chosen_ids += [tokenizer.eos_token_id] - rejected_ids += [tokenizer.eos_token_id] + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1: + continue - source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids)) - max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args) - if source_len > max_source_len: - prompt_ids = prompt_ids[:max_source_len] - if target_len > max_target_len: - chosen_ids = chosen_ids[:max_target_len] - rejected_ids = rejected_ids[:max_target_len] - - model_inputs["prompt_ids"].append(prompt_ids) - model_inputs["chosen_ids"].append(chosen_ids) - model_inputs["rejected_ids"].append(rejected_ids) - - return model_inputs - - def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None: - print("input_ids:\n{}".format(example["input_ids"])) - print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) - print("label_ids:\n{}".format(example["labels"])) - print("labels:\n{}".format( - tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False) - )) - - def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None: - print("prompt_ids:\n{}".format(example["prompt_ids"])) - print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False))) - print("chosen_ids:\n{}".format(example["chosen_ids"])) - print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False))) - print("rejected_ids:\n{}".format(example["rejected_ids"])) - print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False))) - - def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None: - print("input_ids:\n{}".format(example["input_ids"])) - print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) - - if stage == "pt": - preprocess_func = preprocess_pretrain_dataset - print_function = print_unsupervised_dataset_example - elif stage == "sft" and not training_args.predict_with_generate: - preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset - print_function = print_supervised_dataset_example - elif stage == "rm": - preprocess_func = preprocess_pairwise_dataset - print_function = print_pairwise_dataset_example - else: - preprocess_func = preprocess_unsupervised_dataset - print_function = print_unsupervised_dataset_example - - with training_args.main_process_first(desc="dataset map pre-processing"): - column_names = list(next(iter(dataset)).keys()) - kwargs = {} - if not data_args.streaming: - kwargs = dict( - num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=(not data_args.overwrite_cache), - desc="Running tokenizer on dataset" - ) - - dataset = dataset.map( - preprocess_func, - batched=True, - remove_columns=column_names, - **kwargs + messages = examples["prompt"][i] + examples["response"][i] + input_ids, labels = template.encode_oneturn( + tokenizer, messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len ) - if data_args.cache_path is not None and not os.path.exists(data_args.cache_path): - if training_args.should_save: - dataset.save_to_disk(data_args.cache_path) - logger.info("Dataset cache saved at 
{}.".format(data_args.cache_path)) + if template.efficient_eos: + labels += [tokenizer.eos_token_id] - if training_args.should_log: - try: - print_function(next(iter(dataset))) - except StopIteration: - raise RuntimeError("Empty dataset!") + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) - return dataset + return model_inputs + + +def preprocess_pairwise_dataset( + examples: Dict[str, List[Any]], + tokenizer: "PreTrainedTokenizer", + template: "Template", + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + # build input pairs with format ` X`, `Y1 ` and `Y2 ` + model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []} + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) < 2: + continue + + chosen_messages = examples["prompt"][i] + [examples["response"][i][0]] + rejected_messages = examples["prompt"][i] + [examples["response"][i][1]] + + prompt_ids, chosen_ids = template.encode_oneturn( + tokenizer, chosen_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len + ) + _, rejected_ids = template.encode_oneturn( + tokenizer, rejected_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len + ) + + if template.efficient_eos: + chosen_ids += [tokenizer.eos_token_id] + rejected_ids += [tokenizer.eos_token_id] + + model_inputs["prompt_ids"].append(prompt_ids) + model_inputs["chosen_ids"].append(chosen_ids) + model_inputs["rejected_ids"].append(rejected_ids) + + return model_inputs + + +def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + print("label_ids:\n{}".format(example["labels"])) + print("labels:\n{}".format( + tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False) + )) + + +def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: + print("prompt_ids:\n{}".format(example["prompt_ids"])) + print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False))) + print("chosen_ids:\n{}".format(example["chosen_ids"])) + print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False))) + print("rejected_ids:\n{}".format(example["rejected_ids"])) + print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False))) + + +def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + + +def get_preprocess_and_print_func( + tokenizer: "PreTrainedTokenizer", + template: "Template", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["pt", "sft", "rm", "ppo"], +) -> Tuple[Callable, Callable]: + if stage == "pt": + preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args) + print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) + elif stage == "sft" and not training_args.predict_with_generate: + if data_args.sft_packing: + preprocess_func = partial( + 
preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args + ) + else: + preprocess_func = partial( + preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args + ) + + print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer) + elif stage == "rm": + preprocess_func = partial(preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args) + print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer) + else: + preprocess_func = partial(preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args) + print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) + + return preprocess_func, print_function diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py index 9fa4be57..5690e773 100644 --- a/src/llmtuner/data/template.py +++ b/src/llmtuner/data/template.py @@ -1,8 +1,10 @@ -import tiktoken from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union + +from ..extras.logging import get_logger +from .utils import Role +from .formatter import StringFormatter, FunctionFormatter, ToolFormatter -from llmtuner.extras.logging import get_logger if TYPE_CHECKING: from transformers import PreTrainedTokenizer @@ -14,28 +16,30 @@ logger = get_logger(__name__) @dataclass class Template: - prefix: List[Union[str, Dict[str, str]]] - prompt: List[Union[str, Dict[str, str]]] + format_user: Callable + format_assistant: Callable + format_system: Callable + format_tool: Callable + format_observation: Callable + format_function: Callable system: str - sep: List[Union[str, Dict[str, str]]] + separator: List[Union[str, Dict[str, str]]] stop_words: List[str] - use_history: bool efficient_eos: bool replace_eos: bool def encode_oneturn( self, tokenizer: "PreTrainedTokenizer", - query: str, - resp: str, - history: Optional[List[Tuple[str, str]]] = None, - system: Optional[str] = None + messages: List[Dict[str, str]], + system: str, + tool: str, + cutoff_len: int ) -> Tuple[List[int], List[int]]: r""" Returns a single pair of token ids representing prompt and response respectively. """ - system, history = self._format(query, resp, history, system) - encoded_pairs = self._encode(tokenizer, system, history) + encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len) prompt_ids = [] for query_ids, resp_ids in encoded_pairs[:-1]: prompt_ids = prompt_ids + query_ids + resp_ids @@ -46,109 +50,75 @@ class Template: def encode_multiturn( self, tokenizer: "PreTrainedTokenizer", - query: str, - resp: str, - history: Optional[List[Tuple[str, str]]] = None, - system: Optional[str] = None + messages: List[Dict[str, str]], + system: str, + tool: str, + cutoff_len: int ) -> List[Tuple[List[int], List[int]]]: r""" Returns multiple pairs of token ids representing prompts and responses respectively. """ - system, history = self._format(query, resp, history, system) - encoded_pairs = self._encode(tokenizer, system, history) + encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len) return encoded_pairs - def _format( - self, - query: str, - resp: str, - history: Optional[List[Tuple[str, str]]] = None, - system: Optional[str] = None - ) -> Tuple[str, List[Tuple[str, str]]]: - r""" - Aligns inputs to the standard format. 
- """ - system = system or self.system # use system if provided - history = history if (history and self.use_history) else [] - history = history + [(query, resp)] - return system, history - - def _get_special_ids( - self, - tokenizer: "PreTrainedTokenizer" - ) -> Tuple[List[int], List[int]]: - if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True): - bos_ids = [tokenizer.bos_token_id] - else: # baichuan, gpt2, qwen, yi models have no bos token - bos_ids = [] - - if tokenizer.eos_token_id is None: - raise ValueError("EOS token is required.") - - if self.efficient_eos: - eos_ids = [] - else: - eos_ids = [tokenizer.eos_token_id] - - return bos_ids, eos_ids - def _encode( self, tokenizer: "PreTrainedTokenizer", + messages: List[Dict[str, str]], system: str, - history: List[Tuple[str, str]] + tool: str, + cutoff_len: int ) -> List[Tuple[List[int], List[int]]]: r""" Encodes formatted inputs to pairs of token ids. - Turn 0: bos + prefix + sep + query resp + eos - Turn t: sep + bos + query resp + eos + Turn 0: system + query resp + eos + Turn t: sep + query resp + eos """ - bos_ids, eos_ids = self._get_special_ids(tokenizer) - sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) - encoded_pairs = [] - for turn_idx, (query, resp) in enumerate(history): - if turn_idx == 0: - prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system) - if len(prefix_ids) != 0: # has prefix - prefix_ids = bos_ids + prefix_ids + sep_ids - else: - prefix_ids = bos_ids - else: - prefix_ids = sep_ids + bos_ids + system = system or self.system + encoded_messages = [] + for i, message in enumerate(messages): + elements = [] + if i == 0 and (system or tool): + tool_text = self.format_tool(content=tool)[0] if tool else "" + elements += self.format_system(content=(system + tool_text)) + elif i > 0 and i % 2 == 0: + elements += self.separator - query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx+1)) - resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) - encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids)) - return encoded_pairs + if message["role"] == Role.USER: + elements += self.format_user(content=message["content"], idx=str(i // 2)) + elif message["role"] == Role.ASSISTANT: + elements += self.format_assistant(content=message["content"]) + elif message["role"] == Role.OBSERVATION: + elements += self.format_observation(content=message["content"]) + elif message["role"] == Role.FUNCTION: + elements += self.format_function(content=message["content"]) - def _convert_inputs_to_ids( + encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) + + return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)] + + def _convert_elements_to_ids( self, tokenizer: "PreTrainedTokenizer", - context: List[Union[str, Dict[str, str]]], - system: Optional[str] = None, - query: Optional[str] = None, - idx: Optional[str] = None + elements: List[Union[str, Dict[str, str]]] ) -> List[int]: r""" - Converts context to token ids. + Converts elements to token ids. 
""" - if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen) - kwargs = dict(allowed_special="all") - else: - kwargs = dict(add_special_tokens=False) - token_ids = [] - for elem in context: + for elem in elements: if isinstance(elem, str): - elem = elem.replace("{{system}}", system, 1) if system is not None else elem - elem = elem.replace("{{query}}", query, 1) if query is not None else elem - elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem if len(elem) != 0: - token_ids = token_ids + tokenizer.encode(elem, **kwargs) + token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False) elif isinstance(elem, dict): token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] + elif isinstance(elem, set): + if "bos_token" in elem and tokenizer.bos_token_id: + token_ids = token_ids + [tokenizer.bos_token_id] + elif "eos_token" in elem and tokenizer.eos_token_id: + token_ids = token_ids + [tokenizer.eos_token_id] else: - raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem))) + raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem))) return token_ids @@ -159,23 +129,39 @@ class Llama2Template(Template): def _encode( self, tokenizer: "PreTrainedTokenizer", + messages: List[Dict[str, str]], system: str, - history: List[Tuple[str, str]] + tool: str, + cutoff_len: int ) -> List[Tuple[List[int], List[int]]]: r""" Encodes formatted inputs to pairs of token ids. - Turn 0: bos + prefix + query resp + eos - Turn t: bos + query resp + eos + Turn 0: system + query resp + eos + Turn t: sep + query resp + eos """ - bos_ids, eos_ids = self._get_special_ids(tokenizer) - encoded_pairs = [] - for turn_idx, (query, resp) in enumerate(history): - if turn_idx == 0: # llama2 template has no sep_ids - query = self.prefix[0].replace("{{system}}", system) + query - query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query) - resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) - encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids)) - return encoded_pairs + system = system or self.system + encoded_messages = [] + for i, message in enumerate(messages): + elements = [] + system_text = "" + if i == 0 and (system or tool): + tool_text = self.format_tool(content=tool)[0] if tool else "" + system_text = self.format_system(content=(system + tool_text))[0] + elif i > 0 and i % 2 == 0: + elements += self.separator + + if message["role"] == Role.USER: + elements += self.format_user(content=system_text + message["content"], idx=str(i // 2)) + elif message["role"] == Role.ASSISTANT: + elements += self.format_assistant(content=message["content"]) + elif message["role"] == Role.OBSERVATION: + elements += self.format_observation(content=message["content"]) + elif message["role"] == Role.FUNCTION: + elements += self.format_function(content=message["content"]) + + encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) + + return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)] templates: Dict[str, Template] = {} @@ -183,23 +169,33 @@ templates: Dict[str, Template] = {} def register_template( name: str, - prefix: List[Union[str, Dict[str, str]]], - prompt: List[Union[str, Dict[str, str]]], - system: str, - sep: List[Union[str, Dict[str, str]]], + format_user: Optional[Callable] = None, + format_assistant: Optional[Callable] = None, + format_system: Optional[Callable] = 
None, + format_tool: Optional[Callable] = None, + format_observation: Optional[Callable] = None, + format_function: Optional[Callable] = None, + system: Optional[str] = "", + separator: Optional[List[Union[str, Dict[str, str]]]] = "", stop_words: Optional[List[str]] = [], - use_history: Optional[bool] = True, efficient_eos: Optional[bool] = False, replace_eos: Optional[bool] = False ) -> None: template_class = Llama2Template if name.startswith("llama2") else Template templates[name] = template_class( - prefix=prefix, - prompt=prompt, + format_user=format_user or StringFormatter(container=["{{content}}"]), + format_assistant=format_assistant or StringFormatter(container=[ + "{{content}}", {"eos_token"} + ]), + format_system=format_system or StringFormatter(container=["{{content}}"]), + format_tool=format_tool or ToolFormatter(type="default"), + format_observation=format_observation or format_user, + format_function=format_function or FunctionFormatter(container=[ + "Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"} + ]), system=system, - sep=sep, + separator=separator, stop_words=stop_words, - use_history=use_history, efficient_eos=efficient_eos, replace_eos=replace_eos ) @@ -244,17 +240,14 @@ def get_template_and_fix_tokenizer( register_template( name="alpaca", - prefix=[ - "{{system}}" - ], - prompt=[ - "### Instruction:\n{{query}}\n\n### Response:\n" - ], + format_user=StringFormatter(container=[ + "### Instruction:\n{{content}}\n\n### Response:\n" + ]), system=( "Below is an instruction that describes a task. " "Write a response that appropriately completes the request." ), - sep=[ + separator=[ "\n\n" ] ) @@ -262,17 +255,14 @@ register_template( register_template( name="aquila", - prefix=[ - "{{system}}" - ], - prompt=[ - "Human: {{query}}###Assistant:" - ], + format_user=StringFormatter(container=[ + "Human: {{content}}###Assistant:" + ]), system=( "A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions." 
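# The StringFormatter/FunctionFormatter/ToolFormatter classes used above come from the new
# data/formatter.py module, which this diff does not show. A minimal sketch of the contract the
# Template code assumes: a formatter is a callable that returns its container with
# "{{placeholder}}" slots filled in, leaving {"token": ...} dicts and {"bos_token"}/{"eos_token"}
# sets untouched for _convert_elements_to_ids to resolve later. Names below are hypothetical.
from dataclasses import dataclass
from typing import Dict, List, Set, Union

Element = Union[str, Dict[str, str], Set[str]]

@dataclass
class StringFormatterSketch:
    container: List[Element]

    def __call__(self, **kwargs: str) -> List[Element]:
        filled: List[Element] = []
        for elem in self.container:
            if isinstance(elem, str):
                for name, value in kwargs.items():
                    elem = elem.replace("{{" + name + "}}", value)
                filled.append(elem)
            else:  # special-token dicts and bos/eos sets pass through unchanged
                filled.append(elem)
        return filled

# e.g. StringFormatterSketch(container=["Human: {{content}}\nAssistant:"])(content="Hi")
# -> ["Human: Hi\nAssistant:"]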
), - sep=[ + separator=[ "###" ], stop_words=[ @@ -284,46 +274,32 @@ register_template( register_template( name="baichuan", - prefix=[ - "{{system}}" - ], - prompt=[ - {"token": ""}, # user token - "{{query}}", - {"token": ""} # assistant token - ], - system="", - sep=[], + format_user=StringFormatter(container=[ + {"token": ""}, + "{{content}}", + {"token": ""} + ]), efficient_eos=True ) register_template( name="baichuan2", - prefix=[ - "{{system}}" - ], - prompt=[ - {"token": ""}, # user token - "{{query}}", - {"token": ""} # assistant token - ], - system="", - sep=[], + format_user=StringFormatter(container=[ + {"token": ""}, + "{{content}}", + {"token": ""} + ]), efficient_eos=True ) register_template( name="belle", - prefix=[ - "{{system}}" - ], - prompt=[ - "Human: {{query}}\n\nBelle: " - ], - system="", - sep=[ + format_user=StringFormatter(container=[ + "Human: {{content}}\n\nBelle: " + ]), + separator=[ "\n\n" ] ) @@ -331,31 +307,25 @@ register_template( register_template( name="bluelm", - prefix=[ - "{{system}}" - ], - prompt=[ + format_user=StringFormatter(container=[ {"token": "[|Human|]:"}, - "{{query}}", + "{{content}}", {"token": "[|AI|]:"} - ], - system="", - sep=[] + ]) ) register_template( name="chatglm2", - prefix=[ + format_user=StringFormatter(container=[ + "[Round {{idx}}]\n\n问:{{content}}\n\n答:" + ]), + format_system=StringFormatter(container=[ {"token": "[gMASK]"}, {"token": "sop"}, - "{{system}}" - ], - prompt=[ - "[Round {{idx}}]\n\n问:{{query}}\n\n答:" - ], - system="", - sep=[ + "{{content}}" + ]), + separator=[ "\n\n" ], efficient_eos=True @@ -364,53 +334,35 @@ register_template( register_template( name="chatglm3", - prefix=[ - {"token": "[gMASK]"}, - {"token": "sop"}, - {"token": "<|system|>"}, - "\n", - "{{system}}" - ], - prompt=[ + format_user=StringFormatter(container=[ {"token": "<|user|>"}, "\n", - "{{query}}", - {"token": "<|assistant|>"}, - "\n" # add an extra newline to avoid error in ChatGLM's process_response method - ], - system=( - "You are ChatGLM3, a large language model trained by Zhipu.AI. " - "Follow the user's instructions carefully. Respond using markdown." - ), - sep=[], - stop_words=[ - "<|user|>", - "<|observation|>" - ], - efficient_eos=True -) - - -register_template( - name="chatglm3_raw", # the raw template for tool tuning - prefix=[ - {"token": "[gMASK]"}, - {"token": "sop"}, - {"token": "<|system|>"}, - "\n", - "{{system}}" - ], - prompt=[ - {"token": "<|user|>"}, - "\n", - "{{query}}", + "{{content}}", {"token": "<|assistant|>"} - ], + ]), + format_assistant=StringFormatter(container=[ + "\n" + "{{content}}" + ]), + format_system=StringFormatter(container=[ + {"token": "[gMASK]"}, + {"token": "sop"}, + {"token": "<|system|>"}, + "\n", + "{{content}}" + ]), + format_observation=StringFormatter(container=[ + {"token": "<|observation|>"}, + "\n", + "{{content}}" + ]), + format_function=FunctionFormatter(container=[ + "{{name}}\n{{arguments}}" + ]), system=( "You are ChatGLM3, a large language model trained by Zhipu.AI. " "Follow the user's instructions carefully. Respond using markdown." 
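# Hedged sketch (not in this diff) of what the default FunctionFormatter registered above is
# presumably expected to do: the function-role message content is assumed to be a JSON object
# with "name" and "arguments" fields, spliced into the {{name}}/{{arguments}} slots of the
# container ("Action: {{name}}\nAction Input: {{arguments}}" by default). The real
# implementation lives in the new data/formatter.py, which is not shown here.
import json

def format_function_sketch(content: str):
    # hypothetical tool call, e.g. '{"name": "get_weather", "arguments": {"city": "Paris"}}'
    call = json.loads(content)
    text = "Action: {}\nAction Input: {}".format(
        call["name"], json.dumps(call["arguments"], ensure_ascii=False)
    )
    return [text, {"eos_token"}]  # trailing eos marker is resolved by _convert_elements_to_ids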
), - sep=[], stop_words=[ "<|user|>", "<|observation|>" @@ -421,47 +373,34 @@ register_template( register_template( name="codegeex2", - prefix=[ + format_system=StringFormatter(container=[ {"token": "[gMASK]"}, {"token": "sop"}, - "{{system}}" - ], - prompt=[ - "{{query}}" - ], - system="", - sep=[] + "{{content}}" + ]) ) register_template( name="deepseek", - prefix=[ - "{{system}}" - ], - prompt=[ - "User: {{query}}\n\nAssistant:" - ], - system="", - sep=[] + format_user=StringFormatter(container=[ + "User: {{content}}\n\nAssistant:" + ]) ) register_template( name="deepseekcoder", - prefix=[ - "{{system}}" - ], - prompt=[ - "### Instruction:\n{{query}}\n### Response:\n" - ], + format_user=StringFormatter(container=[ + "### Instruction:\n{{content}}\n### Response:\n" + ]), system=( "You are an AI programming assistant, utilizing the Deepseek Coder model, " "developed by Deepseek Company, and you only answer questions related to computer science. " "For politically sensitive questions, security and privacy issues, " "and other non-computer science questions, you will refuse to answer\n" ), - sep=[ + separator=[ "\n", {"token": "<|EOT|>"}, "\n" @@ -475,17 +414,14 @@ register_template( register_template( name="default", - prefix=[ - "{{system}}" - ], - prompt=[ - "Human: {{query}}\nAssistant:" - ], + format_user=StringFormatter(container=[ + "Human: {{content}}\nAssistant: " + ]), system=( "A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions." + "The assistant gives helpful, detailed, and polite answers to the user's questions.\n" ), - sep=[ + separator=[ "\n" ] ) @@ -493,14 +429,10 @@ register_template( register_template( name="falcon", - prefix=[ - "{{system}}" - ], - prompt=[ - "User: {{query}}\nFalcon:" - ], - system="", - sep=[ + format_user=StringFormatter(container=[ + "User: {{content}}\nFalcon:" + ]), + separator=[ "\n" ], efficient_eos=True @@ -509,16 +441,12 @@ register_template( register_template( name="intern", - prefix=[ - "{{system}}" - ], - prompt=[ - "<|User|>:{{query}}", + format_user=StringFormatter(container=[ + "<|User|>:{{content}}", {"token": ""}, "\n<|Bot|>:" - ], - system="", - sep=[ + ]), + separator=[ {"token": ""}, "\n" ], @@ -529,14 +457,44 @@ register_template( ) +register_template( + name="intern2", + format_user=StringFormatter(container=[ + {"token": "[UNUSED_TOKEN_146]"}, + "user\n{{content}}", + {"token": "[UNUSED_TOKEN_145]"}, + "\n", + {"token": "[UNUSED_TOKEN_146]"}, + "assistant\n" + ]), + format_system=StringFormatter(container=[ + {"token": "[UNUSED_TOKEN_146]"}, + "system\n{{content}}", + {"token": "[UNUSED_TOKEN_145]"}, + "\n" + ]), + system=( + "You are an AI assistant whose name is InternLM (书生·浦语).\n" + "- InternLM (书生·浦语) is a conversational language model that is developed " + "by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n" + "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen " + "by the user such as English and 中文." 
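# Worked illustration (not part of the diff) of what the reworked "default" template above
# renders for a two-turn chat, using the formatter defaults from register_template:
# format_system passes the system string through, format_assistant appends the eos token, and
# the "\n" separator is prepended to every user turn after the first.
SYSTEM = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
)
EOS = "</s>"  # assumed eos token; the actual value depends on the tokenizer

pair_0 = (SYSTEM + "Human: Hi!\nAssistant: ", "Hello! How can I help you?" + EOS)
pair_1 = ("\n" + "Human: Who wrote Hamlet?\nAssistant: ", "William Shakespeare." + EOS)
# encode_multiturn tokenizes these pairs one by one; encode_oneturn concatenates the token ids
# of everything except the final response into a single prompt.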
+ ), + separator=[ + {"token": "[UNUSED_TOKEN_145]"}, + "\n" + ], + stop_words=[ + "[UNUSED_TOKEN_145]" + ], + efficient_eos=True +) + + register_template( name="llama2", - prefix=[ - "<>\n{{system}}\n<>\n\n" - ], - prompt=[ - "[INST] {{query}} [/INST]" - ], + format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]), + format_system=StringFormatter(container=["<>\n{{content}}\n<>\n\n"]), system=( "You are a helpful, respectful and honest assistant. " "Always answer as helpfully as possible, while being safe. " @@ -546,49 +504,32 @@ register_template( "If a question does not make any sense, or is not factually coherent, " "explain why instead of answering something not correct. " "If you don't know the answer to a question, please don't share false information." - ), - sep=[] + ) ) register_template( name="llama2_zh", - prefix=[ - "<>\n{{system}}\n<>\n\n" - ], - prompt=[ - "[INST] {{query}} [/INST]" - ], - system="You are a helpful assistant. 你是一个乐于助人的助手。", - sep=[] + format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]), + format_system=StringFormatter(container=["<>\n{{content}}\n<>\n\n"]), + system="You are a helpful assistant. 你是一个乐于助人的助手。" ) register_template( name="mistral", - prefix=[ - "{{system}}" - ], - prompt=[ - "[INST] {{query}} [/INST]" - ], - system="", - sep=[] + format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]) ) register_template( name="openchat", - prefix=[ - "{{system}}" - ], - prompt=[ - "GPT4 Correct User: {{query}}", + format_user=StringFormatter(container=[ + "GPT4 Correct User: {{content}}", {"token": "<|end_of_turn|>"}, "GPT4 Correct Assistant:" - ], - system="", - sep=[ + ]), + separator=[ {"token": "<|end_of_turn|>"} ], stop_words=[ @@ -600,14 +541,14 @@ register_template( register_template( name="qwen", - prefix=[ - "<|im_start|>system\n{{system}}<|im_end|>" - ], - prompt=[ - "<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n" - ], + format_user=StringFormatter(container=[ + "<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n" + ]), + format_system=StringFormatter(container=[ + "<|im_start|>system\n{{content}}<|im_end|>\n" + ]), system="You are a helpful assistant.", - sep=[ + separator=[ "\n" ], stop_words=[ @@ -619,32 +560,28 @@ register_template( register_template( name="solar", - prefix=[ - "{{system}}" - ], - prompt=[ - "### User:\n{{query}}\n\n### Assistant:\n" - ], - system="", - sep=[] + format_user=StringFormatter(container=[ + "### User:\n{{content}}\n\n### Assistant:\n" + ]) ) register_template( name="starchat", - prefix=[ - {"token": "<|system|>"}, - "\n{{system}}", - ], - prompt=[ + format_user=StringFormatter(container=[ {"token": "<|user|>"}, - "\n{{query}}", + "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"} - ], - system="", - sep=[ + ]), + format_system=StringFormatter(container=[ + {"token": "<|system|>"}, + "\n{{content}}", + {"token": "<|end|>"}, + "\n" + ]), + separator=[ {"token": "<|end|>"}, "\n" ], @@ -656,75 +593,55 @@ register_template( register_template( - name="vanilla", - prefix=[], - prompt=[ - "{{query}}" - ], - system="", - sep=[], - use_history=False + name="vanilla" ) register_template( name="vicuna", - prefix=[ - "{{system}}" - ], - prompt=[ - "USER: {{query}} ASSISTANT:" - ], + format_user=StringFormatter(container=[ + "USER: {{content}} ASSISTANT:" + ]), system=( "A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions." 
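# Note (not part of the diff) on the Llama2Template subclass used by the "llama2" templates
# above: unlike the base Template, it emits no separate system turn — the rendered system text
# is prepended to the first user message only. Pseudo-rendering with placeholder contents:
system_text = "<rendered system text>\n"               # format_system(content=system)[0], illustrative
turn_0 = "[INST] " + system_text + "Hello! [/INST]"    # format_user(content=system_text + content)
turn_1 = "[INST] Thanks! [/INST]"                      # later user turns carry no system text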
- ), - sep=[] + ) ) register_template( name="xuanyuan", - prefix=[ - "{{system}}" - ], - prompt=[ - "Human: {{query}} Assistant:" - ], + format_user=StringFormatter(container=[ + "Human: {{content}} Assistant:" + ]), system=( "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头," "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、" "不安全、有争议、政治敏感等相关的话题、问题和指示。\n" - ), - sep=[] + ) ) register_template( name="xverse", - prefix=[ - "{{system}}" - ], - prompt=[ - "Human: {{query}}\n\nAssistant: " - ], - system="", - sep=[] + format_user=StringFormatter(container=[ + "Human: {{content}}\n\nAssistant: " + ]) ) register_template( name="yayi", - prefix=[ - {"token": "<|System|>"}, - ":\n{{system}}" - ], - prompt=[ + format_user=StringFormatter(container=[ {"token": "<|Human|>"}, - ":\n{{query}}\n\n", + ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":" - ], + ]), + format_system=StringFormatter(container=[ + {"token": "<|System|>"}, + ":\n{{content}}\n\n" + ]), system=( "You are a helpful, respectful and honest assistant named YaYi " "developed by Beijing Wenge Technology Co.,Ltd. " @@ -736,7 +653,7 @@ register_template( "explain why instead of answering something not correct. " "If you don't know the answer to a question, please don't share false information." ), - sep=[ + separator=[ "\n\n" ], stop_words=[ @@ -747,14 +664,10 @@ register_template( register_template( name="yi", - prefix=[ - "{{system}}" - ], - prompt=[ - "<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n" - ], - system="", - sep=[ + format_user=StringFormatter(container=[ + "<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n" + ]), + separator=[ "\n" ], stop_words=[ @@ -766,15 +679,11 @@ register_template( register_template( name="yuan", - prefix=[ - "{{system}}" - ], - prompt=[ - "{{query}}", + format_user=StringFormatter(container=[ + "{{content}}", {"token": ""} - ], - system="", - sep=[ + ]), + separator=[ "\n" ], stop_words=[ @@ -786,30 +695,25 @@ register_template( register_template( name="zephyr", - prefix=[ - "<|system|>\n{{system}}", - ], - prompt=[ - "<|user|>\n{{query}}<|assistant|>" - ], - system="You are a friendly chatbot who always responds in the style of a pirate", - sep=[] + format_user=StringFormatter(container=[ + "<|user|>\n{{content}}<|assistant|>" + ]), + format_system=StringFormatter(container=[ + "<|system|>\n{{content}}", + ]), + system="You are a friendly chatbot who always responds in the style of a pirate" ) register_template( name="ziya", - prefix=[ - "{{system}}" - ], - prompt=[ + format_user=StringFormatter(container=[ {"token": ""}, - ":{{query}}\n", + ":{{content}}\n", {"token": ""}, ":" - ], - system="", - sep=[ + ]), + separator=[ "\n" ] ) diff --git a/src/llmtuner/data/utils.py b/src/llmtuner/data/utils.py index 9dfe4dc3..106e87a7 100644 --- a/src/llmtuner/data/utils.py +++ b/src/llmtuner/data/utils.py @@ -1,7 +1,8 @@ import hashlib -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from enum import Enum, unique +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union -from llmtuner.extras.logging import get_logger +from ..extras.logging import get_logger if TYPE_CHECKING: from datasets import Dataset, IterableDataset @@ -12,6 +13,14 @@ if TYPE_CHECKING: logger = get_logger(__name__) +@unique +class Role(str, Enum): + USER = "user" + ASSISTANT = "assistant" + OBSERVATION = "observation" + FUNCTION = "function" + + def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None: if file_sha1 is None: logger.warning("Checksum failed: missing SHA-1 
hash value in dataset_info.json.") @@ -27,6 +36,13 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None: logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0])) +def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]: + max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len))) + max_target_len = max(max_target_len, data_args.reserved_label_len) + max_source_len = data_args.cutoff_len - max_target_len + return max_source_len, max_target_len + + def split_dataset( dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", diff --git a/src/llmtuner/eval/__init__.py b/src/llmtuner/eval/__init__.py index a7c9a127..95ce0377 100644 --- a/src/llmtuner/eval/__init__.py +++ b/src/llmtuner/eval/__init__.py @@ -1 +1,4 @@ -from llmtuner.eval.evaluator import Evaluator +from .evaluator import Evaluator + + +__all__ = ["Evaluator"] diff --git a/src/llmtuner/eval/evaluator.py b/src/llmtuner/eval/evaluator.py index 0bf4c3f4..251dfc0b 100644 --- a/src/llmtuner/eval/evaluator.py +++ b/src/llmtuner/eval/evaluator.py @@ -3,7 +3,6 @@ import os import json import torch -import tiktoken import numpy as np from tqdm import tqdm, trange from typing import Any, Dict, List, Optional @@ -11,10 +10,11 @@ from typing import Any, Dict, List, Optional from datasets import load_dataset from transformers.utils import cached_file -from llmtuner.data.template import get_template_and_fix_tokenizer -from llmtuner.eval.template import get_eval_template -from llmtuner.extras.constants import CHOICES, SUBJECTS -from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer +from ..data import get_template_and_fix_tokenizer +from .template import get_eval_template +from ..extras.constants import CHOICES, SUBJECTS +from ..hparams import get_eval_args +from ..model import dispatch_model, load_model_and_tokenizer class Evaluator: @@ -26,15 +26,9 @@ class Evaluator: self.model = dispatch_model(self.model) self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer) self.eval_template = get_eval_template(self.eval_args.lang) - self.choice_inputs = self._encode_choices() - - def _encode_choices(self) -> List[int]: - if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen) - kwargs = dict(allowed_special="all") - else: - kwargs = dict(add_special_tokens=False) - - return [self.tokenizer.encode(self.eval_template.prefix + ch, **kwargs)[-1] for ch in CHOICES] + self.choice_inputs = [self.tokenizer.encode( + self.eval_template.prefix + ch, add_special_tokens=False + )[-1] for ch in CHOICES] @torch.inference_mode() def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]: diff --git a/src/llmtuner/eval/template.py b/src/llmtuner/eval/template.py index 2251ad57..924a3c8b 100644 --- a/src/llmtuner/eval/template.py +++ b/src/llmtuner/eval/template.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Tuple -from llmtuner.extras.constants import CHOICES +from ..extras.constants import CHOICES if TYPE_CHECKING: from datasets import Dataset @@ -44,7 +44,7 @@ class EvalTemplate: return query.strip(), resp, history -eval_templates: Dict[str, EvalTemplate] = {} +eval_templates: Dict[str, "EvalTemplate"] = {} def register_eval_template( @@ -62,7 +62,7 @@ def register_eval_template( ) -def get_eval_template(name: str) -> EvalTemplate: +def 
get_eval_template(name: str) -> "EvalTemplate": eval_template = eval_templates.get(name, None) assert eval_template is not None, "Template {} does not exist.".format(name) return eval_template diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py index 17ab5dc1..b97d0168 100644 --- a/src/llmtuner/extras/callbacks.py +++ b/src/llmtuner/extras/callbacks.py @@ -6,9 +6,9 @@ from datetime import timedelta from transformers import TrainerCallback from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR -from llmtuner.extras.constants import LOG_FILE_NAME -from llmtuner.extras.logging import get_logger -from llmtuner.extras.misc import fix_valuehead_checkpoint +from .constants import LOG_FILE_NAME +from .logging import get_logger +from .misc import fix_valuehead_checkpoint if TYPE_CHECKING: diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index 546e3d5f..8b13aac8 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -5,6 +5,8 @@ from typing import Dict, Optional CHOICES = ["A", "B", "C", "D"] +DATA_CONFIG = "dataset_info.json" + DEFAULT_MODULE = defaultdict(str) DEFAULT_TEMPLATE = defaultdict(str) diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index dee101ec..2a1199e4 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -13,8 +13,8 @@ from transformers.utils import ( ) from peft import PeftModel -from llmtuner.extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME -from llmtuner.extras.logging import get_logger +from .constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME +from .logging import get_logger _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() diff --git a/src/llmtuner/extras/patches/llama_patch.py b/src/llmtuner/extras/patches/llama_patch.py index 1fb7ed3b..a9f5da28 100644 --- a/src/llmtuner/extras/patches/llama_patch.py +++ b/src/llmtuner/extras/patches/llama_patch.py @@ -10,7 +10,7 @@ try: except ImportError: print("Please upgrade `transformers`.") -from llmtuner.extras.packages import is_flash_attn2_available +from ..packages import is_flash_attn2_available if is_flash_attn2_available(): diff --git a/src/llmtuner/extras/ploting.py b/src/llmtuner/extras/ploting.py index cf2c72ac..65b3bf42 100644 --- a/src/llmtuner/extras/ploting.py +++ b/src/llmtuner/extras/ploting.py @@ -4,8 +4,8 @@ import json from typing import List, Optional from transformers.trainer import TRAINER_STATE_NAME -from llmtuner.extras.logging import get_logger -from llmtuner.extras.packages import is_matplotlib_available +from .logging import get_logger +from .packages import is_matplotlib_available if is_matplotlib_available(): import matplotlib.pyplot as plt diff --git a/src/llmtuner/hparams/__init__.py b/src/llmtuner/hparams/__init__.py index 623d6517..80deeb72 100644 --- a/src/llmtuner/hparams/__init__.py +++ b/src/llmtuner/hparams/__init__.py @@ -3,3 +3,16 @@ from .evaluation_args import EvaluationArguments from .finetuning_args import FinetuningArguments from .generating_args import GeneratingArguments from .model_args import ModelArguments +from .parser import get_train_args, get_infer_args, get_eval_args + + +__all__ = [ + "DataArguments", + "EvaluationArguments", + "FinetuningArguments", + "GeneratingArguments", + "ModelArguments", + "get_train_args", + "get_infer_args", + "get_eval_args" +] diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index 7be4f4f5..a635e47a 100644 
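# Worked example (not part of the diff) for the two helpers added to data/utils.py earlier in
# this change set. infer_max_len splits the cutoff budget between source and target in
# proportion to their lengths, with reserved_label_len as a floor for the target share; the
# numbers below are illustrative (cutoff_len=1024, reserved_label_len=1).
source_len, target_len = 900, 300
max_target_len = max(int(1024 * target_len / (source_len + target_len)), 1)  # -> 256
max_source_len = 1024 - max_target_len                                       # -> 768

# Role subclasses str, so role values coming straight from the dataset compare equal to the
# enum members, which is what lets Template._encode match message["role"] against Role.USER.
from llmtuner.data.utils import Role
assert Role.USER == "user" and Role.OBSERVATION == "observation"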
--- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -1,40 +1,7 @@ -import os -import json -from typing import List, Literal, Optional +from typing import Literal, Optional from dataclasses import dataclass, field -DATA_CONFIG = "dataset_info.json" - - -def use_modelscope() -> bool: - return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0"))) - - -@dataclass -class DatasetAttr: - - load_from: Literal["hf_hub", "ms_hub", "script", "file"] - dataset_name: Optional[str] = None - dataset_sha1: Optional[str] = None - subset: Optional[str] = None - folder: Optional[str] = None - ranking: Optional[bool] = False - formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca" - - prompt: Optional[str] = "instruction" - query: Optional[str] = "input" - response: Optional[str] = "output" - history: Optional[str] = None - messages: Optional[str] = "conversations" - role: Optional[str] = "from" - content: Optional[str] = "value" - system: Optional[str] = None - - def __repr__(self) -> str: - return self.dataset_name - - @dataclass class DataArguments: r""" @@ -126,64 +93,3 @@ class DataArguments: if self.streaming and self.max_samples is not None: raise ValueError("`max_samples` is incompatible with `streaming`.") - - def init_for_training(self, seed: int): # support mixing multiple datasets - self.seed = seed - dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else [] - try: - with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f: - dataset_info = json.load(f) - except Exception as err: - if self.dataset is not None: - raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err))) - dataset_info = None - - if self.interleave_probs is not None: - self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")] - - self.dataset_list: List[DatasetAttr] = [] - for name in dataset_names: - if name not in dataset_info: - raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG)) - - has_hf_url = "hf_hub_url" in dataset_info[name] - has_ms_url = "ms_hub_url" in dataset_info[name] - - if has_hf_url or has_ms_url: - if (use_modelscope() and has_ms_url) or (not has_hf_url): - dataset_attr = DatasetAttr( - "ms_hub", - dataset_name=dataset_info[name]["ms_hub_url"] - ) - else: - dataset_attr = DatasetAttr( - "hf_hub", - dataset_name=dataset_info[name]["hf_hub_url"] - ) - elif "script_url" in dataset_info[name]: - dataset_attr = DatasetAttr( - "script", - dataset_name=dataset_info[name]["script_url"] - ) - else: - dataset_attr = DatasetAttr( - "file", - dataset_name=dataset_info[name]["file_name"], - dataset_sha1=dataset_info[name].get("file_sha1", None) - ) - - if "columns" in dataset_info[name]: - dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None) - dataset_attr.query = dataset_info[name]["columns"].get("query", None) - dataset_attr.response = dataset_info[name]["columns"].get("response", None) - dataset_attr.history = dataset_info[name]["columns"].get("history", None) - dataset_attr.messages = dataset_info[name]["columns"].get("messages", None) - dataset_attr.role = dataset_info[name]["columns"].get("role", None) - dataset_attr.content = dataset_info[name]["columns"].get("content", None) - dataset_attr.system = dataset_info[name]["columns"].get("system", None) - - dataset_attr.subset = dataset_info[name].get("subset", None) - dataset_attr.folder = dataset_info[name].get("folder", None) - dataset_attr.ranking = 
dataset_info[name].get("ranking", False) - dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca") - self.dataset_list.append(dataset_attr) diff --git a/src/llmtuner/hparams/evaluation_args.py b/src/llmtuner/hparams/evaluation_args.py index 5f507698..c70103ed 100644 --- a/src/llmtuner/hparams/evaluation_args.py +++ b/src/llmtuner/hparams/evaluation_args.py @@ -43,13 +43,5 @@ class EvaluationArguments: ) def __post_init__(self): - task_available = [] - for folder in os.listdir(self.task_dir): - if os.path.isdir(os.path.join(self.task_dir, folder)): - task_available.append(folder) - - if self.task not in task_available: - raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir)) - if self.save_dir is not None and os.path.exists(self.save_dir): raise ValueError("`save_dir` already exists, use another one.") diff --git a/src/llmtuner/model/parser.py b/src/llmtuner/hparams/parser.py similarity index 97% rename from src/llmtuner/model/parser.py rename to src/llmtuner/hparams/parser.py index f3626f69..cba9c690 100644 --- a/src/llmtuner/model/parser.py +++ b/src/llmtuner/hparams/parser.py @@ -8,14 +8,12 @@ from typing import Any, Dict, Optional, Tuple from transformers import HfArgumentParser, Seq2SeqTrainingArguments from transformers.trainer_utils import get_last_checkpoint -from llmtuner.extras.logging import get_logger -from llmtuner.hparams import ( - ModelArguments, - DataArguments, - EvaluationArguments, - FinetuningArguments, - GeneratingArguments -) +from ..extras.logging import get_logger +from .data_args import DataArguments +from .evaluation_args import EvaluationArguments +from .finetuning_args import FinetuningArguments +from .generating_args import GeneratingArguments +from .model_args import ModelArguments logger = get_logger(__name__) @@ -107,8 +105,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: _set_transformers_logging() # Check arguments - data_args.init_for_training(training_args.seed) - if finetuning_args.stage != "pt" and data_args.template is None: raise ValueError("Please specify which `template` to use.") diff --git a/src/llmtuner/model/__init__.py b/src/llmtuner/model/__init__.py index f12acb58..6d598361 100644 --- a/src/llmtuner/model/__init__.py +++ b/src/llmtuner/model/__init__.py @@ -1,5 +1,5 @@ -# Level: loader > adapter > parser, utils +from .loader import load_model_and_tokenizer +from .utils import dispatch_model, get_modelcard_args, load_valuehead_params -from llmtuner.model.loader import load_model_and_tokenizer -from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args -from llmtuner.model.utils import dispatch_model, get_modelcard_args, load_valuehead_params + +__all__ = ["load_model_and_tokenizer", "dispatch_model", "get_modelcard_args", "load_valuehead_params"] diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py index 261650b7..f0d7ce21 100644 --- a/src/llmtuner/model/adapter.py +++ b/src/llmtuner/model/adapter.py @@ -3,12 +3,12 @@ from typing import TYPE_CHECKING from transformers.integrations import is_deepspeed_zero3_enabled from peft import PeftModel, TaskType, LoraConfig, get_peft_model -from llmtuner.extras.logging import get_logger -from llmtuner.model.utils import find_all_linear_modules +from ..extras.logging import get_logger +from .utils import find_all_linear_modules if TYPE_CHECKING: from transformers.modeling_utils import PreTrainedModel - from llmtuner.hparams import ModelArguments, FinetuningArguments + from ..hparams import 
ModelArguments, FinetuningArguments logger = get_logger(__name__) diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index 8cdf85bf..adc45ea8 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -4,15 +4,15 @@ from transformers.integrations import is_deepspeed_zero3_enabled from transformers.utils.versions import require_version from trl import AutoModelForCausalLMWithValueHead -from llmtuner.extras.logging import get_logger -from llmtuner.extras.misc import count_parameters, get_current_device, try_download_model_from_ms -from llmtuner.model.adapter import init_adapter -from llmtuner.model.patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model -from llmtuner.model.utils import load_valuehead_params, register_autoclass +from ..extras.logging import get_logger +from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms +from .adapter import init_adapter +from .patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model +from .utils import load_valuehead_params, register_autoclass if TYPE_CHECKING: from transformers import PreTrainedModel, PreTrainedTokenizer - from llmtuner.hparams import ModelArguments, FinetuningArguments + from ..hparams import ModelArguments, FinetuningArguments logger = get_logger(__name__) diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index 381436d2..d21d87dc 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -10,15 +10,15 @@ from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTra from transformers.integrations import is_deepspeed_zero3_enabled from transformers.utils.versions import require_version -from llmtuner.extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES -from llmtuner.extras.logging import get_logger -from llmtuner.extras.misc import get_current_device, infer_optim_dtype -from llmtuner.extras.packages import is_flash_attn2_available +from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES +from ..extras.logging import get_logger +from ..extras.misc import get_current_device, infer_optim_dtype +from ..extras.packages import is_flash_attn2_available if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedTokenizer from trl import AutoModelForCausalLMWithValueHead - from llmtuner.hparams import ModelArguments + from ..hparams import ModelArguments logger = get_logger(__name__) diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py index 14bd4c59..ba4478fb 100644 --- a/src/llmtuner/model/utils.py +++ b/src/llmtuner/model/utils.py @@ -4,13 +4,13 @@ from typing import TYPE_CHECKING, Any, Dict, List from transformers import PreTrainedModel from transformers.utils import cached_file -from llmtuner.extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME -from llmtuner.extras.logging import get_logger -from llmtuner.extras.misc import get_current_device +from ..extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME +from ..extras.logging import get_logger +from ..extras.misc import get_current_device if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedTokenizer - from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments + from ..hparams import ModelArguments, DataArguments, FinetuningArguments logger = get_logger(__name__) diff --git a/src/llmtuner/train/__init__.py b/src/llmtuner/train/__init__.py index e57c163b..6c22bc15 100644 --- 
a/src/llmtuner/train/__init__.py +++ b/src/llmtuner/train/__init__.py @@ -1 +1,4 @@ -from llmtuner.train.tuner import export_model, run_exp +from .tuner import export_model, run_exp + + +__all__ = ["export_model", "run_exp"] diff --git a/src/llmtuner/train/dpo/__init__.py b/src/llmtuner/train/dpo/__init__.py index 96c8ed09..43fe9420 100644 --- a/src/llmtuner/train/dpo/__init__.py +++ b/src/llmtuner/train/dpo/__init__.py @@ -1 +1,4 @@ -from llmtuner.train.dpo.workflow import run_dpo +from .workflow import run_dpo + + +__all__ = ["run_dpo"] diff --git a/src/llmtuner/train/dpo/trainer.py b/src/llmtuner/train/dpo/trainer.py index 97d80353..b5a44f5e 100644 --- a/src/llmtuner/train/dpo/trainer.py +++ b/src/llmtuner/train/dpo/trainer.py @@ -5,7 +5,7 @@ from transformers import BatchEncoding, Trainer from trl import DPOTrainer from trl.trainer.utils import disable_dropout_in_model -from llmtuner.extras.constants import IGNORE_INDEX +from ...extras.constants import IGNORE_INDEX if TYPE_CHECKING: from transformers import PreTrainedModel diff --git a/src/llmtuner/train/dpo/workflow.py b/src/llmtuner/train/dpo/workflow.py index 12a6b545..bd61a308 100644 --- a/src/llmtuner/train/dpo/workflow.py +++ b/src/llmtuner/train/dpo/workflow.py @@ -3,18 +3,18 @@ from typing import TYPE_CHECKING, Optional, List from transformers import Seq2SeqTrainingArguments -from llmtuner.data import get_dataset, preprocess_dataset, split_dataset -from llmtuner.extras.constants import IGNORE_INDEX -from llmtuner.extras.ploting import plot_loss -from llmtuner.hparams import ModelArguments -from llmtuner.model import load_model_and_tokenizer -from llmtuner.train.dpo.collator import DPODataCollatorWithPadding -from llmtuner.train.dpo.trainer import CustomDPOTrainer -from llmtuner.train.utils import create_modelcard_and_push, create_ref_model +from ...data import get_dataset, split_dataset +from ...extras.constants import IGNORE_INDEX +from ...extras.ploting import plot_loss +from ...hparams import ModelArguments +from ...model import load_model_and_tokenizer +from ...train.dpo.collator import DPODataCollatorWithPadding +from ...train.dpo.trainer import CustomDPOTrainer +from ...train.utils import create_modelcard_and_push, create_ref_model if TYPE_CHECKING: from transformers import TrainerCallback - from llmtuner.hparams import DataArguments, FinetuningArguments + from ...hparams import DataArguments, FinetuningArguments def run_dpo( @@ -24,9 +24,8 @@ def run_dpo( finetuning_args: "FinetuningArguments", callbacks: Optional[List["TrainerCallback"]] = None ): - dataset = get_dataset(model_args, data_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train) - dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm") + dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="rm") data_collator = DPODataCollatorWithPadding( tokenizer=tokenizer, pad_to_multiple_of=8, diff --git a/src/llmtuner/train/ppo/__init__.py b/src/llmtuner/train/ppo/__init__.py index c32b23fa..d17336d5 100644 --- a/src/llmtuner/train/ppo/__init__.py +++ b/src/llmtuner/train/ppo/__init__.py @@ -1 +1,4 @@ -from llmtuner.train.ppo.workflow import run_ppo +from .workflow import run_ppo + + +__all__ = ["run_ppo"] diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py index 31cab7c0..8b2116ea 100644 --- a/src/llmtuner/train/ppo/trainer.py +++ b/src/llmtuner/train/ppo/trainer.py @@ -13,15 +13,15 @@ from transformers.trainer_pt_utils import 
remove_dummy_checkpoint from trl import PPOTrainer from trl.core import PPODecorators, logprobs_from_logits -from llmtuner.extras.callbacks import LogCallback, FixValueHeadModelCallback -from llmtuner.extras.logging import get_logger -from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor -from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model +from ...extras.callbacks import LogCallback, FixValueHeadModelCallback +from ...extras.logging import get_logger +from ...extras.misc import AverageMeter, count_parameters, get_logits_processor +from .utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback from trl import AutoModelForCausalLMWithValueHead - from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments + from ...hparams import ModelArguments, FinetuningArguments, GeneratingArguments logger = get_logger(__name__) diff --git a/src/llmtuner/train/ppo/utils.py b/src/llmtuner/train/ppo/utils.py index 12e9bfcb..44e62067 100644 --- a/src/llmtuner/train/ppo/utils.py +++ b/src/llmtuner/train/ppo/utils.py @@ -2,7 +2,7 @@ import json import torch from typing import TYPE_CHECKING, Dict, List, Literal, Optional -from llmtuner.extras.packages import is_requests_available +from ...extras.packages import is_requests_available if TYPE_CHECKING: from transformers import PreTrainedModel diff --git a/src/llmtuner/train/ppo/workflow.py b/src/llmtuner/train/ppo/workflow.py index 10c6a227..7b0dcc53 100644 --- a/src/llmtuner/train/ppo/workflow.py +++ b/src/llmtuner/train/ppo/workflow.py @@ -7,17 +7,17 @@ from typing import TYPE_CHECKING, Optional, List from transformers import DataCollatorWithPadding from transformers.optimization import get_scheduler -from llmtuner.data import get_dataset, preprocess_dataset -from llmtuner.extras.callbacks import FixValueHeadModelCallback -from llmtuner.extras.misc import fix_valuehead_checkpoint -from llmtuner.extras.ploting import plot_loss -from llmtuner.model import load_model_and_tokenizer -from llmtuner.train.utils import create_ref_model, create_reward_model -from llmtuner.train.ppo.trainer import CustomPPOTrainer +from ...data import get_dataset +from ...extras.callbacks import FixValueHeadModelCallback +from ...extras.misc import fix_valuehead_checkpoint +from ...extras.ploting import plot_loss +from ...model import load_model_and_tokenizer +from ...train.utils import create_ref_model, create_reward_model +from ...train.ppo.trainer import CustomPPOTrainer if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback - from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments + from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments def run_ppo( @@ -28,9 +28,8 @@ def run_ppo( generating_args: "GeneratingArguments", callbacks: Optional[List["TrainerCallback"]] = None ): - dataset = get_dataset(model_args, data_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True) - dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo") + dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="ppo") tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training data_collator = 
DataCollatorWithPadding(tokenizer=tokenizer) diff --git a/src/llmtuner/train/pt/__init__.py b/src/llmtuner/train/pt/__init__.py index eacbeadb..bdf397f6 100644 --- a/src/llmtuner/train/pt/__init__.py +++ b/src/llmtuner/train/pt/__init__.py @@ -1 +1,4 @@ -from llmtuner.train.pt.workflow import run_pt +from .workflow import run_pt + + +__all__ = ["run_pt"] diff --git a/src/llmtuner/train/pt/workflow.py b/src/llmtuner/train/pt/workflow.py index 27a6d2c4..3b7267eb 100644 --- a/src/llmtuner/train/pt/workflow.py +++ b/src/llmtuner/train/pt/workflow.py @@ -4,14 +4,14 @@ import math from typing import TYPE_CHECKING, Optional, List from transformers import DataCollatorForLanguageModeling, Trainer -from llmtuner.data import get_dataset, preprocess_dataset, split_dataset -from llmtuner.extras.ploting import plot_loss -from llmtuner.model import load_model_and_tokenizer -from llmtuner.train.utils import create_modelcard_and_push +from ...data import get_dataset, split_dataset +from ...extras.ploting import plot_loss +from ...model import load_model_and_tokenizer +from ...train.utils import create_modelcard_and_push if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback - from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments + from ...hparams import ModelArguments, DataArguments, FinetuningArguments def run_pt( @@ -21,9 +21,8 @@ def run_pt( finetuning_args: "FinetuningArguments", callbacks: Optional[List["TrainerCallback"]] = None ): - dataset = get_dataset(model_args, data_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train) - dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt") + dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="pt") data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) # Initialize our Trainer diff --git a/src/llmtuner/train/rm/__init__.py b/src/llmtuner/train/rm/__init__.py index c80ccfb9..dedac35f 100644 --- a/src/llmtuner/train/rm/__init__.py +++ b/src/llmtuner/train/rm/__init__.py @@ -1 +1,4 @@ -from llmtuner.train.rm.workflow import run_rm +from .workflow import run_rm + + +__all__ = ["run_rm"] diff --git a/src/llmtuner/train/rm/trainer.py b/src/llmtuner/train/rm/trainer.py index b018a8c4..909d4373 100644 --- a/src/llmtuner/train/rm/trainer.py +++ b/src/llmtuner/train/rm/trainer.py @@ -4,7 +4,7 @@ import torch from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from transformers import Trainer -from llmtuner.extras.logging import get_logger +from ...extras.logging import get_logger if TYPE_CHECKING: from transformers.trainer import PredictionOutput diff --git a/src/llmtuner/train/rm/workflow.py b/src/llmtuner/train/rm/workflow.py index 52070027..e055e216 100644 --- a/src/llmtuner/train/rm/workflow.py +++ b/src/llmtuner/train/rm/workflow.py @@ -3,19 +3,19 @@ from typing import TYPE_CHECKING, Optional, List from transformers import Seq2SeqTrainingArguments -from llmtuner.data import get_dataset, preprocess_dataset, split_dataset -from llmtuner.extras.callbacks import FixValueHeadModelCallback -from llmtuner.extras.misc import fix_valuehead_checkpoint -from llmtuner.extras.ploting import plot_loss -from llmtuner.model import load_model_and_tokenizer -from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding -from llmtuner.train.rm.metric import compute_accuracy -from llmtuner.train.rm.trainer import PairwiseTrainer -from llmtuner.train.utils import 
create_modelcard_and_push +from ...data import get_dataset, split_dataset +from ...extras.callbacks import FixValueHeadModelCallback +from ...extras.misc import fix_valuehead_checkpoint +from ...extras.ploting import plot_loss +from ...model import load_model_and_tokenizer +from ...train.rm.collator import PairwiseDataCollatorWithPadding +from ...train.rm.metric import compute_accuracy +from ...train.rm.trainer import PairwiseTrainer +from ...train.utils import create_modelcard_and_push if TYPE_CHECKING: from transformers import TrainerCallback - from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments + from ...hparams import ModelArguments, DataArguments, FinetuningArguments def run_rm( @@ -25,9 +25,8 @@ def run_rm( finetuning_args: "FinetuningArguments", callbacks: Optional[List["TrainerCallback"]] = None ): - dataset = get_dataset(model_args, data_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True) - dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm") + dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="rm") data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) # Update arguments diff --git a/src/llmtuner/train/sft/__init__.py b/src/llmtuner/train/sft/__init__.py index cb5448f4..f2f84e78 100644 --- a/src/llmtuner/train/sft/__init__.py +++ b/src/llmtuner/train/sft/__init__.py @@ -1 +1,4 @@ -from llmtuner.train.sft.workflow import run_sft +from .workflow import run_sft + + +__all__ = ["run_sft"] diff --git a/src/llmtuner/train/sft/metric.py b/src/llmtuner/train/sft/metric.py index 18db0b88..2741c66b 100644 --- a/src/llmtuner/train/sft/metric.py +++ b/src/llmtuner/train/sft/metric.py @@ -2,8 +2,8 @@ import numpy as np from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union -from llmtuner.extras.constants import IGNORE_INDEX -from llmtuner.extras.packages import ( +from ...extras.constants import IGNORE_INDEX +from ...extras.packages import ( is_jieba_available, is_nltk_available, is_rouge_available ) diff --git a/src/llmtuner/train/sft/trainer.py b/src/llmtuner/train/sft/trainer.py index 291bbc7a..c8d9f039 100644 --- a/src/llmtuner/train/sft/trainer.py +++ b/src/llmtuner/train/sft/trainer.py @@ -6,8 +6,8 @@ import torch.nn as nn from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from transformers import Seq2SeqTrainer -from llmtuner.extras.constants import IGNORE_INDEX -from llmtuner.extras.logging import get_logger +from ...extras.constants import IGNORE_INDEX +from ...extras.logging import get_logger if TYPE_CHECKING: from transformers.trainer import PredictionOutput diff --git a/src/llmtuner/train/sft/workflow.py b/src/llmtuner/train/sft/workflow.py index 0e9bf7e4..6d3f34e8 100644 --- a/src/llmtuner/train/sft/workflow.py +++ b/src/llmtuner/train/sft/workflow.py @@ -3,18 +3,19 @@ from typing import TYPE_CHECKING, Optional, List from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments -from llmtuner.data import get_dataset, preprocess_dataset, split_dataset -from llmtuner.extras.constants import IGNORE_INDEX -from llmtuner.extras.misc import get_logits_processor -from llmtuner.extras.ploting import plot_loss -from llmtuner.model import load_model_and_tokenizer -from llmtuner.train.sft.metric import ComputeMetrics -from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer -from llmtuner.train.utils import 
create_modelcard_and_push +from ...data import get_dataset, split_dataset +from ...extras.constants import IGNORE_INDEX +from ...extras.misc import get_logits_processor +from ...extras.ploting import plot_loss +from ...model import load_model_and_tokenizer +from ...train.sft.metric import ComputeMetrics +from ...train.sft.trainer import CustomSeq2SeqTrainer +from ...train.utils import create_modelcard_and_push + if TYPE_CHECKING: from transformers import TrainerCallback - from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments + from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments def run_sft( @@ -25,9 +26,8 @@ def run_sft( generating_args: "GeneratingArguments", callbacks: Optional[List["TrainerCallback"]] = None ): - dataset = get_dataset(model_args, data_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train) - dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft") + dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="sft") if training_args.predict_with_generate: tokenizer.padding_side = "left" # use left-padding in generation diff --git a/src/llmtuner/train/tuner.py b/src/llmtuner/train/tuner.py index 8705c98e..32f1cda0 100644 --- a/src/llmtuner/train/tuner.py +++ b/src/llmtuner/train/tuner.py @@ -2,14 +2,15 @@ import torch from typing import TYPE_CHECKING, Any, Dict, List, Optional from transformers import PreTrainedModel -from llmtuner.extras.callbacks import LogCallback -from llmtuner.extras.logging import get_logger -from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer -from llmtuner.train.pt import run_pt -from llmtuner.train.sft import run_sft -from llmtuner.train.rm import run_rm -from llmtuner.train.ppo import run_ppo -from llmtuner.train.dpo import run_dpo +from ..extras.callbacks import LogCallback +from ..extras.logging import get_logger +from ..hparams import get_train_args, get_infer_args +from ..model import load_model_and_tokenizer +from .pt import run_pt +from .sft import run_sft +from .rm import run_rm +from .ppo import run_ppo +from .dpo import run_dpo if TYPE_CHECKING: from transformers import TrainerCallback diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py index 4cc775eb..789986e4 100644 --- a/src/llmtuner/train/utils.py +++ b/src/llmtuner/train/utils.py @@ -1,15 +1,15 @@ import torch from typing import TYPE_CHECKING, Optional, Union -from llmtuner.extras.logging import get_logger -from llmtuner.hparams import ModelArguments, FinetuningArguments -from llmtuner.model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params +from ..extras.logging import get_logger +from ..hparams import ModelArguments, FinetuningArguments +from ..model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, Trainer from transformers.modeling_utils import PreTrainedModel from trl import AutoModelForCausalLMWithValueHead - from llmtuner.hparams import DataArguments + from ..hparams import DataArguments logger = get_logger(__name__) diff --git a/src/llmtuner/webui/__init__.py b/src/llmtuner/webui/__init__.py index a27c7f6e..3e82dd69 100644 --- a/src/llmtuner/webui/__init__.py +++ b/src/llmtuner/webui/__init__.py @@ -1 +1,4 @@ -from llmtuner.webui.interface import create_ui, create_web_demo +from .interface import create_ui, create_web_demo + 
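# Quick import smoke test (not part of the diff) for the public surface re-exported by the new
# package __init__ files in this change set; run with src/ on PYTHONPATH. It only checks that
# the switch to relative imports keeps the package importable from the outside.
from llmtuner.hparams import get_train_args, get_infer_args, get_eval_args
from llmtuner.model import load_model_and_tokenizer, dispatch_model
from llmtuner.eval import Evaluator
from llmtuner.train import export_model, run_exp
from llmtuner.webui import create_ui, create_web_demo

print(Evaluator, run_exp.__module__, create_ui.__module__)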
+ +__all__ = ["create_ui", "create_web_demo"] diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py index 08027e38..a6681665 100644 --- a/src/llmtuner/webui/chatter.py +++ b/src/llmtuner/webui/chatter.py @@ -2,14 +2,14 @@ import gradio as gr from gradio.components import Component # cannot use TYPE_CHECKING here from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple -from llmtuner.chat import ChatModel -from llmtuner.extras.misc import torch_gc -from llmtuner.hparams import GeneratingArguments -from llmtuner.webui.common import get_save_dir -from llmtuner.webui.locales import ALERTS +from ..chat import ChatModel +from ..extras.misc import torch_gc +from ..hparams import GeneratingArguments +from .common import get_save_dir +from .locales import ALERTS if TYPE_CHECKING: - from llmtuner.webui.manager import Manager + from .manager import Manager class WebChatModel(ChatModel): diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py index 28d8a805..3d431aeb 100644 --- a/src/llmtuner/webui/common.py +++ b/src/llmtuner/webui/common.py @@ -5,7 +5,8 @@ from collections import defaultdict from typing import Any, Dict, Optional from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME -from llmtuner.extras.constants import ( +from ..extras.constants import ( + DATA_CONFIG, DEFAULT_MODULE, DEFAULT_TEMPLATE, PEFT_METHODS, @@ -13,8 +14,7 @@ from llmtuner.extras.constants import ( TRAINING_STAGES, DownloadSource ) -from llmtuner.extras.misc import use_modelscope -from llmtuner.hparams.data_args import DATA_CONFIG +from ..extras.misc import use_modelscope ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME} diff --git a/src/llmtuner/webui/components/__init__.py b/src/llmtuner/webui/components/__init__.py index 32228b8e..2e9a87ec 100644 --- a/src/llmtuner/webui/components/__init__.py +++ b/src/llmtuner/webui/components/__init__.py @@ -1,6 +1,11 @@ -from llmtuner.webui.components.top import create_top -from llmtuner.webui.components.train import create_train_tab -from llmtuner.webui.components.eval import create_eval_tab -from llmtuner.webui.components.infer import create_infer_tab -from llmtuner.webui.components.export import create_export_tab -from llmtuner.webui.components.chatbot import create_chat_box +from .top import create_top +from .train import create_train_tab +from .eval import create_eval_tab +from .infer import create_infer_tab +from .export import create_export_tab +from .chatbot import create_chat_box + + +__all__ = [ + "create_top", "create_train_tab", "create_eval_tab", "create_infer_tab", "create_export_tab", "create_chat_box" +] diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py index 13e2dd4d..ee128aca 100644 --- a/src/llmtuner/webui/components/chatbot.py +++ b/src/llmtuner/webui/components/chatbot.py @@ -4,7 +4,8 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple if TYPE_CHECKING: from gradio.blocks import Block from gradio.components import Component - from llmtuner.webui.engine import Engine + + from ..engine import Engine def create_chat_box( diff --git a/src/llmtuner/webui/components/data.py b/src/llmtuner/webui/components/data.py index a74bd029..3a50065a 100644 --- a/src/llmtuner/webui/components/data.py +++ b/src/llmtuner/webui/components/data.py @@ -3,7 +3,7 @@ import json import gradio as gr from typing import TYPE_CHECKING, Any, Dict, Tuple -from llmtuner.webui.common import DATA_CONFIG +from ...extras.constants import DATA_CONFIG if TYPE_CHECKING: from 
gradio.components import Component diff --git a/src/llmtuner/webui/components/eval.py b/src/llmtuner/webui/components/eval.py index 0718c63e..d900ad29 100644 --- a/src/llmtuner/webui/components/eval.py +++ b/src/llmtuner/webui/components/eval.py @@ -1,12 +1,13 @@ import gradio as gr from typing import TYPE_CHECKING, Dict -from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR -from llmtuner.webui.components.data import create_preview_box +from ..common import list_dataset, DEFAULT_DATA_DIR +from .data import create_preview_box if TYPE_CHECKING: from gradio.components import Component - from llmtuner.webui.engine import Engine + + from ..engine import Engine def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: diff --git a/src/llmtuner/webui/components/export.py b/src/llmtuner/webui/components/export.py index a4b65591..187cda64 100644 --- a/src/llmtuner/webui/components/export.py +++ b/src/llmtuner/webui/components/export.py @@ -1,13 +1,14 @@ import gradio as gr from typing import TYPE_CHECKING, Dict, Generator, List -from llmtuner.train import export_model -from llmtuner.webui.common import get_save_dir -from llmtuner.webui.locales import ALERTS +from ...train import export_model +from ..common import get_save_dir +from ..locales import ALERTS if TYPE_CHECKING: from gradio.components import Component - from llmtuner.webui.engine import Engine + + from ..engine import Engine GPTQ_BITS = ["8", "4", "3", "2"] diff --git a/src/llmtuner/webui/components/infer.py b/src/llmtuner/webui/components/infer.py index d6dd7eed..ba578f10 100644 --- a/src/llmtuner/webui/components/infer.py +++ b/src/llmtuner/webui/components/infer.py @@ -1,11 +1,12 @@ import gradio as gr from typing import TYPE_CHECKING, Dict -from llmtuner.webui.components.chatbot import create_chat_box +from .chatbot import create_chat_box if TYPE_CHECKING: from gradio.components import Component - from llmtuner.webui.engine import Engine + + from ..engine import Engine def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]: diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py index 74441ab2..b8468186 100644 --- a/src/llmtuner/webui/components/top.py +++ b/src/llmtuner/webui/components/top.py @@ -1,10 +1,10 @@ import gradio as gr from typing import TYPE_CHECKING, Dict -from llmtuner.data.template import templates -from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS -from llmtuner.webui.common import get_model_path, get_template, list_adapters, save_config -from llmtuner.webui.utils import can_quantize +from ...data import templates +from ...extras.constants import METHODS, SUPPORTED_MODELS +from ..common import get_model_path, get_template, list_adapters, save_config +from ..utils import can_quantize if TYPE_CHECKING: from gradio.components import Component diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py index 5989c421..08e861f0 100644 --- a/src/llmtuner/webui/components/train.py +++ b/src/llmtuner/webui/components/train.py @@ -2,14 +2,15 @@ import gradio as gr from typing import TYPE_CHECKING, Dict from transformers.trainer_utils import SchedulerType -from llmtuner.extras.constants import TRAINING_STAGES -from llmtuner.webui.common import list_adapters, list_dataset, DEFAULT_DATA_DIR -from llmtuner.webui.components.data import create_preview_box -from llmtuner.webui.utils import gen_plot +from ...extras.constants import TRAINING_STAGES +from ..common import list_adapters, list_dataset, DEFAULT_DATA_DIR +from 
..components.data import create_preview_box +from ..utils import gen_plot if TYPE_CHECKING: from gradio.components import Component - from llmtuner.webui.engine import Engine + + from ..engine import Engine def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: diff --git a/src/llmtuner/webui/engine.py b/src/llmtuner/webui/engine.py index 991b281c..db60b5df 100644 --- a/src/llmtuner/webui/engine.py +++ b/src/llmtuner/webui/engine.py @@ -2,12 +2,12 @@ import gradio as gr from gradio.components import Component # cannot use TYPE_CHECKING here from typing import Any, Dict, Generator, Optional -from llmtuner.webui.chatter import WebChatModel -from llmtuner.webui.common import get_model_path, list_dataset, load_config -from llmtuner.webui.locales import LOCALES -from llmtuner.webui.manager import Manager -from llmtuner.webui.runner import Runner -from llmtuner.webui.utils import get_time +from .chatter import WebChatModel +from .common import get_model_path, list_dataset, load_config +from .locales import LOCALES +from .manager import Manager +from .runner import Runner +from .utils import get_time class Engine: diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py index 2525c3fd..39ddca04 100644 --- a/src/llmtuner/webui/interface.py +++ b/src/llmtuner/webui/interface.py @@ -2,7 +2,7 @@ import gradio as gr from typing import Optional from transformers.utils.versions import require_version -from llmtuner.webui.components import ( +from .components import ( create_top, create_train_tab, create_eval_tab, @@ -10,9 +10,9 @@ from llmtuner.webui.components import ( create_export_tab, create_chat_box ) -from llmtuner.webui.common import save_config -from llmtuner.webui.css import CSS -from llmtuner.webui.engine import Engine +from .common import save_config +from .css import CSS +from .engine import Engine require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"") diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 374d72a3..5d8efbfb 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -9,17 +9,17 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple import transformers from transformers.trainer import TRAINING_ARGS_NAME -from llmtuner.extras.callbacks import LogCallback -from llmtuner.extras.constants import TRAINING_STAGES -from llmtuner.extras.logging import LoggerHandler -from llmtuner.extras.misc import get_device_count, torch_gc -from llmtuner.train import run_exp -from llmtuner.webui.common import get_module, get_save_dir, load_config -from llmtuner.webui.locales import ALERTS -from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar +from ..extras.callbacks import LogCallback +from ..extras.constants import TRAINING_STAGES +from ..extras.logging import LoggerHandler +from ..extras.misc import get_device_count, torch_gc +from ..train import run_exp +from .common import get_module, get_save_dir, load_config +from .locales import ALERTS +from .utils import gen_cmd, get_eval_results, update_process_bar if TYPE_CHECKING: - from llmtuner.webui.manager import Manager + from .manager import Manager class Runner: diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py index 4579d296..c273b635 100644 --- a/src/llmtuner/webui/utils.py +++ b/src/llmtuner/webui/utils.py @@ -4,12 +4,12 @@ import gradio as gr from typing import TYPE_CHECKING, Any, Dict from datetime import datetime -from llmtuner.extras.packages import 
is_matplotlib_available -from llmtuner.extras.ploting import smooth -from llmtuner.webui.common import get_save_dir +from ..extras.packages import is_matplotlib_available +from ..extras.ploting import smooth +from .common import get_save_dir if TYPE_CHECKING: - from llmtuner.extras.callbacks import LogCallback + from ..extras.callbacks import LogCallback if is_matplotlib_available(): import matplotlib.figure diff --git a/tests/llamafy_internlm2.py b/tests/llamafy_internlm2.py index 3cd59f96..8fb1448c 100644 --- a/tests/llamafy_internlm2.py +++ b/tests/llamafy_internlm2.py @@ -1,6 +1,7 @@ # coding=utf-8 # Converts the InternLM2 model in the same format as LLaMA2. # Usage: python llamafy_internlm2.py --input_dir input --output_dir output --shard_size 10GB +# Warning: We have found that the converted model cannot infer correctly. It will be fixed later. import os import fire @@ -43,19 +44,18 @@ def save_weight( llama2_state_dict[key.replace("output", "lm_head")] = value elif "tok_embeddings" in key: llama2_state_dict[key.replace("tok_embeddings", "embed_tokens")] = value - elif "attention_norm" in key: - llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value elif "wqkv" in key: - proj_size = value.size(0) num_q_heads = internlm2_config_dict["num_attention_heads"] num_kv_heads = internlm2_config_dict["num_key_value_heads"] - q_size = proj_size // (num_q_heads + 2 * num_kv_heads) * num_q_heads - kv_size = proj_size // (num_q_heads + 2 * num_kv_heads) * num_kv_heads + q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads + kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...] llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[q_size:q_size+kv_size, ...] llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size+kv_size:, ...] elif "wo" in key: llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value + elif "attention_norm" in key: + llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value elif "ffn_norm" in key: llama2_state_dict[key.replace("ffn_norm", "post_attention_layernorm")] = value elif "w1" in key:
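
Editor's addendum (not part of the patch): the wqkv hunk above splits InternLM2's fused attention projection into separate q/k/v weights by slicing rows, using the same integer arithmetic on both sides of the change. Below is a minimal, self-contained sketch of that split with assumed InternLM2-7B-like dimensions (hidden_size 4096, 32 attention heads, 8 key/value heads, head_dim 128); the names and sizes are illustrative only, and per the warning added at the top of llamafy_internlm2.py, this contiguous split may not match the checkpoint's actual head layout.

import torch

num_q_heads = 32    # assumed num_attention_heads from the InternLM2 config
num_kv_heads = 8    # assumed num_key_value_heads from the InternLM2 config
head_dim = 128
hidden_size = num_q_heads * head_dim                                           # 4096

# The fused projection stacks query, key and value heads along dim 0,
# giving (num_q_heads + 2 * num_kv_heads) * head_dim rows in total.
wqkv = torch.randn((num_q_heads + 2 * num_kv_heads) * head_dim, hidden_size)   # (6144, 4096)

rows_per_head = wqkv.size(0) // (num_q_heads + 2 * num_kv_heads)               # 128
q_size = rows_per_head * num_q_heads                                           # 4096
kv_size = rows_per_head * num_kv_heads                                         # 1024

q_proj = wqkv[:q_size, ...]                                                    # (4096, 4096)
k_proj = wqkv[q_size:q_size + kv_size, ...]                                    # (1024, 4096)
v_proj = wqkv[q_size + kv_size:, ...]                                          # (1024, 4096)

print(q_proj.shape, k_proj.shape, v_proj.shape)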