format style

This commit is contained in:
hiyouga 2024-01-20 20:15:56 +08:00
parent f6d6e00337
commit 638234ceee
73 changed files with 1492 additions and 2325 deletions

View File

@ -7,7 +7,7 @@ line-length = 119
target-version = ["py38"]
[tool.ruff]
ignore = ["C901", "E501", "E741", "W605"]
ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]
line-length = 119

View File

@ -1,4 +1,5 @@
import os
import uvicorn
from llmtuner import ChatModel, create_app

View File

@ -1,10 +1,12 @@
from llmtuner import ChatModel
from llmtuner.extras.misc import torch_gc
try:
import platform
if platform.system() != "Windows":
import readline
import readline # noqa: F401
except ImportError:
print("Install `readline` for a better experience.")

View File

@ -8,12 +8,4 @@ from llmtuner.webui import create_ui, create_web_demo
__version__ = "0.4.0"
__all__ = [
"create_app",
"ChatModel",
"Evaluator",
"export_model",
"run_exp",
"create_ui",
"create_web_demo"
]
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]

View File

@ -1,30 +1,29 @@
import os
import json
import asyncio
from typing import List, Tuple
from pydantic import BaseModel
import json
import os
from contextlib import asynccontextmanager
from typing import List, Tuple
from pydantic import BaseModel
from ..chat import ChatModel
from ..extras.misc import torch_gc
from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
from .protocol import (
Role,
Finish,
ModelCard,
ModelList,
ChatMessage,
DeltaMessage,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionResponseUsage,
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
Finish,
ModelCard,
ModelList,
Role,
ScoreEvaluationRequest,
ScoreEvaluationResponse
)
from ..chat import ChatModel
from ..extras.misc import torch_gc
from ..extras.packages import (
is_fastapi_availble, is_starlette_available, is_uvicorn_available
ScoreEvaluationResponse,
)
@ -42,15 +41,15 @@ if is_uvicorn_available():
@asynccontextmanager
async def lifespan(app: "FastAPI"): # collects GPU memory
async def lifespan(app: "FastAPI"): # collects GPU memory
yield
torch_gc()
def to_json(data: BaseModel) -> str:
try: # pydantic v2
try: # pydantic v2
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except: # pydantic v1
except Exception: # pydantic v1
return data.json(exclude_unset=True, ensure_ascii=False)
@ -90,8 +89,8 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
history = []
if len(prev_messages) % 2 == 0:
for i in range(0, len(prev_messages), 2):
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
history.append([prev_messages[i].content, prev_messages[i+1].content])
if prev_messages[i].role == Role.USER and prev_messages[i + 1].role == Role.ASSISTANT:
history.append([prev_messages[i].content, prev_messages[i + 1].content])
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
else:
@ -107,65 +106,65 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
return EventSourceResponse(generate, media_type="text/event-stream")
responses = chat_model.chat(
query, history, system,
query,
history,
system,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
num_return_sequences=request.n
num_return_sequences=request.n,
)
prompt_length, response_length = 0, 0
choices = []
for i, response in enumerate(responses):
choices.append(ChatCompletionResponseChoice(
index=i,
message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
))
choices.append(
ChatCompletionResponseChoice(
index=i,
message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH,
)
)
prompt_length = response.prompt_length
response_length += response.response_length
usage = ChatCompletionResponseUsage(
prompt_tokens=prompt_length,
completion_tokens=response_length,
total_tokens=prompt_length+response_length
total_tokens=prompt_length + response_length,
)
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
def stream_chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
def stream_chat_completion(
query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest
):
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role=Role.ASSISTANT, content=""),
finish_reason=None
index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield to_json(chunk)
for new_text in chat_model.stream_chat(
query, history, system,
query,
history,
system,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens
max_new_tokens=request.max_tokens,
):
if len(new_text) == 0:
continue
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=new_text),
finish_reason=None
index=0, delta=DeltaMessage(content=new_text), finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield to_json(chunk)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason=Finish.STOP
)
choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield to_json(chunk)
yield "[DONE]"

View File

@ -1,8 +1,9 @@
import time
from enum import Enum, unique
from pydantic import BaseModel, Field
from typing import List, Optional
from pydantic import BaseModel, Field
@unique
class Role(str, Enum):

View File

@ -1,18 +1,18 @@
import torch
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
from threading import Thread
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
import torch
from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer, Role
from ..data import Role, get_template_and_fix_tokenizer
from ..extras.misc import get_logits_processor
from ..model import dispatch_model, load_model_and_tokenizer
from ..hparams import get_infer_args
from ..model import dispatch_model, load_model_and_tokenizer
@dataclass
class Response:
response_text: str
response_length: int
prompt_length: int
@ -20,10 +20,9 @@ class Response:
class ChatModel:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
self.can_generate = (finetuning_args.stage == "sft")
self.can_generate = finetuning_args.stage == "sft"
self.model, self.tokenizer = load_model_and_tokenizer(
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)
@ -37,7 +36,7 @@ class ChatModel:
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs
**input_kwargs,
) -> Tuple[Dict[str, Any], int]:
messages = []
if history is not None:
@ -63,16 +62,18 @@ class ChatModel:
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
generating_args = self.generating_args.to_dict()
generating_args.update(dict(
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
temperature=temperature or generating_args["temperature"],
top_p=top_p or generating_args["top_p"],
top_k=top_k or generating_args["top_k"],
num_return_sequences=num_return_sequences or 1,
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
pad_token_id=self.tokenizer.pad_token_id
))
generating_args.update(
dict(
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
temperature=temperature or generating_args["temperature"],
top_p=top_p or generating_args["top_p"],
top_k=top_k or generating_args["top_k"],
num_return_sequences=num_return_sequences or 1,
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
pad_token_id=self.tokenizer.pad_token_id,
)
)
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
generating_args["do_sample"] = True
@ -88,7 +89,7 @@ class ChatModel:
gen_kwargs = dict(
inputs=input_ids,
generation_config=GenerationConfig(**generating_args),
logits_processor=get_logits_processor()
logits_processor=get_logits_processor(),
)
return gen_kwargs, prompt_length
@ -100,7 +101,7 @@ class ChatModel:
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs
**input_kwargs,
) -> List[Response]:
r"""
Args: query, history, system, **input_kwargs
@ -117,12 +118,14 @@ class ChatModel:
for i in range(len(response)):
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
results.append(Response(
response_text=response[i],
response_length=response_length,
prompt_length=prompt_length,
finish_reason="stop" if len(eos_index) else "length"
))
results.append(
Response(
response_text=response[i],
response_length=response_length,
prompt_length=prompt_length,
finish_reason="stop" if len(eos_index) else "length",
)
)
return results
@ -133,7 +136,7 @@ class ChatModel:
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs
**input_kwargs,
) -> Generator[str, None, None]:
gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
@ -145,11 +148,7 @@ class ChatModel:
yield from streamer
@torch.inference_mode()
def get_scores(
self,
batch_input: List[str],
**input_kwargs
) -> List[float]:
def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
max_length = input_kwargs.pop("max_length", None)
device = getattr(self.model.pretrained_model, "device", "cuda")
@ -159,7 +158,7 @@ class ChatModel:
truncation=True,
max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
return_tensors="pt",
add_special_tokens=True
add_special_tokens=True,
).to(device)
input_ids: torch.Tensor = inputs["input_ids"]

View File

@ -1,6 +1,6 @@
from .loader import get_dataset
from .template import get_template_and_fix_tokenizer, templates
from .utils import split_dataset, Role
from .utils import Role, split_dataset
__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset", "Role"]

View File

@ -27,7 +27,9 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
if dataset_attr.response:
if isinstance(examples[dataset_attr.response][i], list):
response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
response = [
{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]
]
else:
response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
else:
@ -47,10 +49,10 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
dataset_attr.user_tag: Role.USER,
dataset_attr.assistant_tag: Role.ASSISTANT,
dataset_attr.observation_tag: Role.OBSERVATION,
dataset_attr.function_tag: Role.FUNCTION
dataset_attr.function_tag: Role.FUNCTION,
}
for i, messages in enumerate(examples[dataset_attr.messages]):
messages = messages[:len(messages) // 2 * 2] # should be multiples of 2
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
if len(messages) == 0:
continue
@ -65,7 +67,9 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
if message[dataset_attr.role_tag] not in accept_tags:
raise ValueError("Invalid role tag in {}.".format(messages))
prompt.append({"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]})
prompt.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
last_message = prompt.pop(-1)
response.append(last_message)
@ -98,12 +102,7 @@ def align_dataset(
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Converting format of dataset"
desc="Converting format of dataset",
)
return dataset.map(
convert_func,
batched=True,
remove_columns=column_names,
**kwargs
)
return dataset.map(convert_func, batched=True, remove_columns=column_names, **kwargs)

View File

@ -76,7 +76,11 @@ class ToolFormatter:
required = ", required" if name in tool["parameters"].get("required", []) else ""
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
param_text += " - {name} ({type}{required}): {desc}{enum}\n".format(
name=name, type=param.get("type", ""), required=required, desc=param.get("description", ""), enum=enum
name=name,
type=param.get("type", ""),
required=required,
desc=param.get("description", ""),
enum=enum,
)
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
@ -85,9 +89,7 @@ class ToolFormatter:
tool_names.append(tool["name"])
return TOOL_SYSTEM_PROMPT.format(
tool_text=tool_text,
tool_names=", ".join(tool_names),
format_prompt=JSON_FORMAT_PROMPT
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
)
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:

View File

@ -1,16 +1,16 @@
import os
import inspect
import os
from typing import TYPE_CHECKING, List, Literal, Union
from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from .utils import checksum
from .parser import get_dataset_list
from .aligner import align_dataset
from .template import get_template_and_fix_tokenizer
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer
from .utils import checksum
if TYPE_CHECKING:
@ -18,8 +18,8 @@ if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments, ModelArguments
from .parser import DatasetAttr
from ..hparams import ModelArguments, DataArguments
logger = get_logger(__name__)
@ -44,14 +44,14 @@ def load_single_dataset(
elif dataset_attr.load_from == "file":
data_files = []
local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
if os.path.isdir(local_path): # is directory
if os.path.isdir(local_path): # is directory
for file_name in os.listdir(local_path):
data_files.append(os.path.join(local_path, file_name))
if data_path is None:
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
raise ValueError("File types should be identical.")
elif os.path.isfile(local_path): # is file
elif os.path.isfile(local_path): # is file
data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
@ -78,12 +78,12 @@ def load_single_dataset(
split=data_args.split,
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
).to_hf_dataset()
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")
else:
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
kwargs = {"trust_remote_code": True}
else:
kwargs = {}
@ -97,13 +97,13 @@ def load_single_dataset(
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
**kwargs
**kwargs,
)
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if data_args.max_samples is not None: # truncate dataset
if data_args.max_samples is not None: # truncate dataset
num_samples = min(data_args.max_samples, len(dataset))
dataset = dataset.select(range(num_samples))
@ -113,7 +113,7 @@ def load_single_dataset(
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]],
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments"
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
if len(all_datasets) == 1:
return all_datasets[0]
@ -128,7 +128,7 @@ def merge_dataset(
datasets=all_datasets,
probabilities=data_args.interleave_probs,
seed=training_args.seed,
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
)
else:
raise ValueError("Unknown mixing strategy.")
@ -160,7 +160,7 @@ def get_dataset(
with training_args.main_process_first(desc="load dataset"):
all_datasets = []
for dataset_attr in get_dataset_list(data_args): # TODO: add split
for dataset_attr in get_dataset_list(data_args): # TODO: add split
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
dataset = merge_dataset(all_datasets, data_args, training_args)
@ -174,15 +174,10 @@ def get_dataset(
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Running tokenizer on dataset"
desc="Running tokenizer on dataset",
)
dataset = dataset.map(
preprocess_func,
batched=True,
remove_columns=column_names,
**kwargs
)
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
if training_args.should_save:

View File

@ -1,18 +1,18 @@
import os
import json
from typing import TYPE_CHECKING, List, Literal, Optional
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Literal, Optional
from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope
if TYPE_CHECKING:
from ..hparams import DataArguments
@dataclass
class DatasetAttr:
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None
@ -49,7 +49,9 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_info = json.load(f)
except Exception as err:
if data_args.dataset is not None:
raise ValueError("Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)))
raise ValueError(
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
)
dataset_info = None
if data_args.interleave_probs is not None:
@ -74,7 +76,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_attr = DatasetAttr(
"file",
dataset_name=dataset_info[name]["file_name"],
dataset_sha1=dataset_info[name].get("file_sha1", None)
dataset_sha1=dataset_info[name].get("file_sha1", None),
)
dataset_attr.subset = dataset_info[name].get("subset", None)

View File

@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
@ -17,9 +18,7 @@ logger = get_logger(__name__)
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments"
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...`
text_examples = [examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))]
@ -35,7 +34,7 @@ def preprocess_pretrain_dataset(
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
result = {
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
return result
@ -57,9 +56,11 @@ def preprocess_supervised_dataset(
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
)):
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
)
):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
@ -96,9 +97,9 @@ def preprocess_packed_supervised_dataset(
continue
messages = examples["prompt"][i] + examples["response"][i]
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, messages, examples["system"][i], examples["tools"][i]
)):
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(tokenizer, messages, examples["system"][i], examples["tools"][i])
):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
@ -119,9 +120,9 @@ def preprocess_packed_supervised_dataset(
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
for i in range(0, total_length, block_size):
model_inputs["input_ids"].append(input_ids[i: i + block_size])
model_inputs["input_ids"].append(input_ids[i : i + block_size])
model_inputs["attention_mask"].append([1] * block_size)
model_inputs["labels"].append(labels[i: i + block_size])
model_inputs["labels"].append(labels[i : i + block_size])
return model_inputs
@ -191,9 +192,11 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
))
print(
"labels:\n{}".format(
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
)
)
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
@ -232,10 +235,14 @@ def get_preprocess_and_print_func(
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
elif stage == "rm":
preprocess_func = partial(preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
preprocess_func = partial(
preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
else:
preprocess_func = partial(preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
preprocess_func = partial(
preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
return preprocess_func, print_function

View File

@ -2,8 +2,8 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from ..extras.logging import get_logger
from .formatter import FunctionFormatter, StringFormatter, ToolFormatter
from .utils import Role
from .formatter import StringFormatter, FunctionFormatter, ToolFormatter
if TYPE_CHECKING:
@ -15,7 +15,6 @@ logger = get_logger(__name__)
@dataclass
class Template:
format_user: Callable
format_assistant: Callable
format_system: Callable
@ -34,7 +33,7 @@ class Template:
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: Optional[int] = 1_000_000
cutoff_len: Optional[int] = 1_000_000,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
@ -53,7 +52,7 @@ class Template:
messages: List[Dict[str, str]],
system: str,
tools: str,
cutoff_len: Optional[int] = 1_000_000
cutoff_len: Optional[int] = 1_000_000,
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
@ -67,7 +66,7 @@ class Template:
messages: List[Dict[str, str]],
system: str,
tools: str,
cutoff_len: int
cutoff_len: int,
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
@ -102,19 +101,17 @@ class Template:
if total_length >= cutoff_len:
break
encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
encoded_messages[i] = encoded_messages[i][: cutoff_len - total_length]
total_length += len(encoded_messages[i])
encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
total_length += len(encoded_messages[i+1])
encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
encoded_messages[i + 1] = encoded_messages[i + 1][: max(1, cutoff_len - total_length)]
total_length += len(encoded_messages[i + 1])
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
return encoded_pairs
def _convert_elements_to_ids(
self,
tokenizer: "PreTrainedTokenizer",
elements: List[Union[str, Dict[str, str]]]
self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
) -> List[int]:
r"""
Converts elements to token ids.
@ -139,14 +136,13 @@ class Template:
@dataclass
class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
tools: str,
cutoff_len: int
cutoff_len: int,
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
@ -182,12 +178,12 @@ class Llama2Template(Template):
if total_length >= cutoff_len:
break
encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
encoded_messages[i] = encoded_messages[i][: cutoff_len - total_length]
total_length += len(encoded_messages[i])
encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
total_length += len(encoded_messages[i+1])
encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
encoded_messages[i + 1] = encoded_messages[i + 1][: max(1, cutoff_len - total_length)]
total_length += len(encoded_messages[i + 1])
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
return encoded_pairs
@ -207,32 +203,26 @@ def register_template(
separator: Optional[List[Union[str, Dict[str, str]]]] = "",
stop_words: Optional[List[str]] = [],
efficient_eos: Optional[bool] = False,
replace_eos: Optional[bool] = False
replace_eos: Optional[bool] = False,
) -> None:
template_class = Llama2Template if name.startswith("llama2") else Template
templates[name] = template_class(
format_user=format_user or StringFormatter(container=["{{content}}"]),
format_assistant=format_assistant or StringFormatter(container=[
"{{content}}", {"eos_token"}
]),
format_assistant=format_assistant or StringFormatter(container=["{{content}}", {"eos_token"}]),
format_system=format_system or StringFormatter(container=["{{content}}"]),
format_tool=format_tool or ToolFormatter(type="default"),
format_observation=format_observation or format_user,
format_function=format_function or FunctionFormatter(container=[
"Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"}
]),
format_function=format_function
or FunctionFormatter(container=["Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"}]),
system=system,
separator=separator,
stop_words=stop_words,
efficient_eos=efficient_eos,
replace_eos=replace_eos
replace_eos=replace_eos,
)
def get_template_and_fix_tokenizer(
name: str,
tokenizer: "PreTrainedTokenizer"
) -> Template:
def get_template_and_fix_tokenizer(name: str, tokenizer: "PreTrainedTokenizer") -> Template:
if tokenizer.eos_token_id is None:
tokenizer.eos_token = "<|endoftext|>"
logger.info("Add eos token: {}".format(tokenizer.eos_token))
@ -241,7 +231,7 @@ def get_template_and_fix_tokenizer(
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
if name is None: # for pre-training
if name is None: # for pre-training
return None
template = templates.get(name, None)
@ -258,8 +248,7 @@ def get_template_and_fix_tokenizer(
if stop_words:
tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words),
replace_additional_special_tokens=False
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
logger.info("Add {} to stop words.".format(",".join(stop_words)))
@ -268,263 +257,153 @@ def get_template_and_fix_tokenizer(
register_template(
name="alpaca",
format_user=StringFormatter(container=[
"### Instruction:\n{{content}}\n\n### Response:\n"
]),
format_user=StringFormatter(container=["### Instruction:\n{{content}}\n\n### Response:\n"]),
system=(
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request."
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
),
separator=[
"\n\n"
]
separator=["\n\n"],
)
register_template(
name="aquila",
format_user=StringFormatter(container=[
"Human: {{content}}###Assistant:"
]),
format_assistant=StringFormatter(container=[
"{{content}}"
]),
format_user=StringFormatter(container=["Human: {{content}}###Assistant:"]),
format_assistant=StringFormatter(container=["{{content}}"]),
system=(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
),
separator=[
"###"
],
stop_words=[
"</s>"
],
efficient_eos=True
separator=["###"],
stop_words=["</s>"],
efficient_eos=True,
)
register_template(
name="baichuan",
format_user=StringFormatter(container=[
{"token": "<reserved_102>"},
"{{content}}",
{"token": "<reserved_103>"}
]),
format_assistant=StringFormatter(container=[
"{{content}}"
]),
efficient_eos=True
format_user=StringFormatter(container=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
format_assistant=StringFormatter(container=["{{content}}"]),
efficient_eos=True,
)
register_template(
name="baichuan2",
format_user=StringFormatter(container=[
{"token": "<reserved_106>"},
"{{content}}",
{"token": "<reserved_107>"}
]),
format_assistant=StringFormatter(container=[
"{{content}}"
]),
efficient_eos=True
format_user=StringFormatter(container=[{"token": "<reserved_106>"}, "{{content}}", {"token": "<reserved_107>"}]),
format_assistant=StringFormatter(container=["{{content}}"]),
efficient_eos=True,
)
register_template(
name="belle",
format_user=StringFormatter(container=[
"Human: {{content}}\n\nBelle: "
]),
separator=[
"\n\n"
]
name="belle", format_user=StringFormatter(container=["Human: {{content}}\n\nBelle: "]), separator=["\n\n"]
)
register_template(
name="bluelm",
format_user=StringFormatter(container=[
{"token": "[|Human|]:"},
"{{content}}",
{"token": "[|AI|]:"}
])
format_user=StringFormatter(container=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)
register_template(
name="chatglm2",
format_user=StringFormatter(container=[
"[Round {{idx}}]\n\n问:{{content}}\n\n答:"
]),
format_assistant=StringFormatter(container=[
"{{content}}"
]),
format_system=StringFormatter(container=[
{"token": "[gMASK]"},
{"token": "sop"},
"{{content}}"
]),
separator=[
"\n\n"
],
efficient_eos=True
format_user=StringFormatter(container=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_assistant=StringFormatter(container=["{{content}}"]),
format_system=StringFormatter(container=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
separator=["\n\n"],
efficient_eos=True,
)
register_template(
name="chatglm3",
format_user=StringFormatter(container=[
{"token": "<|user|>"},
"\n",
"{{content}}",
{"token": "<|assistant|>"}
]),
format_assistant=StringFormatter(container=[
"\n"
"{{content}}"
]),
format_system=StringFormatter(container=[
{"token": "[gMASK]"},
{"token": "sop"},
{"token": "<|system|>"},
"\n",
"{{content}}"
]),
format_observation=StringFormatter(container=[
{"token": "<|observation|>"},
"\n",
"{{content}}"
]),
format_function=FunctionFormatter(container=[
"{{name}}\n{{arguments}}"
]),
format_user=StringFormatter(container=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(container=["\n" "{{content}}"]),
format_system=StringFormatter(
container=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
),
format_observation=StringFormatter(container=[{"token": "<|observation|>"}, "\n", "{{content}}"]),
format_function=FunctionFormatter(container=["{{name}}\n{{arguments}}"]),
system=(
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
"Follow the user's instructions carefully. Respond using markdown."
),
stop_words=[
"<|user|>",
"<|observation|>"
],
efficient_eos=True
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
)
register_template(
name="codegeex2",
format_system=StringFormatter(container=[
{"token": "[gMASK]"},
{"token": "sop"},
"{{content}}"
])
name="codegeex2", format_system=StringFormatter(container=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"])
)
register_template(
name="deepseek",
format_user=StringFormatter(container=[
"User: {{content}}\n\nAssistant:"
])
)
register_template(name="deepseek", format_user=StringFormatter(container=["User: {{content}}\n\nAssistant:"]))
register_template(
name="deepseekcoder",
format_user=StringFormatter(container=[
"### Instruction:\n{{content}}\n### Response:\n"
]),
format_assistant=StringFormatter(container=[
"{{content}}"
]),
format_user=StringFormatter(container=["### Instruction:\n{{content}}\n### Response:\n"]),
format_assistant=StringFormatter(container=["{{content}}"]),
system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer\n"
),
separator=[
"\n",
{"token": "<|EOT|>"},
"\n"
],
stop_words=[
"<|EOT|>"
],
efficient_eos=True
separator=["\n", {"token": "<|EOT|>"}, "\n"],
stop_words=["<|EOT|>"],
efficient_eos=True,
)
register_template(
name="default",
format_user=StringFormatter(container=[
"Human: {{content}}\nAssistant: "
]),
format_user=StringFormatter(container=["Human: {{content}}\nAssistant: "]),
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
),
separator=[
"\n"
]
separator=["\n"],
)
register_template(
name="falcon",
format_user=StringFormatter(container=[
"User: {{content}}\nFalcon:"
]),
format_assistant=StringFormatter(container=[
"{{content}}"
]),
separator=[
"\n"
],
efficient_eos=True
format_user=StringFormatter(container=["User: {{content}}\nFalcon:"]),
format_assistant=StringFormatter(container=["{{content}}"]),
separator=["\n"],
efficient_eos=True,
)
register_template(
name="intern",
format_user=StringFormatter(container=[
"<|User|>:{{content}}",
{"token": "<eoh>"},
"\n<|Bot|>:"
]),
format_assistant=StringFormatter(container=[
"{{content}}"
]),
separator=[
{"token": "<eoa>"},
"\n"
],
stop_words=[
"<eoa>"
],
efficient_eos=True
format_user=StringFormatter(container=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
format_assistant=StringFormatter(container=["{{content}}"]),
separator=[{"token": "<eoa>"}, "\n"],
stop_words=["<eoa>"],
efficient_eos=True,
)
register_template(
name="intern2",
format_user=StringFormatter(container=[
{"token": "[UNUSED_TOKEN_146]"},
"user\n{{content}}",
{"token": "[UNUSED_TOKEN_145]"},
"\n",
{"token": "[UNUSED_TOKEN_146]"},
"assistant\n"
]),
format_assistant=StringFormatter(container=[
"{{content}}"
]),
format_system=StringFormatter(container=[
{"token": "[UNUSED_TOKEN_146]"},
"system\n{{content}}",
{"token": "[UNUSED_TOKEN_145]"},
"\n"
]),
format_user=StringFormatter(
container=[
{"token": "[UNUSED_TOKEN_146]"},
"user\n{{content}}",
{"token": "[UNUSED_TOKEN_145]"},
"\n",
{"token": "[UNUSED_TOKEN_146]"},
"assistant\n",
]
),
format_assistant=StringFormatter(container=["{{content}}"]),
format_system=StringFormatter(
container=[{"token": "[UNUSED_TOKEN_146]"}, "system\n{{content}}", {"token": "[UNUSED_TOKEN_145]"}, "\n"]
),
system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed "
@ -532,14 +411,9 @@ register_template(
"- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
"by the user such as English and 中文."
),
separator=[
{"token": "[UNUSED_TOKEN_145]"},
"\n"
],
stop_words=[
"[UNUSED_TOKEN_145]"
],
efficient_eos=True
separator=[{"token": "[UNUSED_TOKEN_145]"}, "\n"],
stop_words=["[UNUSED_TOKEN_145]"],
efficient_eos=True,
)
@ -556,7 +430,7 @@ register_template(
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
)
),
)
@ -564,142 +438,83 @@ register_template(
name="llama2_zh",
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]),
format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
system="You are a helpful assistant. 你是一个乐于助人的助手。"
system="You are a helpful assistant. 你是一个乐于助人的助手。",
)
register_template(
name="mistral",
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"])
)
register_template(name="mistral", format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]))
register_template(
name="openchat",
format_user=StringFormatter(container=[
"GPT4 Correct User: {{content}}",
{"token": "<|end_of_turn|>"},
"GPT4 Correct Assistant:"
]),
format_assistant=StringFormatter(container=[
"{{content}}"
]),
separator=[
{"token": "<|end_of_turn|>"}
],
stop_words=[
"<|end_of_turn|>"
],
efficient_eos=True
format_user=StringFormatter(
container=["GPT4 Correct User: {{content}}", {"token": "<|end_of_turn|>"}, "GPT4 Correct Assistant:"]
),
format_assistant=StringFormatter(container=["{{content}}"]),
separator=[{"token": "<|end_of_turn|>"}],
stop_words=["<|end_of_turn|>"],
efficient_eos=True,
)
register_template(
name="qwen",
format_user=StringFormatter(container=[
"<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
]),
format_system=StringFormatter(container=[
"<|im_start|>system\n{{content}}<|im_end|>\n"
]),
format_user=StringFormatter(container=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(container=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
system="You are a helpful assistant.",
separator=[
"\n"
],
stop_words=[
"<|im_end|>"
],
replace_eos=True
separator=["\n"],
stop_words=["<|im_end|>"],
replace_eos=True,
)
register_template(
name="solar",
format_user=StringFormatter(container=[
"### User:\n{{content}}\n\n### Assistant:\n"
])
)
register_template(name="solar", format_user=StringFormatter(container=["### User:\n{{content}}\n\n### Assistant:\n"]))
register_template(
name="starchat",
format_user=StringFormatter(container=[
{"token": "<|user|>"},
"\n{{content}}",
{"token": "<|end|>"},
"\n",
{"token": "<|assistant|>"}
]),
format_assistant=StringFormatter(container=[
"{{content}}"
]),
format_system=StringFormatter(container=[
{"token": "<|system|>"},
"\n{{content}}",
{"token": "<|end|>"},
"\n"
]),
separator=[
{"token": "<|end|>"},
"\n"
],
stop_words=[
"<|end|>"
],
efficient_eos=True
format_user=StringFormatter(
container=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}]
),
format_assistant=StringFormatter(container=["{{content}}"]),
format_system=StringFormatter(container=[{"token": "<|system|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n"]),
separator=[{"token": "<|end|>"}, "\n"],
stop_words=["<|end|>"],
efficient_eos=True,
)
register_template(
name="vanilla"
)
register_template(name="vanilla")
register_template(
name="vicuna",
format_user=StringFormatter(container=[
"USER: {{content}} ASSISTANT:"
]),
format_user=StringFormatter(container=["USER: {{content}} ASSISTANT:"]),
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
)
),
)
register_template(
name="xuanyuan",
format_user=StringFormatter(container=[
"Human: {{content}} Assistant:"
]),
format_user=StringFormatter(container=["Human: {{content}} Assistant:"]),
system=(
"以下是用户和人工智能助手之间的对话。用户以Human开头人工智能助手以Assistant开头"
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
)
),
)
register_template(
name="xverse",
format_user=StringFormatter(container=[
"Human: {{content}}\n\nAssistant: "
])
)
register_template(name="xverse", format_user=StringFormatter(container=["Human: {{content}}\n\nAssistant: "]))
register_template(
name="yayi",
format_user=StringFormatter(container=[
{"token": "<|Human|>"},
":\n{{content}}\n\n",
{"token": "<|YaYi|>"},
":"
]),
format_system=StringFormatter(container=[
{"token": "<|System|>"},
":\n{{content}}\n\n"
]),
format_user=StringFormatter(container=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
format_system=StringFormatter(container=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
system=(
"You are a helpful, respectful and honest assistant named YaYi "
"developed by Beijing Wenge Technology Co.,Ltd. "
@ -711,67 +526,43 @@ register_template(
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
separator=[
"\n\n"
],
stop_words=[
"<|End|>"
]
separator=["\n\n"],
stop_words=["<|End|>"],
)
register_template(
name="yi",
format_user=StringFormatter(container=[
"<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
]),
separator=[
"\n"
],
stop_words=[
"<|im_end|>"
],
replace_eos=True
format_user=StringFormatter(container=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
separator=["\n"],
stop_words=["<|im_end|>"],
replace_eos=True,
)
register_template(
name="yuan",
format_user=StringFormatter(container=[
"{{content}}",
{"token": "<sep>"}
]),
separator=[
"\n"
],
stop_words=[
"<eod>"
],
replace_eos=True
format_user=StringFormatter(container=["{{content}}", {"token": "<sep>"}]),
separator=["\n"],
stop_words=["<eod>"],
replace_eos=True,
)
register_template(
name="zephyr",
format_user=StringFormatter(container=[
"<|user|>\n{{content}}</s><|assistant|>"
]),
format_system=StringFormatter(container=[
"<|system|>\n{{content}}</s>",
]),
system="You are a friendly chatbot who always responds in the style of a pirate"
format_user=StringFormatter(container=["<|user|>\n{{content}}</s><|assistant|>"]),
format_system=StringFormatter(
container=[
"<|system|>\n{{content}}</s>",
]
),
system="You are a friendly chatbot who always responds in the style of a pirate",
)
register_template(
name="ziya",
format_user=StringFormatter(container=[
{"token": "<human>"},
":{{content}}\n",
{"token": "<bot>"},
":"
]),
separator=[
"\n"
]
format_user=StringFormatter(container=[{"token": "<human>"}, ":{{content}}\n", {"token": "<bot>"}, ":"]),
separator=["\n"],
)

View File

@ -4,9 +4,11 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from ..extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import TrainingArguments
from llmtuner.hparams import DataArguments
@ -44,12 +46,10 @@ def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments")
def split_dataset(
dataset: Union["Dataset", "IterableDataset"],
data_args: "DataArguments",
training_args: "TrainingArguments"
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "TrainingArguments"
) -> Dict[str, "Dataset"]:
if training_args.do_train:
if data_args.val_size > 1e-6: # Split the dataset
if data_args.val_size > 1e-6: # Split the dataset
if data_args.streaming:
val_set = dataset.take(int(data_args.val_size))
train_set = dataset.skip(int(data_args.val_size))
@ -63,5 +63,5 @@ def split_dataset(
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": dataset}
else: # do_eval or do_predict
else: # do_eval or do_predict
return {"eval_dataset": dataset}

View File

@ -1,35 +1,34 @@
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
import os
import json
import torch
import numpy as np
import inspect
from tqdm import tqdm, trange
import json
import os
from typing import Any, Dict, List, Optional
import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm, trange
from transformers.utils import cached_file
from ..data import get_template_and_fix_tokenizer
from .template import get_eval_template
from ..extras.constants import CHOICES, SUBJECTS
from ..hparams import get_eval_args
from ..model import dispatch_model, load_model_and_tokenizer
from .template import get_eval_template
class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
self.model = dispatch_model(self.model)
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = [self.tokenizer.encode(
self.eval_template.prefix + ch, add_special_tokens=False
)[-1] for ch in CHOICES]
self.choice_inputs = [
self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
]
@torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
@ -41,10 +40,10 @@ class Evaluator:
def eval(self) -> None:
mapping = cached_file(
path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task),
path_or_repo_id=os.path.join(self.eval_args.task_dir, self.eval_args.task),
filename="mapping.json",
cache_dir=self.model_args.cache_dir,
token=self.model_args.hf_hub_token
token=self.model_args.hf_hub_token,
)
with open(mapping, "r", encoding="utf-8") as f:
@ -54,7 +53,7 @@ class Evaluator:
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
results = {}
for subject in pbar:
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
kwargs = {"trust_remote_code": True}
else:
kwargs = {}
@ -65,32 +64,34 @@ class Evaluator:
cache_dir=self.model_args.cache_dir,
download_mode=self.eval_args.download_mode,
token=self.model_args.hf_hub_token,
**kwargs
**kwargs,
)
pbar.set_postfix_str(categorys[subject]["name"])
inputs, outputs, labels = [], [], []
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
support_set = (
dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
)
messages = self.eval_template.format_example(
target_data=dataset[self.data_args.split][i],
support_set=support_set,
subject_name=categorys[subject]["name"]
subject_name=categorys[subject]["name"],
)
input_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=messages
)
input_ids, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=messages)
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
labels.append(messages[-1]["content"])
for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
for i in trange(
0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False
):
batch_input = self.tokenizer.pad(
inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
).to(self.model.device)
preds = self.batch_inference(batch_input)
outputs += preds
corrects = (np.array(outputs) == np.array(labels))
corrects = np.array(outputs) == np.array(labels)
category_name = categorys[subject]["category"]
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
@ -100,10 +101,13 @@ class Evaluator:
self._save_results(category_corrects, results)
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
score_info = "\n".join([
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
for category_name, category_correct in category_corrects.items() if len(category_correct)
])
score_info = "\n".join(
[
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
for category_name, category_correct in category_corrects.items()
if len(category_correct)
]
)
print(score_info)
if self.eval_args.save_dir is not None:
os.makedirs(self.eval_args.save_dir, exist_ok=False)

View File

@ -1,8 +1,9 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple
from ..extras.constants import CHOICES
from ..data import Role
from ..extras.constants import CHOICES
if TYPE_CHECKING:
from datasets import Dataset
@ -10,24 +11,17 @@ if TYPE_CHECKING:
@dataclass
class EvalTemplate:
system: str
choice: str
answer: str
prefix: str
def parse_example(
self,
example: Dict[str, str]
) -> Tuple[str, str]:
def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
def format_example(
self,
target_data: Dict[str, str],
support_set: "Dataset",
subject_name: str
self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
) -> List[Dict[str, str]]:
messages = []
for k in range(len(support_set)):
@ -45,19 +39,8 @@ class EvalTemplate:
eval_templates: Dict[str, "EvalTemplate"] = {}
def register_eval_template(
name: str,
system: str,
choice: str,
answer: str,
prefix: str
) -> None:
eval_templates[name] = EvalTemplate(
system=system,
choice=choice,
answer=answer,
prefix=prefix
)
def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
def get_eval_template(name: str) -> "EvalTemplate":
@ -71,7 +54,7 @@ register_eval_template(
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}",
answer="\nAnswer: ",
prefix=" "
prefix=" ",
)
@ -80,5 +63,5 @@ register_eval_template(
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}",
answer="\n答案:",
prefix="\n"
prefix="\n",
)

View File

@ -1,10 +1,11 @@
import os
import json
import os
import time
from typing import TYPE_CHECKING
from datetime import timedelta
from typing import TYPE_CHECKING
from transformers import TrainerCallback
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from .constants import LOG_FILE_NAME
from .logging import get_logger
@ -12,14 +13,13 @@ from .misc import fix_valuehead_checkpoint
if TYPE_CHECKING:
from transformers import TrainingArguments, TrainerState, TrainerControl
from transformers import TrainerControl, TrainerState, TrainingArguments
logger = get_logger(__name__)
class FixValueHeadModelCallback(TrainerCallback):
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
@ -28,12 +28,11 @@ class FixValueHeadModelCallback(TrainerCallback):
fix_valuehead_checkpoint(
model=kwargs.pop("model"),
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
safe_serialization=args.save_safetensors
safe_serialization=args.save_safetensors,
)
class LogCallback(TrainerCallback):
def __init__(self, runner=None):
self.runner = runner
self.in_training = False
@ -99,7 +98,9 @@ class LogCallback(TrainerCallback):
self.cur_steps = 0
self.max_steps = 0
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
def on_predict(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
):
r"""
Event called after a successful prediction.
"""
@ -125,18 +126,22 @@ class LogCallback(TrainerCallback):
epoch=state.log_history[-1].get("epoch", None),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time
remaining_time=self.remaining_time,
)
if self.runner is not None:
logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
))
logger.info(
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
)
)
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n")
def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
):
r"""
Event called after a prediction step.
"""

View File

@ -1,5 +1,5 @@
from collections import OrderedDict, defaultdict
from enum import Enum
from collections import defaultdict, OrderedDict
from typing import Dict, Optional
@ -11,14 +11,7 @@ DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)
FILEEXT2TYPE = {
"arrow": "arrow",
"csv": "csv",
"json": "json",
"jsonl": "json",
"parquet": "parquet",
"txt": "text"
}
FILEEXT2TYPE = {"arrow": "arrow", "csv": "csv", "json": "json", "jsonl": "json", "parquet": "parquet", "txt": "text"}
IGNORE_INDEX = -100
@ -39,22 +32,21 @@ TRAINING_STAGES = {
"Reward Modeling": "rm",
"PPO": "ppo",
"DPO": "dpo",
"Pre-Training": "pt"
"Pre-Training": "pt",
}
V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
class DownloadSource(str, Enum):
DEFAULT = "hf"
MODELSCOPE = "ms"
def register_model_group(
models: Dict[str, Dict[DownloadSource, str]],
module: Optional[str] = None,
template: Optional[str] = None
models: Dict[str, Dict[DownloadSource, str]], module: Optional[str] = None, template: Optional[str] = None
) -> None:
prefix = None
for name, path in models.items():
@ -73,19 +65,19 @@ register_model_group(
models={
"Baichuan-7B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B"
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B",
},
"Baichuan-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base"
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base",
},
"Baichuan-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat"
}
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
},
},
module="W_pack",
template="baichuan"
template="baichuan",
)
@ -93,23 +85,23 @@ register_model_group(
models={
"Baichuan2-7B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base"
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base",
},
"Baichuan2-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base"
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
},
"Baichuan2-7B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat"
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
},
"Baichuan2-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat"
}
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
},
},
module="W_pack",
template="baichuan2"
template="baichuan2",
)
@ -117,18 +109,18 @@ register_model_group(
models={
"BLOOM-560M": {
DownloadSource.DEFAULT: "bigscience/bloom-560m",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m"
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m",
},
"BLOOM-3B": {
DownloadSource.DEFAULT: "bigscience/bloom-3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b"
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b",
},
"BLOOM-7B1": {
DownloadSource.DEFAULT: "bigscience/bloom-7b1",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1"
}
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1",
},
},
module="query_key_value"
module="query_key_value",
)
@ -136,18 +128,18 @@ register_model_group(
models={
"BLOOMZ-560M": {
DownloadSource.DEFAULT: "bigscience/bloomz-560m",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m"
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m",
},
"BLOOMZ-3B": {
DownloadSource.DEFAULT: "bigscience/bloomz-3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b"
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b",
},
"BLOOMZ-7B1-mt": {
DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt"
}
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt",
},
},
module="query_key_value"
module="query_key_value",
)
@ -155,14 +147,14 @@ register_model_group(
models={
"BlueLM-7B-Base": {
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base"
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base",
},
"BlueLM-7B-Chat": {
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat"
}
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat",
},
},
template="bluelm"
template="bluelm",
)
@ -170,11 +162,11 @@ register_model_group(
models={
"ChatGLM2-6B-Chat": {
DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b"
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
}
},
module="query_key_value",
template="chatglm2"
template="chatglm2",
)
@ -182,15 +174,15 @@ register_model_group(
models={
"ChatGLM3-6B-Base": {
DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base"
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base",
},
"ChatGLM3-6B-Chat": {
DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b"
}
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
},
},
module="query_key_value",
template="chatglm3"
template="chatglm3",
)
@ -198,30 +190,30 @@ register_model_group(
models={
"ChineseLLaMA2-1.3B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b"
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
},
"ChineseLLaMA2-7B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b"
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
},
"ChineseLLaMA2-13B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b"
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
},
"ChineseLLaMA2-1.3B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b"
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
},
"ChineseLLaMA2-7B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b"
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
},
"ChineseLLaMA2-13B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b"
}
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
},
},
template="llama2_zh"
template="llama2_zh",
)
@ -229,22 +221,22 @@ register_model_group(
models={
"DeepSeekLLM-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base"
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base",
},
"DeepSeekLLM-67B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base"
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base",
},
"DeepSeekLLM-7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat"
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat",
},
"DeepSeekLLM-67B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat"
}
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat",
},
},
template="deepseek"
template="deepseek",
)
@ -252,22 +244,22 @@ register_model_group(
models={
"DeepSeekCoder-6.7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base"
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
},
"DeepSeekCoder-33B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base"
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
},
"DeepSeekCoder-6.7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct"
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
},
"DeepSeekCoder-33B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct"
}
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
},
},
template="deepseekcoder"
template="deepseekcoder",
)
@ -275,14 +267,14 @@ register_model_group(
models={
"DeepSeekMoE-16B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base"
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
},
"DeepSeekMoE-16B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat"
}
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
},
},
template="deepseek"
template="deepseek",
)
@ -290,31 +282,31 @@ register_model_group(
models={
"Falcon-7B": {
DownloadSource.DEFAULT: "tiiuae/falcon-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b"
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b",
},
"Falcon-40B": {
DownloadSource.DEFAULT: "tiiuae/falcon-40b",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b"
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b",
},
"Falcon-180B": {
DownloadSource.DEFAULT: "tiiuae/falcon-180b",
DownloadSource.MODELSCOPE: "modelscope/falcon-180B"
DownloadSource.MODELSCOPE: "modelscope/falcon-180B",
},
"Falcon-7B-Chat": {
DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct"
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct",
},
"Falcon-40B-Chat": {
DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct"
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct",
},
"Falcon-180B-Chat": {
DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat"
}
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat",
},
},
module="query_key_value",
template="falcon"
template="falcon",
)
@ -322,22 +314,22 @@ register_model_group(
models={
"InternLM-7B": {
DownloadSource.DEFAULT: "internlm/internlm-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b"
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b",
},
"InternLM-20B": {
DownloadSource.DEFAULT: "internlm/internlm-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b"
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b",
},
"InternLM-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b"
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b",
},
"InternLM-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b"
}
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b",
},
},
template="intern"
template="intern",
)
@ -345,23 +337,23 @@ register_model_group(
models={
"InternLM2-7B": {
DownloadSource.DEFAULT: "internlm/internlm2-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b"
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b",
},
"InternLM2-20B": {
DownloadSource.DEFAULT: "internlm/internlm2-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b"
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b",
},
"InternLM2-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2-chat-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b"
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b",
},
"InternLM2-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b"
}
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
},
},
module="wqkv",
template="intern2"
template="intern2",
)
@ -369,31 +361,28 @@ register_model_group(
models={
"LingoWhale-8B": {
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B"
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
}
},
module="qkv_proj"
module="qkv_proj",
)
register_model_group(
models={
"LLaMA-7B": {
DownloadSource.DEFAULT: "huggyllama/llama-7b",
DownloadSource.MODELSCOPE: "skyline2006/llama-7b"
},
"LLaMA-7B": {DownloadSource.DEFAULT: "huggyllama/llama-7b", DownloadSource.MODELSCOPE: "skyline2006/llama-7b"},
"LLaMA-13B": {
DownloadSource.DEFAULT: "huggyllama/llama-13b",
DownloadSource.MODELSCOPE: "skyline2006/llama-13b"
DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
},
"LLaMA-30B": {
DownloadSource.DEFAULT: "huggyllama/llama-30b",
DownloadSource.MODELSCOPE: "skyline2006/llama-30b"
DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
},
"LLaMA-65B": {
DownloadSource.DEFAULT: "huggyllama/llama-65b",
DownloadSource.MODELSCOPE: "skyline2006/llama-65b"
}
DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
},
}
)
@ -402,30 +391,30 @@ register_model_group(
models={
"LLaMA2-7B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms"
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
},
"LLaMA2-13B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms"
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
},
"LLaMA2-70B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms"
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
},
"LLaMA2-7B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms"
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
},
"LLaMA2-13B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms"
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
},
"LLaMA2-70B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms"
}
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
},
},
template="llama2"
template="llama2",
)
@ -433,18 +422,18 @@ register_model_group(
models={
"Mistral-7B": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1"
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
},
"Mistral-7B-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1"
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
},
"Mistral-7B-v0.2-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2"
}
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
},
},
template="mistral"
template="mistral",
)
@ -452,14 +441,14 @@ register_model_group(
models={
"Mixtral-8x7B": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1"
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
},
"Mixtral-8x7B-Chat": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1"
}
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
},
},
template="mistral"
template="mistral",
)
@ -467,110 +456,87 @@ register_model_group(
models={
"OpenChat3.5-7B-Chat": {
DownloadSource.DEFAULT: "openchat/openchat_3.5",
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5"
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5",
}
},
template="openchat"
template="openchat",
)
register_model_group(
models={
"Phi-1.5-1.3B": {
DownloadSource.DEFAULT: "microsoft/phi-1_5",
DownloadSource.MODELSCOPE: "allspace/PHI_1-5"
},
"Phi-2-2.7B": {
DownloadSource.DEFAULT: "microsoft/phi-2",
DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2"
}
"Phi-1.5-1.3B": {DownloadSource.DEFAULT: "microsoft/phi-1_5", DownloadSource.MODELSCOPE: "allspace/PHI_1-5"},
"Phi-2-2.7B": {DownloadSource.DEFAULT: "microsoft/phi-2", DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2"},
}
)
register_model_group(
models={
"Qwen-1.8B": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B"
},
"Qwen-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B"
},
"Qwen-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B"
},
"Qwen-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B"
},
"Qwen-1.8B": {DownloadSource.DEFAULT: "Qwen/Qwen-1_8B", DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B"},
"Qwen-7B": {DownloadSource.DEFAULT: "Qwen/Qwen-7B", DownloadSource.MODELSCOPE: "qwen/Qwen-7B"},
"Qwen-14B": {DownloadSource.DEFAULT: "Qwen/Qwen-14B", DownloadSource.MODELSCOPE: "qwen/Qwen-14B"},
"Qwen-72B": {DownloadSource.DEFAULT: "Qwen/Qwen-72B", DownloadSource.MODELSCOPE: "qwen/Qwen-72B"},
"Qwen-1.8B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat"
},
"Qwen-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
},
"Qwen-7B-Chat": {DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"},
"Qwen-14B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat"
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat",
},
"Qwen-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat"
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
},
"Qwen-1.8B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8"
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
},
"Qwen-1.8B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4"
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
},
"Qwen-7B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8"
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
},
"Qwen-7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4"
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
},
"Qwen-14B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8"
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
},
"Qwen-14B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4"
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
},
"Qwen-72B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8"
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
},
"Qwen-72B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4"
}
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
},
},
module="c_attn",
template="qwen"
template="qwen",
)
register_model_group(
models={
"SOLAR-10.7B": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0"
},
"SOLAR-10.7B": {DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0"},
"SOLAR-10.7B-Chat": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0"
}
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
},
},
template="solar"
template="solar",
)
@ -578,7 +544,7 @@ register_model_group(
models={
"Skywork-13B-Base": {
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base"
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base",
}
}
)
@ -588,68 +554,51 @@ register_model_group(
models={
"Vicuna1.5-7B-Chat": {
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5"
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
},
"Vicuna1.5-13B-Chat": {
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5"
}
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
},
},
template="vicuna"
template="vicuna",
)
register_model_group(
models={
"XuanYuan-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B"
},
"XuanYuan-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat"
},
"XuanYuan-70B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit"
},
"XuanYuan-70B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit"
}
"XuanYuan-70B": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B"},
"XuanYuan-70B-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat"},
"XuanYuan-70B-int8-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit"},
"XuanYuan-70B-int4-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit"},
},
template="xuanyuan"
template="xuanyuan",
)
register_model_group(
models={
"XVERSE-7B": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B"
},
"XVERSE-13B": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B"
},
"XVERSE-65B": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B"
},
"XVERSE-7B": {DownloadSource.DEFAULT: "xverse/XVERSE-7B", DownloadSource.MODELSCOPE: "xverse/XVERSE-7B"},
"XVERSE-13B": {DownloadSource.DEFAULT: "xverse/XVERSE-13B", DownloadSource.MODELSCOPE: "xverse/XVERSE-13B"},
"XVERSE-65B": {DownloadSource.DEFAULT: "xverse/XVERSE-65B", DownloadSource.MODELSCOPE: "xverse/XVERSE-65B"},
"XVERSE-65B-2": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2"
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
},
"XVERSE-7B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat"
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat",
},
"XVERSE-13B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat"
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat",
},
"XVERSE-65B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat"
}
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
},
},
template="xverse"
template="xverse",
)
@ -657,45 +606,33 @@ register_model_group(
models={
"Yayi-7B": {
DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2"
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2",
},
"Yayi-13B": {
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2"
}
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2",
},
},
template="yayi"
template="yayi",
)
register_model_group(
models={
"Yi-6B": {
DownloadSource.DEFAULT: "01-ai/Yi-6B",
DownloadSource.MODELSCOPE: "01ai/Yi-6B"
},
"Yi-34B": {
DownloadSource.DEFAULT: "01-ai/Yi-34B",
DownloadSource.MODELSCOPE: "01ai/Yi-34B"
},
"Yi-6B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat"
},
"Yi-34B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat"
},
"Yi-6B": {DownloadSource.DEFAULT: "01-ai/Yi-6B", DownloadSource.MODELSCOPE: "01ai/Yi-6B"},
"Yi-34B": {DownloadSource.DEFAULT: "01-ai/Yi-34B", DownloadSource.MODELSCOPE: "01ai/Yi-34B"},
"Yi-6B-Chat": {DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat"},
"Yi-34B-Chat": {DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat"},
"Yi-6B-int8-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits"
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
},
"Yi-34B-int8-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits"
}
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
},
},
template="yi"
template="yi",
)
@ -703,18 +640,18 @@ register_model_group(
models={
"Yuan2-2B-Chat": {
DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf",
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf"
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf",
},
"Yuan2-51B-Chat": {
DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf",
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf"
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf",
},
"Yuan2-102B-Chat": {
DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf",
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf"
}
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf",
},
},
template="yuan"
template="yuan",
)
@ -722,12 +659,12 @@ register_model_group(
models={
"Zephyr-7B-Alpha-Chat": {
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha"
DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha",
},
"Zephyr-7B-Beta-Chat": {
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta"
}
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
},
},
template="zephyr"
template="zephyr",
)

View File

@ -1,5 +1,5 @@
import sys
import logging
import sys
class LoggerHandler(logging.Handler):
@ -27,8 +27,7 @@ def get_logger(name: str) -> logging.Logger:
Gets a standard logger with a stream hander to stdout.
"""
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S"
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)

View File

@ -1,31 +1,33 @@
import gc
import os
import torch
from typing import TYPE_CHECKING, Dict, Tuple
import torch
from peft import PeftModel
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
from transformers.utils import (
WEIGHTS_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_NAME,
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_npu_available,
is_torch_xpu_available
is_torch_xpu_available,
)
from peft import PeftModel
from .constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from .logging import get_logger
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try:
_is_bf16_available = is_torch_bf16_gpu_available()
except:
except Exception:
_is_bf16_available = False
if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import ModelArguments
@ -36,6 +38,7 @@ class AverageMeter:
r"""
Computes and stores the average and current value.
"""
def __init__(self):
self.reset()
@ -75,9 +78,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
def fix_valuehead_checkpoint(
model: "AutoModelForCausalLMWithValueHead",
output_dir: str,
safe_serialization: bool
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
r"""
The model is already unwrapped.
@ -95,6 +96,7 @@ def fix_valuehead_checkpoint(
if safe_serialization:
from safetensors import safe_open
from safetensors.torch import save_file
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
@ -112,9 +114,7 @@ def fix_valuehead_checkpoint(
os.remove(path_to_checkpoint)
model.pretrained_model.save_pretrained(
output_dir,
state_dict=decoder_state_dict or None,
safe_serialization=safe_serialization
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
)
if safe_serialization:
@ -182,11 +182,10 @@ def try_download_model_from_ms(model_args: "ModelArguments") -> None:
try:
from modelscope import snapshot_download
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
model_args.model_name_or_path = snapshot_download(
model_args.model_name_or_path,
revision=revision,
cache_dir=model_args.cache_dir
model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir
)
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")

View File

@ -9,7 +9,7 @@ def is_package_available(name: str) -> bool:
def get_package_version(name: str) -> str:
try:
return importlib.metadata.version(name)
except:
except Exception:
return "0.0.0"

View File

@ -1,11 +1,16 @@
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
from typing import Optional, Tuple
from transformers.utils import logging
from transformers.models.llama.modeling_llama import (
Cache, LlamaAttention, LlamaFlashAttention2, apply_rotary_pos_emb, repeat_kv
Cache,
LlamaAttention,
LlamaFlashAttention2,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import logging
logger = logging.get_logger(__name__)
@ -19,7 +24,7 @@ def llama_torch_attn_forward(
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
**kwargs
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@ -45,15 +50,17 @@ def llama_torch_attn_forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat((
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
), dim=2)
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
@ -68,14 +75,17 @@ def llama_torch_attn_forward(
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
attn_output = attn_output.transpose(1, 2).contiguous()
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat((
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
))
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
@ -94,7 +104,7 @@ def llama_flash_attn_forward(
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
**kwargs
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
@ -124,9 +134,9 @@ def llama_flash_attn_forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
dropout_rate = self.attention_dropout if self.training else 0.0
@ -144,14 +154,16 @@ def llama_flash_attn_forward(
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = torch.cat((
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
), dim=2)
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
@ -162,11 +174,14 @@ def llama_flash_attn_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat((
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
))
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)

View File

@ -1,12 +1,14 @@
import os
import math
import json
import math
import os
from typing import List, Optional
from transformers.trainer import TRAINER_STATE_NAME
from .logging import get_logger
from .packages import is_matplotlib_available
if is_matplotlib_available():
import matplotlib.pyplot as plt
@ -20,7 +22,7 @@ def smooth(scalars: List[float]) -> List[float]:
"""
last = scalars[0]
smoothed = list()
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
for next_val in scalars:
smoothed_val = last * weight + (1 - weight) * next_val
smoothed.append(smoothed_val)
@ -29,7 +31,6 @@ def smooth(scalars: List[float]) -> List[float]:
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
data = json.load(f)

View File

@ -3,7 +3,7 @@ from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
from .parser import get_train_args, get_infer_args, get_eval_args
from .parser import get_eval_args, get_infer_args, get_train_args
__all__ = [
@ -14,5 +14,5 @@ __all__ = [
"ModelArguments",
"get_train_args",
"get_infer_args",
"get_eval_args"
"get_eval_args",
]

View File

@ -1,5 +1,5 @@
from typing import Literal, Optional
from dataclasses import dataclass, field
from typing import Literal, Optional
@dataclass
@ -8,80 +8,66 @@ class DataArguments:
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
template: Optional[str] = field(
default=None,
metadata={"help": "Which template to use for constructing prompts in training and inference."}
default=None, metadata={"help": "Which template to use for constructing prompts in training and inference."}
)
dataset: Optional[str] = field(
default=None,
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
)
dataset_dir: Optional[str] = field(
default="data",
metadata={"help": "Path to the folder containing the datasets."}
default="data", metadata={"help": "Path to the folder containing the datasets."}
)
split: Optional[str] = field(
default="train",
metadata={"help": "Which dataset split to use for training and evaluation."}
default="train", metadata={"help": "Which dataset split to use for training and evaluation."}
)
cutoff_len: Optional[int] = field(
default=1024,
metadata={"help": "The maximum length of the model inputs after tokenization."}
default=1024, metadata={"help": "The maximum length of the model inputs after tokenization."}
)
reserved_label_len: Optional[int] = field(
default=1,
metadata={"help": "The maximum length reserved for label after tokenization."}
default=1, metadata={"help": "The maximum length reserved for label after tokenization."}
)
train_on_prompt: Optional[bool] = field(
default=False,
metadata={"help": "Whether to disable the mask on the prompt or not."}
)
streaming: Optional[bool] = field(
default=False,
metadata={"help": "Enable dataset streaming."}
default=False, metadata={"help": "Whether to disable the mask on the prompt or not."}
)
streaming: Optional[bool] = field(default=False, metadata={"help": "Enable dataset streaming."})
buffer_size: Optional[int] = field(
default=16384,
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
default=16384, metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
)
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
default="concat",
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
)
interleave_probs: Optional[str] = field(
default=None,
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
)
overwrite_cache: Optional[bool] = field(
default=False,
metadata={"help": "Overwrite the cached training and evaluation sets."}
default=False, metadata={"help": "Overwrite the cached training and evaluation sets."}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."}
default=None, metadata={"help": "The number of processes to use for the preprocessing."}
)
max_samples: Optional[int] = field(
default=None,
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
default=None, metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
)
eval_num_beams: Optional[int] = field(
default=None,
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
)
ignore_pad_token_for_loss: Optional[bool] = field(
default=True,
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
metadata={
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
},
)
val_size: Optional[float] = field(
default=0,
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
default=0, metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
)
sft_packing: Optional[bool] = field(
default=False,
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
default=False, metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
)
cache_path: Optional[str] = field(
default=None,
metadata={"help": "Path to save or load the preprocessed datasets."}
default=None, metadata={"help": "Path to save or load the preprocessed datasets."}
)
def __post_init__(self):

View File

@ -1,6 +1,6 @@
import os
from typing import Literal, Optional
from dataclasses import dataclass, field
from typing import Literal, Optional
from datasets import DownloadMode
@ -10,36 +10,18 @@ class EvaluationArguments:
r"""
Arguments pertaining to specify the evaluation parameters.
"""
task: str = field(
metadata={"help": "Name of the evaluation task."}
)
task: str = field(metadata={"help": "Name of the evaluation task."})
task_dir: Optional[str] = field(
default="evaluation",
metadata={"help": "Path to the folder containing the evaluation datasets."}
)
batch_size: Optional[int] = field(
default=4,
metadata={"help": "The batch size per GPU for evaluation."}
)
seed: Optional[int] = field(
default=42,
metadata={"help": "Random seed to be used with data loaders."}
)
lang: Optional[Literal["en", "zh"]] = field(
default="en",
metadata={"help": "Language used at evaluation."}
)
n_shot: Optional[int] = field(
default=5,
metadata={"help": "Number of examplars for few-shot learning."}
)
save_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to save the evaluation results."}
default="evaluation", metadata={"help": "Path to the folder containing the evaluation datasets."}
)
batch_size: Optional[int] = field(default=4, metadata={"help": "The batch size per GPU for evaluation."})
seed: Optional[int] = field(default=42, metadata={"help": "Random seed to be used with data loaders."})
lang: Optional[Literal["en", "zh"]] = field(default="en", metadata={"help": "Language used at evaluation."})
n_shot: Optional[int] = field(default=5, metadata={"help": "Number of examplars for few-shot learning."})
save_dir: Optional[str] = field(default=None, metadata={"help": "Path to save the evaluation results."})
download_mode: Optional[DownloadMode] = field(
default=DownloadMode.REUSE_DATASET_IF_EXISTS,
metadata={"help": "Download mode used for the evaluation datasets."}
metadata={"help": "Download mode used for the evaluation datasets."},
)
def __post_init__(self):

View File

@ -1,6 +1,6 @@
import json
from typing import Literal, Optional
from dataclasses import asdict, dataclass, field
from typing import Literal, Optional
@dataclass
@ -10,17 +10,18 @@ class FreezeArguments:
"""
name_module_trainable: Optional[str] = field(
default="mlp",
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
metadata={
"help": 'Name of trainable modules for partial-parameter (freeze) fine-tuning. \
Use commas to separate multiple modules. \
LLaMA choices: [\"mlp\", \"self_attn\"], \
BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
Qwen choices: [\"mlp\", \"attn\"], \
Phi choices: [\"mlp\", \"mixer\"], \
Others choices: the same as LLaMA."}
LLaMA choices: ["mlp", "self_attn"], \
BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
Qwen choices: ["mlp", "attn"], \
Phi choices: ["mlp", "mixer"], \
Others choices: the same as LLaMA.'
},
)
num_layer_trainable: Optional[int] = field(
default=3,
metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}
default=3, metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}
)
@ -31,37 +32,32 @@ class LoraArguments:
"""
additional_target: Optional[str] = field(
default=None,
metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
metadata={
"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."
},
)
lora_alpha: Optional[int] = field(
default=None,
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}
)
lora_dropout: Optional[float] = field(
default=0.0,
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
)
lora_rank: Optional[int] = field(
default=8,
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
default=None, metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}
)
lora_dropout: Optional[float] = field(default=0.0, metadata={"help": "Dropout rate for the LoRA fine-tuning."})
lora_rank: Optional[int] = field(default=8, metadata={"help": "The intrinsic dimension for LoRA fine-tuning."})
lora_target: Optional[str] = field(
default=None,
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"], \
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
Phi choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
Others choices: the same as LLaMA."}
metadata={
"help": 'Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
Phi choices: ["Wqkv", "out_proj", "fc1", "fc2"], \
Others choices: the same as LLaMA.'
},
)
lora_bf16_mode: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to train lora adapters in bf16 precision."}
default=False, metadata={"help": "Whether or not to train lora adapters in bf16 precision."}
)
create_new_adapter: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}
default=False, metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}
)
@ -70,69 +66,53 @@ class RLHFArguments:
r"""
Arguments pertaining to the PPO and DPO training.
"""
dpo_beta: Optional[float] = field(
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."}
)
dpo_beta: Optional[float] = field(default=0.1, metadata={"help": "The beta parameter for the DPO loss."})
dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
default="sigmoid",
metadata={"help": "The type of DPO loss to use."}
default="sigmoid", metadata={"help": "The type of DPO loss to use."}
)
dpo_ftx: Optional[float] = field(
default=0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}
default=0, metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}
)
ppo_buffer_size: Optional[int] = field(
default=1,
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
)
ppo_epochs: Optional[int] = field(
default=4,
metadata={"help": "The number of epochs to perform in a PPO optimization step."}
default=4, metadata={"help": "The number of epochs to perform in a PPO optimization step."}
)
ppo_logger: Optional[str] = field(
default=None,
metadata={"help": "Log with either \"wandb\" or \"tensorboard\" in PPO training."}
default=None, metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'}
)
ppo_score_norm: Optional[bool] = field(
default=False,
metadata={"help": "Use score normalization in PPO training."}
default=False, metadata={"help": "Use score normalization in PPO training."}
)
ppo_target: Optional[float] = field(
default=6.0,
metadata={"help": "Target KL value for adaptive KL control in PPO training."}
default=6.0, metadata={"help": "Target KL value for adaptive KL control in PPO training."}
)
ppo_whiten_rewards: Optional[bool] = field(
default=False,
metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
default=False, metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
)
ref_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the reference model used for the PPO or DPO training."}
default=None, metadata={"help": "Path to the reference model used for the PPO or DPO training."}
)
ref_model_adapters: Optional[str] = field(
default=None,
metadata={"help": "Path to the adapters of the reference model."}
default=None, metadata={"help": "Path to the adapters of the reference model."}
)
ref_model_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the reference model."}
default=None, metadata={"help": "The number of bits to quantize the reference model."}
)
reward_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the reward model used for the PPO training."}
default=None, metadata={"help": "Path to the reward model used for the PPO training."}
)
reward_model_adapters: Optional[str] = field(
default=None,
metadata={"help": "Path to the adapters of the reward model."}
default=None, metadata={"help": "Path to the adapters of the reward model."}
)
reward_model_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the reward model."}
default=None, metadata={"help": "The number of bits to quantize the reward model."}
)
reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
default="lora",
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."}
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
)
@ -142,16 +122,13 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."}
default="sft", metadata={"help": "Which stage will be performed in training."}
)
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."}
default="lora", metadata={"help": "Which fine-tuning method to use."}
)
plot_loss: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to save the training loss curves."}
default=False, metadata={"help": "Whether or not to save the training loss curves."}
)
def __post_init__(self):

View File

@ -1,5 +1,5 @@
from typing import Any, Dict, Optional
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional
@dataclass
@ -8,40 +8,37 @@ class GeneratingArguments:
Arguments pertaining to specify the decoding parameters.
"""
do_sample: Optional[bool] = field(
default=True,
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
default=True, metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
)
temperature: Optional[float] = field(
default=0.95,
metadata={"help": "The value used to modulate the next token probabilities."}
default=0.95, metadata={"help": "The value used to modulate the next token probabilities."}
)
top_p: Optional[float] = field(
default=0.7,
metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
metadata={
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
},
)
top_k: Optional[int] = field(
default=50,
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
)
num_beams: Optional[int] = field(
default=1,
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
default=1, metadata={"help": "Number of beams for beam search. 1 means no beam search."}
)
max_length: Optional[int] = field(
default=512,
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
)
max_new_tokens: Optional[int] = field(
default=512,
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
)
repetition_penalty: Optional[float] = field(
default=1.0,
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
default=1.0, metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
)
length_penalty: Optional[float] = field(
default=1.0,
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
default=1.0, metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
)
def to_dict(self) -> Dict[str, Any]:

View File

@ -1,5 +1,5 @@
from typing import Any, Dict, Literal, Optional
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional
@dataclass
@ -11,108 +11,82 @@ class ModelArguments:
metadata={"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."}
)
adapter_name_or_path: Optional[str] = field(
default=None,
metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."}
default=None, metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."}
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
)
use_fast_tokenizer: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."}
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
)
resize_vocab: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."}
default=False, metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."}
)
split_special_tokens: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
)
model_revision: Optional[str] = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the model."}
default=None, metadata={"help": "The number of bits to quantize the model."}
)
quantization_type: Optional[Literal["fp4", "nf4"]] = field(
default="nf4",
metadata={"help": "Quantization data type to use in int4 training."}
default="nf4", metadata={"help": "Quantization data type to use in int4 training."}
)
double_quantization: Optional[bool] = field(
default=True,
metadata={"help": "Whether or not to use double quantization in int4 training."}
default=True, metadata={"help": "Whether or not to use double quantization in int4 training."}
)
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}
default=None, metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}
)
flash_attn: Optional[bool] = field(
default=False,
metadata={"help": "Enable FlashAttention-2 for faster training."}
default=False, metadata={"help": "Enable FlashAttention-2 for faster training."}
)
shift_attn: Optional[bool] = field(
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
default=False, metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
)
use_unsloth: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."}
default=False, metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."}
)
disable_gradient_checkpointing: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to disable gradient checkpointing."}
default=False, metadata={"help": "Whether or not to disable gradient checkpointing."}
)
upcast_layernorm: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}
default=False, metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}
)
upcast_lmhead_output: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}
)
hf_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."}
)
ms_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."}
default=False, metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}
)
hf_hub_token: Optional[str] = field(default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."})
ms_hub_token: Optional[str] = field(default=None, metadata={"help": "Auth token to log in with ModelScope Hub."})
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."}
default=None, metadata={"help": "Path to the directory to save the exported model."}
)
export_size: Optional[int] = field(
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."}
default=1, metadata={"help": "The file shard size (in GB) of the exported model."}
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."}
default=None, metadata={"help": "The number of bits to quantize the exported model."}
)
export_quantization_dataset: Optional[str] = field(
default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}
default=None, metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}
)
export_quantization_nsamples: Optional[int] = field(
default=128,
metadata={"help": "The number of samples used for quantization."}
default=128, metadata={"help": "The number of samples used for quantization."}
)
export_quantization_maxlen: Optional[int] = field(
default=1024,
metadata={"help": "The maximum length of the model inputs used for quantization."}
default=1024, metadata={"help": "The maximum length of the model inputs used for quantization."}
)
export_legacy_format: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}
default=False, metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}
)
export_hub_model_id: Optional[str] = field(
default=None,
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."}
default=None, metadata={"help": "The name of the repository if push the model to the Hugging Face hub."}
)
def __post_init__(self):
@ -122,7 +96,7 @@ class ModelArguments:
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.adapter_name_or_path is not None: # support merging multiple lora weights
if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."

View File

@ -1,10 +1,11 @@
import logging
import os
import sys
import torch
import logging
import datasets
import transformers
from typing import Any, Dict, Optional, Tuple
import datasets
import torch
import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
@ -19,24 +20,12 @@ from .model_args import ModelArguments
logger = get_logger(__name__)
_TRAIN_ARGS = [
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_TRAIN_CLS = Tuple[
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_INFER_ARGS = [
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_INFER_CLS = Tuple[
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_EVAL_ARGS = [
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
_EVAL_CLS = Tuple[
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
@ -77,7 +66,7 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Multiple adapters are only available for LoRA tuning.")
if model_args.quantization_bit is not None:
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
@ -181,18 +170,22 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
training_args = Seq2SeqTrainingArguments(**training_args_dict)
logger.info("Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
training_args.resume_from_checkpoint
))
logger.info(
"Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
training_args.resume_from_checkpoint
)
)
if (
finetuning_args.stage in ["rm", "ppo"]
and finetuning_args.finetuning_type == "lora"
and training_args.resume_from_checkpoint is not None
):
logger.warning("Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
training_args.resume_from_checkpoint
))
logger.warning(
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
training_args.resume_from_checkpoint
)
)
# postprocess model_args
model_args.compute_dtype = (
@ -201,10 +194,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args.model_max_length = data_args.cutoff_len
# Log on each process the small summary:
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
training_args.local_rank, training_args.device, training_args.n_gpu,
bool(training_args.local_rank != -1), str(model_args.compute_dtype)
))
logger.info(
"Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
training_args.local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
str(model_args.compute_dtype),
)
)
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.

View File

@ -1,25 +1,25 @@
import torch
import inspect
from typing import TYPE_CHECKING
import torch
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
from ..extras.logging import get_logger
from .utils import find_all_linear_modules
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from ..hparams import ModelArguments, FinetuningArguments
from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
def init_adapter(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool
model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool
) -> "PreTrainedModel":
r"""
Initializes the adapters.
@ -47,10 +47,10 @@ def init_adapter(
if not num_layers:
raise ValueError("Current model does not support freeze tuning.")
if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)] # noqa: C416
trainable_layers = []
for module_name in finetuning_args.name_module_trainable:
@ -69,7 +69,7 @@ def init_adapter(
if model_args.adapter_name_or_path is not None:
is_mergeable = True
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
is_mergeable = False
@ -90,10 +90,10 @@ def init_adapter(
if len(adapter_to_merge) > 0:
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
if adapter_to_resume is not None: # resume lora training
if adapter_to_resume is not None: # resume lora training
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable)
if is_trainable and adapter_to_resume is None: # create new lora weights while training
if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
target_modules = find_all_linear_modules(model)
else:
@ -103,11 +103,12 @@ def init_adapter(
"r": finetuning_args.lora_rank,
"target_modules": target_modules,
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout
"lora_dropout": finetuning_args.lora_dropout,
}
if model_args.use_unsloth:
from unsloth import FastLlamaModel, FastMistralModel # type: ignore
from unsloth import FastLlamaModel, FastMistralModel # type: ignore
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
if "loftq_config" in inspect.signature(FastLlamaModel.get_peft_model).parameters:
unsloth_peft_kwargs["loftq_config"] = {}
@ -124,7 +125,7 @@ def init_adapter(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
modules_to_save=finetuning_args.additional_target,
**peft_kwargs
**peft_kwargs,
)
model = get_peft_model(model, lora_config)

View File

@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Optional, Tuple
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
@ -7,12 +8,14 @@ from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from .adapter import init_adapter
from .patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
from .utils import load_valuehead_params, register_autoclass
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
from ..hparams import ModelArguments, FinetuningArguments
from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
@ -29,7 +32,7 @@ def load_model_and_tokenizer(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: Optional[bool] = False,
add_valuehead: Optional[bool] = False
add_valuehead: Optional[bool] = False,
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
r"""
Loads pretrained model and tokenizer.
@ -43,7 +46,7 @@ def load_model_and_tokenizer(
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token
"token": model_args.hf_hub_token,
}
tokenizer = AutoTokenizer.from_pretrained(
@ -51,7 +54,7 @@ def load_model_and_tokenizer(
use_fast=model_args.use_fast_tokenizer,
split_special_tokens=model_args.split_special_tokens,
padding_side="right",
**config_kwargs
**config_kwargs,
)
patch_tokenizer(tokenizer)
@ -61,7 +64,8 @@ def load_model_and_tokenizer(
model = None
if is_trainable and model_args.use_unsloth:
require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth")
from unsloth import FastLlamaModel, FastMistralModel # type: ignore
from unsloth import FastLlamaModel, FastMistralModel # type: ignore
unsloth_kwargs = {
"model_name": model_args.model_name_or_path,
"max_seq_length": model_args.model_max_length,
@ -69,7 +73,7 @@ def load_model_and_tokenizer(
"load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token,
"device_map": get_current_device(),
"rope_scaling": getattr(config, "rope_scaling", None)
"rope_scaling": getattr(config, "rope_scaling", None),
}
if getattr(config, "model_type", None) == "llama":
model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs)
@ -89,7 +93,7 @@ def load_model_and_tokenizer(
config=config,
torch_dtype=model_args.compute_dtype,
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
**config_kwargs
**config_kwargs,
)
patch_model(model, tokenizer, model_args, is_trainable)
@ -119,9 +123,11 @@ def load_model_and_tokenizer(
model.train()
trainable_params, all_param = count_parameters(model)
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
))
logger.info(
"trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
)
)
if not is_trainable:
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")

View File

@ -1,12 +1,12 @@
import os
import math
import torch
import os
import random
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from datasets import load_dataset
from contextlib import nullcontext
import torch
from datasets import load_dataset
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
@ -17,9 +17,11 @@ from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import ModelArguments
@ -40,7 +42,8 @@ def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToke
Resize token embeddings.
"""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
import deepspeed # type: ignore
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
@ -88,7 +91,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
if sample["input_ids"].size(1) >= maxlen:
break # TODO: fix large maxlen
break # TODO: fix large maxlen
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
@ -119,9 +122,9 @@ def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
model_args.rope_scaling, scaling_factor
))
logger.info(
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
)
def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
@ -146,22 +149,22 @@ def _configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
config_kwargs: Dict[str, Any]
config_kwargs: Dict[str, Any],
) -> None:
r"""
Priority: GPTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
"""
if getattr(config, "quantization_config", None): # gptq
if getattr(config, "quantization_config", None): # gptq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
config_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
quantization_config["use_exllama"] = False # disable exllama
quantization_config["use_exllama"] = False # disable exllama
logger.info("Loading {}-bit GPTQ-quantized model.".format(quantization_config.get("bits", -1)))
elif model_args.export_quantization_bit is not None: # auto-gptq
elif model_args.export_quantization_bit is not None: # auto-gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
@ -172,13 +175,13 @@ def _configure_quantization(
config_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit,
tokenizer=tokenizer,
dataset=_get_quantization_dataset(tokenizer, model_args)
dataset=_get_quantization_dataset(tokenizer, model_args),
)
config_kwargs["device_map"] = "auto"
config_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb
elif model_args.quantization_bit is not None: # bnb
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
@ -192,7 +195,7 @@ def _configure_quantization(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
bnb_4bit_quant_type=model_args.quantization_type,
)
config_kwargs["device_map"] = {"": get_current_device()}
@ -200,9 +203,7 @@ def _configure_quantization(
def _prepare_model_for_training(
model: "PreTrainedModel",
model_args: "ModelArguments",
output_layer_name: Optional[str] = "lm_head"
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: Optional[str] = "lm_head"
) -> None:
r"""
Includes:
@ -222,10 +223,11 @@ def _prepare_model_for_training(
logger.warning("Current model does not support gradient checkpointing.")
else:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.config.use_cache = False # turn off when gradient checkpointing is enabled
model.config.use_cache = False # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
return output.to(torch.float32)
@ -244,9 +246,9 @@ def patch_config(
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
config_kwargs: Dict[str, Any],
is_trainable: bool
is_trainable: bool,
) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
if getattr(config, "model_type", None) == "qwen":
@ -266,10 +268,7 @@ def patch_config(
def patch_model(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
is_trainable: bool
model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
) -> None:
if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)

View File

@ -1,16 +1,19 @@
import torch
import inspect
from typing import TYPE_CHECKING, Any, Dict, List
import torch
from transformers import PreTrainedModel
from transformers.utils import cached_file
from ..extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import get_logger
from ..extras.misc import get_current_device
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from ..hparams import ModelArguments, DataArguments, FinetuningArguments
from ..hparams import DataArguments, FinetuningArguments, ModelArguments
logger = get_logger(__name__)
@ -21,7 +24,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
"""
if getattr(model, "quantization_method", None): # already set on current device
if getattr(model, "quantization_method", None): # already set on current device
return model
if (
@ -31,7 +34,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
and model.config.model_type != "chatglm"
):
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
from accelerate.utils import get_balanced_memory, infer_auto_device_map
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
max_memory = get_balanced_memory(model, **kwargs)
@ -55,6 +58,7 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
linear_cls = torch.nn.Linear
elif quantization_method == "bitsandbytes":
import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
else:
raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
@ -65,10 +69,7 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
module_names = set()
for name, module in model.named_modules():
if (
isinstance(module, linear_cls)
and not any([output_layer in name for output_layer in output_layer_names])
):
if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names):
module_names.add(name.split(".")[-1])
logger.info("Found linear modules: {}".format(",".join(module_names)))
@ -76,16 +77,14 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
def get_modelcard_args(
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments"
model_args: "ModelArguments", data_args: "DataArguments", finetuning_args: "FinetuningArguments"
) -> Dict[str, Any]:
return {
"tasks": "text-generation",
"license": "other",
"finetuned_from": model_args.model_name_or_path,
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else [])
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else []),
}
@ -95,14 +94,11 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {
"path_or_repo_id": path_or_repo_id,
"cache_dir": model_args.cache_dir,
"token": model_args.hf_hub_token
}
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
try:
from safetensors import safe_open
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys()}

View File

@ -1,6 +1,7 @@
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Tuple
import torch
from transformers import DataCollatorForSeq2Seq
@ -20,7 +21,7 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
padded_tensor[start:end] = feature[start:end]
padded_labels.append(padded_tensor)
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r"""
@ -34,10 +35,12 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
for key in ("chosen_ids", "rejected_ids"):
for feature in features:
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
concatenated_features.append({
"input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (prompt_len + answer_len)
})
concatenated_features.append(
{
"input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (prompt_len + answer_len),
}
)
label_positions.append((prompt_len, answer_len))
batch = self.tokenizer.pad(

View File

@ -1,19 +1,20 @@
import torch
from contextlib import nullcontext
from collections import defaultdict
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch
from transformers import BatchEncoding, Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
if TYPE_CHECKING:
from transformers import PreTrainedModel
class CustomDPOTrainer(DPOTrainer):
def __init__(
self,
beta: float,
@ -22,15 +23,15 @@ class CustomDPOTrainer(DPOTrainer):
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
disable_dropout: Optional[bool] = True,
**kwargs
**kwargs,
):
if disable_dropout:
disable_dropout_in_model(model)
if ref_model is not None:
disable_dropout_in_model(ref_model)
self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation
self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation
self.label_pad_token_id = IGNORE_INDEX
self.padding_value = 0
self.is_encoder_decoder = model.config.is_encoder_decoder
@ -53,42 +54,29 @@ class CustomDPOTrainer(DPOTrainer):
if ref_model is not None:
if self.is_deepspeed_enabled:
if not (
getattr(ref_model, "is_loaded_in_8bit", False)
or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
def sft_loss(
self,
chosen_logits: torch.FloatTensor,
chosen_labels: torch.LongTensor
) -> torch.Tensor:
def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
r"""
Computes supervised cross-entropy loss of given labels under the given logits.
Returns:
A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
"""
all_logps = self.get_batch_logps(
chosen_logits,
chosen_labels,
average_log_prob=True
)
all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
return -all_logps
def concatenated_forward(
self,
model: "PreTrainedModel",
batch: Dict[str, torch.Tensor]
self, model: "PreTrainedModel", batch: Dict[str, torch.Tensor]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
all_logits = model(
input_ids=batch_copied["input_ids"],
attention_mask=batch_copied["attention_mask"],
return_dict=True
input_ids=batch_copied["input_ids"], attention_mask=batch_copied["attention_mask"], return_dict=True
).logits.to(torch.float32)
all_logps = self.get_batch_logps(
@ -106,7 +94,7 @@ class CustomDPOTrainer(DPOTrainer):
self,
model: "PreTrainedModel",
batch: Dict[str, torch.Tensor],
train_eval: Optional[Literal["train", "eval"]] = "train"
train_eval: Optional[Literal["train", "eval"]] = "train",
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.

View File

@ -1,6 +1,7 @@
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
from typing import TYPE_CHECKING, Optional, List
from typing import TYPE_CHECKING, List, Optional
from transformers import Seq2SeqTrainingArguments
from ...data import get_dataset, split_dataset
@ -12,8 +13,10 @@ from ...train.dpo.collator import DPODataCollatorWithPadding
from ...train.dpo.trainer import CustomDPOTrainer
from ...train.utils import create_modelcard_and_push, create_ref_model
if TYPE_CHECKING:
from transformers import TrainerCallback
from ...hparams import DataArguments, FinetuningArguments
@ -22,25 +25,25 @@ def run_dpo(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None
callbacks: Optional[List["TrainerCallback"]] = None,
):
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
data_collator = DPODataCollatorWithPadding(
tokenizer=tokenizer,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
)
# Create reference model
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
ref_model = model
else:
ref_model = create_ref_model(model_args, finetuning_args)
# Update arguments
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer
@ -54,7 +57,7 @@ def run_dpo(
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**split_dataset(dataset, data_args, training_args)
**split_dataset(dataset, data_args, training_args),
)
# Training
@ -70,7 +73,7 @@ def run_dpo(
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
if id(model) == id(ref_model): # unable to compute rewards without a reference model
if id(model) == id(ref_model): # unable to compute rewards without a reference model
remove_keys = [key for key in metrics.keys() if "rewards" in key]
for key in remove_keys:
metrics.pop(key)

View File

@ -1,27 +1,28 @@
import math
import os
import sys
import math
import torch
from tqdm import tqdm
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
import torch
from tqdm import tqdm
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from trl import PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from ...extras.callbacks import LogCallback, FixValueHeadModelCallback
from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
from .utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
from ...hparams import ModelArguments, FinetuningArguments, GeneratingArguments
from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
@ -40,7 +41,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
generating_args: "GeneratingArguments",
callbacks: List["TrainerCallback"],
reward_model: "AutoModelForCausalLMWithValueHead",
**kwargs
**kwargs,
):
PPOTrainer.__init__(self, **kwargs)
@ -52,7 +53,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
**generating_args.to_dict()
**generating_args.to_dict(),
)
self.state = TrainerState()
@ -71,7 +72,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if not (
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
): # quantized models are already set on the correct device
self.reward_model = self._prepare_deepspeed(self.reward_model)
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
@ -111,9 +112,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
logger.info(" Num examples = {}".format(num_examples))
logger.info(" Num Epochs = {}".format(num_train_epochs))
logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
logger.info(" Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
total_train_batch_size
))
logger.info(
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
total_train_batch_size
)
)
logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
logger.info(" Total training steps = {}".format(max_steps))
@ -138,10 +141,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.model.eval()
# Get inputs
self.tokenizer.padding_side = "right" # change padding side
self.tokenizer.padding_side = "right" # change padding side
queries, responses, rewards = [], [], []
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
mini_batch_queries, mini_batch_responses = self.get_inputs(
batch[idx : idx + self.config.mini_batch_size]
)
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
queries.extend(mini_batch_queries)
responses.extend(mini_batch_responses)
@ -154,7 +159,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
# Run PPO step
stats = self.step(queries, responses, rewards)
self.tokenizer.padding_side = "left" # restore padding side
self.tokenizer.padding_side = "left" # restore padding side
loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
@ -163,18 +168,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
self.log_stats(stats, batch, rewards)
except:
except Exception:
logger.warning("Failed to save stats due to unknown errors.")
self.state.global_step += 1
self.log_callback.on_step_end(self.args, self.state, self.control)
if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
logs = dict(
loss=round(loss_meter.avg, 4),
reward=round(reward_meter.avg, 4),
learning_rate=stats["ppo/learning_rate"],
epoch=round(step / steps_in_epoch, 2)
epoch=round(step / steps_in_epoch, 2),
)
tqdm.write(str(logs))
logs["step"] = step
@ -183,10 +188,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
loss_meter.reset()
reward_meter.reset()
if (step+1) % self.args.save_steps == 0: # save checkpoint
self.save_model(os.path.join(
self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)
))
if (step + 1) % self.args.save_steps == 0: # save checkpoint
self.save_model(
os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
)
self.save_callback.on_save(
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
)
@ -207,35 +212,33 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.model_args.upcast_layernorm:
layernorm_params = dump_layernorm(self.model)
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
for k, v in batch.items():
batch[k] = v[:, start_index:]
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
generate_output: torch.Tensor = unwrapped_model.generate(
generation_config=self.generation_config,
logits_processor=get_logits_processor(),
**batch
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
)
if self.model_args.upcast_layernorm:
restore_layernorm(self.model, layernorm_params)
query = batch["input_ids"].detach().cpu()
response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
queries, responses = [], []
for i in range(len(query)):
query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
if len(response_index) == 0:
response_length = 1 # allow empty response
response_length = 1 # allow empty response
else:
response_length = response_index[-1].item() + 1
queries.append(query[i, query_start_index:]) # remove padding from left
responses.append(response[i, :response_length]) # remove padding from right
queries.append(query[i, query_start_index:]) # remove padding from left
responses.append(response[i, :response_length]) # remove padding from right
return queries, responses
@ -244,7 +247,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self,
queries: List[torch.Tensor],
responses: List[torch.Tensor],
unwrapped_model: "AutoModelForCausalLMWithValueHead"
unwrapped_model: "AutoModelForCausalLMWithValueHead",
) -> List[torch.Tensor]:
r"""
Computes scores using given reward model.
@ -264,17 +267,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch = self.prepare_model_inputs(queries, responses)
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
values = torch.transpose(values, 0, 1)
rewards = []
for i in range(values.size(0)):
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default")
@ -289,7 +292,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
responses: torch.Tensor,
model_inputs: dict,
return_logits: Optional[bool] = False,
response_masks: Optional[torch.Tensor] = None
response_masks: Optional[torch.Tensor] = None,
):
r"""
Calculates model outputs in multiple batches.
@ -312,7 +315,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
input_ids = input_kwargs["input_ids"]
attention_mask = input_kwargs["attention_mask"]
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
logits, _, values = model(**input_kwargs)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
@ -325,14 +328,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
for j in range(len(query_batch)):
start = len(query_batch[j]) - 1
if attention_mask[j, 0] == 0: # offset left padding
if attention_mask[j, 0] == 0: # offset left padding
start += attention_mask[j, :].nonzero()[0].item()
end = start + len(response_batch[j])
if response_masks is not None:
response_masks_batch = torch.cat(
(torch.zeros_like(query_batch[j]), response_masks_batch[j])
)[1:]
response_masks_batch = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:]
masks[j, :start] = 0
masks[j, end:] = 0

View File

@ -1,9 +1,11 @@
import json
import torch
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
import torch
from ...extras.packages import is_requests_available
if TYPE_CHECKING:
from transformers import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead
@ -21,16 +23,18 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
if target == "reward": # save default head temporarily
if target == "reward": # save default head temporarily
valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
model.v_head.load_state_dict({
"summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
"summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone()
})
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
model.v_head.load_state_dict(
{
"summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
"summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone(),
}
)
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:

View File

@ -1,23 +1,26 @@
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
import math
from trl import PPOConfig
from typing import TYPE_CHECKING, List, Optional
from torch.optim import AdamW
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorWithPadding
from transformers.optimization import get_scheduler
from trl import PPOConfig
from ...data import get_dataset
from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model_and_tokenizer
from ...train.utils import create_ref_model, create_reward_model
from ...train.ppo.trainer import CustomPPOTrainer
from ...train.utils import create_ref_model, create_reward_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
def run_ppo(
@ -26,12 +29,14 @@ def run_ppo(
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
callbacks: Optional[List["TrainerCallback"]] = None,
):
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
model, tokenizer = load_model_and_tokenizer(
model_args, finetuning_args, training_args.do_train, add_valuehead=True
)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Create reference model and reward model
@ -55,7 +60,7 @@ def run_ppo(
use_score_scaling=finetuning_args.ppo_score_norm,
use_score_norm=finetuning_args.ppo_score_norm,
whiten_rewards=finetuning_args.ppo_whiten_rewards,
accelerator_kwargs={"step_scheduler_with_optimizer": False}
accelerator_kwargs={"step_scheduler_with_optimizer": False},
)
# Create optimizer and scheduler
@ -70,7 +75,7 @@ def run_ppo(
training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps
num_training_steps=num_training_steps,
)
# Initialize our Trainer
@ -88,7 +93,7 @@ def run_ppo(
dataset=dataset,
data_collator=data_collator,
optimizer=optimizer,
lr_scheduler=lr_scheduler
lr_scheduler=lr_scheduler,
)
# Training
@ -97,6 +102,6 @@ def run_ppo(
ppo_trainer.save_model()
if training_args.should_save:
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
ppo_trainer.save_state() # must be called after save_model to have a folder
ppo_trainer.save_state() # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"])

View File

@ -1,7 +1,8 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py
import math
from typing import TYPE_CHECKING, Optional, List
from typing import TYPE_CHECKING, List, Optional
from transformers import DataCollatorForLanguageModeling, Trainer
from ...data import get_dataset, split_dataset
@ -9,9 +10,11 @@ from ...extras.ploting import plot_loss
from ...model import load_model_and_tokenizer
from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from ...hparams import ModelArguments, DataArguments, FinetuningArguments
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
def run_pt(
@ -19,7 +22,7 @@ def run_pt(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None
callbacks: Optional[List["TrainerCallback"]] = None,
):
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt")
@ -32,7 +35,7 @@ def run_pt(
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**split_dataset(dataset, data_args, training_args)
**split_dataset(dataset, data_args, training_args),
)
# Training

View File

@ -1,6 +1,7 @@
import torch
from dataclasses import dataclass
from typing import Any, Dict, Sequence
import torch
from transformers import DataCollatorWithPadding
@ -20,8 +21,9 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
features = [
{
"input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])),
}
for key in ("chosen_ids", "rejected_ids") for feature in features
for key in ("chosen_ids", "rejected_ids")
for feature in features
]
return super().__call__(features)

View File

@ -1,6 +1,7 @@
import numpy as np
from typing import Dict, Sequence, Tuple, Union
import numpy as np
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
preds, _ = eval_preds

View File

@ -1,14 +1,16 @@
import os
import json
import torch
import os
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from transformers import Trainer
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.trainer import PredictionOutput
logger = get_logger(__name__)
@ -21,13 +23,10 @@ class PairwiseTrainer(Trainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.can_return_loss = True # override property to return eval_loss
self.can_return_loss = True # override property to return eval_loss
def compute_loss(
self,
model: "PreTrainedModel",
inputs: Dict[str, torch.Tensor],
return_outputs: Optional[bool] = False
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: Optional[bool] = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
@ -68,9 +67,9 @@ class PairwiseTrainer(Trainer):
assert div_index > 0
chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
if return_outputs: # use the score on the last token except pad token for inference
chosen_scores.append(chosen_rewards[i, chosen_length-1])
rejected_scores.append(rejected_rewards[i, rejected_length-1])
if return_outputs: # use the score on the last token except pad token for inference
chosen_scores.append(chosen_rewards[i, chosen_length - 1])
rejected_scores.append(rejected_rewards[i, rejected_length - 1])
loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
loss = loss / batch_size
@ -80,10 +79,7 @@ class PairwiseTrainer(Trainer):
return loss
def save_predictions(
self,
predict_results: "PredictionOutput"
) -> None:
def save_predictions(self, predict_results: "PredictionOutput") -> None:
r"""
Saves model predictions to `output_dir`.

View File

@ -1,6 +1,7 @@
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from typing import TYPE_CHECKING, Optional, List
from typing import TYPE_CHECKING, List, Optional
from transformers import Seq2SeqTrainingArguments
from ...data import get_dataset, split_dataset
@ -13,9 +14,11 @@ from ...train.rm.metric import compute_accuracy
from ...train.rm.trainer import PairwiseTrainer
from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING:
from transformers import TrainerCallback
from ...hparams import ModelArguments, DataArguments, FinetuningArguments
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
def run_rm(
@ -23,15 +26,17 @@ def run_rm(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None
callbacks: Optional[List["TrainerCallback"]] = None,
):
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
model, tokenizer = load_model_and_tokenizer(
model_args, finetuning_args, training_args.do_train, add_valuehead=True
)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
# Update arguments
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer
@ -42,7 +47,7 @@ def run_rm(
data_collator=data_collator,
callbacks=callbacks + [FixValueHeadModelCallback()],
compute_metrics=compute_accuracy,
**split_dataset(dataset, data_args, training_args)
**split_dataset(dataset, data_args, training_args),
)
# Training

View File

@ -1,11 +1,11 @@
import numpy as np
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
import numpy as np
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import (
is_jieba_available, is_nltk_available, is_rouge_available
)
from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
@ -14,7 +14,7 @@ if is_jieba_available():
import jieba
if is_nltk_available():
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
if is_rouge_available():
from rouge_chinese import Rouge

View File

@ -1,14 +1,16 @@
import os
import json
import torch
import numpy as np
import torch.nn as nn
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from transformers import Seq2SeqTrainer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
@ -33,16 +35,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
Subclass and override to inject custom behavior.
"""
labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels
labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels
if self.args.predict_with_generate:
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
if prompt_len > label_len:
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
inputs["labels"] = inputs["labels"][:, :prompt_len]
loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated)
loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated)
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
)
if generated_tokens is not None and self.args.predict_with_generate:
@ -51,23 +53,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
return loss, generated_tokens, labels
def _pad_tensors_to_target_len(
self,
src_tensor: torch.Tensor,
tgt_tensor: torch.Tensor
) -> torch.Tensor:
def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
r"""
Pads the tensor to the same length as the target tensor.
"""
assert self.tokenizer.pad_token_id is not None, "Pad token is required."
padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
return padded_tensor.contiguous() # in contiguous memory
padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding
return padded_tensor.contiguous() # in contiguous memory
def save_predictions(
self,
predict_results: "PredictionOutput"
) -> None:
def save_predictions(self, predict_results: "PredictionOutput") -> None:
r"""
Saves model predictions to `output_dir`.
@ -79,15 +74,23 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}")
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
labels = np.where(
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
)
preds = np.where(
predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
)
for i in range(len(preds)):
pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
if len(pad_len):
preds[i] = np.concatenate((preds[i][pad_len[0]:], preds[i][:pad_len[0]]), axis=-1) # move pad token to last
preds[i] = np.concatenate(
(preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1
) # move pad token to last
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=False)
decoded_labels = self.tokenizer.batch_decode(
labels, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
with open(output_prediction_file, "w", encoding="utf-8") as writer:

View File

@ -1,6 +1,7 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
from typing import TYPE_CHECKING, Optional, List
from typing import TYPE_CHECKING, List, Optional
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from ...data import get_dataset, split_dataset
@ -15,7 +16,8 @@ from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING:
from transformers import TrainerCallback
from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
def run_sft(
@ -24,29 +26,31 @@ def run_sft(
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
callbacks: Optional[List["TrainerCallback"]] = None,
):
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
if training_args.predict_with_generate:
tokenizer.padding_side = "left" # use left-padding in generation
tokenizer.padding_side = "left" # use left-padding in generation
if getattr(model, "is_quantized", False) and not training_args.do_train:
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
)
# Override the decoding parameters of Seq2SeqTrainer
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(
generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams
))
training_args_dict.update(
dict(
generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams,
)
)
training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer
@ -57,7 +61,7 @@ def run_sft(
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
**split_dataset(dataset, data_args, training_args)
**split_dataset(dataset, data_args, training_args),
)
# Keyword arguments for `model.generate`
@ -79,7 +83,7 @@ def run_sft(
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
metrics.pop("eval_loss", None)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
@ -87,7 +91,7 @@ def run_sft(
# Predict
if training_args.do_predict:
predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
predict_results.metrics.pop("predict_loss", None)
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)

View File

@ -1,16 +1,18 @@
import torch
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from transformers import PreTrainedModel
from ..extras.callbacks import LogCallback
from ..extras.logging import get_logger
from ..hparams import get_train_args, get_infer_args
from ..hparams import get_infer_args, get_train_args
from ..model import load_model_and_tokenizer
from .pt import run_pt
from .sft import run_sft
from .rm import run_rm
from .ppo import run_ppo
from .dpo import run_dpo
from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
if TYPE_CHECKING:
from transformers import TrainerCallback
@ -64,23 +66,23 @@ def export_model(args: Optional[Dict[str, Any]] = None):
model.save_pretrained(
save_directory=model_args.export_dir,
max_shard_size="{}GB".format(model_args.export_size),
safe_serialization=(not model_args.export_legacy_format)
safe_serialization=(not model_args.export_legacy_format),
)
if model_args.export_hub_model_id is not None:
model.push_to_hub(
model_args.export_hub_model_id,
token=model_args.hf_hub_token,
max_shard_size="{}GB".format(model_args.export_size),
safe_serialization=(not model_args.export_legacy_format)
safe_serialization=(not model_args.export_legacy_format),
)
try:
tokenizer.padding_side = "left" # restore padding side
tokenizer.padding_side = "left" # restore padding side
tokenizer.init_kwargs["padding_side"] = "left"
tokenizer.save_pretrained(model_args.export_dir)
if model_args.export_hub_model_id is not None:
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
except:
except Exception:
logger.warning("Cannot save tokenizer, please copy the files manually.")

View File

@ -1,14 +1,17 @@
import torch
from typing import TYPE_CHECKING, Optional, Union
import torch
from ..extras.logging import get_logger
from ..hparams import ModelArguments, FinetuningArguments
from ..hparams import FinetuningArguments, ModelArguments
from ..model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, Trainer
from transformers.modeling_utils import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import DataArguments
@ -20,7 +23,7 @@ def create_modelcard_and_push(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments"
finetuning_args: "FinetuningArguments",
) -> None:
if training_args.do_train:
if training_args.push_to_hub:
@ -33,9 +36,7 @@ def create_modelcard_and_push(
def create_ref_model(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
add_valuehead: Optional[bool] = False
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: Optional[bool] = False
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
r"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
@ -44,11 +45,13 @@ def create_ref_model(
"""
if finetuning_args.ref_model is not None:
ref_model_args_dict = model_args.to_dict()
ref_model_args_dict.update(dict(
model_name_or_path=finetuning_args.ref_model,
adapter_name_or_path=finetuning_args.ref_model_adapters,
quantization_bit=finetuning_args.ref_model_quantization_bit
))
ref_model_args_dict.update(
dict(
model_name_or_path=finetuning_args.ref_model,
adapter_name_or_path=finetuning_args.ref_model_adapters,
quantization_bit=finetuning_args.ref_model_quantization_bit,
)
)
ref_model_args = ModelArguments(**ref_model_args_dict)
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
ref_model, _ = load_model_and_tokenizer(
@ -68,9 +71,7 @@ def create_ref_model(
def create_reward_model(
model: "AutoModelForCausalLMWithValueHead",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments"
model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
) -> "AutoModelForCausalLMWithValueHead":
r"""
Creates reward model for PPO training.
@ -81,24 +82,30 @@ def create_reward_model(
return finetuning_args.reward_model
elif finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
param.data = param.data.to(torch.float32) # trainable params should in fp32
param.data = param.data.to(torch.float32) # trainable params should in fp32
vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
assert vhead_params is not None, "Reward model is not correctly loaded."
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
model.register_buffer(
"default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False
)
model.register_buffer(
"default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
)
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
return None
else:
reward_model_args_dict = model_args.to_dict()
reward_model_args_dict.update(dict(
model_name_or_path=finetuning_args.reward_model,
adapter_name_or_path=finetuning_args.reward_model_adapters,
quantization_bit=finetuning_args.reward_model_quantization_bit
))
reward_model_args_dict.update(
dict(
model_name_or_path=finetuning_args.reward_model,
adapter_name_or_path=finetuning_args.reward_model_adapters,
quantization_bit=finetuning_args.reward_model_quantization_bit,
)
)
reward_model_args = ModelArguments(**reward_model_args_dict)
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
reward_model, _ = load_model_and_tokenizer(

View File

@ -1,24 +1,22 @@
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from ..chat import ChatModel
from ..extras.misc import torch_gc
from ..hparams import GeneratingArguments
from .common import get_save_dir
from .locales import ALERTS
if TYPE_CHECKING:
from .manager import Manager
class WebChatModel(ChatModel):
def __init__(
self,
manager: "Manager",
demo_mode: Optional[bool] = False,
lazy_init: Optional[bool] = True
self, manager: "Manager", demo_mode: Optional[bool] = False, lazy_init: Optional[bool] = True
) -> None:
self.manager = manager
self.demo_mode = demo_mode
@ -26,11 +24,12 @@ class WebChatModel(ChatModel):
self.tokenizer = None
self.generating_args = GeneratingArguments()
if not lazy_init: # read arguments from command line
if not lazy_init: # read arguments from command line
super().__init__()
if demo_mode: # load demo_config.json if exists
if demo_mode: # load demo_config.json if exists
import json
try:
with open("demo_config.json", "r", encoding="utf-8") as f:
args = json.load(f)
@ -38,7 +37,7 @@ class WebChatModel(ChatModel):
super().__init__(args)
except AssertionError:
print("Please provided model name and template in `demo_config.json`.")
except:
except Exception:
print("Cannot find `demo_config.json` at current directory.")
@property
@ -64,9 +63,12 @@ class WebChatModel(ChatModel):
return
if get("top.adapter_path"):
adapter_name_or_path = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")])
adapter_name_or_path = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else:
adapter_name_or_path = None
@ -79,7 +81,7 @@ class WebChatModel(ChatModel):
template=get("top.template"),
flash_attn=(get("top.booster") == "flash_attn"),
use_unsloth=(get("top.booster") == "unsloth"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
)
super().__init__(args)
@ -108,7 +110,7 @@ class WebChatModel(ChatModel):
tools: str,
max_new_tokens: int,
top_p: float,
temperature: float
temperature: float,
) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
chatbot.append([query, ""])
response = ""

View File

@ -1,9 +1,10 @@
import os
import json
import gradio as gr
import os
from collections import defaultdict
from typing import Any, Dict, Optional
from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
import gradio as gr
from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
from ..extras.constants import (
DATA_CONFIG,
@ -12,7 +13,7 @@ from ..extras.constants import (
PEFT_METHODS,
SUPPORTED_MODELS,
TRAINING_STAGES,
DownloadSource
DownloadSource,
)
from ..extras.misc import use_modelscope
@ -36,7 +37,7 @@ def load_config() -> Dict[str, Any]:
try:
with open(get_config_path(), "r", encoding="utf-8") as f:
return json.load(f)
except:
except Exception:
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
@ -59,7 +60,7 @@ def get_model_path(model_name: str) -> str:
use_modelscope()
and path_dict.get(DownloadSource.MODELSCOPE)
and model_path == path_dict.get(DownloadSource.DEFAULT)
): # replace path
): # replace path
model_path = path_dict.get(DownloadSource.MODELSCOPE)
return model_path
@ -87,9 +88,8 @@ def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
for adapter in os.listdir(save_dir):
if (
os.path.isdir(os.path.join(save_dir, adapter))
and any([os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES])
if os.path.isdir(os.path.join(save_dir, adapter)) and any(
os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES
):
adapters.append(adapter)
return gr.update(value=[], choices=adapters, interactive=True)

View File

@ -1,11 +1,16 @@
from .chatbot import create_chat_box
from .eval import create_eval_tab
from .export import create_export_tab
from .infer import create_infer_tab
from .top import create_top
from .train import create_train_tab
from .eval import create_eval_tab
from .infer import create_infer_tab
from .export import create_export_tab
from .chatbot import create_chat_box
__all__ = [
"create_top", "create_train_tab", "create_eval_tab", "create_infer_tab", "create_export_tab", "create_chat_box"
"create_top",
"create_train_tab",
"create_eval_tab",
"create_infer_tab",
"create_export_tab",
"create_chat_box",
]

View File

@ -1,6 +1,7 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict, Optional, Tuple
import gradio as gr
from ..utils import check_json_schema
@ -12,8 +13,7 @@ if TYPE_CHECKING:
def create_chat_box(
engine: "Engine",
visible: Optional[bool] = False
engine: "Engine", visible: Optional[bool] = False
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
with gr.Box(visible=visible) as chat_box:
chatbot = gr.Chatbot()
@ -38,20 +38,23 @@ def create_chat_box(
engine.chatter.predict,
[chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
[chatbot, history],
show_progress=True
).then(
lambda: gr.update(value=""), outputs=[query]
)
show_progress=True,
).then(lambda: gr.update(value=""), outputs=[query])
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
return chat_box, chatbot, history, dict(
system=system,
tools=tools,
query=query,
submit_btn=submit_btn,
clear_btn=clear_btn,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature
return (
chat_box,
chatbot,
history,
dict(
system=system,
tools=tools,
query=query,
submit_btn=submit_btn,
clear_btn=clear_btn,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
),
)

View File

@ -1,10 +1,12 @@
import os
import json
import gradio as gr
import os
from typing import TYPE_CHECKING, Any, Dict, Tuple
import gradio as gr
from ...extras.constants import DATA_CONFIG
if TYPE_CHECKING:
from gradio.components import Component
@ -24,7 +26,7 @@ def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
except:
except Exception:
return gr.update(interactive=False)
if (
@ -48,7 +50,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
elif data_file.endswith(".jsonl"):
data = [json.loads(line) for line in f]
else:
data = [line for line in f]
data = [line for line in f] # noqa: C416
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True)
@ -67,32 +69,17 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
with gr.Row():
preview_samples = gr.JSON(interactive=False)
dataset.change(
can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False
).then(
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
lambda: 0, outputs=[page_index], queue=False
)
data_preview_btn.click(
get_preview,
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
prev_btn.click(
prev_page, [page_index], [page_index], queue=False
).then(
get_preview,
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
prev_btn.click(prev_page, [page_index], [page_index], queue=False).then(
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
next_btn.click(
next_page, [page_index, preview_count], [page_index], queue=False
).then(
get_preview,
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
return dict(
@ -102,5 +89,5 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
prev_btn=prev_btn,
next_btn=next_btn,
close_btn=close_btn,
preview_samples=preview_samples
preview_samples=preview_samples,
)

View File

@ -1,9 +1,11 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
from ..common import list_dataset, DEFAULT_DATA_DIR
import gradio as gr
from ..common import DEFAULT_DATA_DIR, list_dataset
from .data import create_preview_box
if TYPE_CHECKING:
from gradio.components import Component
@ -31,9 +33,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
predict = gr.Checkbox(value=True)
input_elems.update({cutoff_len, max_samples, batch_size, predict})
elem_dict.update(dict(
cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict
))
elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))
with gr.Row():
max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
@ -42,9 +42,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
output_dir = gr.Textbox()
input_elems.update({max_new_tokens, top_p, temperature, output_dir})
elem_dict.update(dict(
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir
))
elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir))
with gr.Row():
cmd_preview_btn = gr.Button()
@ -59,10 +57,16 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
output_box = gr.Markdown()
output_elems = [output_box, process_bar]
elem_dict.update(dict(
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box
))
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,
resume_btn=resume_btn,
process_bar=process_bar,
output_box=output_box,
)
)
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
start_btn.click(engine.runner.run_eval, input_elems, output_elems)

View File

@ -1,10 +1,12 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict, Generator, List
import gradio as gr
from ...train import export_model
from ..common import get_save_dir
from ..locales import ALERTS
if TYPE_CHECKING:
from gradio.components import Component
@ -24,7 +26,7 @@ def save_model(
max_shard_size: int,
export_quantization_bit: int,
export_quantization_dataset: str,
export_dir: str
export_dir: str,
) -> Generator[str, None, None]:
error = ""
if not model_name:
@ -44,7 +46,9 @@ def save_model(
return
if adapter_path:
adapter_name_or_path = ",".join([get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path])
adapter_name_or_path = ",".join(
[get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
)
else:
adapter_name_or_path = None
@ -56,7 +60,7 @@ def save_model(
export_dir=export_dir,
export_size=max_shard_size,
export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
export_quantization_dataset=export_quantization_dataset
export_quantization_dataset=export_quantization_dataset,
)
yield ALERTS["info_exporting"][lang]
@ -86,9 +90,9 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
max_shard_size,
export_quantization_bit,
export_quantization_dataset,
export_dir
export_dir,
],
[info_box]
[info_box],
)
return dict(
@ -97,5 +101,5 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
export_quantization_dataset=export_quantization_dataset,
export_dir=export_dir,
export_btn=export_btn,
info_box=info_box
info_box=info_box,
)

View File

@ -1,8 +1,10 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
import gradio as gr
from .chatbot import create_chat_box
if TYPE_CHECKING:
from gradio.components import Component
@ -23,18 +25,12 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
load_btn.click(
engine.chatter.load_model, input_elems, [info_box]
).then(
load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
)
unload_btn.click(
engine.chatter.unload_model, input_elems, [info_box]
).then(
unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
lambda: ([], []), outputs=[chatbot, history]
).then(
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
)
).then(lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box])
return elem_dict

View File

@ -1,11 +1,13 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
import gradio as gr
from ...data import templates
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ..common import get_model_path, get_template, list_adapters, save_config
from ..utils import can_quantize
if TYPE_CHECKING:
from gradio.components import Component
@ -30,25 +32,19 @@ def create_top() -> Dict[str, "Component"]:
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
booster = gr.Radio(choices=["none", "flash_attn", "unsloth"], value="none")
model_name.change(
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
).then(
model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
get_model_path, [model_name], [model_path], queue=False
).then(
get_template, [model_name], [template], queue=False
) # do not save config since the below line will save
) # do not save config since the below line will save
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
finetuning_type.change(
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
).then(
finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
can_quantize, [finetuning_type], [quantization_bit], queue=False
)
refresh_btn.click(
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
)
refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
return dict(
lang=lang,
@ -61,5 +57,5 @@ def create_top() -> Dict[str, "Component"]:
quantization_bit=quantization_bit,
template=template,
rope_scaling=rope_scaling,
booster=booster
booster=booster,
)

View File

@ -1,12 +1,14 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
import gradio as gr
from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES
from ..common import list_adapters, list_dataset, DEFAULT_DATA_DIR
from ..common import DEFAULT_DATA_DIR, list_adapters, list_dataset
from ..components.data import create_preview_box
from ..utils import gen_plot
if TYPE_CHECKING:
from gradio.components import Component
@ -29,9 +31,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
input_elems.update({training_stage, dataset_dir, dataset})
elem_dict.update(dict(
training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems
))
elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
with gr.Row():
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
@ -41,25 +41,33 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
compute_type = gr.Radio(choices=["fp16", "bf16", "fp32"], value="fp16")
input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
elem_dict.update(dict(
cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs,
max_samples=max_samples, compute_type=compute_type
))
elem_dict.update(
dict(
cutoff_len=cutoff_len,
learning_rate=learning_rate,
num_train_epochs=num_train_epochs,
max_samples=max_samples,
compute_type=compute_type,
)
)
with gr.Row():
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
lr_scheduler_type = gr.Dropdown(
choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
)
lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
max_grad_norm = gr.Textbox(value="1.0")
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
elem_dict.update(dict(
batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size
))
elem_dict.update(
dict(
batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
max_grad_norm=max_grad_norm,
val_size=val_size,
)
)
with gr.Accordion(label="Extra config", open=False) as extra_tab:
with gr.Row():
@ -73,10 +81,17 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
upcast_layernorm = gr.Checkbox(value=False)
input_elems.update({logging_steps, save_steps, warmup_steps, neftune_alpha, sft_packing, upcast_layernorm})
elem_dict.update(dict(
extra_tab=extra_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps,
neftune_alpha=neftune_alpha, sft_packing=sft_packing, upcast_layernorm=upcast_layernorm
))
elem_dict.update(
dict(
extra_tab=extra_tab,
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
neftune_alpha=neftune_alpha,
sft_packing=sft_packing,
upcast_layernorm=upcast_layernorm,
)
)
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
with gr.Row():
@ -87,10 +102,16 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
create_new_adapter = gr.Checkbox(scale=1)
input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, create_new_adapter})
elem_dict.update(dict(
lora_tab=lora_tab, lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target,
additional_target=additional_target, create_new_adapter=create_new_adapter
))
elem_dict.update(
dict(
lora_tab=lora_tab,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target,
additional_target=additional_target,
create_new_adapter=create_new_adapter,
)
)
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
with gr.Row():
@ -103,13 +124,13 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
list_adapters,
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
[reward_model],
queue=False
queue=False,
)
input_elems.update({dpo_beta, dpo_ftx, reward_model})
elem_dict.update(dict(
rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn
))
elem_dict.update(
dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn)
)
with gr.Row():
cmd_preview_btn = gr.Button()
@ -139,20 +160,28 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
stop_btn.click(engine.runner.set_abort, queue=False)
resume_btn.change(engine.runner.monitor, outputs=output_elems)
elem_dict.update(dict(
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
))
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,
output_dir=output_dir,
resume_btn=resume_btn,
process_bar=process_bar,
output_box=output_box,
loss_viewer=loss_viewer,
)
)
output_box.change(
gen_plot,
[
engine.manager.get_elem_by_name("top.model_name"),
engine.manager.get_elem_by_name("top.finetuning_type"),
output_dir
output_dir,
],
loss_viewer,
queue=False
queue=False,
)
return elem_dict

View File

@ -1,7 +1,8 @@
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import Any, Dict, Generator, Optional
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from .chatter import WebChatModel
from .common import get_model_path, list_dataset, load_config
from .locales import LOCALES
@ -11,7 +12,6 @@ from .utils import get_time
class Engine:
def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None:
self.demo_mode = demo_mode
self.pure_chat = pure_chat
@ -26,10 +26,7 @@ class Engine:
user_config = load_config() if not self.demo_mode else {}
lang = user_config.get("lang", None) or "en"
init_dict = {
"top.lang": {"value": lang},
"infer.chat_box": {"visible": self.chatter.loaded}
}
init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
if not self.pure_chat:
init_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
@ -49,13 +46,17 @@ class Engine:
else:
yield self._form_dict({"eval.resume_btn": {"value": True}})
else:
yield self._form_dict({
"train.output_dir": {"value": "train_" + get_time()},
"eval.output_dir": {"value": "eval_" + get_time()},
})
yield self._form_dict(
{
"train.output_dir": {"value": "train_" + get_time()},
"eval.output_dir": {"value": "eval_" + get_time()},
}
)
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
return {
component: gr.update(**LOCALES[name][lang])
for elems in self.manager.all_elems.values() for name, component in elems.items() if name in LOCALES
for elems in self.manager.all_elems.values()
for name, component in elems.items()
if name in LOCALES
}

View File

@ -1,21 +1,22 @@
import gradio as gr
from typing import Optional
import gradio as gr
from transformers.utils.versions import require_version
from .common import save_config
from .components import (
create_chat_box,
create_eval_tab,
create_export_tab,
create_infer_tab,
create_top,
create_train_tab,
create_eval_tab,
create_infer_tab,
create_export_tab,
create_chat_box
)
from .common import save_config
from .css import CSS
from .engine import Engine
require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")
require_version("gradio>=3.38.0,<4.0.0", 'To fix: pip install "gradio>=3.38.0,<4.0.0"')
def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
@ -23,11 +24,9 @@ def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
if demo_mode:
gr.HTML("<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>")
gr.HTML(
"<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>"
)
gr.HTML(
"<h3><center>Visit <a href=\"https://github.com/hiyouga/LLaMA-Factory\" target=\"_blank\">"
'<h3><center>Visit <a href="https://github.com/hiyouga/LLaMA-Factory" target="_blank">'
"LLaMA Factory</a> for details.</center></h3>"
)
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

View File

@ -1,726 +1,220 @@
LOCALES = {
"lang": {
"en": {
"label": "Lang"
},
"zh": {
"label": "语言"
}
},
"model_name": {
"en": {
"label": "Model name"
},
"zh": {
"label": "模型名称"
}
},
"lang": {"en": {"label": "Lang"}, "zh": {"label": "语言"}},
"model_name": {"en": {"label": "Model name"}, "zh": {"label": "模型名称"}},
"model_path": {
"en": {
"label": "Model path",
"info": "Path to pretrained model or model identifier from Hugging Face."
},
"zh": {
"label": "模型路径",
"info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
}
},
"finetuning_type": {
"en": {
"label": "Finetuning method"
},
"zh": {
"label": "微调方法"
}
},
"adapter_path": {
"en": {
"label": "Adapter path"
},
"zh": {
"label": "适配器路径"
}
},
"refresh_btn": {
"en": {
"value": "Refresh adapters"
},
"zh": {
"value": "刷新适配器"
}
},
"advanced_tab": {
"en": {
"label": "Advanced configurations"
},
"zh": {
"label": "高级设置"
}
"en": {"label": "Model path", "info": "Path to pretrained model or model identifier from Hugging Face."},
"zh": {"label": "模型路径", "info": "本地模型的文件路径或 Hugging Face 的模型标识符。"},
},
"finetuning_type": {"en": {"label": "Finetuning method"}, "zh": {"label": "微调方法"}},
"adapter_path": {"en": {"label": "Adapter path"}, "zh": {"label": "适配器路径"}},
"refresh_btn": {"en": {"value": "Refresh adapters"}, "zh": {"value": "刷新适配器"}},
"advanced_tab": {"en": {"label": "Advanced configurations"}, "zh": {"label": "高级设置"}},
"quantization_bit": {
"en": {
"label": "Quantization bit",
"info": "Enable 4/8-bit model quantization (QLoRA)."
},
"zh": {
"label": "量化等级",
"info": "启用 4/8 比特模型量化QLoRA"
}
"en": {"label": "Quantization bit", "info": "Enable 4/8-bit model quantization (QLoRA)."},
"zh": {"label": "量化等级", "info": "启用 4/8 比特模型量化QLoRA"},
},
"template": {
"en": {
"label": "Prompt template",
"info": "The template used in constructing prompts."
},
"zh": {
"label": "提示模板",
"info": "构建提示词时使用的模板"
}
},
"rope_scaling": {
"en": {
"label": "RoPE scaling"
},
"zh": {
"label": "RoPE 插值方法"
}
},
"booster": {
"en": {
"label": "Booster"
},
"zh": {
"label": "加速方式"
}
"en": {"label": "Prompt template", "info": "The template used in constructing prompts."},
"zh": {"label": "提示模板", "info": "构建提示词时使用的模板"},
},
"rope_scaling": {"en": {"label": "RoPE scaling"}, "zh": {"label": "RoPE 插值方法"}},
"booster": {"en": {"label": "Booster"}, "zh": {"label": "加速方式"}},
"training_stage": {
"en": {
"label": "Stage",
"info": "The stage to perform in training."
},
"zh": {
"label": "训练阶段",
"info": "目前采用的训练方式。"
}
"en": {"label": "Stage", "info": "The stage to perform in training."},
"zh": {"label": "训练阶段", "info": "目前采用的训练方式。"},
},
"dataset_dir": {
"en": {
"label": "Data dir",
"info": "Path to the data directory."
},
"zh": {
"label": "数据路径",
"info": "数据文件夹的路径。"
}
},
"dataset": {
"en": {
"label": "Dataset"
},
"zh": {
"label": "数据集"
}
},
"data_preview_btn": {
"en": {
"value": "Preview dataset"
},
"zh": {
"value": "预览数据集"
}
},
"preview_count": {
"en": {
"label": "Count"
},
"zh": {
"label": "数量"
}
},
"page_index": {
"en": {
"label": "Page"
},
"zh": {
"label": "页数"
}
},
"prev_btn": {
"en": {
"value": "Prev"
},
"zh": {
"value": "上一页"
}
},
"next_btn": {
"en": {
"value": "Next"
},
"zh": {
"value": "下一页"
}
},
"close_btn": {
"en": {
"value": "Close"
},
"zh": {
"value": "关闭"
}
},
"preview_samples": {
"en": {
"label": "Samples"
},
"zh": {
"label": "样例"
}
"en": {"label": "Data dir", "info": "Path to the data directory."},
"zh": {"label": "数据路径", "info": "数据文件夹的路径。"},
},
"dataset": {"en": {"label": "Dataset"}, "zh": {"label": "数据集"}},
"data_preview_btn": {"en": {"value": "Preview dataset"}, "zh": {"value": "预览数据集"}},
"preview_count": {"en": {"label": "Count"}, "zh": {"label": "数量"}},
"page_index": {"en": {"label": "Page"}, "zh": {"label": "页数"}},
"prev_btn": {"en": {"value": "Prev"}, "zh": {"value": "上一页"}},
"next_btn": {"en": {"value": "Next"}, "zh": {"value": "下一页"}},
"close_btn": {"en": {"value": "Close"}, "zh": {"value": "关闭"}},
"preview_samples": {"en": {"label": "Samples"}, "zh": {"label": "样例"}},
"cutoff_len": {
"en": {
"label": "Cutoff length",
"info": "Max tokens in input sequence."
},
"zh": {
"label": "截断长度",
"info": "输入序列分词后的最大长度。"
}
"en": {"label": "Cutoff length", "info": "Max tokens in input sequence."},
"zh": {"label": "截断长度", "info": "输入序列分词后的最大长度。"},
},
"learning_rate": {
"en": {
"label": "Learning rate",
"info": "Initial learning rate for AdamW."
},
"zh": {
"label": "学习率",
"info": "AdamW 优化器的初始学习率。"
}
"en": {"label": "Learning rate", "info": "Initial learning rate for AdamW."},
"zh": {"label": "学习率", "info": "AdamW 优化器的初始学习率。"},
},
"num_train_epochs": {
"en": {
"label": "Epochs",
"info": "Total number of training epochs to perform."
},
"zh": {
"label": "训练轮数",
"info": "需要执行的训练总轮数。"
}
"en": {"label": "Epochs", "info": "Total number of training epochs to perform."},
"zh": {"label": "训练轮数", "info": "需要执行的训练总轮数。"},
},
"max_samples": {
"en": {
"label": "Max samples",
"info": "Maximum samples per dataset."
},
"zh": {
"label": "最大样本数",
"info": "每个数据集最多使用的样本数。"
}
"en": {"label": "Max samples", "info": "Maximum samples per dataset."},
"zh": {"label": "最大样本数", "info": "每个数据集最多使用的样本数。"},
},
"compute_type": {
"en": {
"label": "Compute type",
"info": "Whether to use fp16 or bf16 mixed precision training."
},
"zh": {
"label": "计算类型",
"info": "是否启用 FP16 或 BF16 混合精度训练。"
}
"en": {"label": "Compute type", "info": "Whether to use fp16 or bf16 mixed precision training."},
"zh": {"label": "计算类型", "info": "是否启用 FP16 或 BF16 混合精度训练。"},
},
"batch_size": {
"en": {
"label": "Batch size",
"info": "Number of samples to process per GPU."
},
"zh":{
"label": "批处理大小",
"info": "每块 GPU 上处理的样本数量。"
}
"en": {"label": "Batch size", "info": "Number of samples to process per GPU."},
"zh": {"label": "批处理大小", "info": "每块 GPU 上处理的样本数量。"},
},
"gradient_accumulation_steps": {
"en": {
"label": "Gradient accumulation",
"info": "Number of gradient accumulation steps."
},
"zh": {
"label": "梯度累积",
"info": "梯度累积的步数。"
}
"en": {"label": "Gradient accumulation", "info": "Number of gradient accumulation steps."},
"zh": {"label": "梯度累积", "info": "梯度累积的步数。"},
},
"lr_scheduler_type": {
"en": {
"label": "LR Scheduler",
"info": "Name of learning rate scheduler.",
},
"zh": {
"label": "学习率调节器",
"info": "采用的学习率调节器名称。"
}
"zh": {"label": "学习率调节器", "info": "采用的学习率调节器名称。"},
},
"max_grad_norm": {
"en": {
"label": "Maximum gradient norm",
"info": "Norm for gradient clipping.."
},
"zh": {
"label": "最大梯度范数",
"info": "用于梯度裁剪的范数。"
}
"en": {"label": "Maximum gradient norm", "info": "Norm for gradient clipping.."},
"zh": {"label": "最大梯度范数", "info": "用于梯度裁剪的范数。"},
},
"val_size": {
"en": {
"label": "Val size",
"info": "Proportion of data in the dev set."
},
"zh": {
"label": "验证集比例",
"info": "验证集占全部样本的百分比。"
}
},
"extra_tab": {
"en": {
"label": "Extra configurations"
},
"zh": {
"label": "其它参数设置"
}
"en": {"label": "Val size", "info": "Proportion of data in the dev set."},
"zh": {"label": "验证集比例", "info": "验证集占全部样本的百分比。"},
},
"extra_tab": {"en": {"label": "Extra configurations"}, "zh": {"label": "其它参数设置"}},
"logging_steps": {
"en": {
"label": "Logging steps",
"info": "Number of steps between two logs."
},
"zh": {
"label": "日志间隔",
"info": "每两次日志输出间的更新步数。"
}
"en": {"label": "Logging steps", "info": "Number of steps between two logs."},
"zh": {"label": "日志间隔", "info": "每两次日志输出间的更新步数。"},
},
"save_steps": {
"en": {
"label": "Save steps",
"info": "Number of steps between two checkpoints."
},
"zh": {
"label": "保存间隔",
"info": "每两次断点保存间的更新步数。"
}
"en": {"label": "Save steps", "info": "Number of steps between two checkpoints."},
"zh": {"label": "保存间隔", "info": "每两次断点保存间的更新步数。"},
},
"warmup_steps": {
"en": {
"label": "Warmup steps",
"info": "Number of steps used for warmup."
},
"zh": {
"label": "预热步数",
"info": "学习率预热采用的步数。"
}
"en": {"label": "Warmup steps", "info": "Number of steps used for warmup."},
"zh": {"label": "预热步数", "info": "学习率预热采用的步数。"},
},
"neftune_alpha": {
"en": {
"label": "NEFTune Alpha",
"info": "Magnitude of noise adding to embedding vectors."
},
"zh": {
"label": "NEFTune 噪声参数",
"info": "嵌入向量所添加的噪声大小。"
}
"en": {"label": "NEFTune Alpha", "info": "Magnitude of noise adding to embedding vectors."},
"zh": {"label": "NEFTune 噪声参数", "info": "嵌入向量所添加的噪声大小。"},
},
"sft_packing": {
"en": {
"label": "Pack sequences",
"info": "Pack sequences into samples of fixed length in supervised fine-tuning."
"info": "Pack sequences into samples of fixed length in supervised fine-tuning.",
},
"zh": {
"label": "序列打包",
"info": "在有监督微调阶段将序列打包为相同长度的样本。"
}
"zh": {"label": "序列打包", "info": "在有监督微调阶段将序列打包为相同长度的样本。"},
},
"upcast_layernorm": {
"en": {
"label": "Upcast LayerNorm",
"info": "Upcast weights of layernorm in float32."
},
"zh": {
"label": "缩放归一化层",
"info": "将归一化层权重缩放至 32 位精度。"
}
},
"lora_tab": {
"en": {
"label": "LoRA configurations"
},
"zh": {
"label": "LoRA 参数设置"
}
"en": {"label": "Upcast LayerNorm", "info": "Upcast weights of layernorm in float32."},
"zh": {"label": "缩放归一化层", "info": "将归一化层权重缩放至 32 位精度。"},
},
"lora_tab": {"en": {"label": "LoRA configurations"}, "zh": {"label": "LoRA 参数设置"}},
"lora_rank": {
"en": {
"label": "LoRA rank",
"info": "The rank of LoRA matrices."
},
"zh": {
"label": "LoRA 秩",
"info": "LoRA 矩阵的秩。"
}
"en": {"label": "LoRA rank", "info": "The rank of LoRA matrices."},
"zh": {"label": "LoRA 秩", "info": "LoRA 矩阵的秩。"},
},
"lora_dropout": {
"en": {
"label": "LoRA Dropout",
"info": "Dropout ratio of LoRA weights."
},
"zh": {
"label": "LoRA 随机丢弃",
"info": "LoRA 权重随机丢弃的概率。"
}
"en": {"label": "LoRA Dropout", "info": "Dropout ratio of LoRA weights."},
"zh": {"label": "LoRA 随机丢弃", "info": "LoRA 权重随机丢弃的概率。"},
},
"lora_target": {
"en": {
"label": "LoRA modules (optional)",
"info": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
"info": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules.",
},
"zh": {
"label": "LoRA 作用模块(非必填)",
"info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。"
}
"zh": {"label": "LoRA 作用模块(非必填)", "info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。"},
},
"additional_target": {
"en": {
"label": "Additional modules (optional)",
"info": "Name(s) of modules apart from LoRA layers to be set as trainable. Use commas to separate multiple modules."
"info": "Name(s) of modules apart from LoRA layers to be set as trainable. Use commas to separate multiple modules.",
},
"zh": {
"label": "附加模块(非必填)",
"info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。"
}
"zh": {"label": "附加模块(非必填)", "info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。"},
},
"create_new_adapter": {
"en": {
"label": "Create new adapter",
"info": "Whether to create a new adapter with randomly initialized weight or not."
"info": "Whether to create a new adapter with randomly initialized weight or not.",
},
"zh": {
"label": "新建适配器",
"info": "是否创建一个经过随机初始化的新适配器。"
}
},
"rlhf_tab": {
"en": {
"label": "RLHF configurations"
},
"zh": {
"label": "RLHF 参数设置"
}
"zh": {"label": "新建适配器", "info": "是否创建一个经过随机初始化的新适配器。"},
},
"rlhf_tab": {"en": {"label": "RLHF configurations"}, "zh": {"label": "RLHF 参数设置"}},
"dpo_beta": {
"en": {
"label": "DPO beta",
"info": "Value of the beta parameter in the DPO loss."
},
"zh": {
"label": "DPO beta 参数",
"info": "DPO 损失函数中 beta 超参数大小。"
}
"en": {"label": "DPO beta", "info": "Value of the beta parameter in the DPO loss."},
"zh": {"label": "DPO beta 参数", "info": "DPO 损失函数中 beta 超参数大小。"},
},
"dpo_ftx": {
"en": {
"label": "DPO-ftx weight",
"info": "The weight of SFT loss in the DPO-ftx."
},
"zh": {
"label": "DPO-ftx 权重",
"info": "DPO-ftx 中 SFT 损失的权重大小。"
}
"en": {"label": "DPO-ftx weight", "info": "The weight of SFT loss in the DPO-ftx."},
"zh": {"label": "DPO-ftx 权重", "info": "DPO-ftx 中 SFT 损失的权重大小。"},
},
"reward_model": {
"en": {
"label": "Reward model",
"info": "Adapter of the reward model for PPO training. (Needs to refresh adapters)"
"info": "Adapter of the reward model for PPO training. (Needs to refresh adapters)",
},
"zh": {
"label": "奖励模型",
"info": "PPO 训练中奖励模型的适配器路径。(需要刷新适配器)"
}
},
"cmd_preview_btn": {
"en": {
"value": "Preview command"
},
"zh": {
"value": "预览命令"
}
},
"start_btn": {
"en": {
"value": "Start"
},
"zh": {
"value": "开始"
}
},
"stop_btn": {
"en": {
"value": "Abort"
},
"zh": {
"value": "中断"
}
"zh": {"label": "奖励模型", "info": "PPO 训练中奖励模型的适配器路径。(需要刷新适配器)"},
},
"cmd_preview_btn": {"en": {"value": "Preview command"}, "zh": {"value": "预览命令"}},
"start_btn": {"en": {"value": "Start"}, "zh": {"value": "开始"}},
"stop_btn": {"en": {"value": "Abort"}, "zh": {"value": "中断"}},
"output_dir": {
"en": {
"label": "Output dir",
"info": "Directory for saving results."
},
"zh": {
"label": "输出目录",
"info": "保存结果的路径。"
}
},
"output_box": {
"en": {
"value": "Ready."
},
"zh": {
"value": "准备就绪。"
}
},
"loss_viewer": {
"en": {
"label": "Loss"
},
"zh": {
"label": "损失"
}
},
"predict": {
"en": {
"label": "Save predictions"
},
"zh": {
"label": "保存预测结果"
}
},
"load_btn": {
"en": {
"value": "Load model"
},
"zh": {
"value": "加载模型"
}
},
"unload_btn": {
"en": {
"value": "Unload model"
},
"zh": {
"value": "卸载模型"
}
},
"info_box": {
"en": {
"value": "Model unloaded, please load a model first."
},
"zh": {
"value": "模型未加载,请先加载模型。"
}
},
"system": {
"en": {
"placeholder": "System prompt (optional)"
},
"zh": {
"placeholder": "系统提示词(非必填)"
}
},
"tools": {
"en": {
"placeholder": "Tools (optional)"
},
"zh": {
"placeholder": "工具列表(非必填)"
}
},
"query": {
"en": {
"placeholder": "Input..."
},
"zh": {
"placeholder": "输入..."
}
},
"submit_btn": {
"en": {
"value": "Submit"
},
"zh": {
"value": "提交"
}
},
"clear_btn": {
"en": {
"value": "Clear history"
},
"zh": {
"value": "清空历史"
}
},
"max_length": {
"en": {
"label": "Maximum length"
},
"zh": {
"label": "最大长度"
}
},
"max_new_tokens": {
"en": {
"label": "Maximum new tokens"
},
"zh": {
"label": "最大生成长度"
}
},
"top_p": {
"en": {
"label": "Top-p"
},
"zh": {
"label": "Top-p 采样值"
}
},
"temperature": {
"en": {
"label": "Temperature"
},
"zh": {
"label": "温度系数"
}
"en": {"label": "Output dir", "info": "Directory for saving results."},
"zh": {"label": "输出目录", "info": "保存结果的路径。"},
},
"output_box": {"en": {"value": "Ready."}, "zh": {"value": "准备就绪。"}},
"loss_viewer": {"en": {"label": "Loss"}, "zh": {"label": "损失"}},
"predict": {"en": {"label": "Save predictions"}, "zh": {"label": "保存预测结果"}},
"load_btn": {"en": {"value": "Load model"}, "zh": {"value": "加载模型"}},
"unload_btn": {"en": {"value": "Unload model"}, "zh": {"value": "卸载模型"}},
"info_box": {"en": {"value": "Model unloaded, please load a model first."}, "zh": {"value": "模型未加载,请先加载模型。"}},
"system": {"en": {"placeholder": "System prompt (optional)"}, "zh": {"placeholder": "系统提示词(非必填)"}},
"tools": {"en": {"placeholder": "Tools (optional)"}, "zh": {"placeholder": "工具列表(非必填)"}},
"query": {"en": {"placeholder": "Input..."}, "zh": {"placeholder": "输入..."}},
"submit_btn": {"en": {"value": "Submit"}, "zh": {"value": "提交"}},
"clear_btn": {"en": {"value": "Clear history"}, "zh": {"value": "清空历史"}},
"max_length": {"en": {"label": "Maximum length"}, "zh": {"label": "最大长度"}},
"max_new_tokens": {"en": {"label": "Maximum new tokens"}, "zh": {"label": "最大生成长度"}},
"top_p": {"en": {"label": "Top-p"}, "zh": {"label": "Top-p 采样值"}},
"temperature": {"en": {"label": "Temperature"}, "zh": {"label": "温度系数"}},
"max_shard_size": {
"en": {
"label": "Max shard size (GB)",
"info": "The maximum size for a model file."
},
"zh": {
"label": "最大分块大小GB",
"info": "单个模型文件的最大大小。"
}
"en": {"label": "Max shard size (GB)", "info": "The maximum size for a model file."},
"zh": {"label": "最大分块大小GB", "info": "单个模型文件的最大大小。"},
},
"export_quantization_bit": {
"en": {
"label": "Export quantization bit.",
"info": "Quantizing the exported model."
},
"zh": {
"label": "导出量化等级",
"info": "量化导出模型。"
}
"en": {"label": "Export quantization bit.", "info": "Quantizing the exported model."},
"zh": {"label": "导出量化等级", "info": "量化导出模型。"},
},
"export_quantization_dataset": {
"en": {
"label": "Export quantization dataset.",
"info": "The calibration dataset used for quantization."
},
"zh": {
"label": "导出量化数据集",
"info": "量化过程中使用的校准数据集。"
}
"en": {"label": "Export quantization dataset.", "info": "The calibration dataset used for quantization."},
"zh": {"label": "导出量化数据集", "info": "量化过程中使用的校准数据集。"},
},
"export_dir": {
"en": {
"label": "Export dir",
"info": "Directory to save exported model."
},
"zh": {
"label": "导出目录",
"info": "保存导出模型的文件夹路径。"
}
"en": {"label": "Export dir", "info": "Directory to save exported model."},
"zh": {"label": "导出目录", "info": "保存导出模型的文件夹路径。"},
},
"export_btn": {
"en": {
"value": "Export"
},
"zh": {
"value": "开始导出"
}
}
"export_btn": {"en": {"value": "Export"}, "zh": {"value": "开始导出"}},
}
ALERTS = {
"err_conflict": {
"en": "A process is in running, please abort it firstly.",
"zh": "任务已存在,请先中断训练。"
},
"err_exists": {
"en": "You have loaded a model, please unload it first.",
"zh": "模型已存在,请先卸载模型。"
},
"err_no_model": {
"en": "Please select a model.",
"zh": "请选择模型。"
},
"err_no_path": {
"en": "Model not found.",
"zh": "模型未找到。"
},
"err_no_dataset": {
"en": "Please choose a dataset.",
"zh": "请选择数据集。"
},
"err_no_adapter": {
"en": "Please select an adapter.",
"zh": "请选择一个适配器。"
},
"err_no_export_dir": {
"en": "Please provide export dir.",
"zh": "请填写导出目录"
},
"err_failed": {
"en": "Failed.",
"zh": "训练出错。"
},
"err_conflict": {"en": "A process is in running, please abort it firstly.", "zh": "任务已存在,请先中断训练。"},
"err_exists": {"en": "You have loaded a model, please unload it first.", "zh": "模型已存在,请先卸载模型。"},
"err_no_model": {"en": "Please select a model.", "zh": "请选择模型。"},
"err_no_path": {"en": "Model not found.", "zh": "模型未找到。"},
"err_no_dataset": {"en": "Please choose a dataset.", "zh": "请选择数据集。"},
"err_no_adapter": {"en": "Please select an adapter.", "zh": "请选择一个适配器。"},
"err_no_export_dir": {"en": "Please provide export dir.", "zh": "请填写导出目录"},
"err_failed": {"en": "Failed.", "zh": "训练出错。"},
"err_demo": {
"en": "Training is unavailable in demo mode, duplicate the space to a private one first.",
"zh": "展示模式不支持训练,请先复制到私人空间。"
"zh": "展示模式不支持训练,请先复制到私人空间。",
},
"err_device_count": {
"en": "Multiple GPUs are not supported yet.",
"zh": "尚不支持多 GPU 训练。"
},
"info_aborting": {
"en": "Aborted, wait for terminating...",
"zh": "训练中断,正在等待线程结束……"
},
"info_aborted": {
"en": "Ready.",
"zh": "准备就绪。"
},
"info_finished": {
"en": "Finished.",
"zh": "训练完毕。"
},
"info_loading": {
"en": "Loading model...",
"zh": "加载中……"
},
"info_unloading": {
"en": "Unloading model...",
"zh": "卸载中……"
},
"info_loaded": {
"en": "Model loaded, now you can chat with your model!",
"zh": "模型已加载,可以开始聊天了!"
},
"info_unloaded": {
"en": "Model unloaded.",
"zh": "模型已卸载。"
},
"info_exporting": {
"en": "Exporting model...",
"zh": "正在导出模型……"
},
"info_exported": {
"en": "Model exported.",
"zh": "模型导出完成。"
}
"err_device_count": {"en": "Multiple GPUs are not supported yet.", "zh": "尚不支持多 GPU 训练。"},
"info_aborting": {"en": "Aborted, wait for terminating...", "zh": "训练中断,正在等待线程结束……"},
"info_aborted": {"en": "Ready.", "zh": "准备就绪。"},
"info_finished": {"en": "Finished.", "zh": "训练完毕。"},
"info_loading": {"en": "Loading model...", "zh": "加载中……"},
"info_unloading": {"en": "Unloading model...", "zh": "卸载中……"},
"info_loaded": {"en": "Model loaded, now you can chat with your model!", "zh": "模型已加载,可以开始聊天了!"},
"info_unloaded": {"en": "Model unloaded.", "zh": "模型已卸载。"},
"info_exporting": {"en": "Exporting model...", "zh": "正在导出模型……"},
"info_exported": {"en": "Model exported.", "zh": "模型导出完成。"},
}

View File

@ -1,11 +1,11 @@
from typing import TYPE_CHECKING, Dict, List, Set
if TYPE_CHECKING:
from gradio.components import Component
class Manager:
def __init__(self) -> None:
self.all_elems: Dict[str, Dict[str, "Component"]] = {}
@ -26,7 +26,7 @@ class Manager:
self.all_elems["top"]["quantization_bit"],
self.all_elems["top"]["template"],
self.all_elems["top"]["rope_scaling"],
self.all_elems["top"]["booster"]
self.all_elems["top"]["booster"],
}
def list_elems(self) -> List["Component"]:

View File

@ -1,12 +1,12 @@
import logging
import os
import time
import logging
import gradio as gr
from threading import Thread
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
import gradio as gr
import transformers
from gradio.components import Component # cannot use TYPE_CHECKING here
from transformers.trainer import TRAINING_ARGS_NAME
from ..extras.callbacks import LogCallback
@ -18,12 +18,12 @@ from .common import get_module, get_save_dir, load_config
from .locales import ALERTS
from .utils import gen_cmd, get_eval_results, update_process_bar
if TYPE_CHECKING:
from .manager import Manager
class Runner:
def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
self.manager = manager
self.demo_mode = demo_mode
@ -90,9 +90,12 @@ class Runner:
user_config = load_config()
if get("top.adapter_path"):
adapter_name_or_path = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")])
adapter_name_or_path = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else:
adapter_name_or_path = None
@ -131,12 +134,12 @@ class Runner:
create_new_adapter=get("train.create_new_adapter"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
fp16=(get("train.compute_type") == "fp16"),
bf16=(get("train.compute_type") == "bf16")
bf16=(get("train.compute_type") == "bf16"),
)
args["disable_tqdm"] = True
if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
args["create_new_adapter"] = (args["quantization_bit"] is None)
args["create_new_adapter"] = args["quantization_bit"] is None
if args["stage"] == "ppo":
args["reward_model"] = get_save_dir(
@ -161,9 +164,12 @@ class Runner:
user_config = load_config()
if get("top.adapter_path"):
adapter_name_or_path = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")])
adapter_name_or_path = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else:
adapter_name_or_path = None
@ -187,7 +193,7 @@ class Runner:
max_new_tokens=get("eval.max_new_tokens"),
top_p=get("eval.top_p"),
temperature=get("eval.temperature"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir"))
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
)
if get("eval.predict"):
@ -197,7 +203,9 @@ class Runner:
return args
def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
def _preview(
self, data: Dict[Component, Any], do_train: bool
) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
error = self._initialize(data, do_train, from_preview=True)
if error:
gr.Warning(error)
@ -235,9 +243,11 @@ class Runner:
get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
self.running = True
lang = get("top.lang")
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
"{}.output_dir".format("train" if self.do_train else "eval")
))
output_dir = get_save_dir(
get("top.model_name"),
get("top.finetuning_type"),
get("{}.output_dir".format("train" if self.do_train else "eval")),
)
while self.thread.is_alive():
time.sleep(2)

View File

@ -1,13 +1,15 @@
import os
import json
import gradio as gr
from typing import TYPE_CHECKING, Any, Dict
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict
import gradio as gr
from ..extras.packages import is_matplotlib_available
from ..extras.ploting import smooth
from .common import get_save_dir
if TYPE_CHECKING:
from ..extras.callbacks import LogCallback
@ -22,16 +24,13 @@ def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
label = "Running {:d}/{:d}: {} < {}".format(
callback.cur_steps,
callback.max_steps,
callback.elapsed_time,
callback.remaining_time
callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
)
return gr.update(label=label, value=percentage, visible=True)
def get_time() -> str:
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
return datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
def can_quantize(finetuning_type: str) -> Dict[str, Any]:

View File

@ -3,11 +3,12 @@
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
from typing import Optional
import fire
import torch
from typing import Optional
from deepspeed.accelerator import get_accelerator # type: ignore
from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
from deepspeed.accelerator import get_accelerator # type: ignore
from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
from llmtuner import ChatModel
@ -16,25 +17,13 @@ def calculate_flops(
model_name_or_path: str,
batch_size: Optional[int] = 1,
seq_length: Optional[int] = 256,
flash_attn: Optional[bool] = False
flash_attn: Optional[bool] = False,
):
with get_accelerator().device(0):
chat_model = ChatModel(dict(
model_name_or_path=model_name_or_path,
template="vanilla",
flash_attn=flash_attn
))
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="vanilla", flash_attn=flash_attn))
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
input_dict = {
"input_ids": fake_input,
"labels": fake_input.clone()
}
flops, macs, params = get_model_profile(
chat_model.model,
kwargs=input_dict,
print_profile=True,
detailed=True
)
input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
flops, macs, params = get_model_profile(chat_model.model, kwargs=input_dict, print_profile=True, detailed=True)
print("FLOPs:", flops)
print("MACs:", macs)
print("Params:", params)

View File

@ -3,12 +3,13 @@
# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
import fire
import math
import torch
from tqdm import tqdm
from typing import Optional
import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForSeq2Seq
from llmtuner.data import get_dataset
@ -17,8 +18,8 @@ from llmtuner.hparams import get_train_args
from llmtuner.model import load_model_and_tokenizer
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
BASE_BS = 4_000_000 # from llama paper
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
BASE_BS = 4_000_000 # from llama paper
def calculate_lr(
@ -26,18 +27,20 @@ def calculate_lr(
dataset: str,
cutoff_len: int, # i.e. maximum input length during training
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
is_mistral: bool, # mistral model uses a smaller learning rate,
dataset_dir: Optional[str] = "data"
is_mistral: bool, # mistral model uses a smaller learning rate,
dataset_dir: Optional[str] = "data",
):
model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
stage="sft",
model_name_or_path=model_name_or_path,
dataset=dataset,
dataset_dir=dataset_dir,
template="default",
cutoff_len=cutoff_len,
output_dir="dummy_dir"
))
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
dict(
stage="sft",
model_name_or_path=model_name_or_path,
dataset=dataset,
dataset_dir=dataset_dir,
template="default",
cutoff_len=cutoff_len,
output_dir="dummy_dir",
)
)
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
@ -49,14 +52,16 @@ def calculate_lr(
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
total_tokens += torch.numel(batch["labels"])
batch_max_len = cutoff_len * batch_size # max tokens in a batch
batch_max_len = cutoff_len * batch_size # max tokens in a batch
valid_ratio = valid_tokens / total_tokens
batch_valid_len = batch_max_len * valid_ratio
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size)
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size)
lr = lr / 6.0 if is_mistral else lr
print("Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
lr, valid_ratio * 100, batch_valid_len
))
print(
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
lr, valid_ratio * 100, batch_valid_len
)
)
if __name__ == "__main__":

View File

@ -4,32 +4,28 @@
# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
import os
import fire
import json
import torch
from tqdm import tqdm
import os
from collections import OrderedDict
from safetensors.torch import save_file
from transformers.modeling_utils import (
shard_checkpoint,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME
)
from typing import Any, Dict, Optional
import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
shard_checkpoint,
)
CONFIG_NAME = "config.json"
def save_weight(
input_dir: str,
output_dir: str,
shard_size: str,
save_safetensors: bool
):
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
@ -41,8 +37,8 @@ def save_weight(
if "W_pack" in key:
proj_size = value.size(0) // 3
llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size:2*proj_size, :]
llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2*proj_size:, :]
llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size : 2 * proj_size, :]
llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2 * proj_size :, :]
elif "lm_head" in key:
llama2_state_dict[key] = torch.nn.functional.normalize(value)
else:
@ -56,7 +52,7 @@ def save_weight(
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
else:
torch.save(shard, os.path.join(output_dir, shard_file))
if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
else:
@ -66,10 +62,7 @@ def save_weight(
print("Model weights saved in {}".format(output_dir))
def save_config(
input_dir: str,
output_dir: str
):
def save_config(input_dir: str, output_dir: str):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
llama2_config_dict: Dict[str, Any] = json.load(f)
@ -83,19 +76,14 @@ def save_config(
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
def llamafy_baichuan2(
input_dir: str,
output_dir: str,
shard_size: str,
save_safetensors: Optional[bool] = False
):
def llamafy_baichuan2(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False):
try:
os.makedirs(output_dir, exist_ok=False)
except Exception as e:
raise print("Output dir already exists", e)
save_weight(input_dir, output_dir, shard_size, save_safetensors)
save_config(input_dir, output_dir)
save_config(input_dir, output_dir)
if __name__ == "__main__":

View File

@ -3,32 +3,28 @@
# Usage: python llamafy_internlm2.py --input_dir input --output_dir output --shard_size 10GB
# Warning: We have found that the converted model cannot infer correctly. It will be fixed later.
import os
import fire
import json
import torch
from tqdm import tqdm
import os
from collections import OrderedDict
from safetensors.torch import save_file
from transformers.modeling_utils import (
shard_checkpoint,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME
)
from typing import Any, Dict, Optional
import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
shard_checkpoint,
)
CONFIG_NAME = "config.json"
def save_weight(
input_dir: str,
output_dir: str,
shard_size: str,
save_safetensors: bool
):
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
internlm2_config_dict: Dict[str, Any] = json.load(f)
@ -50,8 +46,10 @@ def save_weight(
q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[q_size:q_size+kv_size, ...]
llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size+kv_size:, ...]
llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[
q_size : q_size + kv_size, ...
]
llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size + kv_size :, ...]
elif "wo" in key:
llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value
elif "attention_norm" in key:
@ -85,10 +83,7 @@ def save_weight(
print("Model weights saved in {}".format(output_dir))
def save_config(
input_dir: str,
output_dir: str
):
def save_config(input_dir: str, output_dir: str):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
llama2_config_dict: Dict[str, Any] = json.load(f)
@ -103,12 +98,7 @@ def save_config(
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
def llamafy_internlm2(
input_dir: str,
output_dir: str,
shard_size: str,
save_safetensors: Optional[bool] = False
):
def llamafy_internlm2(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False):
try:
os.makedirs(output_dir, exist_ok=False)
except Exception as e:

View File

@ -3,39 +3,36 @@
# Usage: python llamafy_qwen.py --input_dir input --output_dir output --shard_size 10GB
# Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
import os
import fire
import json
import torch
from tqdm import tqdm
import os
from collections import OrderedDict
from typing import Any, Dict, Optional
import fire
import torch
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
shard_checkpoint,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME
shard_checkpoint,
)
from transformers.utils import check_min_version
from typing import Any, Dict, Optional
try:
check_min_version("4.34.0")
except:
except Exception:
raise ValueError("Please upgrade `transformers` to 4.34.0")
CONFIG_NAME = "config.json"
def save_weight(
input_dir: str,
output_dir: str,
shard_size: str,
save_safetensors: bool
) -> str:
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool) -> str:
qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
@ -57,13 +54,15 @@ def save_weight(
if "attn.c_attn" in key:
proj_size = value.size(0) // 3
llama2_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[proj_size:2*proj_size, ...]
llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2*proj_size:, ...]
llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[
proj_size : 2 * proj_size, ...
]
llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2 * proj_size :, ...]
elif "attn.c_proj" in key:
llama2_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = (
torch.zeros_like(value[:, 0]).squeeze()
)
llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = torch.zeros_like(
value[:, 0]
).squeeze()
elif "ln_1" in key:
llama2_state_dict[key.replace("ln_1", "input_layernorm")] = value
elif "ln_2" in key:
@ -99,11 +98,7 @@ def save_weight(
return str(torch_dtype).replace("torch.", "")
def save_config(
input_dir: str,
output_dir: str,
torch_dtype: str
):
def save_config(input_dir: str, output_dir: str, torch_dtype: str):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
qwen_config_dict: Dict[str, Any] = json.load(f)
@ -133,12 +128,7 @@ def save_config(
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
def llamafy_qwen(
input_dir: str,
output_dir: str,
shard_size: str,
save_safetensors: Optional[bool] = False
):
def llamafy_qwen(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False):
try:
os.makedirs(output_dir, exist_ok=False)
except Exception as e:

View File

@ -4,12 +4,13 @@
# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py
import os
from typing import TYPE_CHECKING, Optional
import fire
import torch
import torch.nn as nn
from typing import TYPE_CHECKING, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
if TYPE_CHECKING:
@ -17,7 +18,6 @@ if TYPE_CHECKING:
class Shell(nn.Module):
def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
super().__init__()
self.weight = nn.Parameter(weight, requires_grad=False)
@ -26,7 +26,7 @@ class Shell(nn.Module):
def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
for name in set([k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k]):
for name in set([k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k]): # noqa: C403
parent_name = ".".join(name.split(".")[:-1])
child_name = name.split(".")[-1]
parent_module = model.get_submodule(parent_name)
@ -35,7 +35,7 @@ def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
weight = getattr(base_layer, "weight", None)
bias = getattr(base_layer, "bias", None)
setattr(parent_module, child_name, Shell(weight, bias))
print("Model unwrapped.")
@ -60,7 +60,7 @@ def quantize_loftq(
lora_dropout=0.1,
target_modules=[name.strip() for name in lora_target.split(",")],
init_lora_weights="loftq",
loftq_config=loftq_config
loftq_config=loftq_config,
)
# Init LoftQ model