format style

commit 638234ceee (parent f6d6e00337)
Author: hiyouga
Date: 2024-01-20 20:15:56 +08:00

73 changed files with 1492 additions and 2325 deletions

View File

@@ -7,7 +7,7 @@ line-length = 119
 target-version = ["py38"]
 
 [tool.ruff]
-ignore = ["C901", "E501", "E741", "W605"]
+ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
 select = ["C", "E", "F", "I", "W"]
 line-length = 119
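For reference, the two rule codes newly added to the ignore list cover patterns the project has chosen to keep. A brief, illustrative sketch of what each rule would otherwise flag (rule meanings per ruff; the examples are not taken from the repository):

# C408 -- "unnecessary dict call": keyword dicts built with dict(...) are kept deliberately
gen_kwargs = dict(do_sample=True, top_p=0.7)  # ruff would suggest {"do_sample": True, "top_p": 0.7}

# E731 -- "do not assign a lambda expression": short lambdas are likewise tolerated
normalize = lambda s: s.strip().lower()  # ruff would suggest a def statement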

View File

@@ -1,4 +1,5 @@
 import os
+
 import uvicorn
 
 from llmtuner import ChatModel, create_app

View File

@@ -1,10 +1,12 @@
 from llmtuner import ChatModel
 from llmtuner.extras.misc import torch_gc
 
 try:
     import platform
+
     if platform.system() != "Windows":
-        import readline
+        import readline  # noqa: F401
 except ImportError:
     print("Install `readline` for a better experience.")

View File

@@ -8,12 +8,4 @@ from llmtuner.webui import create_ui, create_web_demo
 
 __version__ = "0.4.0"
-__all__ = [
-    "create_app",
-    "ChatModel",
-    "Evaluator",
-    "export_model",
-    "run_exp",
-    "create_ui",
-    "create_web_demo"
-]
+__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]

View File

@@ -1,30 +1,29 @@
-import os
-import json
 import asyncio
-from typing import List, Tuple
-from pydantic import BaseModel
+import json
+import os
 from contextlib import asynccontextmanager
+from typing import List, Tuple
+
+from pydantic import BaseModel
 
+from ..chat import ChatModel
+from ..extras.misc import torch_gc
+from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
 from .protocol import (
-    Role,
-    Finish,
-    ModelCard,
-    ModelList,
-    ChatMessage,
-    DeltaMessage,
     ChatCompletionRequest,
     ChatCompletionResponse,
-    ChatCompletionStreamResponse,
     ChatCompletionResponseChoice,
     ChatCompletionResponseStreamChoice,
     ChatCompletionResponseUsage,
+    ChatCompletionStreamResponse,
+    ChatMessage,
+    DeltaMessage,
+    Finish,
+    ModelCard,
+    ModelList,
+    Role,
     ScoreEvaluationRequest,
-    ScoreEvaluationResponse
+    ScoreEvaluationResponse,
 )
-from ..chat import ChatModel
-from ..extras.misc import torch_gc
-from ..extras.packages import (
-    is_fastapi_availble, is_starlette_available, is_uvicorn_available
-)
@@ -42,15 +41,15 @@ if is_uvicorn_available():
 @asynccontextmanager
-async def lifespan(app: "FastAPI"): # collects GPU memory
+async def lifespan(app: "FastAPI"):  # collects GPU memory
     yield
     torch_gc()
 
 
 def to_json(data: BaseModel) -> str:
-    try: # pydantic v2
+    try:  # pydantic v2
         return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
-    except: # pydantic v1
+    except Exception:  # pydantic v1
         return data.json(exclude_unset=True, ensure_ascii=False)
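Narrowing the bare `except:` to `except Exception:` keeps the pydantic v1 fallback from swallowing KeyboardInterrupt and SystemExit. An equivalent version-agnostic sketch that avoids exception handling altogether (assuming only the public pydantic API):

from pydantic import BaseModel


def dump_json(data: BaseModel) -> str:
    # pydantic v2 models expose model_dump_json(); v1 models only provide .json()
    if hasattr(data, "model_dump_json"):
        return data.model_dump_json(exclude_unset=True)
    return data.json(exclude_unset=True, ensure_ascii=False)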
@@ -90,8 +89,8 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         history = []
         if len(prev_messages) % 2 == 0:
             for i in range(0, len(prev_messages), 2):
-                if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
-                    history.append([prev_messages[i].content, prev_messages[i+1].content])
+                if prev_messages[i].role == Role.USER and prev_messages[i + 1].role == Role.ASSISTANT:
+                    history.append([prev_messages[i].content, prev_messages[i + 1].content])
                 else:
                     raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
         else:
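The loop above folds the preceding messages into (user, assistant) turns for the chat history. A standalone sketch of the same pairing over plain dicts (a hypothetical helper, not part of the project's API):

from typing import Dict, List, Tuple


def build_history(prev_messages: List[Dict[str, str]]) -> List[Tuple[str, str]]:
    if len(prev_messages) % 2 != 0:
        raise ValueError("Only supports u/a/u/a/u... ordering")
    history = []
    for i in range(0, len(prev_messages), 2):
        user, assistant = prev_messages[i], prev_messages[i + 1]
        if user["role"] != "user" or assistant["role"] != "assistant":
            raise ValueError("Only supports u/a/u/a/u... ordering")
        history.append((user["content"], assistant["content"]))
    return history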
@@ -107,65 +106,65 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             return EventSourceResponse(generate, media_type="text/event-stream")
 
         responses = chat_model.chat(
-            query, history, system,
+            query,
+            history,
+            system,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
             max_new_tokens=request.max_tokens,
-            num_return_sequences=request.n
+            num_return_sequences=request.n,
         )
 
         prompt_length, response_length = 0, 0
         choices = []
         for i, response in enumerate(responses):
-            choices.append(ChatCompletionResponseChoice(
-                index=i,
-                message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
-                finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
-            ))
+            choices.append(
+                ChatCompletionResponseChoice(
+                    index=i,
+                    message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
+                    finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH,
+                )
+            )
             prompt_length = response.prompt_length
             response_length += response.response_length
 
         usage = ChatCompletionResponseUsage(
             prompt_tokens=prompt_length,
             completion_tokens=response_length,
-            total_tokens=prompt_length+response_length
+            total_tokens=prompt_length + response_length,
         )
 
         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
 
-    def stream_chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
+    def stream_chat_completion(
+        query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest
+    ):
         choice_data = ChatCompletionResponseStreamChoice(
-            index=0,
-            delta=DeltaMessage(role=Role.ASSISTANT, content=""),
-            finish_reason=None
+            index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield to_json(chunk)
 
         for new_text in chat_model.stream_chat(
-            query, history, system,
+            query,
+            history,
+            system,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
-            max_new_tokens=request.max_tokens
+            max_new_tokens=request.max_tokens,
        ):
             if len(new_text) == 0:
                 continue
 
             choice_data = ChatCompletionResponseStreamChoice(
-                index=0,
-                delta=DeltaMessage(content=new_text),
-                finish_reason=None
+                index=0, delta=DeltaMessage(content=new_text), finish_reason=None
             )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
             yield to_json(chunk)
 
-        choice_data = ChatCompletionResponseStreamChoice(
-            index=0,
-            delta=DeltaMessage(),
-            finish_reason=Finish.STOP
-        )
+        choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP)
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield to_json(chunk)
         yield "[DONE]"

View File

@@ -1,8 +1,9 @@
 import time
 from enum import Enum, unique
-from pydantic import BaseModel, Field
 from typing import List, Optional
+
+from pydantic import BaseModel, Field
 
 
 @unique
 class Role(str, Enum):

View File

@@ -1,18 +1,18 @@
-import torch
 from dataclasses import dataclass
-from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
 from threading import Thread
+from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
+
+import torch
 from transformers import GenerationConfig, TextIteratorStreamer
 
-from ..data import get_template_and_fix_tokenizer, Role
+from ..data import Role, get_template_and_fix_tokenizer
 from ..extras.misc import get_logits_processor
-from ..model import dispatch_model, load_model_and_tokenizer
 from ..hparams import get_infer_args
+from ..model import dispatch_model, load_model_and_tokenizer
 
 
 @dataclass
 class Response:
     response_text: str
     response_length: int
     prompt_length: int
@@ -20,10 +20,9 @@ class Response:
 class ChatModel:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
-        self.can_generate = (finetuning_args.stage == "sft")
+        self.can_generate = finetuning_args.stage == "sft"
         self.model, self.tokenizer = load_model_and_tokenizer(
             model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
         )
@@ -37,7 +36,7 @@ class ChatModel:
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        **input_kwargs
+        **input_kwargs,
     ) -> Tuple[Dict[str, Any], int]:
         messages = []
         if history is not None:
@@ -63,16 +62,18 @@ class ChatModel:
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
 
         generating_args = self.generating_args.to_dict()
-        generating_args.update(dict(
-            do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
-            temperature=temperature or generating_args["temperature"],
-            top_p=top_p or generating_args["top_p"],
-            top_k=top_k or generating_args["top_k"],
-            num_return_sequences=num_return_sequences or 1,
-            repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
-            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
-            pad_token_id=self.tokenizer.pad_token_id
-        ))
+        generating_args.update(
+            dict(
+                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+                temperature=temperature or generating_args["temperature"],
+                top_p=top_p or generating_args["top_p"],
+                top_k=top_k or generating_args["top_k"],
+                num_return_sequences=num_return_sequences or 1,
+                repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
+                eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+                pad_token_id=self.tokenizer.pad_token_id,
+            )
+        )
 
         if isinstance(num_return_sequences, int) and num_return_sequences > 1:
             generating_args["do_sample"] = True
@@ -88,7 +89,7 @@ class ChatModel:
         gen_kwargs = dict(
             inputs=input_ids,
             generation_config=GenerationConfig(**generating_args),
-            logits_processor=get_logits_processor()
+            logits_processor=get_logits_processor(),
         )
 
         return gen_kwargs, prompt_length
@@ -100,7 +101,7 @@ class ChatModel:
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        **input_kwargs
+        **input_kwargs,
     ) -> List[Response]:
         r"""
         Args: query, history, system, **input_kwargs
@@ -117,12 +118,14 @@ class ChatModel:
         for i in range(len(response)):
             eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
             response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
-            results.append(Response(
-                response_text=response[i],
-                response_length=response_length,
-                prompt_length=prompt_length,
-                finish_reason="stop" if len(eos_index) else "length"
-            ))
+            results.append(
+                Response(
+                    response_text=response[i],
+                    response_length=response_length,
+                    prompt_length=prompt_length,
+                    finish_reason="stop" if len(eos_index) else "length",
+                )
+            )
 
         return results
@@ -133,7 +136,7 @@ class ChatModel:
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        **input_kwargs
+        **input_kwargs,
     ) -> Generator[str, None, None]:
         gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
@@ -145,11 +148,7 @@ class ChatModel:
         yield from streamer
 
     @torch.inference_mode()
-    def get_scores(
-        self,
-        batch_input: List[str],
-        **input_kwargs
-    ) -> List[float]:
+    def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
         max_length = input_kwargs.pop("max_length", None)
         device = getattr(self.model.pretrained_model, "device", "cuda")
 
@@ -159,7 +158,7 @@ class ChatModel:
             truncation=True,
             max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
             return_tensors="pt",
-            add_special_tokens=True
+            add_special_tokens=True,
         ).to(device)
 
         input_ids: torch.Tensor = inputs["input_ids"]
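For context, a minimal usage sketch of the class being reformatted here; the argument keys follow the hparams used elsewhere in the project and the model name is a placeholder, so treat both as assumptions:

from llmtuner import ChatModel

chat_model = ChatModel(dict(model_name_or_path="meta-llama/Llama-2-7b-chat-hf", template="llama2"))

# blocking call: returns a list of Response objects
for resp in chat_model.chat("What does this project do?", history=[]):
    print(resp.response_text)

# streaming call: yields text pieces as they are generated
for piece in chat_model.stream_chat("Tell me a joke.", history=[]):
    print(piece, end="", flush=True)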

View File

@@ -1,6 +1,6 @@
 from .loader import get_dataset
 from .template import get_template_and_fix_tokenizer, templates
-from .utils import split_dataset, Role
+from .utils import Role, split_dataset
 
 
 __all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset", "Role"]

View File

@@ -27,7 +27,9 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
         if dataset_attr.response:
             if isinstance(examples[dataset_attr.response][i], list):
-                response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
+                response = [
+                    {"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]
+                ]
             else:
                 response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
         else:
@@ -47,10 +49,10 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
         dataset_attr.user_tag: Role.USER,
         dataset_attr.assistant_tag: Role.ASSISTANT,
         dataset_attr.observation_tag: Role.OBSERVATION,
-        dataset_attr.function_tag: Role.FUNCTION
+        dataset_attr.function_tag: Role.FUNCTION,
     }
 
     for i, messages in enumerate(examples[dataset_attr.messages]):
-        messages = messages[:len(messages) // 2 * 2] # should be multiples of 2
+        messages = messages[: len(messages) // 2 * 2]  # should be multiples of 2
         if len(messages) == 0:
             continue
@@ -65,7 +67,9 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
             if message[dataset_attr.role_tag] not in accept_tags:
                 raise ValueError("Invalid role tag in {}.".format(messages))
 
-            prompt.append({"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]})
+            prompt.append(
+                {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
+            )
 
         last_message = prompt.pop(-1)
         response.append(last_message)
@@ -98,12 +102,7 @@ def align_dataset(
         kwargs = dict(
             num_proc=data_args.preprocessing_num_workers,
             load_from_cache_file=(not data_args.overwrite_cache),
-            desc="Converting format of dataset"
+            desc="Converting format of dataset",
         )
 
-    return dataset.map(
-        convert_func,
-        batched=True,
-        remove_columns=column_names,
-        **kwargs
-    )
+    return dataset.map(convert_func, batched=True, remove_columns=column_names, **kwargs)
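To make the alignment concrete, here is roughly what one alpaca-style record looks like before and after conversion; the column names ("prompt", "response", "system", "tools") follow the surrounding code, while the exact content layout is an assumption:

# one raw alpaca-style record (illustrative)
raw = {"instruction": "Translate to French", "input": "Good morning", "output": "Bonjour"}

# roughly the aligned, role-tagged form produced for that record
aligned_example = {
    "prompt": [{"role": "user", "content": "Translate to French\nGood morning"}],
    "response": [{"role": "assistant", "content": "Bonjour"}],
    "system": "",
    "tools": "",
}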

View File

@@ -76,7 +76,11 @@ class ToolFormatter:
                 required = ", required" if name in tool["parameters"].get("required", []) else ""
                 enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
                 param_text += " - {name} ({type}{required}): {desc}{enum}\n".format(
-                    name=name, type=param.get("type", ""), required=required, desc=param.get("description", ""), enum=enum
+                    name=name,
+                    type=param.get("type", ""),
+                    required=required,
+                    desc=param.get("description", ""),
+                    enum=enum,
                 )
 
             tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
@@ -85,9 +89,7 @@ class ToolFormatter:
             tool_names.append(tool["name"])
 
         return TOOL_SYSTEM_PROMPT.format(
-            tool_text=tool_text,
-            tool_names=", ".join(tool_names),
-            format_prompt=JSON_FORMAT_PROMPT
+            tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
         )
 
     def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
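For reference, a hypothetical tool definition of the shape this formatter consumes; the field names ("name", "description", "parameters", "required", "enum") mirror the dictionary accesses visible in the snippet above, so the structure is an inference rather than documented API:

weather_tool = {
    "name": "get_weather",
    "description": "Look up the current weather for a city",
    "parameters": {
        "type": "object",
        "properties": {
            "city": {"type": "string", "description": "Name of the city"},
            "unit": {"type": "string", "description": "Temperature unit", "enum": ["celsius", "fahrenheit"]},
        },
        "required": ["city"],
    },
}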

View File

@@ -1,16 +1,16 @@
-import os
 import inspect
+import os
 from typing import TYPE_CHECKING, List, Literal, Union
 
 from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
 
 from ..extras.constants import FILEEXT2TYPE
 from ..extras.logging import get_logger
-from .utils import checksum
-from .parser import get_dataset_list
 from .aligner import align_dataset
-from .template import get_template_and_fix_tokenizer
+from .parser import get_dataset_list
 from .preprocess import get_preprocess_and_print_func
+from .template import get_template_and_fix_tokenizer
+from .utils import checksum
 
 
 if TYPE_CHECKING:
@@ -18,8 +18,8 @@ if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments
     from transformers.tokenization_utils import PreTrainedTokenizer
 
+    from ..hparams import DataArguments, ModelArguments
     from .parser import DatasetAttr
-    from ..hparams import ModelArguments, DataArguments
 
 
 logger = get_logger(__name__)
@@ -44,14 +44,14 @@ def load_single_dataset(
     elif dataset_attr.load_from == "file":
         data_files = []
         local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
-        if os.path.isdir(local_path): # is directory
+        if os.path.isdir(local_path):  # is directory
             for file_name in os.listdir(local_path):
                 data_files.append(os.path.join(local_path, file_name))
                 if data_path is None:
                     data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
                 elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
                     raise ValueError("File types should be identical.")
-        elif os.path.isfile(local_path): # is file
+        elif os.path.isfile(local_path):  # is file
             data_files.append(local_path)
             data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
         else:
@@ -78,12 +78,12 @@ def load_single_dataset(
                 split=data_args.split,
                 cache_dir=cache_dir,
                 token=model_args.ms_hub_token,
-                use_streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
+                use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
             ).to_hf_dataset()
         except ImportError:
             raise ImportError("Please install modelscope via `pip install modelscope -U`")
     else:
-        if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
+        if "trust_remote_code" in inspect.signature(load_dataset).parameters:  # for datasets==2.16.0
             kwargs = {"trust_remote_code": True}
         else:
             kwargs = {}
@@ -97,13 +97,13 @@ def load_single_dataset(
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
             streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
-            **kwargs
+            **kwargs,
         )
 
-        if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
-            dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
+        if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
+            dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
 
-    if data_args.max_samples is not None: # truncate dataset
+    if data_args.max_samples is not None:  # truncate dataset
         num_samples = min(data_args.max_samples, len(dataset))
         dataset = dataset.select(range(num_samples))
@@ -113,7 +113,7 @@ def load_single_dataset(
 def merge_dataset(
     all_datasets: List[Union["Dataset", "IterableDataset"]],
     data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments"
+    training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     if len(all_datasets) == 1:
         return all_datasets[0]
@@ -128,7 +128,7 @@ def merge_dataset(
             datasets=all_datasets,
             probabilities=data_args.interleave_probs,
             seed=training_args.seed,
-            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
+            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
         )
     else:
         raise ValueError("Unknown mixing strategy.")
@@ -160,7 +160,7 @@ def get_dataset(
     with training_args.main_process_first(desc="load dataset"):
         all_datasets = []
-        for dataset_attr in get_dataset_list(data_args): # TODO: add split
+        for dataset_attr in get_dataset_list(data_args):  # TODO: add split
             all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
         dataset = merge_dataset(all_datasets, data_args, training_args)
@@ -174,15 +174,10 @@ def get_dataset(
         kwargs = dict(
             num_proc=data_args.preprocessing_num_workers,
             load_from_cache_file=(not data_args.overwrite_cache),
-            desc="Running tokenizer on dataset"
+            desc="Running tokenizer on dataset",
         )
 
-        dataset = dataset.map(
-            preprocess_func,
-            batched=True,
-            remove_columns=column_names,
-            **kwargs
-        )
+        dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
 
         if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
             if training_args.should_save:
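A standalone sketch of the extension check performed in the local-file branch above; the FILEEXT2TYPE contents here are assumptions used only for illustration:

import os
from typing import List, Optional

FILEEXT2TYPE = {"csv": "csv", "json": "json", "jsonl": "json", "txt": "text"}  # assumed mapping


def detect_builder(data_files: List[str]) -> Optional[str]:
    # every file must map to the same datasets builder, mirroring the ValueError above
    builders = {FILEEXT2TYPE.get(os.path.basename(path).split(".")[-1], None) for path in data_files}
    if len(builders) != 1:
        raise ValueError("File types should be identical.")
    return builders.pop()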

View File

@@ -1,18 +1,18 @@
-import os
 import json
-from typing import TYPE_CHECKING, List, Literal, Optional
+import os
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Literal, Optional
 
 from ..extras.constants import DATA_CONFIG
 from ..extras.misc import use_modelscope
 
 
 if TYPE_CHECKING:
     from ..hparams import DataArguments
 
 
 @dataclass
 class DatasetAttr:
     load_from: Literal["hf_hub", "ms_hub", "script", "file"]
     dataset_name: Optional[str] = None
     dataset_sha1: Optional[str] = None
@@ -49,7 +49,9 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
             dataset_info = json.load(f)
     except Exception as err:
         if data_args.dataset is not None:
-            raise ValueError("Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)))
+            raise ValueError(
+                "Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
+            )
         dataset_info = None
 
     if data_args.interleave_probs is not None:
@@ -74,7 +76,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
             dataset_attr = DatasetAttr(
                 "file",
                 dataset_name=dataset_info[name]["file_name"],
-                dataset_sha1=dataset_info[name].get("file_sha1", None)
+                dataset_sha1=dataset_info[name].get("file_sha1", None),
             )
 
         dataset_attr.subset = dataset_info[name].get("subset", None)
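For context, a hypothetical dataset_info.json entry of the "file" variety handled above, written as a Python literal; the keys mirror the lookups in the snippet and the sha1 value is a placeholder:

dataset_info = {
    "my_local_dataset": {
        "file_name": "my_local_dataset.json",
        "file_sha1": "0000000000000000000000000000000000000000",  # placeholder checksum
        "subset": None,
    }
}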

View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
 from ..extras.constants import IGNORE_INDEX
 from ..extras.logging import get_logger
 
+
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments
     from transformers.tokenization_utils import PreTrainedTokenizer
@@ -17,9 +18,7 @@ logger = get_logger(__name__)
 def preprocess_pretrain_dataset(
-    examples: Dict[str, List[Any]],
-    tokenizer: "PreTrainedTokenizer",
-    data_args: "DataArguments"
+    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
 ) -> Dict[str, List[List[int]]]:
     # build grouped texts with format `X1 X2 X3 ...`
     text_examples = [examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))]
@@ -35,7 +34,7 @@ def preprocess_pretrain_dataset(
     total_length = (total_length // block_size) * block_size
     # split by chunks of cutoff_len
     result = {
-        k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
+        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
         for k, t in concatenated_examples.items()
     }
     return result
@@ -57,9 +56,11 @@ def preprocess_supervised_dataset(
         messages = examples["prompt"][i] + examples["response"][i]
         input_ids, labels = [], []
-        for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
-            tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
-        )):
+        for turn_idx, (source_ids, target_ids) in enumerate(
+            template.encode_multiturn(
+                tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+            )
+        ):
             if data_args.train_on_prompt:
                 source_mask = source_ids
             elif turn_idx != 0 and template.efficient_eos:
@@ -96,9 +97,9 @@ def preprocess_packed_supervised_dataset(
             continue
         messages = examples["prompt"][i] + examples["response"][i]
-        for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
-            tokenizer, messages, examples["system"][i], examples["tools"][i]
-        )):
+        for turn_idx, (source_ids, target_ids) in enumerate(
+            template.encode_multiturn(tokenizer, messages, examples["system"][i], examples["tools"][i])
+        ):
             if data_args.train_on_prompt:
                 source_mask = source_ids
             elif turn_idx != 0 and template.efficient_eos:
@@ -119,9 +120,9 @@ def preprocess_packed_supervised_dataset(
     total_length = (total_length // block_size) * block_size
     # split by chunks of cutoff_len
     for i in range(0, total_length, block_size):
-        model_inputs["input_ids"].append(input_ids[i: i + block_size])
+        model_inputs["input_ids"].append(input_ids[i : i + block_size])
         model_inputs["attention_mask"].append([1] * block_size)
-        model_inputs["labels"].append(labels[i: i + block_size])
+        model_inputs["labels"].append(labels[i : i + block_size])
 
     return model_inputs
@@ -191,9 +192,11 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
     print("input_ids:\n{}".format(example["input_ids"]))
     print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
     print("label_ids:\n{}".format(example["labels"]))
-    print("labels:\n{}".format(
-        tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
-    ))
+    print(
+        "labels:\n{}".format(
+            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
+        )
+    )
 
 
 def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
@@ -232,10 +235,14 @@ def get_preprocess_and_print_func(
         print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
     elif stage == "rm":
-        preprocess_func = partial(preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
+        preprocess_func = partial(
+            preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
+        )
         print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
     else:
-        preprocess_func = partial(preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
+        preprocess_func = partial(
+            preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
+        )
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
 
     return preprocess_func, print_function
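The `[i : i + block_size]` slices above implement simple block packing: concatenate token ids, drop the tail that does not fill a block, and cut fixed-size chunks. A self-contained sketch of the same idea:

from typing import List


def pack_into_blocks(token_ids: List[int], block_size: int) -> List[List[int]]:
    # drop the remainder that does not fill a complete block, then slice fixed-size chunks
    total_length = (len(token_ids) // block_size) * block_size
    return [token_ids[i : i + block_size] for i in range(0, total_length, block_size)]


assert pack_into_blocks(list(range(10)), 4) == [[0, 1, 2, 3], [4, 5, 6, 7]]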

View File

@ -2,8 +2,8 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from ..extras.logging import get_logger from ..extras.logging import get_logger
from .formatter import FunctionFormatter, StringFormatter, ToolFormatter
from .utils import Role from .utils import Role
from .formatter import StringFormatter, FunctionFormatter, ToolFormatter
if TYPE_CHECKING: if TYPE_CHECKING:
@ -15,7 +15,6 @@ logger = get_logger(__name__)
@dataclass @dataclass
class Template: class Template:
format_user: Callable format_user: Callable
format_assistant: Callable format_assistant: Callable
format_system: Callable format_system: Callable
@ -34,7 +33,7 @@ class Template:
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
cutoff_len: Optional[int] = 1_000_000 cutoff_len: Optional[int] = 1_000_000,
) -> Tuple[List[int], List[int]]: ) -> Tuple[List[int], List[int]]:
r""" r"""
Returns a single pair of token ids representing prompt and response respectively. Returns a single pair of token ids representing prompt and response respectively.
@ -53,7 +52,7 @@ class Template:
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
system: str, system: str,
tools: str, tools: str,
cutoff_len: Optional[int] = 1_000_000 cutoff_len: Optional[int] = 1_000_000,
) -> List[Tuple[List[int], List[int]]]: ) -> List[Tuple[List[int], List[int]]]:
r""" r"""
Returns multiple pairs of token ids representing prompts and responses respectively. Returns multiple pairs of token ids representing prompts and responses respectively.
@ -67,7 +66,7 @@ class Template:
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
system: str, system: str,
tools: str, tools: str,
cutoff_len: int cutoff_len: int,
) -> List[Tuple[List[int], List[int]]]: ) -> List[Tuple[List[int], List[int]]]:
r""" r"""
Encodes formatted inputs to pairs of token ids. Encodes formatted inputs to pairs of token ids.
@ -102,19 +101,17 @@ class Template:
if total_length >= cutoff_len: if total_length >= cutoff_len:
break break
encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length] encoded_messages[i] = encoded_messages[i][: cutoff_len - total_length]
total_length += len(encoded_messages[i]) total_length += len(encoded_messages[i])
encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)] encoded_messages[i + 1] = encoded_messages[i + 1][: max(1, cutoff_len - total_length)]
total_length += len(encoded_messages[i+1]) total_length += len(encoded_messages[i + 1])
encoded_pairs.append((encoded_messages[i], encoded_messages[i+1])) encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
return encoded_pairs return encoded_pairs
def _convert_elements_to_ids( def _convert_elements_to_ids(
self, self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
tokenizer: "PreTrainedTokenizer",
elements: List[Union[str, Dict[str, str]]]
) -> List[int]: ) -> List[int]:
r""" r"""
Converts elements to token ids. Converts elements to token ids.
@ -139,14 +136,13 @@ class Template:
@dataclass @dataclass
class Llama2Template(Template): class Llama2Template(Template):
def _encode( def _encode(
self, self,
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
system: str, system: str,
tools: str, tools: str,
cutoff_len: int cutoff_len: int,
) -> List[Tuple[List[int], List[int]]]: ) -> List[Tuple[List[int], List[int]]]:
r""" r"""
Encodes formatted inputs to pairs of token ids. Encodes formatted inputs to pairs of token ids.
@ -182,12 +178,12 @@ class Llama2Template(Template):
if total_length >= cutoff_len: if total_length >= cutoff_len:
break break
encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length] encoded_messages[i] = encoded_messages[i][: cutoff_len - total_length]
total_length += len(encoded_messages[i]) total_length += len(encoded_messages[i])
encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)] encoded_messages[i + 1] = encoded_messages[i + 1][: max(1, cutoff_len - total_length)]
total_length += len(encoded_messages[i+1]) total_length += len(encoded_messages[i + 1])
encoded_pairs.append((encoded_messages[i], encoded_messages[i+1])) encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
return encoded_pairs return encoded_pairs
@ -207,32 +203,26 @@ def register_template(
separator: Optional[List[Union[str, Dict[str, str]]]] = "", separator: Optional[List[Union[str, Dict[str, str]]]] = "",
stop_words: Optional[List[str]] = [], stop_words: Optional[List[str]] = [],
efficient_eos: Optional[bool] = False, efficient_eos: Optional[bool] = False,
replace_eos: Optional[bool] = False replace_eos: Optional[bool] = False,
) -> None: ) -> None:
template_class = Llama2Template if name.startswith("llama2") else Template template_class = Llama2Template if name.startswith("llama2") else Template
templates[name] = template_class( templates[name] = template_class(
format_user=format_user or StringFormatter(container=["{{content}}"]), format_user=format_user or StringFormatter(container=["{{content}}"]),
format_assistant=format_assistant or StringFormatter(container=[ format_assistant=format_assistant or StringFormatter(container=["{{content}}", {"eos_token"}]),
"{{content}}", {"eos_token"}
]),
format_system=format_system or StringFormatter(container=["{{content}}"]), format_system=format_system or StringFormatter(container=["{{content}}"]),
format_tool=format_tool or ToolFormatter(type="default"), format_tool=format_tool or ToolFormatter(type="default"),
format_observation=format_observation or format_user, format_observation=format_observation or format_user,
format_function=format_function or FunctionFormatter(container=[ format_function=format_function
"Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"} or FunctionFormatter(container=["Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"}]),
]),
system=system, system=system,
separator=separator, separator=separator,
stop_words=stop_words, stop_words=stop_words,
efficient_eos=efficient_eos, efficient_eos=efficient_eos,
replace_eos=replace_eos replace_eos=replace_eos,
) )
def get_template_and_fix_tokenizer( def get_template_and_fix_tokenizer(name: str, tokenizer: "PreTrainedTokenizer") -> Template:
name: str,
tokenizer: "PreTrainedTokenizer"
) -> Template:
if tokenizer.eos_token_id is None: if tokenizer.eos_token_id is None:
tokenizer.eos_token = "<|endoftext|>" tokenizer.eos_token = "<|endoftext|>"
logger.info("Add eos token: {}".format(tokenizer.eos_token)) logger.info("Add eos token: {}".format(tokenizer.eos_token))
@ -241,7 +231,7 @@ def get_template_and_fix_tokenizer(
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token)) logger.info("Add pad token: {}".format(tokenizer.pad_token))
if name is None: # for pre-training if name is None: # for pre-training
return None return None
template = templates.get(name, None) template = templates.get(name, None)
@ -258,8 +248,7 @@ def get_template_and_fix_tokenizer(
if stop_words: if stop_words:
tokenizer.add_special_tokens( tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
replace_additional_special_tokens=False
) )
logger.info("Add {} to stop words.".format(",".join(stop_words))) logger.info("Add {} to stop words.".format(",".join(stop_words)))
@ -268,263 +257,153 @@ def get_template_and_fix_tokenizer(
register_template( register_template(
name="alpaca", name="alpaca",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=["### Instruction:\n{{content}}\n\n### Response:\n"]),
"### Instruction:\n{{content}}\n\n### Response:\n"
]),
system=( system=(
"Below is an instruction that describes a task. " "Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
"Write a response that appropriately completes the request."
), ),
separator=[ separator=["\n\n"],
"\n\n"
]
) )
register_template( register_template(
name="aquila", name="aquila",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=["Human: {{content}}###Assistant:"]),
"Human: {{content}}###Assistant:" format_assistant=StringFormatter(container=["{{content}}"]),
]),
format_assistant=StringFormatter(container=[
"{{content}}"
]),
system=( system=(
"A chat between a curious human and an artificial intelligence assistant. " "A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions." "The assistant gives helpful, detailed, and polite answers to the human's questions."
), ),
separator=[ separator=["###"],
"###" stop_words=["</s>"],
], efficient_eos=True,
stop_words=[
"</s>"
],
efficient_eos=True
) )
register_template( register_template(
name="baichuan", name="baichuan",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
{"token": "<reserved_102>"}, format_assistant=StringFormatter(container=["{{content}}"]),
"{{content}}", efficient_eos=True,
{"token": "<reserved_103>"}
]),
format_assistant=StringFormatter(container=[
"{{content}}"
]),
efficient_eos=True
) )
register_template( register_template(
name="baichuan2", name="baichuan2",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=[{"token": "<reserved_106>"}, "{{content}}", {"token": "<reserved_107>"}]),
{"token": "<reserved_106>"}, format_assistant=StringFormatter(container=["{{content}}"]),
"{{content}}", efficient_eos=True,
{"token": "<reserved_107>"}
]),
format_assistant=StringFormatter(container=[
"{{content}}"
]),
efficient_eos=True
) )
register_template( register_template(
name="belle", name="belle", format_user=StringFormatter(container=["Human: {{content}}\n\nBelle: "]), separator=["\n\n"]
format_user=StringFormatter(container=[
"Human: {{content}}\n\nBelle: "
]),
separator=[
"\n\n"
]
) )
register_template( register_template(
name="bluelm", name="bluelm",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
{"token": "[|Human|]:"},
"{{content}}",
{"token": "[|AI|]:"}
])
) )
register_template( register_template(
name="chatglm2", name="chatglm2",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
"[Round {{idx}}]\n\n问:{{content}}\n\n答:" format_assistant=StringFormatter(container=["{{content}}"]),
]), format_system=StringFormatter(container=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
format_assistant=StringFormatter(container=[ separator=["\n\n"],
"{{content}}" efficient_eos=True,
]),
format_system=StringFormatter(container=[
{"token": "[gMASK]"},
{"token": "sop"},
"{{content}}"
]),
separator=[
"\n\n"
],
efficient_eos=True
) )
register_template( register_template(
name="chatglm3", name="chatglm3",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
{"token": "<|user|>"}, format_assistant=StringFormatter(container=["\n" "{{content}}"]),
"\n", format_system=StringFormatter(
"{{content}}", container=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
{"token": "<|assistant|>"} ),
]), format_observation=StringFormatter(container=[{"token": "<|observation|>"}, "\n", "{{content}}"]),
format_assistant=StringFormatter(container=[ format_function=FunctionFormatter(container=["{{name}}\n{{arguments}}"]),
"\n"
"{{content}}"
]),
format_system=StringFormatter(container=[
{"token": "[gMASK]"},
{"token": "sop"},
{"token": "<|system|>"},
"\n",
"{{content}}"
]),
format_observation=StringFormatter(container=[
{"token": "<|observation|>"},
"\n",
"{{content}}"
]),
format_function=FunctionFormatter(container=[
"{{name}}\n{{arguments}}"
]),
system=( system=(
"You are ChatGLM3, a large language model trained by Zhipu.AI. " "You are ChatGLM3, a large language model trained by Zhipu.AI. "
"Follow the user's instructions carefully. Respond using markdown." "Follow the user's instructions carefully. Respond using markdown."
), ),
stop_words=[ stop_words=["<|user|>", "<|observation|>"],
"<|user|>", efficient_eos=True,
"<|observation|>"
],
efficient_eos=True
) )
register_template( register_template(
name="codegeex2", name="codegeex2", format_system=StringFormatter(container=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"])
format_system=StringFormatter(container=[
{"token": "[gMASK]"},
{"token": "sop"},
"{{content}}"
])
) )
register_template( register_template(name="deepseek", format_user=StringFormatter(container=["User: {{content}}\n\nAssistant:"]))
name="deepseek",
format_user=StringFormatter(container=[
"User: {{content}}\n\nAssistant:"
])
)
register_template( register_template(
name="deepseekcoder", name="deepseekcoder",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=["### Instruction:\n{{content}}\n### Response:\n"]),
"### Instruction:\n{{content}}\n### Response:\n" format_assistant=StringFormatter(container=["{{content}}"]),
]),
format_assistant=StringFormatter(container=[
"{{content}}"
]),
system=( system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, " "You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. " "developed by Deepseek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, " "For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer\n" "and other non-computer science questions, you will refuse to answer\n"
), ),
separator=[ separator=["\n", {"token": "<|EOT|>"}, "\n"],
"\n", stop_words=["<|EOT|>"],
{"token": "<|EOT|>"}, efficient_eos=True,
"\n"
],
stop_words=[
"<|EOT|>"
],
efficient_eos=True
) )
register_template( register_template(
name="default", name="default",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=["Human: {{content}}\nAssistant: "]),
"Human: {{content}}\nAssistant: "
]),
system=( system=(
"A chat between a curious user and an artificial intelligence assistant. " "A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.\n" "The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
), ),
separator=[ separator=["\n"],
"\n"
]
) )
register_template( register_template(
name="falcon", name="falcon",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=["User: {{content}}\nFalcon:"]),
"User: {{content}}\nFalcon:" format_assistant=StringFormatter(container=["{{content}}"]),
]), separator=["\n"],
format_assistant=StringFormatter(container=[ efficient_eos=True,
"{{content}}"
]),
separator=[
"\n"
],
efficient_eos=True
) )
register_template( register_template(
name="intern", name="intern",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
"<|User|>:{{content}}", format_assistant=StringFormatter(container=["{{content}}"]),
{"token": "<eoh>"}, separator=[{"token": "<eoa>"}, "\n"],
"\n<|Bot|>:" stop_words=["<eoa>"],
]), efficient_eos=True,
format_assistant=StringFormatter(container=[
"{{content}}"
]),
separator=[
{"token": "<eoa>"},
"\n"
],
stop_words=[
"<eoa>"
],
efficient_eos=True
) )
register_template( register_template(
name="intern2", name="intern2",
format_user=StringFormatter(container=[ format_user=StringFormatter(
{"token": "[UNUSED_TOKEN_146]"}, container=[
"user\n{{content}}", {"token": "[UNUSED_TOKEN_146]"},
{"token": "[UNUSED_TOKEN_145]"}, "user\n{{content}}",
"\n", {"token": "[UNUSED_TOKEN_145]"},
{"token": "[UNUSED_TOKEN_146]"}, "\n",
"assistant\n" {"token": "[UNUSED_TOKEN_146]"},
]), "assistant\n",
format_assistant=StringFormatter(container=[ ]
"{{content}}" ),
]), format_assistant=StringFormatter(container=["{{content}}"]),
format_system=StringFormatter(container=[ format_system=StringFormatter(
{"token": "[UNUSED_TOKEN_146]"}, container=[{"token": "[UNUSED_TOKEN_146]"}, "system\n{{content}}", {"token": "[UNUSED_TOKEN_145]"}, "\n"]
"system\n{{content}}", ),
{"token": "[UNUSED_TOKEN_145]"},
"\n"
]),
system=( system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n" "You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed " "- InternLM (书生·浦语) is a conversational language model that is developed "
@ -532,14 +411,9 @@ register_template(
"- InternLM (书生·浦语) can understand and communicate fluently in the language chosen " "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
"by the user such as English and 中文." "by the user such as English and 中文."
), ),
separator=[ separator=[{"token": "[UNUSED_TOKEN_145]"}, "\n"],
{"token": "[UNUSED_TOKEN_145]"}, stop_words=["[UNUSED_TOKEN_145]"],
"\n" efficient_eos=True,
],
stop_words=[
"[UNUSED_TOKEN_145]"
],
efficient_eos=True
) )
@ -556,7 +430,7 @@ register_template(
"If a question does not make any sense, or is not factually coherent, " "If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. " "explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information." "If you don't know the answer to a question, please don't share false information."
) ),
) )
@ -564,142 +438,83 @@ register_template(
name="llama2_zh", name="llama2_zh",
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]), format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]),
format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]), format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
system="You are a helpful assistant. 你是一个乐于助人的助手。" system="You are a helpful assistant. 你是一个乐于助人的助手。",
) )
register_template( register_template(name="mistral", format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]))
name="mistral",
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"])
)
register_template( register_template(
name="openchat", name="openchat",
format_user=StringFormatter(container=[ format_user=StringFormatter(
"GPT4 Correct User: {{content}}", container=["GPT4 Correct User: {{content}}", {"token": "<|end_of_turn|>"}, "GPT4 Correct Assistant:"]
{"token": "<|end_of_turn|>"}, ),
"GPT4 Correct Assistant:" format_assistant=StringFormatter(container=["{{content}}"]),
]), separator=[{"token": "<|end_of_turn|>"}],
format_assistant=StringFormatter(container=[ stop_words=["<|end_of_turn|>"],
"{{content}}" efficient_eos=True,
]),
separator=[
{"token": "<|end_of_turn|>"}
],
stop_words=[
"<|end_of_turn|>"
],
efficient_eos=True
) )
register_template( register_template(
name="qwen", name="qwen",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
"<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n" format_system=StringFormatter(container=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
]),
format_system=StringFormatter(container=[
"<|im_start|>system\n{{content}}<|im_end|>\n"
]),
system="You are a helpful assistant.", system="You are a helpful assistant.",
separator=[ separator=["\n"],
"\n" stop_words=["<|im_end|>"],
], replace_eos=True,
stop_words=[
"<|im_end|>"
],
replace_eos=True
) )
register_template( register_template(name="solar", format_user=StringFormatter(container=["### User:\n{{content}}\n\n### Assistant:\n"]))
name="solar",
format_user=StringFormatter(container=[
"### User:\n{{content}}\n\n### Assistant:\n"
])
)
register_template( register_template(
name="starchat", name="starchat",
format_user=StringFormatter(container=[ format_user=StringFormatter(
{"token": "<|user|>"}, container=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}]
"\n{{content}}", ),
{"token": "<|end|>"}, format_assistant=StringFormatter(container=["{{content}}"]),
"\n", format_system=StringFormatter(container=[{"token": "<|system|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n"]),
{"token": "<|assistant|>"} separator=[{"token": "<|end|>"}, "\n"],
]), stop_words=["<|end|>"],
format_assistant=StringFormatter(container=[ efficient_eos=True,
"{{content}}"
]),
format_system=StringFormatter(container=[
{"token": "<|system|>"},
"\n{{content}}",
{"token": "<|end|>"},
"\n"
]),
separator=[
{"token": "<|end|>"},
"\n"
],
stop_words=[
"<|end|>"
],
efficient_eos=True
) )
register_template( register_template(name="vanilla")
name="vanilla"
)
register_template( register_template(
name="vicuna", name="vicuna",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=["USER: {{content}} ASSISTANT:"]),
"USER: {{content}} ASSISTANT:"
]),
system=( system=(
"A chat between a curious user and an artificial intelligence assistant. " "A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions." "The assistant gives helpful, detailed, and polite answers to the user's questions."
) ),
) )
register_template( register_template(
name="xuanyuan", name="xuanyuan",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=["Human: {{content}} Assistant:"]),
"Human: {{content}} Assistant:"
]),
system=( system=(
"以下是用户和人工智能助手之间的对话。用户以Human开头人工智能助手以Assistant开头" "以下是用户和人工智能助手之间的对话。用户以Human开头人工智能助手以Assistant开头"
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、" "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n" "不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
) ),
) )
register_template( register_template(name="xverse", format_user=StringFormatter(container=["Human: {{content}}\n\nAssistant: "]))
name="xverse",
format_user=StringFormatter(container=[
"Human: {{content}}\n\nAssistant: "
])
)
register_template( register_template(
name="yayi", name="yayi",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
{"token": "<|Human|>"}, format_system=StringFormatter(container=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
":\n{{content}}\n\n",
{"token": "<|YaYi|>"},
":"
]),
format_system=StringFormatter(container=[
{"token": "<|System|>"},
":\n{{content}}\n\n"
]),
system=( system=(
"You are a helpful, respectful and honest assistant named YaYi " "You are a helpful, respectful and honest assistant named YaYi "
"developed by Beijing Wenge Technology Co.,Ltd. " "developed by Beijing Wenge Technology Co.,Ltd. "
@ -711,67 +526,43 @@ register_template(
"explain why instead of answering something not correct. " "explain why instead of answering something not correct. "
"explain why instead of answering something not correct. " "explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information." "If you don't know the answer to a question, please don't share false information."
), ),
separator=[ separator=["\n\n"],
"\n\n" stop_words=["<|End|>"],
],
stop_words=[
"<|End|>"
]
) )
register_template( register_template(
name="yi", name="yi",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
"<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n" separator=["\n"],
]), stop_words=["<|im_end|>"],
separator=[ replace_eos=True,
"\n"
],
stop_words=[
"<|im_end|>"
],
replace_eos=True
) )
register_template( register_template(
name="yuan", name="yuan",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=["{{content}}", {"token": "<sep>"}]),
"{{content}}", separator=["\n"],
{"token": "<sep>"} stop_words=["<eod>"],
]), replace_eos=True,
separator=[
"\n"
],
stop_words=[
"<eod>"
],
replace_eos=True
) )
register_template( register_template(
name="zephyr", name="zephyr",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=["<|user|>\n{{content}}</s><|assistant|>"]),
"<|user|>\n{{content}}</s><|assistant|>" format_system=StringFormatter(
]), container=[
format_system=StringFormatter(container=[ "<|system|>\n{{content}}</s>",
"<|system|>\n{{content}}</s>", ]
]), ),
system="You are a friendly chatbot who always responds in the style of a pirate" system="You are a friendly chatbot who always responds in the style of a pirate",
) )
register_template( register_template(
name="ziya", name="ziya",
format_user=StringFormatter(container=[ format_user=StringFormatter(container=[{"token": "<human>"}, ":{{content}}\n", {"token": "<bot>"}, ":"]),
{"token": "<human>"}, separator=["\n"],
":{{content}}\n",
{"token": "<bot>"},
":"
]),
separator=[
"\n"
]
) )
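
For context, a minimal sketch of how a registered chat template is consumed downstream, assuming the helper names that appear later in this commit (get_template_and_fix_tokenizer, encode_oneturn); the tokenizer id and the role/content message schema are illustrative assumptions, not part of the diff:

# Sketch only: fetch the "qwen" template registered above and encode one turn.
# The model id and the message schema are assumed for illustration.
from transformers import AutoTokenizer

from llmtuner.data import get_template_and_fix_tokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
template = get_template_and_fix_tokenizer("qwen", tokenizer)
messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi, how can I help you?"},
]
input_ids, _ = template.encode_oneturn(tokenizer=tokenizer, messages=messages)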
@ -4,9 +4,11 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from ..extras.logging import get_logger from ..extras.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from datasets import Dataset, IterableDataset from datasets import Dataset, IterableDataset
from transformers import TrainingArguments from transformers import TrainingArguments
from llmtuner.hparams import DataArguments from llmtuner.hparams import DataArguments
@ -44,12 +46,10 @@ def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments")
def split_dataset( def split_dataset(
dataset: Union["Dataset", "IterableDataset"], dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "TrainingArguments"
data_args: "DataArguments",
training_args: "TrainingArguments"
) -> Dict[str, "Dataset"]: ) -> Dict[str, "Dataset"]:
if training_args.do_train: if training_args.do_train:
if data_args.val_size > 1e-6: # Split the dataset if data_args.val_size > 1e-6: # Split the dataset
if data_args.streaming: if data_args.streaming:
val_set = dataset.take(int(data_args.val_size)) val_set = dataset.take(int(data_args.val_size))
train_set = dataset.skip(int(data_args.val_size)) train_set = dataset.skip(int(data_args.val_size))
@ -63,5 +63,5 @@ def split_dataset(
if data_args.streaming: if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": dataset} return {"train_dataset": dataset}
else: # do_eval or do_predict else: # do_eval or do_predict
return {"eval_dataset": dataset} return {"eval_dataset": dataset}
@ -1,35 +1,34 @@
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py # Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
import os
import json
import torch
import numpy as np
import inspect import inspect
from tqdm import tqdm, trange import json
import os
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import numpy as np
import torch
from datasets import load_dataset from datasets import load_dataset
from tqdm import tqdm, trange
from transformers.utils import cached_file from transformers.utils import cached_file
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from .template import get_eval_template
from ..extras.constants import CHOICES, SUBJECTS from ..extras.constants import CHOICES, SUBJECTS
from ..hparams import get_eval_args from ..hparams import get_eval_args
from ..model import dispatch_model, load_model_and_tokenizer from ..model import dispatch_model, load_model_and_tokenizer
from .template import get_eval_template
class Evaluator: class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args) self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args) self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2 self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
self.model = dispatch_model(self.model) self.model = dispatch_model(self.model)
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer) self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
self.eval_template = get_eval_template(self.eval_args.lang) self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = [self.tokenizer.encode( self.choice_inputs = [
self.eval_template.prefix + ch, add_special_tokens=False self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
)[-1] for ch in CHOICES] ]
@torch.inference_mode() @torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]: def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
@ -41,10 +40,10 @@ class Evaluator:
def eval(self) -> None: def eval(self) -> None:
mapping = cached_file( mapping = cached_file(
path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task), path_or_repo_id=os.path.join(self.eval_args.task_dir, self.eval_args.task),
filename="mapping.json", filename="mapping.json",
cache_dir=self.model_args.cache_dir, cache_dir=self.model_args.cache_dir,
token=self.model_args.hf_hub_token token=self.model_args.hf_hub_token,
) )
with open(mapping, "r", encoding="utf-8") as f: with open(mapping, "r", encoding="utf-8") as f:
@ -54,7 +53,7 @@ class Evaluator:
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0) pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
results = {} results = {}
for subject in pbar: for subject in pbar:
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0 if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
kwargs = {"trust_remote_code": True} kwargs = {"trust_remote_code": True}
else: else:
kwargs = {} kwargs = {}
@ -65,32 +64,34 @@ class Evaluator:
cache_dir=self.model_args.cache_dir, cache_dir=self.model_args.cache_dir,
download_mode=self.eval_args.download_mode, download_mode=self.eval_args.download_mode,
token=self.model_args.hf_hub_token, token=self.model_args.hf_hub_token,
**kwargs **kwargs,
) )
pbar.set_postfix_str(categorys[subject]["name"]) pbar.set_postfix_str(categorys[subject]["name"])
inputs, outputs, labels = [], [], [] inputs, outputs, labels = [], [], []
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False): for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"])))) support_set = (
dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
)
messages = self.eval_template.format_example( messages = self.eval_template.format_example(
target_data=dataset[self.data_args.split][i], target_data=dataset[self.data_args.split][i],
support_set=support_set, support_set=support_set,
subject_name=categorys[subject]["name"] subject_name=categorys[subject]["name"],
) )
input_ids, _ = self.template.encode_oneturn( input_ids, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=messages)
tokenizer=self.tokenizer, messages=messages
)
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)}) inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
labels.append(messages[-1]["content"]) labels.append(messages[-1]["content"])
for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False): for i in trange(
0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False
):
batch_input = self.tokenizer.pad( batch_input = self.tokenizer.pad(
inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt" inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
).to(self.model.device) ).to(self.model.device)
preds = self.batch_inference(batch_input) preds = self.batch_inference(batch_input)
outputs += preds outputs += preds
corrects = (np.array(outputs) == np.array(labels)) corrects = np.array(outputs) == np.array(labels)
category_name = categorys[subject]["category"] category_name = categorys[subject]["category"]
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0) category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0) category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
@ -100,10 +101,13 @@ class Evaluator:
self._save_results(category_corrects, results) self._save_results(category_corrects, results)
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None: def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
score_info = "\n".join([ score_info = "\n".join(
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct)) [
for category_name, category_correct in category_corrects.items() if len(category_correct) "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
]) for category_name, category_correct in category_corrects.items()
if len(category_correct)
]
)
print(score_info) print(score_info)
if self.eval_args.save_dir is not None: if self.eval_args.save_dir is not None:
os.makedirs(self.eval_args.save_dir, exist_ok=False) os.makedirs(self.eval_args.save_dir, exist_ok=False)
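
For orientation, a hedged sketch of driving this Evaluator programmatically; the argument keys mirror the hparams referenced above (task, split, lang, n_shot, batch_size, save_dir), while the concrete values and the top-level import are illustrative assumptions:

# Sketch only: running the evaluator from Python. Values are illustrative;
# get_eval_args is expected to parse this dict into the hparams used above.
from llmtuner import Evaluator  # top-level export assumed

Evaluator(
    args={
        "model_name_or_path": "meta-llama/Llama-2-7b-hf",
        "template": "vanilla",
        "task": "mmlu",          # resolved as {task_dir}/{task}/mapping.json above
        "split": "test",
        "lang": "en",
        "n_shot": 5,
        "batch_size": 4,
        "save_dir": "eval_results",
    }
).eval()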
@ -1,8 +1,9 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple from typing import TYPE_CHECKING, Dict, List, Tuple
from ..extras.constants import CHOICES
from ..data import Role from ..data import Role
from ..extras.constants import CHOICES
if TYPE_CHECKING: if TYPE_CHECKING:
from datasets import Dataset from datasets import Dataset
@ -10,24 +11,17 @@ if TYPE_CHECKING:
@dataclass @dataclass
class EvalTemplate: class EvalTemplate:
system: str system: str
choice: str choice: str
answer: str answer: str
prefix: str prefix: str
def parse_example( def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
self,
example: Dict[str, str]
) -> Tuple[str, str]:
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example] candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"] return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
def format_example( def format_example(
self, self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
target_data: Dict[str, str],
support_set: "Dataset",
subject_name: str
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
messages = [] messages = []
for k in range(len(support_set)): for k in range(len(support_set)):
@ -45,19 +39,8 @@ class EvalTemplate:
eval_templates: Dict[str, "EvalTemplate"] = {} eval_templates: Dict[str, "EvalTemplate"] = {}
def register_eval_template( def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
name: str, eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
system: str,
choice: str,
answer: str,
prefix: str
) -> None:
eval_templates[name] = EvalTemplate(
system=system,
choice=choice,
answer=answer,
prefix=prefix
)
def get_eval_template(name: str) -> "EvalTemplate": def get_eval_template(name: str) -> "EvalTemplate":
@ -71,7 +54,7 @@ register_eval_template(
system="The following are multiple choice questions (with answers) about {subject}.\n\n", system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}", choice="\n{choice}. {content}",
answer="\nAnswer: ", answer="\nAnswer: ",
prefix=" " prefix=" ",
) )
@ -80,5 +63,5 @@ register_eval_template(
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n", system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}", choice="\n{choice}. {content}",
answer="\n答案:", answer="\n答案:",
prefix="\n" prefix="\n",
) )
@ -1,10 +1,11 @@
import os
import json import json
import os
import time import time
from typing import TYPE_CHECKING
from datetime import timedelta from datetime import timedelta
from typing import TYPE_CHECKING
from transformers import TrainerCallback from transformers import TrainerCallback
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from .constants import LOG_FILE_NAME from .constants import LOG_FILE_NAME
from .logging import get_logger from .logging import get_logger
@ -12,14 +13,13 @@ from .misc import fix_valuehead_checkpoint
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainingArguments, TrainerState, TrainerControl from transformers import TrainerControl, TrainerState, TrainingArguments
logger = get_logger(__name__) logger = get_logger(__name__)
class FixValueHeadModelCallback(TrainerCallback): class FixValueHeadModelCallback(TrainerCallback):
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called after a checkpoint save. Event called after a checkpoint save.
@ -28,12 +28,11 @@ class FixValueHeadModelCallback(TrainerCallback):
fix_valuehead_checkpoint( fix_valuehead_checkpoint(
model=kwargs.pop("model"), model=kwargs.pop("model"),
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)), output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
safe_serialization=args.save_safetensors safe_serialization=args.save_safetensors,
) )
class LogCallback(TrainerCallback): class LogCallback(TrainerCallback):
def __init__(self, runner=None): def __init__(self, runner=None):
self.runner = runner self.runner = runner
self.in_training = False self.in_training = False
@ -99,7 +98,9 @@ class LogCallback(TrainerCallback):
self.cur_steps = 0 self.cur_steps = 0
self.max_steps = 0 self.max_steps = 0
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs): def on_predict(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
):
r""" r"""
Event called after a successful prediction. Event called after a successful prediction.
""" """
@ -125,18 +126,22 @@ class LogCallback(TrainerCallback):
epoch=state.log_history[-1].get("epoch", None), epoch=state.log_history[-1].get("epoch", None),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time, elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time remaining_time=self.remaining_time,
) )
if self.runner is not None: if self.runner is not None:
logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format( logger.info(
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0 "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
)) logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
)
)
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f: with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n") f.write(json.dumps(logs) + "\n")
def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
):
r""" r"""
Event called after a prediction step. Event called after a prediction step.
""" """
@ -1,5 +1,5 @@
from collections import OrderedDict, defaultdict
from enum import Enum from enum import Enum
from collections import defaultdict, OrderedDict
from typing import Dict, Optional from typing import Dict, Optional
@ -11,14 +11,7 @@ DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str) DEFAULT_TEMPLATE = defaultdict(str)
FILEEXT2TYPE = { FILEEXT2TYPE = {"arrow": "arrow", "csv": "csv", "json": "json", "jsonl": "json", "parquet": "parquet", "txt": "text"}
"arrow": "arrow",
"csv": "csv",
"json": "json",
"jsonl": "json",
"parquet": "parquet",
"txt": "text"
}
IGNORE_INDEX = -100 IGNORE_INDEX = -100
@ -39,22 +32,21 @@ TRAINING_STAGES = {
"Reward Modeling": "rm", "Reward Modeling": "rm",
"PPO": "ppo", "PPO": "ppo",
"DPO": "dpo", "DPO": "dpo",
"Pre-Training": "pt" "Pre-Training": "pt",
} }
V_HEAD_WEIGHTS_NAME = "value_head.bin" V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors" V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
class DownloadSource(str, Enum): class DownloadSource(str, Enum):
DEFAULT = "hf" DEFAULT = "hf"
MODELSCOPE = "ms" MODELSCOPE = "ms"
def register_model_group( def register_model_group(
models: Dict[str, Dict[DownloadSource, str]], models: Dict[str, Dict[DownloadSource, str]], module: Optional[str] = None, template: Optional[str] = None
module: Optional[str] = None,
template: Optional[str] = None
) -> None: ) -> None:
prefix = None prefix = None
for name, path in models.items(): for name, path in models.items():
@ -73,19 +65,19 @@ register_model_group(
models={ models={
"Baichuan-7B-Base": { "Baichuan-7B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B", DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B" DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B",
}, },
"Baichuan-13B-Base": { "Baichuan-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base", DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base" DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base",
}, },
"Baichuan-13B-Chat": { "Baichuan-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat", DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat" DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
} },
}, },
module="W_pack", module="W_pack",
template="baichuan" template="baichuan",
) )
@ -93,23 +85,23 @@ register_model_group(
models={ models={
"Baichuan2-7B-Base": { "Baichuan2-7B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base", DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base" DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base",
}, },
"Baichuan2-13B-Base": { "Baichuan2-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base", DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base" DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
}, },
"Baichuan2-7B-Chat": { "Baichuan2-7B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat", DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat" DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
}, },
"Baichuan2-13B-Chat": { "Baichuan2-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat", DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat" DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
} },
}, },
module="W_pack", module="W_pack",
template="baichuan2" template="baichuan2",
) )
@ -117,18 +109,18 @@ register_model_group(
models={ models={
"BLOOM-560M": { "BLOOM-560M": {
DownloadSource.DEFAULT: "bigscience/bloom-560m", DownloadSource.DEFAULT: "bigscience/bloom-560m",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m" DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m",
}, },
"BLOOM-3B": { "BLOOM-3B": {
DownloadSource.DEFAULT: "bigscience/bloom-3b", DownloadSource.DEFAULT: "bigscience/bloom-3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b" DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b",
}, },
"BLOOM-7B1": { "BLOOM-7B1": {
DownloadSource.DEFAULT: "bigscience/bloom-7b1", DownloadSource.DEFAULT: "bigscience/bloom-7b1",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1" DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1",
} },
}, },
module="query_key_value" module="query_key_value",
) )
@ -136,18 +128,18 @@ register_model_group(
models={ models={
"BLOOMZ-560M": { "BLOOMZ-560M": {
DownloadSource.DEFAULT: "bigscience/bloomz-560m", DownloadSource.DEFAULT: "bigscience/bloomz-560m",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m" DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m",
}, },
"BLOOMZ-3B": { "BLOOMZ-3B": {
DownloadSource.DEFAULT: "bigscience/bloomz-3b", DownloadSource.DEFAULT: "bigscience/bloomz-3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b" DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b",
}, },
"BLOOMZ-7B1-mt": { "BLOOMZ-7B1-mt": {
DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt", DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt" DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt",
} },
}, },
module="query_key_value" module="query_key_value",
) )
@ -155,14 +147,14 @@ register_model_group(
models={ models={
"BlueLM-7B-Base": { "BlueLM-7B-Base": {
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base", DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base" DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base",
}, },
"BlueLM-7B-Chat": { "BlueLM-7B-Chat": {
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat", DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat" DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat",
} },
}, },
template="bluelm" template="bluelm",
) )
@ -170,11 +162,11 @@ register_model_group(
models={ models={
"ChatGLM2-6B-Chat": { "ChatGLM2-6B-Chat": {
DownloadSource.DEFAULT: "THUDM/chatglm2-6b", DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b" DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
} }
}, },
module="query_key_value", module="query_key_value",
template="chatglm2" template="chatglm2",
) )
@ -182,15 +174,15 @@ register_model_group(
models={ models={
"ChatGLM3-6B-Base": { "ChatGLM3-6B-Base": {
DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base", DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base" DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base",
}, },
"ChatGLM3-6B-Chat": { "ChatGLM3-6B-Chat": {
DownloadSource.DEFAULT: "THUDM/chatglm3-6b", DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b" DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
} },
}, },
module="query_key_value", module="query_key_value",
template="chatglm3" template="chatglm3",
) )
@ -198,30 +190,30 @@ register_model_group(
models={ models={
"ChineseLLaMA2-1.3B": { "ChineseLLaMA2-1.3B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b", DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b" DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
}, },
"ChineseLLaMA2-7B": { "ChineseLLaMA2-7B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b", DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b" DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
}, },
"ChineseLLaMA2-13B": { "ChineseLLaMA2-13B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b", DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b" DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
}, },
"ChineseLLaMA2-1.3B-Chat": { "ChineseLLaMA2-1.3B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b", DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b" DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
}, },
"ChineseLLaMA2-7B-Chat": { "ChineseLLaMA2-7B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b", DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b" DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
}, },
"ChineseLLaMA2-13B-Chat": { "ChineseLLaMA2-13B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b", DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b" DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
} },
}, },
template="llama2_zh" template="llama2_zh",
) )
@ -229,22 +221,22 @@ register_model_group(
models={ models={
"DeepSeekLLM-7B-Base": { "DeepSeekLLM-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base", DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base" DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base",
}, },
"DeepSeekLLM-67B-Base": { "DeepSeekLLM-67B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base", DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base" DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base",
}, },
"DeepSeekLLM-7B-Chat": { "DeepSeekLLM-7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat", DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat" DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat",
}, },
"DeepSeekLLM-67B-Chat": { "DeepSeekLLM-67B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat", DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat" DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat",
} },
}, },
template="deepseek" template="deepseek",
) )
@ -252,22 +244,22 @@ register_model_group(
models={ models={
"DeepSeekCoder-6.7B-Base": { "DeepSeekCoder-6.7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base" DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
}, },
"DeepSeekCoder-33B-Base": { "DeepSeekCoder-33B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base" DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
}, },
"DeepSeekCoder-6.7B-Chat": { "DeepSeekCoder-6.7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct" DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
}, },
"DeepSeekCoder-33B-Chat": { "DeepSeekCoder-33B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct" DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
} },
}, },
template="deepseekcoder" template="deepseekcoder",
) )
@ -275,14 +267,14 @@ register_model_group(
models={ models={
"DeepSeekMoE-16B-Base": { "DeepSeekMoE-16B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base", DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base" DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
}, },
"DeepSeekMoE-16B-Chat": { "DeepSeekMoE-16B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat", DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat" DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
} },
}, },
template="deepseek" template="deepseek",
) )
@ -290,31 +282,31 @@ register_model_group(
models={ models={
"Falcon-7B": { "Falcon-7B": {
DownloadSource.DEFAULT: "tiiuae/falcon-7b", DownloadSource.DEFAULT: "tiiuae/falcon-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b" DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b",
}, },
"Falcon-40B": { "Falcon-40B": {
DownloadSource.DEFAULT: "tiiuae/falcon-40b", DownloadSource.DEFAULT: "tiiuae/falcon-40b",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b" DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b",
}, },
"Falcon-180B": { "Falcon-180B": {
DownloadSource.DEFAULT: "tiiuae/falcon-180b", DownloadSource.DEFAULT: "tiiuae/falcon-180b",
DownloadSource.MODELSCOPE: "modelscope/falcon-180B" DownloadSource.MODELSCOPE: "modelscope/falcon-180B",
}, },
"Falcon-7B-Chat": { "Falcon-7B-Chat": {
DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct", DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct" DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct",
}, },
"Falcon-40B-Chat": { "Falcon-40B-Chat": {
DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct", DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct" DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct",
}, },
"Falcon-180B-Chat": { "Falcon-180B-Chat": {
DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat", DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat" DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat",
} },
}, },
module="query_key_value", module="query_key_value",
template="falcon" template="falcon",
) )
@ -322,22 +314,22 @@ register_model_group(
models={ models={
"InternLM-7B": { "InternLM-7B": {
DownloadSource.DEFAULT: "internlm/internlm-7b", DownloadSource.DEFAULT: "internlm/internlm-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b" DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b",
}, },
"InternLM-20B": { "InternLM-20B": {
DownloadSource.DEFAULT: "internlm/internlm-20b", DownloadSource.DEFAULT: "internlm/internlm-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b" DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b",
}, },
"InternLM-7B-Chat": { "InternLM-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm-chat-7b", DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b" DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b",
}, },
"InternLM-20B-Chat": { "InternLM-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm-chat-20b", DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b" DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b",
} },
}, },
template="intern" template="intern",
) )
@ -345,23 +337,23 @@ register_model_group(
models={ models={
"InternLM2-7B": { "InternLM2-7B": {
DownloadSource.DEFAULT: "internlm/internlm2-7b", DownloadSource.DEFAULT: "internlm/internlm2-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b" DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b",
}, },
"InternLM2-20B": { "InternLM2-20B": {
DownloadSource.DEFAULT: "internlm/internlm2-20b", DownloadSource.DEFAULT: "internlm/internlm2-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b" DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b",
}, },
"InternLM2-7B-Chat": { "InternLM2-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2-chat-7b", DownloadSource.DEFAULT: "internlm/internlm2-chat-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b" DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b",
}, },
"InternLM2-20B-Chat": { "InternLM2-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2-chat-20b", DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b" DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
} },
}, },
module="wqkv", module="wqkv",
template="intern2" template="intern2",
) )
@ -369,31 +361,28 @@ register_model_group(
models={ models={
"LingoWhale-8B": { "LingoWhale-8B": {
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B", DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B" DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
} }
}, },
module="qkv_proj" module="qkv_proj",
) )
register_model_group( register_model_group(
models={ models={
"LLaMA-7B": { "LLaMA-7B": {DownloadSource.DEFAULT: "huggyllama/llama-7b", DownloadSource.MODELSCOPE: "skyline2006/llama-7b"},
DownloadSource.DEFAULT: "huggyllama/llama-7b",
DownloadSource.MODELSCOPE: "skyline2006/llama-7b"
},
"LLaMA-13B": { "LLaMA-13B": {
DownloadSource.DEFAULT: "huggyllama/llama-13b", DownloadSource.DEFAULT: "huggyllama/llama-13b",
DownloadSource.MODELSCOPE: "skyline2006/llama-13b" DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
}, },
"LLaMA-30B": { "LLaMA-30B": {
DownloadSource.DEFAULT: "huggyllama/llama-30b", DownloadSource.DEFAULT: "huggyllama/llama-30b",
DownloadSource.MODELSCOPE: "skyline2006/llama-30b" DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
}, },
"LLaMA-65B": { "LLaMA-65B": {
DownloadSource.DEFAULT: "huggyllama/llama-65b", DownloadSource.DEFAULT: "huggyllama/llama-65b",
DownloadSource.MODELSCOPE: "skyline2006/llama-65b" DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
} },
} }
) )
@ -402,30 +391,30 @@ register_model_group(
models={ models={
"LLaMA2-7B": { "LLaMA2-7B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms" DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
}, },
"LLaMA2-13B": { "LLaMA2-13B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms" DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
}, },
"LLaMA2-70B": { "LLaMA2-70B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms" DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
}, },
"LLaMA2-7B-Chat": { "LLaMA2-7B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms" DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
}, },
"LLaMA2-13B-Chat": { "LLaMA2-13B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms" DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
}, },
"LLaMA2-70B-Chat": { "LLaMA2-70B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms" DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
} },
}, },
template="llama2" template="llama2",
) )
@ -433,18 +422,18 @@ register_model_group(
models={ models={
"Mistral-7B": { "Mistral-7B": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1", DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1" DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
}, },
"Mistral-7B-Chat": { "Mistral-7B-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1", DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1" DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
}, },
"Mistral-7B-v0.2-Chat": { "Mistral-7B-v0.2-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2", DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2" DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
} },
}, },
template="mistral" template="mistral",
) )
@ -452,14 +441,14 @@ register_model_group(
models={ models={
"Mixtral-8x7B": { "Mixtral-8x7B": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1", DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1" DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
}, },
"Mixtral-8x7B-Chat": { "Mixtral-8x7B-Chat": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1", DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1" DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
} },
}, },
template="mistral" template="mistral",
) )
@ -467,110 +456,87 @@ register_model_group(
models={ models={
"OpenChat3.5-7B-Chat": { "OpenChat3.5-7B-Chat": {
DownloadSource.DEFAULT: "openchat/openchat_3.5", DownloadSource.DEFAULT: "openchat/openchat_3.5",
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5" DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5",
} }
}, },
template="openchat" template="openchat",
) )
register_model_group( register_model_group(
models={ models={
"Phi-1.5-1.3B": { "Phi-1.5-1.3B": {DownloadSource.DEFAULT: "microsoft/phi-1_5", DownloadSource.MODELSCOPE: "allspace/PHI_1-5"},
DownloadSource.DEFAULT: "microsoft/phi-1_5", "Phi-2-2.7B": {DownloadSource.DEFAULT: "microsoft/phi-2", DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2"},
DownloadSource.MODELSCOPE: "allspace/PHI_1-5"
},
"Phi-2-2.7B": {
DownloadSource.DEFAULT: "microsoft/phi-2",
DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2"
}
} }
) )
register_model_group( register_model_group(
models={ models={
"Qwen-1.8B": { "Qwen-1.8B": {DownloadSource.DEFAULT: "Qwen/Qwen-1_8B", DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B"},
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B", "Qwen-7B": {DownloadSource.DEFAULT: "Qwen/Qwen-7B", DownloadSource.MODELSCOPE: "qwen/Qwen-7B"},
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B" "Qwen-14B": {DownloadSource.DEFAULT: "Qwen/Qwen-14B", DownloadSource.MODELSCOPE: "qwen/Qwen-14B"},
}, "Qwen-72B": {DownloadSource.DEFAULT: "Qwen/Qwen-72B", DownloadSource.MODELSCOPE: "qwen/Qwen-72B"},
"Qwen-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B"
},
"Qwen-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B"
},
"Qwen-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B"
},
"Qwen-1.8B-Chat": { "Qwen-1.8B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat", DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat" DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
},
"Qwen-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"
}, },
"Qwen-7B-Chat": {DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"},
"Qwen-14B-Chat": { "Qwen-14B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat", DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat" DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat",
}, },
"Qwen-72B-Chat": { "Qwen-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat", DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat" DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
}, },
"Qwen-1.8B-int8-Chat": { "Qwen-1.8B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8", DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8" DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
}, },
"Qwen-1.8B-int4-Chat": { "Qwen-1.8B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4", DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4" DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
}, },
"Qwen-7B-int8-Chat": { "Qwen-7B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8", DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8" DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
}, },
"Qwen-7B-int4-Chat": { "Qwen-7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4", DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4" DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
}, },
"Qwen-14B-int8-Chat": { "Qwen-14B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8", DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8" DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
}, },
"Qwen-14B-int4-Chat": { "Qwen-14B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4", DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4" DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
}, },
"Qwen-72B-int8-Chat": { "Qwen-72B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8", DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8" DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
}, },
"Qwen-72B-int4-Chat": { "Qwen-72B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4", DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4" DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
} },
}, },
module="c_attn", module="c_attn",
template="qwen" template="qwen",
) )
register_model_group( register_model_group(
models={ models={
"SOLAR-10.7B": { "SOLAR-10.7B": {DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0"},
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0"
},
"SOLAR-10.7B-Chat": { "SOLAR-10.7B-Chat": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0", DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0" DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
} },
}, },
template="solar" template="solar",
) )
@ -578,7 +544,7 @@ register_model_group(
models={ models={
"Skywork-13B-Base": { "Skywork-13B-Base": {
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base", DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base" DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base",
} }
} }
) )
@ -588,68 +554,51 @@ register_model_group(
models={ models={
"Vicuna1.5-7B-Chat": { "Vicuna1.5-7B-Chat": {
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5", DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5" DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
}, },
"Vicuna1.5-13B-Chat": { "Vicuna1.5-13B-Chat": {
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5", DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5" DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
} },
}, },
template="vicuna" template="vicuna",
) )
register_model_group( register_model_group(
models={ models={
"XuanYuan-70B": { "XuanYuan-70B": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B"},
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B" "XuanYuan-70B-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat"},
}, "XuanYuan-70B-int8-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit"},
"XuanYuan-70B-Chat": { "XuanYuan-70B-int4-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit"},
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat"
},
"XuanYuan-70B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit"
},
"XuanYuan-70B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit"
}
}, },
template="xuanyuan" template="xuanyuan",
) )
register_model_group( register_model_group(
models={ models={
"XVERSE-7B": { "XVERSE-7B": {DownloadSource.DEFAULT: "xverse/XVERSE-7B", DownloadSource.MODELSCOPE: "xverse/XVERSE-7B"},
DownloadSource.DEFAULT: "xverse/XVERSE-7B", "XVERSE-13B": {DownloadSource.DEFAULT: "xverse/XVERSE-13B", DownloadSource.MODELSCOPE: "xverse/XVERSE-13B"},
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B" "XVERSE-65B": {DownloadSource.DEFAULT: "xverse/XVERSE-65B", DownloadSource.MODELSCOPE: "xverse/XVERSE-65B"},
},
"XVERSE-13B": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B"
},
"XVERSE-65B": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B"
},
"XVERSE-65B-2": { "XVERSE-65B-2": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-2", DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2" DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
}, },
"XVERSE-7B-Chat": { "XVERSE-7B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat", DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat" DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat",
}, },
"XVERSE-13B-Chat": { "XVERSE-13B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat", DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat" DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat",
}, },
"XVERSE-65B-Chat": { "XVERSE-65B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat", DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat" DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
} },
}, },
template="xverse" template="xverse",
) )
@ -657,45 +606,33 @@ register_model_group(
models={ models={
"Yayi-7B": { "Yayi-7B": {
DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2", DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2" DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2",
}, },
"Yayi-13B": { "Yayi-13B": {
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2", DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2" DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2",
} },
}, },
template="yayi" template="yayi",
) )
register_model_group( register_model_group(
models={ models={
"Yi-6B": { "Yi-6B": {DownloadSource.DEFAULT: "01-ai/Yi-6B", DownloadSource.MODELSCOPE: "01ai/Yi-6B"},
DownloadSource.DEFAULT: "01-ai/Yi-6B", "Yi-34B": {DownloadSource.DEFAULT: "01-ai/Yi-34B", DownloadSource.MODELSCOPE: "01ai/Yi-34B"},
DownloadSource.MODELSCOPE: "01ai/Yi-6B" "Yi-6B-Chat": {DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat"},
}, "Yi-34B-Chat": {DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat"},
"Yi-34B": {
DownloadSource.DEFAULT: "01-ai/Yi-34B",
DownloadSource.MODELSCOPE: "01ai/Yi-34B"
},
"Yi-6B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat"
},
"Yi-34B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat"
},
"Yi-6B-int8-Chat": { "Yi-6B-int8-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits", DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits" DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
}, },
"Yi-34B-int8-Chat": { "Yi-34B-int8-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits", DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits" DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
} },
}, },
template="yi" template="yi",
) )
@ -703,18 +640,18 @@ register_model_group(
models={ models={
"Yuan2-2B-Chat": { "Yuan2-2B-Chat": {
DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf", DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf",
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf" DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf",
}, },
"Yuan2-51B-Chat": { "Yuan2-51B-Chat": {
DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf", DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf",
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf" DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf",
}, },
"Yuan2-102B-Chat": { "Yuan2-102B-Chat": {
DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf", DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf",
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf" DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf",
} },
}, },
template="yuan" template="yuan",
) )
@ -722,12 +659,12 @@ register_model_group(
models={ models={
"Zephyr-7B-Alpha-Chat": { "Zephyr-7B-Alpha-Chat": {
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha", DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha" DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha",
}, },
"Zephyr-7B-Beta-Chat": { "Zephyr-7B-Beta-Chat": {
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta", DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta" DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
} },
}, },
template="zephyr" template="zephyr",
) )
@ -1,5 +1,5 @@
import sys
import logging import logging
import sys
class LoggerHandler(logging.Handler): class LoggerHandler(logging.Handler):
@ -27,8 +27,7 @@ def get_logger(name: str) -> logging.Logger:
Gets a standard logger with a stream handler to stdout. Gets a standard logger with a stream handler to stdout.
""" """
formatter = logging.Formatter( formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
datefmt="%m/%d/%Y %H:%M:%S"
) )
handler = logging.StreamHandler(sys.stdout) handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter) handler.setFormatter(formatter)
@ -1,31 +1,33 @@
import gc import gc
import os import os
import torch
from typing import TYPE_CHECKING, Dict, Tuple from typing import TYPE_CHECKING, Dict, Tuple
import torch
from peft import PeftModel
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
from transformers.utils import ( from transformers.utils import (
WEIGHTS_NAME,
SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_NAME,
WEIGHTS_NAME,
is_torch_bf16_gpu_available, is_torch_bf16_gpu_available,
is_torch_cuda_available, is_torch_cuda_available,
is_torch_npu_available, is_torch_npu_available,
is_torch_xpu_available is_torch_xpu_available,
) )
from peft import PeftModel
from .constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from .logging import get_logger from .logging import get_logger
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try: try:
_is_bf16_available = is_torch_bf16_gpu_available() _is_bf16_available = is_torch_bf16_gpu_available()
except: except Exception:
_is_bf16_available = False _is_bf16_available = False
if TYPE_CHECKING: if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import ModelArguments from llmtuner.hparams import ModelArguments
@ -36,6 +38,7 @@ class AverageMeter:
r""" r"""
Computes and stores the average and current value. Computes and stores the average and current value.
""" """
def __init__(self): def __init__(self):
self.reset() self.reset()
@ -75,9 +78,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
def fix_valuehead_checkpoint( def fix_valuehead_checkpoint(
model: "AutoModelForCausalLMWithValueHead", model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
output_dir: str,
safe_serialization: bool
) -> None: ) -> None:
r""" r"""
The model is already unwrapped. The model is already unwrapped.
@ -95,6 +96,7 @@ def fix_valuehead_checkpoint(
if safe_serialization: if safe_serialization:
from safetensors import safe_open from safetensors import safe_open
from safetensors.torch import save_file from safetensors.torch import save_file
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME) path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f: with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()} state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
@ -112,9 +114,7 @@ def fix_valuehead_checkpoint(
os.remove(path_to_checkpoint) os.remove(path_to_checkpoint)
model.pretrained_model.save_pretrained( model.pretrained_model.save_pretrained(
output_dir, output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
state_dict=decoder_state_dict or None,
safe_serialization=safe_serialization
) )
if safe_serialization: if safe_serialization:
@ -182,11 +182,10 @@ def try_download_model_from_ms(model_args: "ModelArguments") -> None:
try: try:
from modelscope import snapshot_download from modelscope import snapshot_download
revision = "master" if model_args.model_revision == "main" else model_args.model_revision revision = "master" if model_args.model_revision == "main" else model_args.model_revision
model_args.model_name_or_path = snapshot_download( model_args.model_name_or_path = snapshot_download(
model_args.model_name_or_path, model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir
revision=revision,
cache_dir=model_args.cache_dir
) )
except ImportError: except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`") raise ImportError("Please install modelscope via `pip install modelscope -U`")

View File

@ -9,7 +9,7 @@ def is_package_available(name: str) -> bool:
def get_package_version(name: str) -> str: def get_package_version(name: str) -> str:
try: try:
return importlib.metadata.version(name) return importlib.metadata.version(name)
except: except Exception:
return "0.0.0" return "0.0.0"

View File

@ -1,11 +1,16 @@
import math import math
from typing import Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import Optional, Tuple
from transformers.utils import logging
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import (
Cache, LlamaAttention, LlamaFlashAttention2, apply_rotary_pos_emb, repeat_kv Cache,
LlamaAttention,
LlamaFlashAttention2,
apply_rotary_pos_emb,
repeat_kv,
) )
from transformers.utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -19,7 +24,7 @@ def llama_torch_attn_forward(
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None, past_key_value: Optional["Cache"] = None,
output_attentions: bool = False, output_attentions: bool = False,
**kwargs **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
@ -45,15 +50,17 @@ def llama_torch_attn_forward(
key_states = repeat_kv(key_states, self.num_key_value_groups) key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio")) groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor: def shift(state: torch.Tensor) -> torch.Tensor:
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim) state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat(( state = torch.cat(
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1) (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
), dim=2) dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2) return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
@ -68,14 +75,17 @@ def llama_torch_attn_forward(
# upcast attention to fp32 # upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :) attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.transpose(1, 2).contiguous()
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat(( attn_output = torch.cat(
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1) (
)) attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
@ -94,7 +104,7 @@ def llama_flash_attn_forward(
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False, output_attentions: bool = False,
**kwargs **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions # LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False output_attentions = False
@ -124,9 +134,9 @@ def llama_flash_attn_forward(
key_states = repeat_kv(key_states, self.num_key_value_groups) key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups)
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
dropout_rate = self.attention_dropout if self.training else 0.0 dropout_rate = self.attention_dropout if self.training else 0.0
@ -144,14 +154,16 @@ def llama_flash_attn_forward(
key_states = key_states.to(target_dtype) key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype) value_states = value_states.to(target_dtype)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio")) groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor: def shift(state: torch.Tensor) -> torch.Tensor:
state = torch.cat(( state = torch.cat(
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1) (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
), dim=2) dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim) return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
@ -162,11 +174,14 @@ def llama_flash_attn_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
) )
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat(( attn_output = torch.cat(
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1) (
)) attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
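The shift() helpers above implement LongLoRA's shift short attention (S^2-Attn): attention is computed inside fixed-size groups, and half of the heads are rolled by half a group so their groups straddle the boundaries between neighbouring groups. A toy, shapes-only sketch of the same roll-and-regroup outside any model:

import torch

bsz, n_heads, q_len, head_dim = 1, 4, 8, 2  # toy sizes; real values come from the model config
group_size_ratio = 0.25                     # plays the role of config.group_size_ratio in the patch
groupsz = int(q_len * group_size_ratio)     # 2
num_groups = q_len // groupsz               # 4

def shift(state: torch.Tensor) -> torch.Tensor:
    state = state.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim) -> (bsz, seq_len, n_heads, head_dim)
    # roll the second half of the heads by half a group so their groups overlap the group boundaries
    state = torch.cat(
        (state[:, :, : n_heads // 2], state[:, :, n_heads // 2 :].roll(-groupsz // 2, dims=1)),
        dim=2,
    )
    return state.reshape(bsz * num_groups, groupsz, n_heads, head_dim).transpose(1, 2)

states = torch.randn(bsz, n_heads, q_len, head_dim)
print(shift(states).shape)  # torch.Size([4, 4, 2, 2]) == (bsz * num_groups, n_heads, groupsz, head_dim)

The flash-attention variant returns from shift() without the final transpose because flash attention expects (batch, seq_len, n_heads, head_dim); both variants undo the shift afterwards by rolling the second half of the heads back by groupsz // 2, as the "shift back" hunks show.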

View File

@ -1,12 +1,14 @@
import os
import math
import json import json
import math
import os
from typing import List, Optional from typing import List, Optional
from transformers.trainer import TRAINER_STATE_NAME from transformers.trainer import TRAINER_STATE_NAME
from .logging import get_logger from .logging import get_logger
from .packages import is_matplotlib_available from .packages import is_matplotlib_available
if is_matplotlib_available(): if is_matplotlib_available():
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -20,7 +22,7 @@ def smooth(scalars: List[float]) -> List[float]:
""" """
last = scalars[0] last = scalars[0]
smoothed = list() smoothed = list()
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
for next_val in scalars: for next_val in scalars:
smoothed_val = last * weight + (1 - weight) * next_val smoothed_val = last * weight + (1 - weight) * next_val
smoothed.append(smoothed_val) smoothed.append(smoothed_val)
@ -29,7 +31,6 @@ def smooth(scalars: List[float]) -> List[float]:
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None: def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
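smooth() above is an exponential moving average whose weight is a sigmoid of the series length, so longer loss curves are smoothed more aggressively while the weight stays below 0.9. A standalone sketch of the recurrence (the tail of the loop is cut off by the hunk, so the last = smoothed_val update is assumed):

import math
from typing import List

def ema_smooth(scalars: List[float]) -> List[float]:
    # weight lies in (0, 0.9): near 0 for very short series, approaching 0.9 as len(scalars) grows
    weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5)
    last = scalars[0]
    smoothed = []
    for next_val in scalars:
        smoothed_val = last * weight + (1 - weight) * next_val
        smoothed.append(smoothed_val)
        last = smoothed_val
    return smoothed

print(ema_smooth([1.0, 0.5, 0.25, 0.125]))  # the first point is reproduced exactly; later points are pulled toward their predecessors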

View File

@ -3,7 +3,7 @@ from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments from .generating_args import GeneratingArguments
from .model_args import ModelArguments from .model_args import ModelArguments
from .parser import get_train_args, get_infer_args, get_eval_args from .parser import get_eval_args, get_infer_args, get_train_args
__all__ = [ __all__ = [
@ -14,5 +14,5 @@ __all__ = [
"ModelArguments", "ModelArguments",
"get_train_args", "get_train_args",
"get_infer_args", "get_infer_args",
"get_eval_args" "get_eval_args",
] ]

View File

@ -1,5 +1,5 @@
from typing import Literal, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal, Optional
@dataclass @dataclass
@ -8,80 +8,66 @@ class DataArguments:
Arguments pertaining to what data we are going to input our model for training and evaluation. Arguments pertaining to what data we are going to input our model for training and evaluation.
""" """
template: Optional[str] = field( template: Optional[str] = field(
default=None, default=None, metadata={"help": "Which template to use for constructing prompts in training and inference."}
metadata={"help": "Which template to use for constructing prompts in training and inference."}
) )
dataset: Optional[str] = field( dataset: Optional[str] = field(
default=None, default=None,
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."} metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
) )
dataset_dir: Optional[str] = field( dataset_dir: Optional[str] = field(
default="data", default="data", metadata={"help": "Path to the folder containing the datasets."}
metadata={"help": "Path to the folder containing the datasets."}
) )
split: Optional[str] = field( split: Optional[str] = field(
default="train", default="train", metadata={"help": "Which dataset split to use for training and evaluation."}
metadata={"help": "Which dataset split to use for training and evaluation."}
) )
cutoff_len: Optional[int] = field( cutoff_len: Optional[int] = field(
default=1024, default=1024, metadata={"help": "The maximum length of the model inputs after tokenization."}
metadata={"help": "The maximum length of the model inputs after tokenization."}
) )
reserved_label_len: Optional[int] = field( reserved_label_len: Optional[int] = field(
default=1, default=1, metadata={"help": "The maximum length reserved for label after tokenization."}
metadata={"help": "The maximum length reserved for label after tokenization."}
) )
train_on_prompt: Optional[bool] = field( train_on_prompt: Optional[bool] = field(
default=False, default=False, metadata={"help": "Whether to disable the mask on the prompt or not."}
metadata={"help": "Whether to disable the mask on the prompt or not."}
)
streaming: Optional[bool] = field(
default=False,
metadata={"help": "Enable dataset streaming."}
) )
streaming: Optional[bool] = field(default=False, metadata={"help": "Enable dataset streaming."})
buffer_size: Optional[int] = field( buffer_size: Optional[int] = field(
default=16384, default=16384, metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
) )
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field( mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
default="concat", default="concat",
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."} metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
) )
interleave_probs: Optional[str] = field( interleave_probs: Optional[str] = field(
default=None, default=None,
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."} metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
) )
overwrite_cache: Optional[bool] = field( overwrite_cache: Optional[bool] = field(
default=False, default=False, metadata={"help": "Overwrite the cached training and evaluation sets."}
metadata={"help": "Overwrite the cached training and evaluation sets."}
) )
preprocessing_num_workers: Optional[int] = field( preprocessing_num_workers: Optional[int] = field(
default=None, default=None, metadata={"help": "The number of processes to use for the preprocessing."}
metadata={"help": "The number of processes to use for the preprocessing."}
) )
max_samples: Optional[int] = field( max_samples: Optional[int] = field(
default=None, default=None, metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
) )
eval_num_beams: Optional[int] = field( eval_num_beams: Optional[int] = field(
default=None, default=None,
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"} metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
) )
ignore_pad_token_for_loss: Optional[bool] = field( ignore_pad_token_for_loss: Optional[bool] = field(
default=True, default=True,
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."} metadata={
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
},
) )
val_size: Optional[float] = field( val_size: Optional[float] = field(
default=0, default=0, metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
) )
sft_packing: Optional[bool] = field( sft_packing: Optional[bool] = field(
default=False, default=False, metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
) )
cache_path: Optional[str] = field( cache_path: Optional[str] = field(
default=None, default=None, metadata={"help": "Path to save or load the preprocessed datasets."}
metadata={"help": "Path to save or load the preprocessed datasets."}
) )
def __post_init__(self): def __post_init__(self):
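Every field(default=..., metadata={"help": ...}) declaration above is consumed by transformers' HfArgumentParser, which turns each dataclass field into a command-line flag carrying that help text. A small illustration with a made-up dataclass (not the project's DataArguments):

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class ToyDataArguments:
    cutoff_len: Optional[int] = field(
        default=1024, metadata={"help": "The maximum length of the model inputs after tokenization."}
    )
    streaming: Optional[bool] = field(default=False, metadata={"help": "Enable dataset streaming."})

parser = HfArgumentParser(ToyDataArguments)
(toy_args,) = parser.parse_args_into_dataclasses(["--cutoff_len", "2048"])
print(toy_args.cutoff_len, toy_args.streaming)  # 2048 False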

View File

@ -1,6 +1,6 @@
import os import os
from typing import Literal, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal, Optional
from datasets import DownloadMode from datasets import DownloadMode
@ -10,36 +10,18 @@ class EvaluationArguments:
r""" r"""
Arguments pertaining to specify the evaluation parameters. Arguments pertaining to specify the evaluation parameters.
""" """
task: str = field( task: str = field(metadata={"help": "Name of the evaluation task."})
metadata={"help": "Name of the evaluation task."}
)
task_dir: Optional[str] = field( task_dir: Optional[str] = field(
default="evaluation", default="evaluation", metadata={"help": "Path to the folder containing the evaluation datasets."}
metadata={"help": "Path to the folder containing the evaluation datasets."}
)
batch_size: Optional[int] = field(
default=4,
metadata={"help": "The batch size per GPU for evaluation."}
)
seed: Optional[int] = field(
default=42,
metadata={"help": "Random seed to be used with data loaders."}
)
lang: Optional[Literal["en", "zh"]] = field(
default="en",
metadata={"help": "Language used at evaluation."}
)
n_shot: Optional[int] = field(
default=5,
metadata={"help": "Number of examplars for few-shot learning."}
)
save_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to save the evaluation results."}
) )
batch_size: Optional[int] = field(default=4, metadata={"help": "The batch size per GPU for evaluation."})
seed: Optional[int] = field(default=42, metadata={"help": "Random seed to be used with data loaders."})
lang: Optional[Literal["en", "zh"]] = field(default="en", metadata={"help": "Language used at evaluation."})
n_shot: Optional[int] = field(default=5, metadata={"help": "Number of exemplars for few-shot learning."})
save_dir: Optional[str] = field(default=None, metadata={"help": "Path to save the evaluation results."})
download_mode: Optional[DownloadMode] = field( download_mode: Optional[DownloadMode] = field(
default=DownloadMode.REUSE_DATASET_IF_EXISTS, default=DownloadMode.REUSE_DATASET_IF_EXISTS,
metadata={"help": "Download mode used for the evaluation datasets."} metadata={"help": "Download mode used for the evaluation datasets."},
) )
def __post_init__(self): def __post_init__(self):

View File

@ -1,6 +1,6 @@
import json import json
from typing import Literal, Optional
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Literal, Optional
@dataclass @dataclass
@ -10,17 +10,18 @@ class FreezeArguments:
""" """
name_module_trainable: Optional[str] = field( name_module_trainable: Optional[str] = field(
default="mlp", default="mlp",
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \ metadata={
"help": 'Name of trainable modules for partial-parameter (freeze) fine-tuning. \
Use commas to separate multiple modules. \ Use commas to separate multiple modules. \
LLaMA choices: [\"mlp\", \"self_attn\"], \ LLaMA choices: ["mlp", "self_attn"], \
BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \ BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
Qwen choices: [\"mlp\", \"attn\"], \ Qwen choices: ["mlp", "attn"], \
Phi choices: [\"mlp\", \"mixer\"], \ Phi choices: ["mlp", "mixer"], \
Others choices: the same as LLaMA."} Others choices: the same as LLaMA.'
},
) )
num_layer_trainable: Optional[int] = field( num_layer_trainable: Optional[int] = field(
default=3, default=3, metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}
metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}
) )
@ -31,37 +32,32 @@ class LoraArguments:
""" """
additional_target: Optional[str] = field( additional_target: Optional[str] = field(
default=None, default=None,
metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."} metadata={
"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."
},
) )
lora_alpha: Optional[int] = field( lora_alpha: Optional[int] = field(
default=None, default=None, metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}
)
lora_dropout: Optional[float] = field(
default=0.0,
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
)
lora_rank: Optional[int] = field(
default=8,
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
) )
lora_dropout: Optional[float] = field(default=0.0, metadata={"help": "Dropout rate for the LoRA fine-tuning."})
lora_rank: Optional[int] = field(default=8, metadata={"help": "The intrinsic dimension for LoRA fine-tuning."})
lora_target: Optional[str] = field( lora_target: Optional[str] = field(
default=None, default=None,
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \ metadata={
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ "help": 'Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"], \ LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \ Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
Phi choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \ Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
Others choices: the same as LLaMA."} Phi choices: ["Wqkv", "out_proj", "fc1", "fc2"], \
Others choices: the same as LLaMA.'
},
) )
lora_bf16_mode: Optional[bool] = field( lora_bf16_mode: Optional[bool] = field(
default=False, default=False, metadata={"help": "Whether or not to train lora adapters in bf16 precision."}
metadata={"help": "Whether or not to train lora adapters in bf16 precision."}
) )
create_new_adapter: Optional[bool] = field( create_new_adapter: Optional[bool] = field(
default=False, default=False, metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}
) )
@ -70,69 +66,53 @@ class RLHFArguments:
r""" r"""
Arguments pertaining to the PPO and DPO training. Arguments pertaining to the PPO and DPO training.
""" """
dpo_beta: Optional[float] = field( dpo_beta: Optional[float] = field(default=0.1, metadata={"help": "The beta parameter for the DPO loss."})
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."}
)
dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field( dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
default="sigmoid", default="sigmoid", metadata={"help": "The type of DPO loss to use."}
metadata={"help": "The type of DPO loss to use."}
) )
dpo_ftx: Optional[float] = field( dpo_ftx: Optional[float] = field(
default=0, default=0, metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}
) )
ppo_buffer_size: Optional[int] = field( ppo_buffer_size: Optional[int] = field(
default=1, default=1,
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."} metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
) )
ppo_epochs: Optional[int] = field( ppo_epochs: Optional[int] = field(
default=4, default=4, metadata={"help": "The number of epochs to perform in a PPO optimization step."}
metadata={"help": "The number of epochs to perform in a PPO optimization step."}
) )
ppo_logger: Optional[str] = field( ppo_logger: Optional[str] = field(
default=None, default=None, metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'}
metadata={"help": "Log with either \"wandb\" or \"tensorboard\" in PPO training."}
) )
ppo_score_norm: Optional[bool] = field( ppo_score_norm: Optional[bool] = field(
default=False, default=False, metadata={"help": "Use score normalization in PPO training."}
metadata={"help": "Use score normalization in PPO training."}
) )
ppo_target: Optional[float] = field( ppo_target: Optional[float] = field(
default=6.0, default=6.0, metadata={"help": "Target KL value for adaptive KL control in PPO training."}
metadata={"help": "Target KL value for adaptive KL control in PPO training."}
) )
ppo_whiten_rewards: Optional[bool] = field( ppo_whiten_rewards: Optional[bool] = field(
default=False, default=False, metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
) )
ref_model: Optional[str] = field( ref_model: Optional[str] = field(
default=None, default=None, metadata={"help": "Path to the reference model used for the PPO or DPO training."}
metadata={"help": "Path to the reference model used for the PPO or DPO training."}
) )
ref_model_adapters: Optional[str] = field( ref_model_adapters: Optional[str] = field(
default=None, default=None, metadata={"help": "Path to the adapters of the reference model."}
metadata={"help": "Path to the adapters of the reference model."}
) )
ref_model_quantization_bit: Optional[int] = field( ref_model_quantization_bit: Optional[int] = field(
default=None, default=None, metadata={"help": "The number of bits to quantize the reference model."}
metadata={"help": "The number of bits to quantize the reference model."}
) )
reward_model: Optional[str] = field( reward_model: Optional[str] = field(
default=None, default=None, metadata={"help": "Path to the reward model used for the PPO training."}
metadata={"help": "Path to the reward model used for the PPO training."}
) )
reward_model_adapters: Optional[str] = field( reward_model_adapters: Optional[str] = field(
default=None, default=None, metadata={"help": "Path to the adapters of the reward model."}
metadata={"help": "Path to the adapters of the reward model."}
) )
reward_model_quantization_bit: Optional[int] = field( reward_model_quantization_bit: Optional[int] = field(
default=None, default=None, metadata={"help": "The number of bits to quantize the reward model."}
metadata={"help": "The number of bits to quantize the reward model."}
) )
reward_model_type: Optional[Literal["lora", "full", "api"]] = field( reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
default="lora", default="lora",
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."} metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
) )
@ -142,16 +122,13 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
Arguments pertaining to which techniques we are going to fine-tuning with. Arguments pertaining to which techniques we are going to fine-tuning with.
""" """
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft", default="sft", metadata={"help": "Which stage will be performed in training."}
metadata={"help": "Which stage will be performed in training."}
) )
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field( finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
default="lora", default="lora", metadata={"help": "Which fine-tuning method to use."}
metadata={"help": "Which fine-tuning method to use."}
) )
plot_loss: Optional[bool] = field( plot_loss: Optional[bool] = field(
default=False, default=False, metadata={"help": "Whether or not to save the training loss curves."}
metadata={"help": "Whether or not to save the training loss curves."}
) )
def __post_init__(self): def __post_init__(self):
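For reference, dpo_beta above is the β of the standard sigmoid DPO objective (the dpo_loss field also allows hinge/IPO/KTO variants):

\mathcal{L}_{\mathrm{DPO}}(\theta) = -\,\mathbb{E}_{(x,\, y_w,\, y_l)}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]

Here π_ref is the frozen reference model configured through ref_model, y_w and y_l are the chosen and rejected responses, and a larger β keeps the policy closer to the reference; per its help text, dpo_ftx additionally mixes a supervised fine-tuning loss term into the objective.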

View File

@ -1,5 +1,5 @@
from typing import Any, Dict, Optional
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional
@dataclass @dataclass
@ -8,40 +8,37 @@ class GeneratingArguments:
Arguments pertaining to specify the decoding parameters. Arguments pertaining to specify the decoding parameters.
""" """
do_sample: Optional[bool] = field( do_sample: Optional[bool] = field(
default=True, default=True, metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
) )
temperature: Optional[float] = field( temperature: Optional[float] = field(
default=0.95, default=0.95, metadata={"help": "The value used to modulate the next token probabilities."}
metadata={"help": "The value used to modulate the next token probabilities."}
) )
top_p: Optional[float] = field( top_p: Optional[float] = field(
default=0.7, default=0.7,
metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."} metadata={
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
},
) )
top_k: Optional[int] = field( top_k: Optional[int] = field(
default=50, default=50,
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."} metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
) )
num_beams: Optional[int] = field( num_beams: Optional[int] = field(
default=1, default=1, metadata={"help": "Number of beams for beam search. 1 means no beam search."}
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
) )
max_length: Optional[int] = field( max_length: Optional[int] = field(
default=512, default=512,
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."} metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
) )
max_new_tokens: Optional[int] = field( max_new_tokens: Optional[int] = field(
default=512, default=512,
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."} metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
) )
repetition_penalty: Optional[float] = field( repetition_penalty: Optional[float] = field(
default=1.0, default=1.0, metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
) )
length_penalty: Optional[float] = field( length_penalty: Optional[float] = field(
default=1.0, default=1.0, metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
) )
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
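GeneratingArguments.to_dict() (declared at the end of the hunk above) is what ultimately feeds transformers' generation machinery. A toy sketch of that hand-off using a GenerationConfig instead of a full model (toy dataclass, not the project's class):

from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional

from transformers import GenerationConfig

@dataclass
class ToyGeneratingArguments:
    do_sample: Optional[bool] = field(default=True)
    temperature: Optional[float] = field(default=0.95)
    top_p: Optional[float] = field(default=0.7)
    top_k: Optional[int] = field(default=50)
    max_new_tokens: Optional[int] = field(default=512)
    repetition_penalty: Optional[float] = field(default=1.0)

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

gen_kwargs = ToyGeneratingArguments().to_dict()
gen_config = GenerationConfig(**gen_kwargs)  # the same kwargs can also be passed straight to model.generate(...)
print(gen_config.temperature, gen_config.max_new_tokens)  # 0.95 512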

View File

@ -1,5 +1,5 @@
from typing import Any, Dict, Literal, Optional
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional
@dataclass @dataclass
@ -11,108 +11,82 @@ class ModelArguments:
metadata={"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."} metadata={"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."}
) )
adapter_name_or_path: Optional[str] = field( adapter_name_or_path: Optional[str] = field(
default=None, default=None, metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."}
metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."}
) )
cache_dir: Optional[str] = field( cache_dir: Optional[str] = field(
default=None, default=None,
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."} metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
) )
use_fast_tokenizer: Optional[bool] = field( use_fast_tokenizer: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."} metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
) )
resize_vocab: Optional[bool] = field( resize_vocab: Optional[bool] = field(
default=False, default=False, metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."}
metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."}
) )
split_special_tokens: Optional[bool] = field( split_special_tokens: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."} metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
) )
model_revision: Optional[str] = field( model_revision: Optional[str] = field(
default="main", default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."} metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
) )
quantization_bit: Optional[int] = field( quantization_bit: Optional[int] = field(
default=None, default=None, metadata={"help": "The number of bits to quantize the model."}
metadata={"help": "The number of bits to quantize the model."}
) )
quantization_type: Optional[Literal["fp4", "nf4"]] = field( quantization_type: Optional[Literal["fp4", "nf4"]] = field(
default="nf4", default="nf4", metadata={"help": "Quantization data type to use in int4 training."}
metadata={"help": "Quantization data type to use in int4 training."}
) )
double_quantization: Optional[bool] = field( double_quantization: Optional[bool] = field(
default=True, default=True, metadata={"help": "Whether or not to use double quantization in int4 training."}
metadata={"help": "Whether or not to use double quantization in int4 training."}
) )
rope_scaling: Optional[Literal["linear", "dynamic"]] = field( rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None, default=None, metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}
) )
flash_attn: Optional[bool] = field( flash_attn: Optional[bool] = field(
default=False, default=False, metadata={"help": "Enable FlashAttention-2 for faster training."}
metadata={"help": "Enable FlashAttention-2 for faster training."}
) )
shift_attn: Optional[bool] = field( shift_attn: Optional[bool] = field(
default=False, default=False, metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
) )
use_unsloth: Optional[bool] = field( use_unsloth: Optional[bool] = field(
default=False, default=False, metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."}
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."}
) )
disable_gradient_checkpointing: Optional[bool] = field( disable_gradient_checkpointing: Optional[bool] = field(
default=False, default=False, metadata={"help": "Whether or not to disable gradient checkpointing."}
metadata={"help": "Whether or not to disable gradient checkpointing."}
) )
upcast_layernorm: Optional[bool] = field( upcast_layernorm: Optional[bool] = field(
default=False, default=False, metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}
metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}
) )
upcast_lmhead_output: Optional[bool] = field( upcast_lmhead_output: Optional[bool] = field(
default=False, default=False, metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}
metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}
)
hf_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."}
)
ms_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."}
) )
hf_hub_token: Optional[str] = field(default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."})
ms_hub_token: Optional[str] = field(default=None, metadata={"help": "Auth token to log in with ModelScope Hub."})
export_dir: Optional[str] = field( export_dir: Optional[str] = field(
default=None, default=None, metadata={"help": "Path to the directory to save the exported model."}
metadata={"help": "Path to the directory to save the exported model."}
) )
export_size: Optional[int] = field( export_size: Optional[int] = field(
default=1, default=1, metadata={"help": "The file shard size (in GB) of the exported model."}
metadata={"help": "The file shard size (in GB) of the exported model."}
) )
export_quantization_bit: Optional[int] = field( export_quantization_bit: Optional[int] = field(
default=None, default=None, metadata={"help": "The number of bits to quantize the exported model."}
metadata={"help": "The number of bits to quantize the exported model."}
) )
export_quantization_dataset: Optional[str] = field( export_quantization_dataset: Optional[str] = field(
default=None, default=None, metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}
) )
export_quantization_nsamples: Optional[int] = field( export_quantization_nsamples: Optional[int] = field(
default=128, default=128, metadata={"help": "The number of samples used for quantization."}
metadata={"help": "The number of samples used for quantization."}
) )
export_quantization_maxlen: Optional[int] = field( export_quantization_maxlen: Optional[int] = field(
default=1024, default=1024, metadata={"help": "The maximum length of the model inputs used for quantization."}
metadata={"help": "The maximum length of the model inputs used for quantization."}
) )
export_legacy_format: Optional[bool] = field( export_legacy_format: Optional[bool] = field(
default=False, default=False, metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}
) )
export_hub_model_id: Optional[str] = field( export_hub_model_id: Optional[str] = field(
default=None, default=None, metadata={"help": "The name of the repository if pushing the model to the Hugging Face hub."}
metadata={"help": "The name of the repository if pushing the model to the Hugging Face hub."}
) )
def __post_init__(self): def __post_init__(self):
@ -122,7 +96,7 @@ class ModelArguments:
if self.split_special_tokens and self.use_fast_tokenizer: if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.adapter_name_or_path is not None: # support merging multiple lora weights if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")] self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."

View File

@ -1,10 +1,11 @@
import logging
import os import os
import sys import sys
import torch
import logging
import datasets
import transformers
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
import datasets
import torch
import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_utils import get_last_checkpoint
@ -19,24 +20,12 @@ from .model_args import ModelArguments
logger = get_logger(__name__) logger = get_logger(__name__)
_TRAIN_ARGS = [ _TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments _TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
] _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
] _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_INFER_ARGS = [
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_INFER_CLS = Tuple[
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_EVAL_ARGS = [
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
_EVAL_CLS = Tuple[
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
@ -77,7 +66,7 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1: if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
if finetuning_args.finetuning_type != "lora": if finetuning_args.finetuning_type != "lora":
raise ValueError("Multiple adapters are only available for LoRA tuning.") raise ValueError("Multiple adapters are only available for LoRA tuning.")
if model_args.quantization_bit is not None: if model_args.quantization_bit is not None:
raise ValueError("Quantized model only accepts a single adapter. Merge them first.") raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
@ -181,18 +170,22 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
training_args_dict = training_args.to_dict() training_args_dict = training_args.to_dict()
training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint)) training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
training_args = Seq2SeqTrainingArguments(**training_args_dict) training_args = Seq2SeqTrainingArguments(**training_args_dict)
logger.info("Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format( logger.info(
training_args.resume_from_checkpoint "Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
)) training_args.resume_from_checkpoint
)
)
if ( if (
finetuning_args.stage in ["rm", "ppo"] finetuning_args.stage in ["rm", "ppo"]
and finetuning_args.finetuning_type == "lora" and finetuning_args.finetuning_type == "lora"
and training_args.resume_from_checkpoint is not None and training_args.resume_from_checkpoint is not None
): ):
logger.warning("Add {} to `adapter_name_or_path` to resume training from checkpoint.".format( logger.warning(
training_args.resume_from_checkpoint "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
)) training_args.resume_from_checkpoint
)
)
# postprocess model_args # postprocess model_args
model_args.compute_dtype = ( model_args.compute_dtype = (
@ -201,10 +194,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args.model_max_length = data_args.cutoff_len model_args.model_max_length = data_args.cutoff_len
# Log on each process the small summary: # Log on each process the small summary:
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format( logger.info(
training_args.local_rank, training_args.device, training_args.n_gpu, "Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
bool(training_args.local_rank != -1), str(model_args.compute_dtype) training_args.local_rank,
)) training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
str(model_args.compute_dtype),
)
)
logger.info(f"Training/evaluation parameters {training_args}") logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model. # Set seed before initializing model.
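_TRAIN_ARGS, _INFER_ARGS and _EVAL_ARGS above are plain tuples of dataclasses handed to HfArgumentParser; when a dict is supplied instead of command-line arguments, the same parser consumes it through parse_dict. A rough sketch with toy dataclasses (not the real argument sets; the model name is hypothetical):

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class ToyModelArguments:
    model_name_or_path: str = field(metadata={"help": "Model path or hub identifier."})

@dataclass
class ToyFinetuningArguments:
    lora_rank: Optional[int] = field(default=8)

parser = HfArgumentParser((ToyModelArguments, ToyFinetuningArguments))
model_args, finetuning_args = parser.parse_dict({"model_name_or_path": "my-org/my-model", "lora_rank": 16})
print(model_args.model_name_or_path, finetuning_args.lora_rank)  # my-org/my-model 16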

View File

@ -1,25 +1,25 @@
import torch
import inspect import inspect
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
from ..extras.logging import get_logger from ..extras.logging import get_logger
from .utils import find_all_linear_modules from .utils import find_all_linear_modules
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from ..hparams import ModelArguments, FinetuningArguments
from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__) logger = get_logger(__name__)
def init_adapter( def init_adapter(
model: "PreTrainedModel", model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool
) -> "PreTrainedModel": ) -> "PreTrainedModel":
r""" r"""
Initializes the adapters. Initializes the adapters.
@ -47,10 +47,10 @@ def init_adapter(
if not num_layers: if not num_layers:
raise ValueError("Current model does not support freeze tuning.") raise ValueError("Current model does not support freeze tuning.")
if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)] trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
else: # fine-tuning the first n layers if num_layer_trainable < 0 else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)] trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)] # noqa: C416
trainable_layers = [] trainable_layers = []
for module_name in finetuning_args.name_module_trainable: for module_name in finetuning_args.name_module_trainable:
@ -69,7 +69,7 @@ def init_adapter(
if model_args.adapter_name_or_path is not None: if model_args.adapter_name_or_path is not None:
is_mergeable = True is_mergeable = True
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter." assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
is_mergeable = False is_mergeable = False
@ -90,10 +90,10 @@ def init_adapter(
if len(adapter_to_merge) > 0: if len(adapter_to_merge) > 0:
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge))) logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
if adapter_to_resume is not None: # resume lora training if adapter_to_resume is not None: # resume lora training
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable) model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable)
if is_trainable and adapter_to_resume is None: # create new lora weights while training if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
target_modules = find_all_linear_modules(model) target_modules = find_all_linear_modules(model)
else: else:
@ -103,11 +103,12 @@ def init_adapter(
"r": finetuning_args.lora_rank, "r": finetuning_args.lora_rank,
"target_modules": target_modules, "target_modules": target_modules,
"lora_alpha": finetuning_args.lora_alpha, "lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout "lora_dropout": finetuning_args.lora_dropout,
} }
if model_args.use_unsloth: if model_args.use_unsloth:
from unsloth import FastLlamaModel, FastMistralModel # type: ignore from unsloth import FastLlamaModel, FastMistralModel # type: ignore
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length} unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
if "loftq_config" in inspect.signature(FastLlamaModel.get_peft_model).parameters: if "loftq_config" in inspect.signature(FastLlamaModel.get_peft_model).parameters:
unsloth_peft_kwargs["loftq_config"] = {} unsloth_peft_kwargs["loftq_config"] = {}
@ -124,7 +125,7 @@ def init_adapter(
task_type=TaskType.CAUSAL_LM, task_type=TaskType.CAUSAL_LM,
inference_mode=False, inference_mode=False,
modules_to_save=finetuning_args.additional_target, modules_to_save=finetuning_args.additional_target,
**peft_kwargs **peft_kwargs,
) )
model = get_peft_model(model, lora_config) model = get_peft_model(model, lora_config)
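The is_trainable branch above is where fresh LoRA weights are created: lora_target (or "all", which triggers find_all_linear_modules) becomes target_modules, and the rank/alpha/dropout fields become a peft LoraConfig. A simplified, hypothetical sketch of that mapping (toy checkpoint name; lora_alpha falling back to rank * 2 follows its help text):

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

lora_rank, lora_alpha, lora_dropout = 8, None, 0.0
target_modules = ["q_proj", "v_proj"]  # hypothetical LLaMA-style targets

peft_kwargs = {
    "r": lora_rank,
    "target_modules": target_modules,
    "lora_alpha": lora_alpha if lora_alpha is not None else lora_rank * 2,
    "lora_dropout": lora_dropout,
}
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, **peft_kwargs)

base_model = AutoModelForCausalLM.from_pretrained("my-org/tiny-llama")  # hypothetical checkpoint
peft_model = get_peft_model(base_model, lora_config)
peft_model.print_trainable_parameters()  # only the injected LoRA matrices require grad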

View File

@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Optional, Tuple from typing import TYPE_CHECKING, Optional, Tuple
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
@ -7,12 +8,14 @@ from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from .adapter import init_adapter from .adapter import init_adapter
from .patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
from .utils import load_valuehead_params, register_autoclass from .utils import load_valuehead_params, register_autoclass
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer from transformers import PreTrainedModel, PreTrainedTokenizer
from ..hparams import ModelArguments, FinetuningArguments
from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__) logger = get_logger(__name__)
@ -29,7 +32,7 @@ def load_model_and_tokenizer(
model_args: "ModelArguments", model_args: "ModelArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
is_trainable: Optional[bool] = False, is_trainable: Optional[bool] = False,
add_valuehead: Optional[bool] = False add_valuehead: Optional[bool] = False,
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]: ) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
r""" r"""
Loads pretrained model and tokenizer. Loads pretrained model and tokenizer.
@ -43,7 +46,7 @@ def load_model_and_tokenizer(
"trust_remote_code": True, "trust_remote_code": True,
"cache_dir": model_args.cache_dir, "cache_dir": model_args.cache_dir,
"revision": model_args.model_revision, "revision": model_args.model_revision,
"token": model_args.hf_hub_token "token": model_args.hf_hub_token,
} }
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
@ -51,7 +54,7 @@ def load_model_and_tokenizer(
use_fast=model_args.use_fast_tokenizer, use_fast=model_args.use_fast_tokenizer,
split_special_tokens=model_args.split_special_tokens, split_special_tokens=model_args.split_special_tokens,
padding_side="right", padding_side="right",
**config_kwargs **config_kwargs,
) )
patch_tokenizer(tokenizer) patch_tokenizer(tokenizer)
@ -61,7 +64,8 @@ def load_model_and_tokenizer(
model = None model = None
if is_trainable and model_args.use_unsloth: if is_trainable and model_args.use_unsloth:
require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth") require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth")
from unsloth import FastLlamaModel, FastMistralModel # type: ignore from unsloth import FastLlamaModel, FastMistralModel # type: ignore
unsloth_kwargs = { unsloth_kwargs = {
"model_name": model_args.model_name_or_path, "model_name": model_args.model_name_or_path,
"max_seq_length": model_args.model_max_length, "max_seq_length": model_args.model_max_length,
@ -69,7 +73,7 @@ def load_model_and_tokenizer(
"load_in_4bit": model_args.quantization_bit == 4, "load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token, "token": model_args.hf_hub_token,
"device_map": get_current_device(), "device_map": get_current_device(),
"rope_scaling": getattr(config, "rope_scaling", None) "rope_scaling": getattr(config, "rope_scaling", None),
} }
if getattr(config, "model_type", None) == "llama": if getattr(config, "model_type", None) == "llama":
model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs) model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs)
@ -89,7 +93,7 @@ def load_model_and_tokenizer(
config=config, config=config,
torch_dtype=model_args.compute_dtype, torch_dtype=model_args.compute_dtype,
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()), low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
**config_kwargs **config_kwargs,
) )
patch_model(model, tokenizer, model_args, is_trainable) patch_model(model, tokenizer, model_args, is_trainable)
@ -119,9 +123,11 @@ def load_model_and_tokenizer(
model.train() model.train()
trainable_params, all_param = count_parameters(model) trainable_params, all_param = count_parameters(model)
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( logger.info(
trainable_params, all_param, 100 * trainable_params / all_param "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
)) trainable_params, all_param, 100 * trainable_params / all_param
)
)
if not is_trainable: if not is_trainable:
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.") logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")

View File

@ -1,12 +1,12 @@
import os
import math import math
import torch import os
import random import random
from contextlib import nullcontext
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from datasets import load_dataset
from contextlib import nullcontext
import torch
from datasets import load_dataset
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
@ -17,9 +17,11 @@ from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch from ..extras.patches.llama_patch import apply_llama_patch
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer from transformers import PretrainedConfig, PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from ..hparams import ModelArguments from ..hparams import ModelArguments
@ -40,7 +42,8 @@ def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToke
Resize token embeddings. Resize token embeddings.
""" """
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore import deepspeed # type: ignore
params = [model.get_input_embeddings().weight] params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings: if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight) params.append(model.get_output_embeddings().weight)
@ -88,7 +91,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
sample_idx = random.randint(0, len(dataset) - 1) sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
if sample["input_ids"].size(1) >= maxlen: if sample["input_ids"].size(1) >= maxlen:
break # TODO: fix large maxlen break # TODO: fix large maxlen
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1) word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen] input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
@ -119,9 +122,9 @@ def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is
scaling_factor = 2.0 scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor}) setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info("Using {} scaling strategy and setting scaling factor to {}".format( logger.info(
model_args.rope_scaling, scaling_factor "Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
)) )
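The hunk above only reflows the log call in `_configure_rope`; the effect of the surrounding code is to write a `rope_scaling` dict onto the config. A small sketch of what that attribute looks like, using a freshly constructed config as a stand-in (the factor value is illustrative):

```python
from transformers import LlamaConfig

config = LlamaConfig()  # stand-in; the real code patches whatever config was just loaded
# "linear" and "dynamic" are the two scaling strategies this code path chooses between
setattr(config, "rope_scaling", {"type": "linear", "factor": 2.0})
print(config.rope_scaling)  # {'type': 'linear', 'factor': 2.0}
```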
def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None: def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
@ -146,22 +149,22 @@ def _configure_quantization(
config: "PretrainedConfig", config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments", model_args: "ModelArguments",
config_kwargs: Dict[str, Any] config_kwargs: Dict[str, Any],
) -> None: ) -> None:
r""" r"""
Priority: GPTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training) Priority: GPTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
""" """
if getattr(config, "quantization_config", None): # gptq if getattr(config, "quantization_config", None): # gptq
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
config_kwargs["device_map"] = {"": get_current_device()} config_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4: if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
quantization_config["use_exllama"] = False # disable exllama quantization_config["use_exllama"] = False # disable exllama
logger.info("Loading {}-bit GPTQ-quantized model.".format(quantization_config.get("bits", -1))) logger.info("Loading {}-bit GPTQ-quantized model.".format(quantization_config.get("bits", -1)))
elif model_args.export_quantization_bit is not None: # auto-gptq elif model_args.export_quantization_bit is not None: # auto-gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0") require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory from accelerate.utils import get_max_memory
@ -172,13 +175,13 @@ def _configure_quantization(
config_kwargs["quantization_config"] = GPTQConfig( config_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit, bits=model_args.export_quantization_bit,
tokenizer=tokenizer, tokenizer=tokenizer,
dataset=_get_quantization_dataset(tokenizer, model_args) dataset=_get_quantization_dataset(tokenizer, model_args),
) )
config_kwargs["device_map"] = "auto" config_kwargs["device_map"] = "auto"
config_kwargs["max_memory"] = get_max_memory() config_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit)) logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb elif model_args.quantization_bit is not None: # bnb
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
@ -192,7 +195,7 @@ def _configure_quantization(
load_in_4bit=True, load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype, bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization, bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type bnb_4bit_quant_type=model_args.quantization_type,
) )
config_kwargs["device_map"] = {"": get_current_device()} config_kwargs["device_map"] = {"": get_current_device()}
@ -200,9 +203,7 @@ def _configure_quantization(
def _prepare_model_for_training( def _prepare_model_for_training(
model: "PreTrainedModel", model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: Optional[str] = "lm_head"
model_args: "ModelArguments",
output_layer_name: Optional[str] = "lm_head"
) -> None: ) -> None:
r""" r"""
Includes: Includes:
@ -222,10 +223,11 @@ def _prepare_model_for_training(
logger.warning("Current model does not support gradient checkpointing.") logger.warning("Current model does not support gradient checkpointing.")
else: else:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.config.use_cache = False # turn off when gradient checkpointing is enabled model.config.use_cache = False # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.") logger.info("Gradient checkpointing enabled.")
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output: if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor): def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
return output.to(torch.float32) return output.to(torch.float32)
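`fp32_forward_post_hook` above upcasts the output layer's logits with a standard PyTorch forward hook. A self-contained sketch of the same mechanism on a toy linear layer (the layer and shapes are made up for illustration):

```python
import torch

def fp32_forward_post_hook(module: torch.nn.Module, args, output: torch.Tensor) -> torch.Tensor:
    # cast the layer's output to fp32 for a numerically stable loss
    return output.to(torch.float32)

lm_head = torch.nn.Linear(16, 32).to(torch.bfloat16)  # stand-in for a half-precision output layer
lm_head.register_forward_hook(fp32_forward_post_hook)
logits = lm_head(torch.randn(2, 16, dtype=torch.bfloat16))
print(logits.dtype)  # torch.float32
```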
@ -244,9 +246,9 @@ def patch_config(
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments", model_args: "ModelArguments",
config_kwargs: Dict[str, Any], config_kwargs: Dict[str, Any],
is_trainable: bool is_trainable: bool,
) -> None: ) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
if getattr(config, "model_type", None) == "qwen": if getattr(config, "model_type", None) == "qwen":
@ -266,10 +268,7 @@ def patch_config(
def patch_model( def patch_model(
model: "PreTrainedModel", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
is_trainable: bool
) -> None: ) -> None:
if "GenerationMixin" not in str(model.generate.__func__): if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model) model.generate = MethodType(PreTrainedModel.generate, model)

View File

@ -1,16 +1,19 @@
import torch
import inspect import inspect
from typing import TYPE_CHECKING, Any, Dict, List from typing import TYPE_CHECKING, Any, Dict, List
import torch
from transformers import PreTrainedModel from transformers import PreTrainedModel
from transformers.utils import cached_file from transformers.utils import cached_file
from ..extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.misc import get_current_device from ..extras.misc import get_current_device
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer from transformers import PretrainedConfig, PreTrainedTokenizer
from ..hparams import ModelArguments, DataArguments, FinetuningArguments
from ..hparams import DataArguments, FinetuningArguments, ModelArguments
logger = get_logger(__name__) logger = get_logger(__name__)
@ -21,7 +24,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available. Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570 Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
""" """
if getattr(model, "quantization_method", None): # already set on current device if getattr(model, "quantization_method", None): # already set on current device
return model return model
if ( if (
@ -31,7 +34,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
and model.config.model_type != "chatglm" and model.config.model_type != "chatglm"
): ):
from accelerate import dispatch_model from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory from accelerate.utils import get_balanced_memory, infer_auto_device_map
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")} kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
max_memory = get_balanced_memory(model, **kwargs) max_memory = get_balanced_memory(model, **kwargs)
@ -55,6 +58,7 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
linear_cls = torch.nn.Linear linear_cls = torch.nn.Linear
elif quantization_method == "bitsandbytes": elif quantization_method == "bitsandbytes":
import bitsandbytes as bnb import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
else: else:
raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method)) raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
@ -65,10 +69,7 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
module_names = set() module_names = set()
for name, module in model.named_modules(): for name, module in model.named_modules():
if (
isinstance(module, linear_cls)
and not any([output_layer in name for output_layer in output_layer_names])
):
if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names):
module_names.add(name.split(".")[-1]) module_names.add(name.split(".")[-1])
logger.info("Found linear modules: {}".format(",".join(module_names))) logger.info("Found linear modules: {}".format(",".join(module_names)))
@ -76,16 +77,14 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
def get_modelcard_args( def get_modelcard_args(
model_args: "ModelArguments", model_args: "ModelArguments", data_args: "DataArguments", finetuning_args: "FinetuningArguments"
data_args: "DataArguments",
finetuning_args: "FinetuningArguments"
) -> Dict[str, Any]: ) -> Dict[str, Any]:
return { return {
"tasks": "text-generation", "tasks": "text-generation",
"license": "other", "license": "other",
"finetuned_from": model_args.model_name_or_path, "finetuned_from": model_args.model_name_or_path,
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")], "dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else []) "tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else []),
} }
@ -95,14 +94,11 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`. Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
""" """
kwargs = {
"path_or_repo_id": path_or_repo_id,
"cache_dir": model_args.cache_dir,
"token": model_args.hf_hub_token
}
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
try: try:
from safetensors import safe_open from safetensors import safe_open
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs) vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
with safe_open(vhead_file, framework="pt", device="cpu") as f: with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys()} return {key: f.get_tensor(key) for key in f.keys()}

View File

@ -1,6 +1,7 @@
import torch
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Tuple from typing import Any, Dict, List, Sequence, Tuple
import torch
from transformers import DataCollatorForSeq2Seq from transformers import DataCollatorForSeq2Seq
@ -20,7 +21,7 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
padded_tensor = self.label_pad_token_id * torch.ones_like(feature) padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
padded_tensor[start:end] = feature[start:end] padded_tensor[start:end] = feature[start:end]
padded_labels.append(padded_tensor) padded_labels.append(padded_tensor)
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r""" r"""
@ -34,10 +35,12 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
for key in ("chosen_ids", "rejected_ids"): for key in ("chosen_ids", "rejected_ids"):
for feature in features: for feature in features:
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key]) prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
concatenated_features.append({
"input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (prompt_len + answer_len)
})
concatenated_features.append(
{
"input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (prompt_len + answer_len),
}
)
label_positions.append((prompt_len, answer_len)) label_positions.append((prompt_len, answer_len))
batch = self.tokenizer.pad( batch = self.tokenizer.pad(
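The collator shown here flattens each preference pair into two rows, chosen first and rejected second, and then pads them together with `tokenizer.pad`. A toy sketch of that pairing step; the tokenizer choice and token ids are placeholders:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # toy tokenizer for illustration
tokenizer.pad_token = tokenizer.eos_token

features = [{"prompt_ids": [1, 2], "chosen_ids": [3, 4, 5], "rejected_ids": [6]}]

concatenated_features, label_positions = [], []
for key in ("chosen_ids", "rejected_ids"):  # chosen rows first, rejected rows second
    for feature in features:
        prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
        concatenated_features.append({
            "input_ids": feature["prompt_ids"] + feature[key],
            "attention_mask": [1] * (prompt_len + answer_len),
        })
        label_positions.append((prompt_len, answer_len))

batch = tokenizer.pad(concatenated_features, padding=True, return_tensors="pt")
print(batch["input_ids"].shape)  # (2, 5): one chosen row and one rejected row, padded together
```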

View File

@ -1,19 +1,20 @@
import torch
from contextlib import nullcontext
from collections import defaultdict from collections import defaultdict
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch
from transformers import BatchEncoding, Trainer from transformers import BatchEncoding, Trainer
from trl import DPOTrainer from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model from trl.trainer.utils import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedModel from transformers import PreTrainedModel
class CustomDPOTrainer(DPOTrainer): class CustomDPOTrainer(DPOTrainer):
def __init__( def __init__(
self, self,
beta: float, beta: float,
@ -22,15 +23,15 @@ class CustomDPOTrainer(DPOTrainer):
model: Union["PreTrainedModel", torch.nn.Module], model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None, ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
disable_dropout: Optional[bool] = True, disable_dropout: Optional[bool] = True,
**kwargs **kwargs,
): ):
if disable_dropout: if disable_dropout:
disable_dropout_in_model(model) disable_dropout_in_model(model)
if ref_model is not None: if ref_model is not None:
disable_dropout_in_model(ref_model) disable_dropout_in_model(ref_model)
self.use_dpo_data_collator = True # hack to avoid warning self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation self.generate_during_eval = False # disable at evaluation
self.label_pad_token_id = IGNORE_INDEX self.label_pad_token_id = IGNORE_INDEX
self.padding_value = 0 self.padding_value = 0
self.is_encoder_decoder = model.config.is_encoder_decoder self.is_encoder_decoder = model.config.is_encoder_decoder
@ -53,42 +54,29 @@ class CustomDPOTrainer(DPOTrainer):
if ref_model is not None: if ref_model is not None:
if self.is_deepspeed_enabled: if self.is_deepspeed_enabled:
if not ( if not (
getattr(ref_model, "is_loaded_in_8bit", False) getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
or getattr(ref_model, "is_loaded_in_4bit", False) ): # quantized models are already set on the correct device
): # quantized models are already set on the correct device
self.ref_model = self._prepare_deepspeed(self.ref_model) self.ref_model = self._prepare_deepspeed(self.ref_model)
else: else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
def sft_loss( def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
self,
chosen_logits: torch.FloatTensor,
chosen_labels: torch.LongTensor
) -> torch.Tensor:
r""" r"""
Computes supervised cross-entropy loss of given labels under the given logits. Computes supervised cross-entropy loss of given labels under the given logits.
Returns: Returns:
A tensor of shape (batch_size,) containing the cross-entropy loss of each sample. A tensor of shape (batch_size,) containing the cross-entropy loss of each sample.
""" """
all_logps = self.get_batch_logps(
chosen_logits,
chosen_labels,
average_log_prob=True
)
all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
return -all_logps return -all_logps
def concatenated_forward( def concatenated_forward(
self,
model: "PreTrainedModel",
batch: Dict[str, torch.Tensor]
self, model: "PreTrainedModel", batch: Dict[str, torch.Tensor]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
all_logits = model( all_logits = model(
input_ids=batch_copied["input_ids"], input_ids=batch_copied["input_ids"], attention_mask=batch_copied["attention_mask"], return_dict=True
attention_mask=batch_copied["attention_mask"],
return_dict=True
).logits.to(torch.float32) ).logits.to(torch.float32)
all_logps = self.get_batch_logps( all_logps = self.get_batch_logps(
@ -106,7 +94,7 @@ class CustomDPOTrainer(DPOTrainer):
self, self,
model: "PreTrainedModel", model: "PreTrainedModel",
batch: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor],
train_eval: Optional[Literal["train", "eval"]] = "train" train_eval: Optional[Literal["train", "eval"]] = "train",
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
r""" r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test. Computes the DPO loss and other metrics for the given batch of inputs for train or test.
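Downstream of this docstring, the trainer combines policy and reference log-probabilities into the standard DPO objective. A toy sketch of that loss with made-up log-probs; `beta` plays the same role as the trainer's `beta` argument:

```python
import torch
import torch.nn.functional as F

beta = 0.1
# per-example sequence log-probs under the policy and the frozen reference model (toy values)
policy_chosen_logps = torch.tensor([-12.0, -9.5])
policy_rejected_logps = torch.tensor([-14.0, -11.0])
ref_chosen_logps = torch.tensor([-12.5, -10.0])
ref_rejected_logps = torch.tensor([-13.0, -10.5])

chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
loss = -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
print(loss.item())  # smaller when the policy prefers chosen over rejected more than the reference does
```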

View File

@ -1,6 +1,7 @@
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py # Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
from typing import TYPE_CHECKING, Optional, List from typing import TYPE_CHECKING, List, Optional
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from ...data import get_dataset, split_dataset from ...data import get_dataset, split_dataset
@ -12,8 +13,10 @@ from ...train.dpo.collator import DPODataCollatorWithPadding
from ...train.dpo.trainer import CustomDPOTrainer from ...train.dpo.trainer import CustomDPOTrainer
from ...train.utils import create_modelcard_and_push, create_ref_model from ...train.utils import create_modelcard_and_push, create_ref_model
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainerCallback from transformers import TrainerCallback
from ...hparams import DataArguments, FinetuningArguments from ...hparams import DataArguments, FinetuningArguments
@ -22,25 +25,25 @@ def run_dpo(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None callbacks: Optional[List["TrainerCallback"]] = None,
): ):
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm") dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
data_collator = DPODataCollatorWithPadding( data_collator = DPODataCollatorWithPadding(
tokenizer=tokenizer, tokenizer=tokenizer,
pad_to_multiple_of=8, pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
) )
# Create reference model # Create reference model
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
ref_model = model ref_model = model
else: else:
ref_model = create_ref_model(model_args, finetuning_args) ref_model = create_ref_model(model_args, finetuning_args)
# Update arguments # Update arguments
training_args_dict = training_args.to_dict() training_args_dict = training_args.to_dict()
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
training_args = Seq2SeqTrainingArguments(**training_args_dict) training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer # Initialize our Trainer
@ -54,7 +57,7 @@ def run_dpo(
tokenizer=tokenizer, tokenizer=tokenizer,
data_collator=data_collator, data_collator=data_collator,
callbacks=callbacks, callbacks=callbacks,
**split_dataset(dataset, data_args, training_args) **split_dataset(dataset, data_args, training_args),
) )
# Training # Training
@ -70,7 +73,7 @@ def run_dpo(
# Evaluation # Evaluation
if training_args.do_eval: if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval") metrics = trainer.evaluate(metric_key_prefix="eval")
if id(model) == id(ref_model): # unable to compute rewards without a reference model if id(model) == id(ref_model): # unable to compute rewards without a reference model
remove_keys = [key for key in metrics.keys() if "rewards" in key] remove_keys = [key for key in metrics.keys() if "rewards" in key]
for key in remove_keys: for key in remove_keys:
metrics.pop(key) metrics.pop(key)

View File

@ -1,27 +1,28 @@
import math
import os import os
import sys import sys
import math
import torch
from tqdm import tqdm
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl import torch
from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME from tqdm import tqdm
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
from transformers.trainer_pt_utils import remove_dummy_checkpoint from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from trl import PPOTrainer from trl import PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits from trl.core import PPODecorators, logprobs_from_logits
from ...extras.callbacks import LogCallback, FixValueHeadModelCallback from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
from ...extras.logging import get_logger from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_logits_processor from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
from .utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback from transformers import Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from ...hparams import ModelArguments, FinetuningArguments, GeneratingArguments
from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__) logger = get_logger(__name__)
@ -40,7 +41,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
callbacks: List["TrainerCallback"], callbacks: List["TrainerCallback"],
reward_model: "AutoModelForCausalLMWithValueHead", reward_model: "AutoModelForCausalLMWithValueHead",
**kwargs **kwargs,
): ):
PPOTrainer.__init__(self, **kwargs) PPOTrainer.__init__(self, **kwargs)
@ -52,7 +53,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.generation_config = GenerationConfig( self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id, pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
**generating_args.to_dict() **generating_args.to_dict(),
) )
self.state = TrainerState() self.state = TrainerState()
@ -71,7 +72,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if not ( if not (
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False) getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False) or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device ): # quantized models are already set on the correct device
self.reward_model = self._prepare_deepspeed(self.reward_model) self.reward_model = self._prepare_deepspeed(self.reward_model)
else: else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True) self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
@ -111,9 +112,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
logger.info(" Num examples = {}".format(num_examples)) logger.info(" Num examples = {}".format(num_examples))
logger.info(" Num Epochs = {}".format(num_train_epochs)) logger.info(" Num Epochs = {}".format(num_train_epochs))
logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size)) logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
logger.info(" Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format( logger.info(
total_train_batch_size " Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
)) total_train_batch_size
)
)
logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps)) logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs)) logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
logger.info(" Total training steps = {}".format(max_steps)) logger.info(" Total training steps = {}".format(max_steps))
@ -138,10 +141,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.model.eval() self.model.eval()
# Get inputs # Get inputs
self.tokenizer.padding_side = "right" # change padding side self.tokenizer.padding_side = "right" # change padding side
queries, responses, rewards = [], [], [] queries, responses, rewards = [], [], []
for idx in range(0, self.config.batch_size, self.config.mini_batch_size): for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
mini_batch_queries, mini_batch_responses = self.get_inputs(
batch[idx : idx + self.config.mini_batch_size]
)
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model) mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
queries.extend(mini_batch_queries) queries.extend(mini_batch_queries)
responses.extend(mini_batch_responses) responses.extend(mini_batch_responses)
@ -154,7 +159,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
# Run PPO step # Run PPO step
stats = self.step(queries, responses, rewards) stats = self.step(queries, responses, rewards)
self.tokenizer.padding_side = "left" # restore padding side self.tokenizer.padding_side = "left" # restore padding side
loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards)) loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards)) reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
@ -163,18 +168,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True) batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True) batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
self.log_stats(stats, batch, rewards) self.log_stats(stats, batch, rewards)
except: except Exception:
logger.warning("Failed to save stats due to unknown errors.") logger.warning("Failed to save stats due to unknown errors.")
self.state.global_step += 1 self.state.global_step += 1
self.log_callback.on_step_end(self.args, self.state, self.control) self.log_callback.on_step_end(self.args, self.state, self.control)
if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0: if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
logs = dict( logs = dict(
loss=round(loss_meter.avg, 4), loss=round(loss_meter.avg, 4),
reward=round(reward_meter.avg, 4), reward=round(reward_meter.avg, 4),
learning_rate=stats["ppo/learning_rate"], learning_rate=stats["ppo/learning_rate"],
epoch=round(step / steps_in_epoch, 2) epoch=round(step / steps_in_epoch, 2),
) )
tqdm.write(str(logs)) tqdm.write(str(logs))
logs["step"] = step logs["step"] = step
@ -183,10 +188,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
loss_meter.reset() loss_meter.reset()
reward_meter.reset() reward_meter.reset()
if (step+1) % self.args.save_steps == 0: # save checkpoint if (step + 1) % self.args.save_steps == 0: # save checkpoint
self.save_model(os.path.join(
self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)
))
self.save_model(
os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
)
self.save_callback.on_save( self.save_callback.on_save(
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model) self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
) )
@ -207,35 +212,33 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.model_args.upcast_layernorm: if self.model_args.upcast_layernorm:
layernorm_params = dump_layernorm(self.model) layernorm_params = dump_layernorm(self.model)
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1 if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item() start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
for k, v in batch.items(): for k, v in batch.items():
batch[k] = v[:, start_index:] batch[k] = v[:, start_index:]
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
generate_output: torch.Tensor = unwrapped_model.generate( generate_output: torch.Tensor = unwrapped_model.generate(
generation_config=self.generation_config,
logits_processor=get_logits_processor(),
**batch
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
) )
if self.model_args.upcast_layernorm: if self.model_args.upcast_layernorm:
restore_layernorm(self.model, layernorm_params) restore_layernorm(self.model, layernorm_params)
query = batch["input_ids"].detach().cpu() query = batch["input_ids"].detach().cpu()
response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu() response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
queries, responses = [], [] queries, responses = [], []
for i in range(len(query)): for i in range(len(query)):
query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item() query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero() response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
if len(response_index) == 0: if len(response_index) == 0:
response_length = 1 # allow empty response response_length = 1 # allow empty response
else: else:
response_length = response_index[-1].item() + 1 response_length = response_index[-1].item() + 1
queries.append(query[i, query_start_index:]) # remove padding from left queries.append(query[i, query_start_index:]) # remove padding from left
responses.append(response[i, :response_length]) # remove padding from right responses.append(response[i, :response_length]) # remove padding from right
return queries, responses return queries, responses
@ -244,7 +247,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self, self,
queries: List[torch.Tensor], queries: List[torch.Tensor],
responses: List[torch.Tensor], responses: List[torch.Tensor],
unwrapped_model: "AutoModelForCausalLMWithValueHead" unwrapped_model: "AutoModelForCausalLMWithValueHead",
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
r""" r"""
Computes scores using given reward model. Computes scores using given reward model.
@ -264,17 +267,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch = self.prepare_model_inputs(queries, responses) batch = self.prepare_model_inputs(queries, responses)
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16 with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True) _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
values = torch.transpose(values, 0, 1) values = torch.transpose(values, 0, 1)
rewards = [] rewards = []
for i in range(values.size(0)): for i in range(values.size(0)):
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero() end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0 end_index = end_indexes[-1].item() if len(end_indexes) else 0
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
if self.finetuning_args.reward_model_type == "lora": if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default") replace_model(unwrapped_model, target="default")
@ -289,7 +292,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
responses: torch.Tensor, responses: torch.Tensor,
model_inputs: dict, model_inputs: dict,
return_logits: Optional[bool] = False, return_logits: Optional[bool] = False,
response_masks: Optional[torch.Tensor] = None response_masks: Optional[torch.Tensor] = None,
): ):
r""" r"""
Calculates model outputs in multiple batches. Calculates model outputs in multiple batches.
@ -312,7 +315,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
input_ids = input_kwargs["input_ids"] input_ids = input_kwargs["input_ids"]
attention_mask = input_kwargs["attention_mask"] attention_mask = input_kwargs["attention_mask"]
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16 with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
logits, _, values = model(**input_kwargs) logits, _, values = model(**input_kwargs)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
@ -325,14 +328,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
for j in range(len(query_batch)): for j in range(len(query_batch)):
start = len(query_batch[j]) - 1 start = len(query_batch[j]) - 1
if attention_mask[j, 0] == 0: # offset left padding if attention_mask[j, 0] == 0: # offset left padding
start += attention_mask[j, :].nonzero()[0].item() start += attention_mask[j, :].nonzero()[0].item()
end = start + len(response_batch[j]) end = start + len(response_batch[j])
if response_masks is not None: if response_masks is not None:
response_masks_batch = torch.cat(
(torch.zeros_like(query_batch[j]), response_masks_batch[j])
)[1:]
response_masks_batch = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:]
masks[j, :start] = 0 masks[j, :start] = 0
masks[j, end:] = 0 masks[j, end:] = 0

View File

@ -1,9 +1,11 @@
import json import json
import torch
from typing import TYPE_CHECKING, Dict, List, Literal, Optional from typing import TYPE_CHECKING, Dict, List, Literal, Optional
import torch
from ...extras.packages import is_requests_available from ...extras.packages import is_requests_available
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedModel from transformers import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
@ -21,16 +23,18 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None: def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
if target == "reward": # save default head temporarily if target == "reward": # save default head temporarily
valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict() valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone()) setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone()) setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
model.v_head.load_state_dict({
"summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
"summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone()
})
model.v_head.load_state_dict(
{
"summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
"summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone(),
}
)
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]: def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:

View File

@ -1,23 +1,26 @@
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py # Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
import math import math
from trl import PPOConfig from typing import TYPE_CHECKING, List, Optional
from torch.optim import AdamW from torch.optim import AdamW
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorWithPadding from transformers import DataCollatorWithPadding
from transformers.optimization import get_scheduler from transformers.optimization import get_scheduler
from trl import PPOConfig
from ...data import get_dataset from ...data import get_dataset
from ...extras.callbacks import FixValueHeadModelCallback from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss from ...extras.ploting import plot_loss
from ...model import load_model_and_tokenizer from ...model import load_model_and_tokenizer
from ...train.utils import create_ref_model, create_reward_model
from ...train.ppo.trainer import CustomPPOTrainer from ...train.ppo.trainer import CustomPPOTrainer
from ...train.utils import create_ref_model, create_reward_model
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback from transformers import Seq2SeqTrainingArguments, TrainerCallback
from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
def run_ppo( def run_ppo(
@ -26,12 +29,14 @@ def run_ppo(
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None callbacks: Optional[List["TrainerCallback"]] = None,
): ):
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
model, tokenizer = load_model_and_tokenizer(
model_args, finetuning_args, training_args.do_train, add_valuehead=True
)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo") dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = DataCollatorWithPadding(tokenizer=tokenizer) data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Create reference model and reward model # Create reference model and reward model
@ -55,7 +60,7 @@ def run_ppo(
use_score_scaling=finetuning_args.ppo_score_norm, use_score_scaling=finetuning_args.ppo_score_norm,
use_score_norm=finetuning_args.ppo_score_norm, use_score_norm=finetuning_args.ppo_score_norm,
whiten_rewards=finetuning_args.ppo_whiten_rewards, whiten_rewards=finetuning_args.ppo_whiten_rewards,
accelerator_kwargs={"step_scheduler_with_optimizer": False} accelerator_kwargs={"step_scheduler_with_optimizer": False},
) )
# Create optimizer and scheduler # Create optimizer and scheduler
@ -70,7 +75,7 @@ def run_ppo(
training_args.lr_scheduler_type, training_args.lr_scheduler_type,
optimizer=optimizer, optimizer=optimizer,
num_warmup_steps=training_args.get_warmup_steps(num_training_steps), num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps num_training_steps=num_training_steps,
) )
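`run_ppo` builds its own AdamW optimizer and a transformers LR scheduler rather than letting the trainer create them. A minimal sketch of that pairing; the model, learning rate, and step counts are placeholders:

```python
import torch
from torch.optim import AdamW
from transformers.optimization import get_scheduler

model = torch.nn.Linear(8, 8)  # placeholder for the trainable policy
num_training_steps = 1000
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)
lr_scheduler = get_scheduler(
    "cosine",                  # stands in for training_args.lr_scheduler_type
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
```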
# Initialize our Trainer # Initialize our Trainer
@ -88,7 +93,7 @@ def run_ppo(
dataset=dataset, dataset=dataset,
data_collator=data_collator, data_collator=data_collator,
optimizer=optimizer, optimizer=optimizer,
lr_scheduler=lr_scheduler lr_scheduler=lr_scheduler,
) )
# Training # Training
@ -97,6 +102,6 @@ def run_ppo(
ppo_trainer.save_model() ppo_trainer.save_model()
if training_args.should_save: if training_args.should_save:
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors) fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
ppo_trainer.save_state() # must be called after save_model to have a folder ppo_trainer.save_state() # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss: if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"]) plot_loss(training_args.output_dir, keys=["loss", "reward"])

View File

@ -1,7 +1,8 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py # Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py
import math import math
from typing import TYPE_CHECKING, Optional, List from typing import TYPE_CHECKING, List, Optional
from transformers import DataCollatorForLanguageModeling, Trainer from transformers import DataCollatorForLanguageModeling, Trainer
from ...data import get_dataset, split_dataset from ...data import get_dataset, split_dataset
@ -9,9 +10,11 @@ from ...extras.ploting import plot_loss
from ...model import load_model_and_tokenizer from ...model import load_model_and_tokenizer
from ...train.utils import create_modelcard_and_push from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback from transformers import Seq2SeqTrainingArguments, TrainerCallback
from ...hparams import ModelArguments, DataArguments, FinetuningArguments
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
def run_pt( def run_pt(
@ -19,7 +22,7 @@ def run_pt(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None callbacks: Optional[List["TrainerCallback"]] = None,
): ):
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt") dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt")
@ -32,7 +35,7 @@ def run_pt(
tokenizer=tokenizer, tokenizer=tokenizer,
data_collator=data_collator, data_collator=data_collator,
callbacks=callbacks, callbacks=callbacks,
**split_dataset(dataset, data_args, training_args) **split_dataset(dataset, data_args, training_args),
) )
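`run_pt` relies on the stock causal-LM collator imported above. A small sketch of that collator in isolation; the tokenizer is a toy choice, not mandated by this diff:

```python
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # toy tokenizer for illustration
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

batch = data_collator([tokenizer("hello world"), tokenizer("hi")])
print(batch["labels"][0])  # labels mirror input_ids; padded positions become -100
```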
# Training # Training

View File

@ -1,6 +1,7 @@
import torch
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Sequence from typing import Any, Dict, Sequence
import torch
from transformers import DataCollatorWithPadding from transformers import DataCollatorWithPadding
@ -20,8 +21,9 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
features = [ features = [
{ {
"input_ids": feature["prompt_ids"] + feature[key], "input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])) "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])),
} }
for key in ("chosen_ids", "rejected_ids") for feature in features for key in ("chosen_ids", "rejected_ids")
for feature in features
] ]
return super().__call__(features) return super().__call__(features)

View File

@ -1,6 +1,7 @@
import numpy as np
from typing import Dict, Sequence, Tuple, Union from typing import Dict, Sequence, Tuple, Union
import numpy as np
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
preds, _ = eval_preds preds, _ = eval_preds

View File

@ -1,14 +1,16 @@
import os
import json import json
import torch import os
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from transformers import Trainer from transformers import Trainer
from ...extras.logging import get_logger from ...extras.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from transformers.trainer import PredictionOutput
logger = get_logger(__name__) logger = get_logger(__name__)
@ -21,13 +23,10 @@ class PairwiseTrainer(Trainer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.can_return_loss = True # override property to return eval_loss self.can_return_loss = True # override property to return eval_loss
def compute_loss( def compute_loss(
self,
model: "PreTrainedModel",
inputs: Dict[str, torch.Tensor],
return_outputs: Optional[bool] = False
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: Optional[bool] = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
r""" r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected. Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
@ -68,9 +67,9 @@ class PairwiseTrainer(Trainer):
assert div_index > 0 assert div_index > 0
chosen_trunc_rewards = chosen_rewards[i, div_index:end_index] chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
rejected_trunc_rewards = rejected_rewards[i, div_index:end_index] rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
if return_outputs: # use the score on the last token except pad token for inference if return_outputs: # use the score on the last token except pad token for inference
chosen_scores.append(chosen_rewards[i, chosen_length-1]) chosen_scores.append(chosen_rewards[i, chosen_length - 1])
rejected_scores.append(rejected_rewards[i, rejected_length-1]) rejected_scores.append(rejected_rewards[i, rejected_length - 1])
loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean() loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
loss = loss / batch_size loss = loss / batch_size
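The pairwise reward loss above is `-logsigmoid(chosen - rejected)` averaged over the positions where the two responses diverge. A toy sketch with made-up reward values:

```python
import torch

chosen_trunc_rewards = torch.tensor([1.2, 0.8, 0.5])    # rewards on the divergent suffix
rejected_trunc_rewards = torch.tensor([0.3, 0.1, 0.4])
loss = -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
print(loss.item())  # smaller when chosen consistently outranks rejected
```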
@ -80,10 +79,7 @@ class PairwiseTrainer(Trainer):
return loss return loss
def save_predictions(
self,
predict_results: "PredictionOutput"
) -> None:
def save_predictions(self, predict_results: "PredictionOutput") -> None:
r""" r"""
Saves model predictions to `output_dir`. Saves model predictions to `output_dir`.
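The docstring above describes the objective only in words; here is a hedged sketch of the log-sigmoid ranking loss it refers to (illustrative shapes, not the project's exact code).

import torch

def pairwise_loss(rewards: torch.Tensor) -> torch.Tensor:
    # rewards holds 2n scalar scores: the first n are chosen, the last n rejected.
    n = rewards.size(0) // 2
    chosen, rejected = rewards[:n], rewards[n:]
    return -torch.nn.functional.logsigmoid(chosen - rejected).mean()

print(pairwise_loss(torch.tensor([2.0, 1.5, 0.5, 1.0])))  # small loss: chosen scores are higher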

View File

@ -1,6 +1,7 @@
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from typing import TYPE_CHECKING, Optional, List from typing import TYPE_CHECKING, List, Optional
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from ...data import get_dataset, split_dataset from ...data import get_dataset, split_dataset
@ -13,9 +14,11 @@ from ...train.rm.metric import compute_accuracy
from ...train.rm.trainer import PairwiseTrainer from ...train.rm.trainer import PairwiseTrainer
from ...train.utils import create_modelcard_and_push from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainerCallback from transformers import TrainerCallback
from ...hparams import ModelArguments, DataArguments, FinetuningArguments
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
def run_rm( def run_rm(
@ -23,15 +26,17 @@ def run_rm(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None callbacks: Optional[List["TrainerCallback"]] = None,
): ):
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True) model, tokenizer = load_model_and_tokenizer(
model_args, finetuning_args, training_args.do_train, add_valuehead=True
)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm") dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
# Update arguments # Update arguments
training_args_dict = training_args.to_dict() training_args_dict = training_args.to_dict()
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
training_args = Seq2SeqTrainingArguments(**training_args_dict) training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer # Initialize our Trainer
@ -42,7 +47,7 @@ def run_rm(
data_collator=data_collator, data_collator=data_collator,
callbacks=callbacks + [FixValueHeadModelCallback()], callbacks=callbacks + [FixValueHeadModelCallback()],
compute_metrics=compute_accuracy, compute_metrics=compute_accuracy,
**split_dataset(dataset, data_args, training_args) **split_dataset(dataset, data_args, training_args),
) )
# Training # Training
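The `remove_unused_columns=False` comment above is the key detail of this hunk: the Trainer default would strip the extra chosen/rejected columns before they reach the pairwise collator. A hedged sketch of the rebuild-from-dict pattern used here; `output_dir` is a placeholder.

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(output_dir="tmp_rm")  # placeholder path
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(remove_unused_columns=False))  # keep pairwise columns
training_args = Seq2SeqTrainingArguments(**training_args_dict)
print(training_args.remove_unused_columns)  # False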

View File

@ -1,11 +1,11 @@
import numpy as np
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
import numpy as np
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.packages import ( from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
is_jieba_available, is_nltk_available, is_rouge_available
)
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer from transformers.tokenization_utils import PreTrainedTokenizer
@ -14,7 +14,7 @@ if is_jieba_available():
import jieba import jieba
if is_nltk_available(): if is_nltk_available():
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
if is_rouge_available(): if is_rouge_available():
from rouge_chinese import Rouge from rouge_chinese import Rouge

View File

@ -1,14 +1,16 @@
import os
import json import json
import torch import os
import numpy as np
import torch.nn as nn
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from transformers import Seq2SeqTrainer from transformers import Seq2SeqTrainer
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger from ...extras.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.trainer import PredictionOutput from transformers.trainer import PredictionOutput
@ -33,16 +35,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
Subclass and override to inject custom behavior. Subclass and override to inject custom behavior.
""" """
labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels
if self.args.predict_with_generate: if self.args.predict_with_generate:
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor." assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1) prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
if prompt_len > label_len: if prompt_len > label_len:
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"]) inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility) if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
inputs["labels"] = inputs["labels"][:, :prompt_len] inputs["labels"] = inputs["labels"][:, :prompt_len]
loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated) loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated)
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
) )
if generated_tokens is not None and self.args.predict_with_generate: if generated_tokens is not None and self.args.predict_with_generate:
@ -51,23 +53,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
return loss, generated_tokens, labels return loss, generated_tokens, labels
def _pad_tensors_to_target_len( def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
self,
src_tensor: torch.Tensor,
tgt_tensor: torch.Tensor
) -> torch.Tensor:
r""" r"""
Pads the tensor to the same length as the target tensor. Pads the tensor to the same length as the target tensor.
""" """
assert self.tokenizer.pad_token_id is not None, "Pad token is required." assert self.tokenizer.pad_token_id is not None, "Pad token is required."
padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor) padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding
return padded_tensor.contiguous() # in contiguous memory return padded_tensor.contiguous() # in contiguous memory
def save_predictions( def save_predictions(self, predict_results: "PredictionOutput") -> None:
self,
predict_results: "PredictionOutput"
) -> None:
r""" r"""
Saves model predictions to `output_dir`. Saves model predictions to `output_dir`.
@ -79,15 +74,23 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}") logger.info(f"Saving prediction results to {output_prediction_file}")
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id) labels = np.where(
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id) predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
)
preds = np.where(
predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
)
for i in range(len(preds)): for i in range(len(preds)):
pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0] pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
if len(pad_len): if len(pad_len):
preds[i] = np.concatenate((preds[i][pad_len[0]:], preds[i][:pad_len[0]]), axis=-1) # move pad token to last preds[i] = np.concatenate(
(preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1
) # move pad token to last
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=False) decoded_labels = self.tokenizer.batch_decode(
labels, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True) decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
with open(output_prediction_file, "w", encoding="utf-8") as writer: with open(output_prediction_file, "w", encoding="utf-8") as writer:
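A small sketch of the left-padding helper compressed above; `pad_left_to` is an illustrative name, not the project's.

import torch

def pad_left_to(src: torch.Tensor, tgt: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Build a pad-filled tensor shaped like the target, then copy src into the
    # right-most columns so the real tokens stay flush right (left padding).
    padded = pad_token_id * torch.ones_like(tgt)
    padded[:, -src.shape[-1]:] = src
    return padded.contiguous()

print(pad_left_to(torch.tensor([[5, 6]]), torch.tensor([[1, 2, 3, 4]]), pad_token_id=0))
# tensor([[0, 0, 5, 6]])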

View File

@ -1,6 +1,7 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py # Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
from typing import TYPE_CHECKING, Optional, List from typing import TYPE_CHECKING, List, Optional
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from ...data import get_dataset, split_dataset from ...data import get_dataset, split_dataset
@ -15,7 +16,8 @@ from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainerCallback from transformers import TrainerCallback
from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
def run_sft( def run_sft(
@ -24,29 +26,31 @@ def run_sft(
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None callbacks: Optional[List["TrainerCallback"]] = None,
): ):
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft") dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
if training_args.predict_with_generate: if training_args.predict_with_generate:
tokenizer.padding_side = "left" # use left-padding in generation tokenizer.padding_side = "left" # use left-padding in generation
if getattr(model, "is_quantized", False) and not training_args.do_train: if getattr(model, "is_quantized", False) and not training_args.do_train:
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
data_collator = DataCollatorForSeq2Seq( data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer, tokenizer=tokenizer,
pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
) )
# Override the decoding parameters of Seq2SeqTrainer # Override the decoding parameters of Seq2SeqTrainer
training_args_dict = training_args.to_dict() training_args_dict = training_args.to_dict()
training_args_dict.update(dict( training_args_dict.update(
generation_max_length=training_args.generation_max_length or data_args.cutoff_len, dict(
generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
)) generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams,
)
)
training_args = Seq2SeqTrainingArguments(**training_args_dict) training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer # Initialize our Trainer
@ -57,7 +61,7 @@ def run_sft(
data_collator=data_collator, data_collator=data_collator,
callbacks=callbacks, callbacks=callbacks,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None, compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
**split_dataset(dataset, data_args, training_args) **split_dataset(dataset, data_args, training_args),
) )
# Keyword arguments for `model.generate` # Keyword arguments for `model.generate`
@ -79,7 +83,7 @@ def run_sft(
# Evaluation # Evaluation
if training_args.do_eval: if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
metrics.pop("eval_loss", None) metrics.pop("eval_loss", None)
trainer.log_metrics("eval", metrics) trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics) trainer.save_metrics("eval", metrics)
@ -87,7 +91,7 @@ def run_sft(
# Predict # Predict
if training_args.do_predict: if training_args.do_predict:
predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
predict_results.metrics.pop("predict_loss", None) predict_results.metrics.pop("predict_loss", None)
trainer.log_metrics("predict", predict_results.metrics) trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics) trainer.save_metrics("predict", predict_results.metrics)

View File

@ -1,16 +1,18 @@
import torch
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from transformers import PreTrainedModel from transformers import PreTrainedModel
from ..extras.callbacks import LogCallback from ..extras.callbacks import LogCallback
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..hparams import get_train_args, get_infer_args from ..hparams import get_infer_args, get_train_args
from ..model import load_model_and_tokenizer from ..model import load_model_and_tokenizer
from .pt import run_pt
from .sft import run_sft
from .rm import run_rm
from .ppo import run_ppo
from .dpo import run_dpo from .dpo import run_dpo
from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainerCallback from transformers import TrainerCallback
@ -64,23 +66,23 @@ def export_model(args: Optional[Dict[str, Any]] = None):
model.save_pretrained( model.save_pretrained(
save_directory=model_args.export_dir, save_directory=model_args.export_dir,
max_shard_size="{}GB".format(model_args.export_size), max_shard_size="{}GB".format(model_args.export_size),
safe_serialization=(not model_args.export_legacy_format) safe_serialization=(not model_args.export_legacy_format),
) )
if model_args.export_hub_model_id is not None: if model_args.export_hub_model_id is not None:
model.push_to_hub( model.push_to_hub(
model_args.export_hub_model_id, model_args.export_hub_model_id,
token=model_args.hf_hub_token, token=model_args.hf_hub_token,
max_shard_size="{}GB".format(model_args.export_size), max_shard_size="{}GB".format(model_args.export_size),
safe_serialization=(not model_args.export_legacy_format) safe_serialization=(not model_args.export_legacy_format),
) )
try: try:
tokenizer.padding_side = "left" # restore padding side tokenizer.padding_side = "left" # restore padding side
tokenizer.init_kwargs["padding_side"] = "left" tokenizer.init_kwargs["padding_side"] = "left"
tokenizer.save_pretrained(model_args.export_dir) tokenizer.save_pretrained(model_args.export_dir)
if model_args.export_hub_model_id is not None: if model_args.export_hub_model_id is not None:
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token) tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
except: except Exception:
logger.warning("Cannot save tokenizer, please copy the files manually.") logger.warning("Cannot save tokenizer, please copy the files manually.")

View File

@ -1,14 +1,17 @@
import torch
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Optional, Union
import torch
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..hparams import ModelArguments, FinetuningArguments from ..hparams import FinetuningArguments, ModelArguments
from ..model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params from ..model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, Trainer from transformers import Seq2SeqTrainingArguments, Trainer
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from ..hparams import DataArguments from ..hparams import DataArguments
@ -20,7 +23,7 @@ def create_modelcard_and_push(
model_args: "ModelArguments", model_args: "ModelArguments",
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments" finetuning_args: "FinetuningArguments",
) -> None: ) -> None:
if training_args.do_train: if training_args.do_train:
if training_args.push_to_hub: if training_args.push_to_hub:
@ -33,9 +36,7 @@ def create_modelcard_and_push(
def create_ref_model( def create_ref_model(
model_args: "ModelArguments", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: Optional[bool] = False
finetuning_args: "FinetuningArguments",
add_valuehead: Optional[bool] = False
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]: ) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
r""" r"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported. Creates reference model for PPO/DPO training. Evaluation mode is not supported.
@ -44,11 +45,13 @@ def create_ref_model(
""" """
if finetuning_args.ref_model is not None: if finetuning_args.ref_model is not None:
ref_model_args_dict = model_args.to_dict() ref_model_args_dict = model_args.to_dict()
ref_model_args_dict.update(dict( ref_model_args_dict.update(
model_name_or_path=finetuning_args.ref_model, dict(
adapter_name_or_path=finetuning_args.ref_model_adapters, model_name_or_path=finetuning_args.ref_model,
quantization_bit=finetuning_args.ref_model_quantization_bit adapter_name_or_path=finetuning_args.ref_model_adapters,
)) quantization_bit=finetuning_args.ref_model_quantization_bit,
)
)
ref_model_args = ModelArguments(**ref_model_args_dict) ref_model_args = ModelArguments(**ref_model_args_dict)
ref_finetuning_args = FinetuningArguments(finetuning_type="lora") ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
ref_model, _ = load_model_and_tokenizer( ref_model, _ = load_model_and_tokenizer(
@ -68,9 +71,7 @@ def create_ref_model(
def create_reward_model( def create_reward_model(
model: "AutoModelForCausalLMWithValueHead", model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments"
) -> "AutoModelForCausalLMWithValueHead": ) -> "AutoModelForCausalLMWithValueHead":
r""" r"""
Creates reward model for PPO training. Creates reward model for PPO training.
@ -81,24 +82,30 @@ def create_reward_model(
return finetuning_args.reward_model return finetuning_args.reward_model
elif finetuning_args.reward_model_type == "lora": elif finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward") model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090 for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name: if "default" in name:
param.data = param.data.to(torch.float32) # trainable params should be in fp32 param.data = param.data.to(torch.float32) # trainable params should be in fp32
vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args) vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
assert vhead_params is not None, "Reward model is not correctly loaded." assert vhead_params is not None, "Reward model is not correctly loaded."
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False) model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False) model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False) model.register_buffer(
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False) "default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False
)
model.register_buffer(
"default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
)
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model)) logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
return None return None
else: else:
reward_model_args_dict = model_args.to_dict() reward_model_args_dict = model_args.to_dict()
reward_model_args_dict.update(dict( reward_model_args_dict.update(
model_name_or_path=finetuning_args.reward_model, dict(
adapter_name_or_path=finetuning_args.reward_model_adapters, model_name_or_path=finetuning_args.reward_model,
quantization_bit=finetuning_args.reward_model_quantization_bit adapter_name_or_path=finetuning_args.reward_model_adapters,
)) quantization_bit=finetuning_args.reward_model_quantization_bit,
)
)
reward_model_args = ModelArguments(**reward_model_args_dict) reward_model_args = ModelArguments(**reward_model_args_dict)
reward_finetuning_args = FinetuningArguments(finetuning_type="lora") reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
reward_model, _ = load_model_and_tokenizer( reward_model, _ = load_model_and_tokenizer(
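The `register_buffer(..., persistent=False)` calls above keep the reward-head tensors attached to the module (so they follow device moves) without writing them into its state_dict. A minimal sketch with a toy module.

import torch
import torch.nn as nn

head = nn.Linear(4, 1)
head.register_buffer("reward_head_weight", torch.randn(1, 4), persistent=False)
head.register_buffer("reward_head_bias", torch.zeros(1), persistent=False)

print("reward_head_weight" in head.state_dict())  # False: non-persistent buffers are skipped
print(head.reward_head_weight.shape)              # torch.Size([1, 4]), still accessible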

View File

@ -1,24 +1,22 @@
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from ..chat import ChatModel from ..chat import ChatModel
from ..extras.misc import torch_gc from ..extras.misc import torch_gc
from ..hparams import GeneratingArguments from ..hparams import GeneratingArguments
from .common import get_save_dir from .common import get_save_dir
from .locales import ALERTS from .locales import ALERTS
if TYPE_CHECKING: if TYPE_CHECKING:
from .manager import Manager from .manager import Manager
class WebChatModel(ChatModel): class WebChatModel(ChatModel):
def __init__( def __init__(
self, self, manager: "Manager", demo_mode: Optional[bool] = False, lazy_init: Optional[bool] = True
manager: "Manager",
demo_mode: Optional[bool] = False,
lazy_init: Optional[bool] = True
) -> None: ) -> None:
self.manager = manager self.manager = manager
self.demo_mode = demo_mode self.demo_mode = demo_mode
@ -26,11 +24,12 @@ class WebChatModel(ChatModel):
self.tokenizer = None self.tokenizer = None
self.generating_args = GeneratingArguments() self.generating_args = GeneratingArguments()
if not lazy_init: # read arguments from command line if not lazy_init: # read arguments from command line
super().__init__() super().__init__()
if demo_mode: # load demo_config.json if exists if demo_mode: # load demo_config.json if exists
import json import json
try: try:
with open("demo_config.json", "r", encoding="utf-8") as f: with open("demo_config.json", "r", encoding="utf-8") as f:
args = json.load(f) args = json.load(f)
@ -38,7 +37,7 @@ class WebChatModel(ChatModel):
super().__init__(args) super().__init__(args)
except AssertionError: except AssertionError:
print("Please provided model name and template in `demo_config.json`.") print("Please provided model name and template in `demo_config.json`.")
except: except Exception:
print("Cannot find `demo_config.json` at current directory.") print("Cannot find `demo_config.json` at current directory.")
@property @property
@ -64,9 +63,12 @@ class WebChatModel(ChatModel):
return return
if get("top.adapter_path"): if get("top.adapter_path"):
adapter_name_or_path = ",".join([ adapter_name_or_path = ",".join(
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter) [
for adapter in get("top.adapter_path")]) get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else: else:
adapter_name_or_path = None adapter_name_or_path = None
@ -79,7 +81,7 @@ class WebChatModel(ChatModel):
template=get("top.template"), template=get("top.template"),
flash_attn=(get("top.booster") == "flash_attn"), flash_attn=(get("top.booster") == "flash_attn"),
use_unsloth=(get("top.booster") == "unsloth"), use_unsloth=(get("top.booster") == "unsloth"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
) )
super().__init__(args) super().__init__(args)
@ -108,7 +110,7 @@ class WebChatModel(ChatModel):
tools: str, tools: str,
max_new_tokens: int, max_new_tokens: int,
top_p: float, top_p: float,
temperature: float temperature: float,
) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]: ) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
chatbot.append([query, ""]) chatbot.append([query, ""])
response = "" response = ""

View File

@ -1,9 +1,10 @@
import os
import json import json
import gradio as gr import os
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
import gradio as gr
from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
from ..extras.constants import ( from ..extras.constants import (
DATA_CONFIG, DATA_CONFIG,
@ -12,7 +13,7 @@ from ..extras.constants import (
PEFT_METHODS, PEFT_METHODS,
SUPPORTED_MODELS, SUPPORTED_MODELS,
TRAINING_STAGES, TRAINING_STAGES,
DownloadSource DownloadSource,
) )
from ..extras.misc import use_modelscope from ..extras.misc import use_modelscope
@ -36,7 +37,7 @@ def load_config() -> Dict[str, Any]:
try: try:
with open(get_config_path(), "r", encoding="utf-8") as f: with open(get_config_path(), "r", encoding="utf-8") as f:
return json.load(f) return json.load(f)
except: except Exception:
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None} return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
@ -59,7 +60,7 @@ def get_model_path(model_name: str) -> str:
use_modelscope() use_modelscope()
and path_dict.get(DownloadSource.MODELSCOPE) and path_dict.get(DownloadSource.MODELSCOPE)
and model_path == path_dict.get(DownloadSource.DEFAULT) and model_path == path_dict.get(DownloadSource.DEFAULT)
): # replace path ): # replace path
model_path = path_dict.get(DownloadSource.MODELSCOPE) model_path = path_dict.get(DownloadSource.MODELSCOPE)
return model_path return model_path
@ -87,9 +88,8 @@ def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
save_dir = get_save_dir(model_name, finetuning_type) save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir): if save_dir and os.path.isdir(save_dir):
for adapter in os.listdir(save_dir): for adapter in os.listdir(save_dir):
if ( if os.path.isdir(os.path.join(save_dir, adapter)) and any(
os.path.isdir(os.path.join(save_dir, adapter)) os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES
and any([os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES])
): ):
adapters.append(adapter) adapters.append(adapter)
return gr.update(value=[], choices=adapters, interactive=True) return gr.update(value=[], choices=adapters, interactive=True)
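The adapter check above swaps `any([...])` for a generator expression, which avoids building a throwaway list and short-circuits on the first hit. A sketch; the directory is a placeholder and the file names are the usual peft checkpoint names.

import os

ADAPTER_NAMES = ("adapter_model.bin", "adapter_model.safetensors")  # typical peft file names
adapter_dir = "saves/llama2-7b/lora/checkpoint-100"  # placeholder path

has_adapter = any(
    os.path.isfile(os.path.join(adapter_dir, name)) for name in ADAPTER_NAMES
)
print(has_adapter)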

View File

@ -1,11 +1,16 @@
from .chatbot import create_chat_box
from .eval import create_eval_tab
from .export import create_export_tab
from .infer import create_infer_tab
from .top import create_top from .top import create_top
from .train import create_train_tab from .train import create_train_tab
from .eval import create_eval_tab
from .infer import create_infer_tab
from .export import create_export_tab
from .chatbot import create_chat_box
__all__ = [ __all__ = [
"create_top", "create_train_tab", "create_eval_tab", "create_infer_tab", "create_export_tab", "create_chat_box" "create_top",
"create_train_tab",
"create_eval_tab",
"create_infer_tab",
"create_export_tab",
"create_chat_box",
] ]

View File

@ -1,6 +1,7 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict, Optional, Tuple from typing import TYPE_CHECKING, Dict, Optional, Tuple
import gradio as gr
from ..utils import check_json_schema from ..utils import check_json_schema
@ -12,8 +13,7 @@ if TYPE_CHECKING:
def create_chat_box( def create_chat_box(
engine: "Engine", engine: "Engine", visible: Optional[bool] = False
visible: Optional[bool] = False
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]: ) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
with gr.Box(visible=visible) as chat_box: with gr.Box(visible=visible) as chat_box:
chatbot = gr.Chatbot() chatbot = gr.Chatbot()
@ -38,20 +38,23 @@ def create_chat_box(
engine.chatter.predict, engine.chatter.predict,
[chatbot, query, history, system, tools, max_new_tokens, top_p, temperature], [chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
[chatbot, history], [chatbot, history],
show_progress=True show_progress=True,
).then( ).then(lambda: gr.update(value=""), outputs=[query])
lambda: gr.update(value=""), outputs=[query]
)
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True) clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
return chat_box, chatbot, history, dict( return (
system=system, chat_box,
tools=tools, chatbot,
query=query, history,
submit_btn=submit_btn, dict(
clear_btn=clear_btn, system=system,
max_new_tokens=max_new_tokens, tools=tools,
top_p=top_p, query=query,
temperature=temperature submit_btn=submit_btn,
clear_btn=clear_btn,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
),
) )

View File

@ -1,10 +1,12 @@
import os
import json import json
import gradio as gr import os
from typing import TYPE_CHECKING, Any, Dict, Tuple from typing import TYPE_CHECKING, Any, Dict, Tuple
import gradio as gr
from ...extras.constants import DATA_CONFIG from ...extras.constants import DATA_CONFIG
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component
@ -24,7 +26,7 @@ def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
try: try:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f) dataset_info = json.load(f)
except: except Exception:
return gr.update(interactive=False) return gr.update(interactive=False)
if ( if (
@ -48,7 +50,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
elif data_file.endswith(".jsonl"): elif data_file.endswith(".jsonl"):
data = [json.loads(line) for line in f] data = [json.loads(line) for line in f]
else: else:
data = [line for line in f] data = [line for line in f] # noqa: C416
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True) return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True)
@ -67,32 +69,17 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
with gr.Row(): with gr.Row():
preview_samples = gr.JSON(interactive=False) preview_samples = gr.JSON(interactive=False)
dataset.change( dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False
).then(
lambda: 0, outputs=[page_index], queue=False lambda: 0, outputs=[page_index], queue=False
) )
data_preview_btn.click( data_preview_btn.click(
get_preview, get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
) )
prev_btn.click( prev_btn.click(prev_page, [page_index], [page_index], queue=False).then(
prev_page, [page_index], [page_index], queue=False get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
).then(
get_preview,
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
) )
next_btn.click( next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
next_page, [page_index, preview_count], [page_index], queue=False get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
).then(
get_preview,
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
) )
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False) close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
return dict( return dict(
@ -102,5 +89,5 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
prev_btn=prev_btn, prev_btn=prev_btn,
next_btn=next_btn, next_btn=next_btn,
close_btn=close_btn, close_btn=close_btn,
preview_samples=preview_samples preview_samples=preview_samples,
) )
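The preview wiring above collapses multi-line `.click(...).then(...)` chains onto single calls; the chaining works because a gradio 3.x event listener returns a dependency whose `.then(...)` runs after the first callback completes. A minimal sketch with made-up components.

import gradio as gr

with gr.Blocks() as demo:
    page_index = gr.Number(value=3)
    status = gr.Textbox()
    prev_btn = gr.Button("prev")
    # Reset the page index first, then report it, mirroring prev_btn.click().then() above.
    prev_btn.click(lambda: 0, outputs=[page_index], queue=False).then(
        lambda i: "page {}".format(int(i)), inputs=[page_index], outputs=[status], queue=False
    )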

View File

@ -1,9 +1,11 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict
from ..common import list_dataset, DEFAULT_DATA_DIR import gradio as gr
from ..common import DEFAULT_DATA_DIR, list_dataset
from .data import create_preview_box from .data import create_preview_box
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component
@ -31,9 +33,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
predict = gr.Checkbox(value=True) predict = gr.Checkbox(value=True)
input_elems.update({cutoff_len, max_samples, batch_size, predict}) input_elems.update({cutoff_len, max_samples, batch_size, predict})
elem_dict.update(dict( elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))
cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict
))
with gr.Row(): with gr.Row():
max_new_tokens = gr.Slider(10, 2048, value=128, step=1) max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
@ -42,9 +42,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
output_dir = gr.Textbox() output_dir = gr.Textbox()
input_elems.update({max_new_tokens, top_p, temperature, output_dir}) input_elems.update({max_new_tokens, top_p, temperature, output_dir})
elem_dict.update(dict( elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir))
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir
))
with gr.Row(): with gr.Row():
cmd_preview_btn = gr.Button() cmd_preview_btn = gr.Button()
@ -59,10 +57,16 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
output_box = gr.Markdown() output_box = gr.Markdown()
output_elems = [output_box, process_bar] output_elems = [output_box, process_bar]
elem_dict.update(dict( elem_dict.update(
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, dict(
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box cmd_preview_btn=cmd_preview_btn,
)) start_btn=start_btn,
stop_btn=stop_btn,
resume_btn=resume_btn,
process_bar=process_bar,
output_box=output_box,
)
)
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems) cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
start_btn.click(engine.runner.run_eval, input_elems, output_elems) start_btn.click(engine.runner.run_eval, input_elems, output_elems)

View File

@ -1,10 +1,12 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict, Generator, List from typing import TYPE_CHECKING, Dict, Generator, List
import gradio as gr
from ...train import export_model from ...train import export_model
from ..common import get_save_dir from ..common import get_save_dir
from ..locales import ALERTS from ..locales import ALERTS
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component
@ -24,7 +26,7 @@ def save_model(
max_shard_size: int, max_shard_size: int,
export_quantization_bit: int, export_quantization_bit: int,
export_quantization_dataset: str, export_quantization_dataset: str,
export_dir: str export_dir: str,
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
error = "" error = ""
if not model_name: if not model_name:
@ -44,7 +46,9 @@ def save_model(
return return
if adapter_path: if adapter_path:
adapter_name_or_path = ",".join([get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]) adapter_name_or_path = ",".join(
[get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
)
else: else:
adapter_name_or_path = None adapter_name_or_path = None
@ -56,7 +60,7 @@ def save_model(
export_dir=export_dir, export_dir=export_dir,
export_size=max_shard_size, export_size=max_shard_size,
export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None, export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
export_quantization_dataset=export_quantization_dataset export_quantization_dataset=export_quantization_dataset,
) )
yield ALERTS["info_exporting"][lang] yield ALERTS["info_exporting"][lang]
@ -86,9 +90,9 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
max_shard_size, max_shard_size,
export_quantization_bit, export_quantization_bit,
export_quantization_dataset, export_quantization_dataset,
export_dir export_dir,
], ],
[info_box] [info_box],
) )
return dict( return dict(
@ -97,5 +101,5 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
export_quantization_dataset=export_quantization_dataset, export_quantization_dataset=export_quantization_dataset,
export_dir=export_dir, export_dir=export_dir,
export_btn=export_btn, export_btn=export_btn,
info_box=info_box info_box=info_box,
) )

View File

@ -1,8 +1,10 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict
import gradio as gr
from .chatbot import create_chat_box from .chatbot import create_chat_box
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component
@ -23,18 +25,12 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False) chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(dict(chat_box=chat_box, **chat_elems)) elem_dict.update(dict(chat_box=chat_box, **chat_elems))
load_btn.click( load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
engine.chatter.load_model, input_elems, [info_box]
).then(
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box] lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
) )
unload_btn.click( unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
engine.chatter.unload_model, input_elems, [info_box]
).then(
lambda: ([], []), outputs=[chatbot, history] lambda: ([], []), outputs=[chatbot, history]
).then( ).then(lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box])
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
)
return elem_dict return elem_dict

View File

@ -1,11 +1,13 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict
import gradio as gr
from ...data import templates from ...data import templates
from ...extras.constants import METHODS, SUPPORTED_MODELS from ...extras.constants import METHODS, SUPPORTED_MODELS
from ..common import get_model_path, get_template, list_adapters, save_config from ..common import get_model_path, get_template, list_adapters, save_config
from ..utils import can_quantize from ..utils import can_quantize
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component
@ -30,25 +32,19 @@ def create_top() -> Dict[str, "Component"]:
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none") rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
booster = gr.Radio(choices=["none", "flash_attn", "unsloth"], value="none") booster = gr.Radio(choices=["none", "flash_attn", "unsloth"], value="none")
model_name.change( model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
).then(
get_model_path, [model_name], [model_path], queue=False get_model_path, [model_name], [model_path], queue=False
).then( ).then(
get_template, [model_name], [template], queue=False get_template, [model_name], [template], queue=False
) # do not save config since the below line will save ) # do not save config since the below line will save
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False) model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
finetuning_type.change( finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
).then(
can_quantize, [finetuning_type], [quantization_bit], queue=False can_quantize, [finetuning_type], [quantization_bit], queue=False
) )
refresh_btn.click( refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
)
return dict( return dict(
lang=lang, lang=lang,
@ -61,5 +57,5 @@ def create_top() -> Dict[str, "Component"]:
quantization_bit=quantization_bit, quantization_bit=quantization_bit,
template=template, template=template,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
booster=booster booster=booster,
) )

View File

@ -1,12 +1,14 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict
import gradio as gr
from transformers.trainer_utils import SchedulerType from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES from ...extras.constants import TRAINING_STAGES
from ..common import list_adapters, list_dataset, DEFAULT_DATA_DIR from ..common import DEFAULT_DATA_DIR, list_adapters, list_dataset
from ..components.data import create_preview_box from ..components.data import create_preview_box
from ..utils import gen_plot from ..utils import gen_plot
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component
@ -29,9 +31,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False) dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
input_elems.update({training_stage, dataset_dir, dataset}) input_elems.update({training_stage, dataset_dir, dataset})
elem_dict.update(dict( elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems
))
with gr.Row(): with gr.Row():
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1) cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
@ -41,25 +41,33 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
compute_type = gr.Radio(choices=["fp16", "bf16", "fp32"], value="fp16") compute_type = gr.Radio(choices=["fp16", "bf16", "fp32"], value="fp16")
input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type}) input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
elem_dict.update(dict( elem_dict.update(
cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs, dict(
max_samples=max_samples, compute_type=compute_type cutoff_len=cutoff_len,
)) learning_rate=learning_rate,
num_train_epochs=num_train_epochs,
max_samples=max_samples,
compute_type=compute_type,
)
)
with gr.Row(): with gr.Row():
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1) batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1) gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
lr_scheduler_type = gr.Dropdown( lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
)
max_grad_norm = gr.Textbox(value="1.0") max_grad_norm = gr.Textbox(value="1.0")
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001) val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size}) input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
elem_dict.update(dict( elem_dict.update(
batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps, dict(
lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size batch_size=batch_size,
)) gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
max_grad_norm=max_grad_norm,
val_size=val_size,
)
)
with gr.Accordion(label="Extra config", open=False) as extra_tab: with gr.Accordion(label="Extra config", open=False) as extra_tab:
with gr.Row(): with gr.Row():
@ -73,10 +81,17 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
upcast_layernorm = gr.Checkbox(value=False) upcast_layernorm = gr.Checkbox(value=False)
input_elems.update({logging_steps, save_steps, warmup_steps, neftune_alpha, sft_packing, upcast_layernorm}) input_elems.update({logging_steps, save_steps, warmup_steps, neftune_alpha, sft_packing, upcast_layernorm})
elem_dict.update(dict( elem_dict.update(
extra_tab=extra_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps, dict(
neftune_alpha=neftune_alpha, sft_packing=sft_packing, upcast_layernorm=upcast_layernorm extra_tab=extra_tab,
)) logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
neftune_alpha=neftune_alpha,
sft_packing=sft_packing,
upcast_layernorm=upcast_layernorm,
)
)
with gr.Accordion(label="LoRA config", open=False) as lora_tab: with gr.Accordion(label="LoRA config", open=False) as lora_tab:
with gr.Row(): with gr.Row():
@ -87,10 +102,16 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
create_new_adapter = gr.Checkbox(scale=1) create_new_adapter = gr.Checkbox(scale=1)
input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, create_new_adapter}) input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, create_new_adapter})
elem_dict.update(dict( elem_dict.update(
lora_tab=lora_tab, lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target, dict(
additional_target=additional_target, create_new_adapter=create_new_adapter lora_tab=lora_tab,
)) lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target,
additional_target=additional_target,
create_new_adapter=create_new_adapter,
)
)
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab: with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
with gr.Row(): with gr.Row():
@ -103,13 +124,13 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
list_adapters, list_adapters,
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")], [engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
[reward_model], [reward_model],
queue=False queue=False,
) )
input_elems.update({dpo_beta, dpo_ftx, reward_model}) input_elems.update({dpo_beta, dpo_ftx, reward_model})
elem_dict.update(dict( elem_dict.update(
rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn)
)) )
with gr.Row(): with gr.Row():
cmd_preview_btn = gr.Button() cmd_preview_btn = gr.Button()
@ -139,20 +160,28 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
stop_btn.click(engine.runner.set_abort, queue=False) stop_btn.click(engine.runner.set_abort, queue=False)
resume_btn.change(engine.runner.monitor, outputs=output_elems) resume_btn.change(engine.runner.monitor, outputs=output_elems)
elem_dict.update(dict( elem_dict.update(
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir, dict(
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer cmd_preview_btn=cmd_preview_btn,
)) start_btn=start_btn,
stop_btn=stop_btn,
output_dir=output_dir,
resume_btn=resume_btn,
process_bar=process_bar,
output_box=output_box,
loss_viewer=loss_viewer,
)
)
output_box.change( output_box.change(
gen_plot, gen_plot,
[ [
engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.model_name"),
engine.manager.get_elem_by_name("top.finetuning_type"), engine.manager.get_elem_by_name("top.finetuning_type"),
output_dir output_dir,
], ],
loss_viewer, loss_viewer,
queue=False queue=False,
) )
return elem_dict return elem_dict

View File

@ -1,7 +1,8 @@
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import Any, Dict, Generator, Optional from typing import Any, Dict, Generator, Optional
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from .chatter import WebChatModel from .chatter import WebChatModel
from .common import get_model_path, list_dataset, load_config from .common import get_model_path, list_dataset, load_config
from .locales import LOCALES from .locales import LOCALES
@ -11,7 +12,6 @@ from .utils import get_time
class Engine: class Engine:
def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None: def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None:
self.demo_mode = demo_mode self.demo_mode = demo_mode
self.pure_chat = pure_chat self.pure_chat = pure_chat
@ -26,10 +26,7 @@ class Engine:
user_config = load_config() if not self.demo_mode else {} user_config = load_config() if not self.demo_mode else {}
lang = user_config.get("lang", None) or "en" lang = user_config.get("lang", None) or "en"
init_dict = { init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
"top.lang": {"value": lang},
"infer.chat_box": {"visible": self.chatter.loaded}
}
if not self.pure_chat: if not self.pure_chat:
init_dict["train.dataset"] = {"choices": list_dataset()["choices"]} init_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
@ -49,13 +46,17 @@ class Engine:
else: else:
yield self._form_dict({"eval.resume_btn": {"value": True}}) yield self._form_dict({"eval.resume_btn": {"value": True}})
else: else:
yield self._form_dict({ yield self._form_dict(
"train.output_dir": {"value": "train_" + get_time()}, {
"eval.output_dir": {"value": "eval_" + get_time()}, "train.output_dir": {"value": "train_" + get_time()},
}) "eval.output_dir": {"value": "eval_" + get_time()},
}
)
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]: def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
return { return {
component: gr.update(**LOCALES[name][lang]) component: gr.update(**LOCALES[name][lang])
for elems in self.manager.all_elems.values() for name, component in elems.items() if name in LOCALES for elems in self.manager.all_elems.values()
for name, component in elems.items()
if name in LOCALES
} }

View File

@ -1,21 +1,22 @@
import gradio as gr
from typing import Optional from typing import Optional
import gradio as gr
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from .common import save_config
from .components import ( from .components import (
create_chat_box,
create_eval_tab,
create_export_tab,
create_infer_tab,
create_top, create_top,
create_train_tab, create_train_tab,
create_eval_tab,
create_infer_tab,
create_export_tab,
create_chat_box
) )
from .common import save_config
from .css import CSS from .css import CSS
from .engine import Engine from .engine import Engine
require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"") require_version("gradio>=3.38.0,<4.0.0", 'To fix: pip install "gradio>=3.38.0,<4.0.0"')
def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks: def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
@ -23,11 +24,9 @@ def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
with gr.Blocks(title="LLaMA Board", css=CSS) as demo: with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
if demo_mode: if demo_mode:
gr.HTML("<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>")
gr.HTML( gr.HTML(
"<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>" '<h3><center>Visit <a href="https://github.com/hiyouga/LLaMA-Factory" target="_blank">'
)
gr.HTML(
"<h3><center>Visit <a href=\"https://github.com/hiyouga/LLaMA-Factory\" target=\"_blank\">"
"LLaMA Factory</a> for details.</center></h3>" "LLaMA Factory</a> for details.</center></h3>"
) )
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button") gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
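A minimal sketch of how the Blocks object returned by `create_ui` might be served, assuming `llmtuner` and `gradio>=3.38,<4` are installed; the host and port values are illustrative, not the project's defaults:

from llmtuner import create_ui


def main():
    demo = create_ui()
    demo.queue()  # enable request queueing so streaming callbacks work
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)


if __name__ == "__main__":
    main()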

View File

@ -1,726 +1,220 @@
LOCALES = { LOCALES = {
"lang": { "lang": {"en": {"label": "Lang"}, "zh": {"label": "语言"}},
"en": { "model_name": {"en": {"label": "Model name"}, "zh": {"label": "模型名称"}},
"label": "Lang"
},
"zh": {
"label": "语言"
}
},
"model_name": {
"en": {
"label": "Model name"
},
"zh": {
"label": "模型名称"
}
},
"model_path": { "model_path": {
"en": { "en": {"label": "Model path", "info": "Path to pretrained model or model identifier from Hugging Face."},
"label": "Model path", "zh": {"label": "模型路径", "info": "本地模型的文件路径或 Hugging Face 的模型标识符。"},
"info": "Path to pretrained model or model identifier from Hugging Face."
},
"zh": {
"label": "模型路径",
"info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
}
},
"finetuning_type": {
"en": {
"label": "Finetuning method"
},
"zh": {
"label": "微调方法"
}
},
"adapter_path": {
"en": {
"label": "Adapter path"
},
"zh": {
"label": "适配器路径"
}
},
"refresh_btn": {
"en": {
"value": "Refresh adapters"
},
"zh": {
"value": "刷新适配器"
}
},
"advanced_tab": {
"en": {
"label": "Advanced configurations"
},
"zh": {
"label": "高级设置"
}
}, },
"finetuning_type": {"en": {"label": "Finetuning method"}, "zh": {"label": "微调方法"}},
"adapter_path": {"en": {"label": "Adapter path"}, "zh": {"label": "适配器路径"}},
"refresh_btn": {"en": {"value": "Refresh adapters"}, "zh": {"value": "刷新适配器"}},
"advanced_tab": {"en": {"label": "Advanced configurations"}, "zh": {"label": "高级设置"}},
"quantization_bit": { "quantization_bit": {
"en": { "en": {"label": "Quantization bit", "info": "Enable 4/8-bit model quantization (QLoRA)."},
"label": "Quantization bit", "zh": {"label": "量化等级", "info": "启用 4/8 比特模型量化QLoRA"},
"info": "Enable 4/8-bit model quantization (QLoRA)."
},
"zh": {
"label": "量化等级",
"info": "启用 4/8 比特模型量化QLoRA"
}
}, },
"template": { "template": {
"en": { "en": {"label": "Prompt template", "info": "The template used in constructing prompts."},
"label": "Prompt template", "zh": {"label": "提示模板", "info": "构建提示词时使用的模板"},
"info": "The template used in constructing prompts."
},
"zh": {
"label": "提示模板",
"info": "构建提示词时使用的模板"
}
},
"rope_scaling": {
"en": {
"label": "RoPE scaling"
},
"zh": {
"label": "RoPE 插值方法"
}
},
"booster": {
"en": {
"label": "Booster"
},
"zh": {
"label": "加速方式"
}
}, },
"rope_scaling": {"en": {"label": "RoPE scaling"}, "zh": {"label": "RoPE 插值方法"}},
"booster": {"en": {"label": "Booster"}, "zh": {"label": "加速方式"}},
"training_stage": { "training_stage": {
"en": { "en": {"label": "Stage", "info": "The stage to perform in training."},
"label": "Stage", "zh": {"label": "训练阶段", "info": "目前采用的训练方式。"},
"info": "The stage to perform in training."
},
"zh": {
"label": "训练阶段",
"info": "目前采用的训练方式。"
}
}, },
"dataset_dir": { "dataset_dir": {
"en": { "en": {"label": "Data dir", "info": "Path to the data directory."},
"label": "Data dir", "zh": {"label": "数据路径", "info": "数据文件夹的路径。"},
"info": "Path to the data directory."
},
"zh": {
"label": "数据路径",
"info": "数据文件夹的路径。"
}
},
"dataset": {
"en": {
"label": "Dataset"
},
"zh": {
"label": "数据集"
}
},
"data_preview_btn": {
"en": {
"value": "Preview dataset"
},
"zh": {
"value": "预览数据集"
}
},
"preview_count": {
"en": {
"label": "Count"
},
"zh": {
"label": "数量"
}
},
"page_index": {
"en": {
"label": "Page"
},
"zh": {
"label": "页数"
}
},
"prev_btn": {
"en": {
"value": "Prev"
},
"zh": {
"value": "上一页"
}
},
"next_btn": {
"en": {
"value": "Next"
},
"zh": {
"value": "下一页"
}
},
"close_btn": {
"en": {
"value": "Close"
},
"zh": {
"value": "关闭"
}
},
"preview_samples": {
"en": {
"label": "Samples"
},
"zh": {
"label": "样例"
}
}, },
"dataset": {"en": {"label": "Dataset"}, "zh": {"label": "数据集"}},
"data_preview_btn": {"en": {"value": "Preview dataset"}, "zh": {"value": "预览数据集"}},
"preview_count": {"en": {"label": "Count"}, "zh": {"label": "数量"}},
"page_index": {"en": {"label": "Page"}, "zh": {"label": "页数"}},
"prev_btn": {"en": {"value": "Prev"}, "zh": {"value": "上一页"}},
"next_btn": {"en": {"value": "Next"}, "zh": {"value": "下一页"}},
"close_btn": {"en": {"value": "Close"}, "zh": {"value": "关闭"}},
"preview_samples": {"en": {"label": "Samples"}, "zh": {"label": "样例"}},
"cutoff_len": { "cutoff_len": {
"en": { "en": {"label": "Cutoff length", "info": "Max tokens in input sequence."},
"label": "Cutoff length", "zh": {"label": "截断长度", "info": "输入序列分词后的最大长度。"},
"info": "Max tokens in input sequence."
},
"zh": {
"label": "截断长度",
"info": "输入序列分词后的最大长度。"
}
}, },
"learning_rate": { "learning_rate": {
"en": { "en": {"label": "Learning rate", "info": "Initial learning rate for AdamW."},
"label": "Learning rate", "zh": {"label": "学习率", "info": "AdamW 优化器的初始学习率。"},
"info": "Initial learning rate for AdamW."
},
"zh": {
"label": "学习率",
"info": "AdamW 优化器的初始学习率。"
}
}, },
"num_train_epochs": { "num_train_epochs": {
"en": { "en": {"label": "Epochs", "info": "Total number of training epochs to perform."},
"label": "Epochs", "zh": {"label": "训练轮数", "info": "需要执行的训练总轮数。"},
"info": "Total number of training epochs to perform."
},
"zh": {
"label": "训练轮数",
"info": "需要执行的训练总轮数。"
}
}, },
"max_samples": { "max_samples": {
"en": { "en": {"label": "Max samples", "info": "Maximum samples per dataset."},
"label": "Max samples", "zh": {"label": "最大样本数", "info": "每个数据集最多使用的样本数。"},
"info": "Maximum samples per dataset."
},
"zh": {
"label": "最大样本数",
"info": "每个数据集最多使用的样本数。"
}
}, },
"compute_type": { "compute_type": {
"en": { "en": {"label": "Compute type", "info": "Whether to use fp16 or bf16 mixed precision training."},
"label": "Compute type", "zh": {"label": "计算类型", "info": "是否启用 FP16 或 BF16 混合精度训练。"},
"info": "Whether to use fp16 or bf16 mixed precision training."
},
"zh": {
"label": "计算类型",
"info": "是否启用 FP16 或 BF16 混合精度训练。"
}
}, },
"batch_size": { "batch_size": {
"en": { "en": {"label": "Batch size", "info": "Number of samples to process per GPU."},
"label": "Batch size", "zh": {"label": "批处理大小", "info": "每块 GPU 上处理的样本数量。"},
"info": "Number of samples to process per GPU."
},
"zh":{
"label": "批处理大小",
"info": "每块 GPU 上处理的样本数量。"
}
}, },
"gradient_accumulation_steps": { "gradient_accumulation_steps": {
"en": { "en": {"label": "Gradient accumulation", "info": "Number of gradient accumulation steps."},
"label": "Gradient accumulation", "zh": {"label": "梯度累积", "info": "梯度累积的步数。"},
"info": "Number of gradient accumulation steps."
},
"zh": {
"label": "梯度累积",
"info": "梯度累积的步数。"
}
}, },
"lr_scheduler_type": { "lr_scheduler_type": {
"en": { "en": {
"label": "LR Scheduler", "label": "LR Scheduler",
"info": "Name of learning rate scheduler.", "info": "Name of learning rate scheduler.",
}, },
"zh": { "zh": {"label": "学习率调节器", "info": "采用的学习率调节器名称。"},
"label": "学习率调节器",
"info": "采用的学习率调节器名称。"
}
}, },
"max_grad_norm": { "max_grad_norm": {
"en": { "en": {"label": "Maximum gradient norm", "info": "Norm for gradient clipping.."},
"label": "Maximum gradient norm", "zh": {"label": "最大梯度范数", "info": "用于梯度裁剪的范数。"},
"info": "Norm for gradient clipping.."
},
"zh": {
"label": "最大梯度范数",
"info": "用于梯度裁剪的范数。"
}
}, },
"val_size": { "val_size": {
"en": { "en": {"label": "Val size", "info": "Proportion of data in the dev set."},
"label": "Val size", "zh": {"label": "验证集比例", "info": "验证集占全部样本的百分比。"},
"info": "Proportion of data in the dev set."
},
"zh": {
"label": "验证集比例",
"info": "验证集占全部样本的百分比。"
}
},
"extra_tab": {
"en": {
"label": "Extra configurations"
},
"zh": {
"label": "其它参数设置"
}
}, },
"extra_tab": {"en": {"label": "Extra configurations"}, "zh": {"label": "其它参数设置"}},
"logging_steps": { "logging_steps": {
"en": { "en": {"label": "Logging steps", "info": "Number of steps between two logs."},
"label": "Logging steps", "zh": {"label": "日志间隔", "info": "每两次日志输出间的更新步数。"},
"info": "Number of steps between two logs."
},
"zh": {
"label": "日志间隔",
"info": "每两次日志输出间的更新步数。"
}
}, },
"save_steps": { "save_steps": {
"en": { "en": {"label": "Save steps", "info": "Number of steps between two checkpoints."},
"label": "Save steps", "zh": {"label": "保存间隔", "info": "每两次断点保存间的更新步数。"},
"info": "Number of steps between two checkpoints."
},
"zh": {
"label": "保存间隔",
"info": "每两次断点保存间的更新步数。"
}
}, },
"warmup_steps": { "warmup_steps": {
"en": { "en": {"label": "Warmup steps", "info": "Number of steps used for warmup."},
"label": "Warmup steps", "zh": {"label": "预热步数", "info": "学习率预热采用的步数。"},
"info": "Number of steps used for warmup."
},
"zh": {
"label": "预热步数",
"info": "学习率预热采用的步数。"
}
}, },
"neftune_alpha": { "neftune_alpha": {
"en": { "en": {"label": "NEFTune Alpha", "info": "Magnitude of noise adding to embedding vectors."},
"label": "NEFTune Alpha", "zh": {"label": "NEFTune 噪声参数", "info": "嵌入向量所添加的噪声大小。"},
"info": "Magnitude of noise adding to embedding vectors."
},
"zh": {
"label": "NEFTune 噪声参数",
"info": "嵌入向量所添加的噪声大小。"
}
}, },
"sft_packing": { "sft_packing": {
"en": { "en": {
"label": "Pack sequences", "label": "Pack sequences",
"info": "Pack sequences into samples of fixed length in supervised fine-tuning." "info": "Pack sequences into samples of fixed length in supervised fine-tuning.",
}, },
"zh": { "zh": {"label": "序列打包", "info": "在有监督微调阶段将序列打包为相同长度的样本。"},
"label": "序列打包",
"info": "在有监督微调阶段将序列打包为相同长度的样本。"
}
}, },
"upcast_layernorm": { "upcast_layernorm": {
"en": { "en": {"label": "Upcast LayerNorm", "info": "Upcast weights of layernorm in float32."},
"label": "Upcast LayerNorm", "zh": {"label": "缩放归一化层", "info": "将归一化层权重缩放至 32 位精度。"},
"info": "Upcast weights of layernorm in float32."
},
"zh": {
"label": "缩放归一化层",
"info": "将归一化层权重缩放至 32 位精度。"
}
},
"lora_tab": {
"en": {
"label": "LoRA configurations"
},
"zh": {
"label": "LoRA 参数设置"
}
}, },
"lora_tab": {"en": {"label": "LoRA configurations"}, "zh": {"label": "LoRA 参数设置"}},
"lora_rank": { "lora_rank": {
"en": { "en": {"label": "LoRA rank", "info": "The rank of LoRA matrices."},
"label": "LoRA rank", "zh": {"label": "LoRA 秩", "info": "LoRA 矩阵的秩。"},
"info": "The rank of LoRA matrices."
},
"zh": {
"label": "LoRA 秩",
"info": "LoRA 矩阵的秩。"
}
}, },
"lora_dropout": { "lora_dropout": {
"en": { "en": {"label": "LoRA Dropout", "info": "Dropout ratio of LoRA weights."},
"label": "LoRA Dropout", "zh": {"label": "LoRA 随机丢弃", "info": "LoRA 权重随机丢弃的概率。"},
"info": "Dropout ratio of LoRA weights."
},
"zh": {
"label": "LoRA 随机丢弃",
"info": "LoRA 权重随机丢弃的概率。"
}
}, },
"lora_target": { "lora_target": {
"en": { "en": {
"label": "LoRA modules (optional)", "label": "LoRA modules (optional)",
"info": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules." "info": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules.",
}, },
"zh": { "zh": {"label": "LoRA 作用模块(非必填)", "info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。"},
"label": "LoRA 作用模块(非必填)",
"info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。"
}
}, },
"additional_target": { "additional_target": {
"en": { "en": {
"label": "Additional modules (optional)", "label": "Additional modules (optional)",
"info": "Name(s) of modules apart from LoRA layers to be set as trainable. Use commas to separate multiple modules." "info": "Name(s) of modules apart from LoRA layers to be set as trainable. Use commas to separate multiple modules.",
}, },
"zh": { "zh": {"label": "附加模块(非必填)", "info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。"},
"label": "附加模块(非必填)",
"info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。"
}
}, },
"create_new_adapter": { "create_new_adapter": {
"en": { "en": {
"label": "Create new adapter", "label": "Create new adapter",
"info": "Whether to create a new adapter with randomly initialized weight or not." "info": "Whether to create a new adapter with randomly initialized weight or not.",
}, },
"zh": { "zh": {"label": "新建适配器", "info": "是否创建一个经过随机初始化的新适配器。"},
"label": "新建适配器",
"info": "是否创建一个经过随机初始化的新适配器。"
}
},
"rlhf_tab": {
"en": {
"label": "RLHF configurations"
},
"zh": {
"label": "RLHF 参数设置"
}
}, },
"rlhf_tab": {"en": {"label": "RLHF configurations"}, "zh": {"label": "RLHF 参数设置"}},
"dpo_beta": { "dpo_beta": {
"en": { "en": {"label": "DPO beta", "info": "Value of the beta parameter in the DPO loss."},
"label": "DPO beta", "zh": {"label": "DPO beta 参数", "info": "DPO 损失函数中 beta 超参数大小。"},
"info": "Value of the beta parameter in the DPO loss."
},
"zh": {
"label": "DPO beta 参数",
"info": "DPO 损失函数中 beta 超参数大小。"
}
}, },
"dpo_ftx": { "dpo_ftx": {
"en": { "en": {"label": "DPO-ftx weight", "info": "The weight of SFT loss in the DPO-ftx."},
"label": "DPO-ftx weight", "zh": {"label": "DPO-ftx 权重", "info": "DPO-ftx 中 SFT 损失的权重大小。"},
"info": "The weight of SFT loss in the DPO-ftx."
},
"zh": {
"label": "DPO-ftx 权重",
"info": "DPO-ftx 中 SFT 损失的权重大小。"
}
}, },
"reward_model": { "reward_model": {
"en": { "en": {
"label": "Reward model", "label": "Reward model",
"info": "Adapter of the reward model for PPO training. (Needs to refresh adapters)" "info": "Adapter of the reward model for PPO training. (Needs to refresh adapters)",
}, },
"zh": { "zh": {"label": "奖励模型", "info": "PPO 训练中奖励模型的适配器路径。(需要刷新适配器)"},
"label": "奖励模型",
"info": "PPO 训练中奖励模型的适配器路径。(需要刷新适配器)"
}
},
"cmd_preview_btn": {
"en": {
"value": "Preview command"
},
"zh": {
"value": "预览命令"
}
},
"start_btn": {
"en": {
"value": "Start"
},
"zh": {
"value": "开始"
}
},
"stop_btn": {
"en": {
"value": "Abort"
},
"zh": {
"value": "中断"
}
}, },
"cmd_preview_btn": {"en": {"value": "Preview command"}, "zh": {"value": "预览命令"}},
"start_btn": {"en": {"value": "Start"}, "zh": {"value": "开始"}},
"stop_btn": {"en": {"value": "Abort"}, "zh": {"value": "中断"}},
"output_dir": { "output_dir": {
"en": { "en": {"label": "Output dir", "info": "Directory for saving results."},
"label": "Output dir", "zh": {"label": "输出目录", "info": "保存结果的路径。"},
"info": "Directory for saving results."
},
"zh": {
"label": "输出目录",
"info": "保存结果的路径。"
}
},
"output_box": {
"en": {
"value": "Ready."
},
"zh": {
"value": "准备就绪。"
}
},
"loss_viewer": {
"en": {
"label": "Loss"
},
"zh": {
"label": "损失"
}
},
"predict": {
"en": {
"label": "Save predictions"
},
"zh": {
"label": "保存预测结果"
}
},
"load_btn": {
"en": {
"value": "Load model"
},
"zh": {
"value": "加载模型"
}
},
"unload_btn": {
"en": {
"value": "Unload model"
},
"zh": {
"value": "卸载模型"
}
},
"info_box": {
"en": {
"value": "Model unloaded, please load a model first."
},
"zh": {
"value": "模型未加载,请先加载模型。"
}
},
"system": {
"en": {
"placeholder": "System prompt (optional)"
},
"zh": {
"placeholder": "系统提示词(非必填)"
}
},
"tools": {
"en": {
"placeholder": "Tools (optional)"
},
"zh": {
"placeholder": "工具列表(非必填)"
}
},
"query": {
"en": {
"placeholder": "Input..."
},
"zh": {
"placeholder": "输入..."
}
},
"submit_btn": {
"en": {
"value": "Submit"
},
"zh": {
"value": "提交"
}
},
"clear_btn": {
"en": {
"value": "Clear history"
},
"zh": {
"value": "清空历史"
}
},
"max_length": {
"en": {
"label": "Maximum length"
},
"zh": {
"label": "最大长度"
}
},
"max_new_tokens": {
"en": {
"label": "Maximum new tokens"
},
"zh": {
"label": "最大生成长度"
}
},
"top_p": {
"en": {
"label": "Top-p"
},
"zh": {
"label": "Top-p 采样值"
}
},
"temperature": {
"en": {
"label": "Temperature"
},
"zh": {
"label": "温度系数"
}
}, },
"output_box": {"en": {"value": "Ready."}, "zh": {"value": "准备就绪。"}},
"loss_viewer": {"en": {"label": "Loss"}, "zh": {"label": "损失"}},
"predict": {"en": {"label": "Save predictions"}, "zh": {"label": "保存预测结果"}},
"load_btn": {"en": {"value": "Load model"}, "zh": {"value": "加载模型"}},
"unload_btn": {"en": {"value": "Unload model"}, "zh": {"value": "卸载模型"}},
"info_box": {"en": {"value": "Model unloaded, please load a model first."}, "zh": {"value": "模型未加载,请先加载模型。"}},
"system": {"en": {"placeholder": "System prompt (optional)"}, "zh": {"placeholder": "系统提示词(非必填)"}},
"tools": {"en": {"placeholder": "Tools (optional)"}, "zh": {"placeholder": "工具列表(非必填)"}},
"query": {"en": {"placeholder": "Input..."}, "zh": {"placeholder": "输入..."}},
"submit_btn": {"en": {"value": "Submit"}, "zh": {"value": "提交"}},
"clear_btn": {"en": {"value": "Clear history"}, "zh": {"value": "清空历史"}},
"max_length": {"en": {"label": "Maximum length"}, "zh": {"label": "最大长度"}},
"max_new_tokens": {"en": {"label": "Maximum new tokens"}, "zh": {"label": "最大生成长度"}},
"top_p": {"en": {"label": "Top-p"}, "zh": {"label": "Top-p 采样值"}},
"temperature": {"en": {"label": "Temperature"}, "zh": {"label": "温度系数"}},
"max_shard_size": { "max_shard_size": {
"en": { "en": {"label": "Max shard size (GB)", "info": "The maximum size for a model file."},
"label": "Max shard size (GB)", "zh": {"label": "最大分块大小GB", "info": "单个模型文件的最大大小。"},
"info": "The maximum size for a model file."
},
"zh": {
"label": "最大分块大小GB",
"info": "单个模型文件的最大大小。"
}
}, },
"export_quantization_bit": { "export_quantization_bit": {
"en": { "en": {"label": "Export quantization bit.", "info": "Quantizing the exported model."},
"label": "Export quantization bit.", "zh": {"label": "导出量化等级", "info": "量化导出模型。"},
"info": "Quantizing the exported model."
},
"zh": {
"label": "导出量化等级",
"info": "量化导出模型。"
}
}, },
"export_quantization_dataset": { "export_quantization_dataset": {
"en": { "en": {"label": "Export quantization dataset.", "info": "The calibration dataset used for quantization."},
"label": "Export quantization dataset.", "zh": {"label": "导出量化数据集", "info": "量化过程中使用的校准数据集。"},
"info": "The calibration dataset used for quantization."
},
"zh": {
"label": "导出量化数据集",
"info": "量化过程中使用的校准数据集。"
}
}, },
"export_dir": { "export_dir": {
"en": { "en": {"label": "Export dir", "info": "Directory to save exported model."},
"label": "Export dir", "zh": {"label": "导出目录", "info": "保存导出模型的文件夹路径。"},
"info": "Directory to save exported model."
},
"zh": {
"label": "导出目录",
"info": "保存导出模型的文件夹路径。"
}
}, },
"export_btn": { "export_btn": {"en": {"value": "Export"}, "zh": {"value": "开始导出"}},
"en": {
"value": "Export"
},
"zh": {
"value": "开始导出"
}
}
} }
ALERTS = { ALERTS = {
"err_conflict": { "err_conflict": {"en": "A process is in running, please abort it firstly.", "zh": "任务已存在,请先中断训练。"},
"en": "A process is in running, please abort it firstly.", "err_exists": {"en": "You have loaded a model, please unload it first.", "zh": "模型已存在,请先卸载模型。"},
"zh": "任务已存在,请先中断训练。" "err_no_model": {"en": "Please select a model.", "zh": "请选择模型。"},
}, "err_no_path": {"en": "Model not found.", "zh": "模型未找到。"},
"err_exists": { "err_no_dataset": {"en": "Please choose a dataset.", "zh": "请选择数据集。"},
"en": "You have loaded a model, please unload it first.", "err_no_adapter": {"en": "Please select an adapter.", "zh": "请选择一个适配器。"},
"zh": "模型已存在,请先卸载模型。" "err_no_export_dir": {"en": "Please provide export dir.", "zh": "请填写导出目录"},
}, "err_failed": {"en": "Failed.", "zh": "训练出错。"},
"err_no_model": {
"en": "Please select a model.",
"zh": "请选择模型。"
},
"err_no_path": {
"en": "Model not found.",
"zh": "模型未找到。"
},
"err_no_dataset": {
"en": "Please choose a dataset.",
"zh": "请选择数据集。"
},
"err_no_adapter": {
"en": "Please select an adapter.",
"zh": "请选择一个适配器。"
},
"err_no_export_dir": {
"en": "Please provide export dir.",
"zh": "请填写导出目录"
},
"err_failed": {
"en": "Failed.",
"zh": "训练出错。"
},
"err_demo": { "err_demo": {
"en": "Training is unavailable in demo mode, duplicate the space to a private one first.", "en": "Training is unavailable in demo mode, duplicate the space to a private one first.",
"zh": "展示模式不支持训练,请先复制到私人空间。" "zh": "展示模式不支持训练,请先复制到私人空间。",
}, },
"err_device_count": { "err_device_count": {"en": "Multiple GPUs are not supported yet.", "zh": "尚不支持多 GPU 训练。"},
"en": "Multiple GPUs are not supported yet.", "info_aborting": {"en": "Aborted, wait for terminating...", "zh": "训练中断,正在等待线程结束……"},
"zh": "尚不支持多 GPU 训练。" "info_aborted": {"en": "Ready.", "zh": "准备就绪。"},
}, "info_finished": {"en": "Finished.", "zh": "训练完毕。"},
"info_aborting": { "info_loading": {"en": "Loading model...", "zh": "加载中……"},
"en": "Aborted, wait for terminating...", "info_unloading": {"en": "Unloading model...", "zh": "卸载中……"},
"zh": "训练中断,正在等待线程结束……" "info_loaded": {"en": "Model loaded, now you can chat with your model!", "zh": "模型已加载,可以开始聊天了!"},
}, "info_unloaded": {"en": "Model unloaded.", "zh": "模型已卸载。"},
"info_aborted": { "info_exporting": {"en": "Exporting model...", "zh": "正在导出模型……"},
"en": "Ready.", "info_exported": {"en": "Model exported.", "zh": "模型导出完成。"},
"zh": "准备就绪。"
},
"info_finished": {
"en": "Finished.",
"zh": "训练完毕。"
},
"info_loading": {
"en": "Loading model...",
"zh": "加载中……"
},
"info_unloading": {
"en": "Unloading model...",
"zh": "卸载中……"
},
"info_loaded": {
"en": "Model loaded, now you can chat with your model!",
"zh": "模型已加载,可以开始聊天了!"
},
"info_unloaded": {
"en": "Model unloaded.",
"zh": "模型已卸载。"
},
"info_exporting": {
"en": "Exporting model...",
"zh": "正在导出模型……"
},
"info_exported": {
"en": "Model exported.",
"zh": "模型导出完成。"
}
} }
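A standalone sketch of the ALERTS lookup pattern, with a hypothetical `alert` helper and an English fallback; in the web UI these strings are surfaced through `gr.Warning` or yielded into the output box:

ALERTS = {
    "err_no_model": {"en": "Please select a model.", "zh": "请选择模型。"},
    "info_loading": {"en": "Loading model...", "zh": "加载中……"},
}


def alert(key: str, lang: str) -> str:
    # Fall back to English if the requested language has no translation.
    return ALERTS[key].get(lang, ALERTS[key]["en"])


print(alert("err_no_model", "zh"))  # 请选择模型。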

View File

@ -1,11 +1,11 @@
from typing import TYPE_CHECKING, Dict, List, Set from typing import TYPE_CHECKING, Dict, List, Set
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component
class Manager: class Manager:
def __init__(self) -> None: def __init__(self) -> None:
self.all_elems: Dict[str, Dict[str, "Component"]] = {} self.all_elems: Dict[str, Dict[str, "Component"]] = {}
@ -26,7 +26,7 @@ class Manager:
self.all_elems["top"]["quantization_bit"], self.all_elems["top"]["quantization_bit"],
self.all_elems["top"]["template"], self.all_elems["top"]["template"],
self.all_elems["top"]["rope_scaling"], self.all_elems["top"]["rope_scaling"],
self.all_elems["top"]["booster"] self.all_elems["top"]["booster"],
} }
def list_elems(self) -> List["Component"]: def list_elems(self) -> List["Component"]:

View File

@ -1,12 +1,12 @@
import logging
import os import os
import time import time
import logging
import gradio as gr
from threading import Thread from threading import Thread
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
import gradio as gr
import transformers import transformers
from gradio.components import Component # cannot use TYPE_CHECKING here
from transformers.trainer import TRAINING_ARGS_NAME from transformers.trainer import TRAINING_ARGS_NAME
from ..extras.callbacks import LogCallback from ..extras.callbacks import LogCallback
@ -18,12 +18,12 @@ from .common import get_module, get_save_dir, load_config
from .locales import ALERTS from .locales import ALERTS
from .utils import gen_cmd, get_eval_results, update_process_bar from .utils import gen_cmd, get_eval_results, update_process_bar
if TYPE_CHECKING: if TYPE_CHECKING:
from .manager import Manager from .manager import Manager
class Runner: class Runner:
def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None: def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
self.manager = manager self.manager = manager
self.demo_mode = demo_mode self.demo_mode = demo_mode
@ -90,9 +90,12 @@ class Runner:
user_config = load_config() user_config = load_config()
if get("top.adapter_path"): if get("top.adapter_path"):
adapter_name_or_path = ",".join([ adapter_name_or_path = ",".join(
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter) [
for adapter in get("top.adapter_path")]) get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else: else:
adapter_name_or_path = None adapter_name_or_path = None
@ -131,12 +134,12 @@ class Runner:
create_new_adapter=get("train.create_new_adapter"), create_new_adapter=get("train.create_new_adapter"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")), output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
fp16=(get("train.compute_type") == "fp16"), fp16=(get("train.compute_type") == "fp16"),
bf16=(get("train.compute_type") == "bf16") bf16=(get("train.compute_type") == "bf16"),
) )
args["disable_tqdm"] = True args["disable_tqdm"] = True
if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]: if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
args["create_new_adapter"] = (args["quantization_bit"] is None) args["create_new_adapter"] = args["quantization_bit"] is None
if args["stage"] == "ppo": if args["stage"] == "ppo":
args["reward_model"] = get_save_dir( args["reward_model"] = get_save_dir(
@ -161,9 +164,12 @@ class Runner:
user_config = load_config() user_config = load_config()
if get("top.adapter_path"): if get("top.adapter_path"):
adapter_name_or_path = ",".join([ adapter_name_or_path = ",".join(
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter) [
for adapter in get("top.adapter_path")]) get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else: else:
adapter_name_or_path = None adapter_name_or_path = None
@ -187,7 +193,7 @@ class Runner:
max_new_tokens=get("eval.max_new_tokens"), max_new_tokens=get("eval.max_new_tokens"),
top_p=get("eval.top_p"), top_p=get("eval.top_p"),
temperature=get("eval.temperature"), temperature=get("eval.temperature"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")) output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
) )
if get("eval.predict"): if get("eval.predict"):
@ -197,7 +203,9 @@ class Runner:
return args return args
def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]: def _preview(
self, data: Dict[Component, Any], do_train: bool
) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
error = self._initialize(data, do_train, from_preview=True) error = self._initialize(data, do_train, from_preview=True)
if error: if error:
gr.Warning(error) gr.Warning(error)
@ -235,9 +243,11 @@ class Runner:
get = lambda name: self.running_data[self.manager.get_elem_by_name(name)] get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
self.running = True self.running = True
lang = get("top.lang") lang = get("top.lang")
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get( output_dir = get_save_dir(
"{}.output_dir".format("train" if self.do_train else "eval") get("top.model_name"),
)) get("top.finetuning_type"),
get("{}.output_dir".format("train" if self.do_train else "eval")),
)
while self.thread.is_alive(): while self.thread.is_alive():
time.sleep(2) time.sleep(2)

View File

@ -1,13 +1,15 @@
import os
import json import json
import gradio as gr import os
from typing import TYPE_CHECKING, Any, Dict
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict
import gradio as gr
from ..extras.packages import is_matplotlib_available from ..extras.packages import is_matplotlib_available
from ..extras.ploting import smooth from ..extras.ploting import smooth
from .common import get_save_dir from .common import get_save_dir
if TYPE_CHECKING: if TYPE_CHECKING:
from ..extras.callbacks import LogCallback from ..extras.callbacks import LogCallback
@ -22,16 +24,13 @@ def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0 percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
label = "Running {:d}/{:d}: {} < {}".format( label = "Running {:d}/{:d}: {} < {}".format(
callback.cur_steps, callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
callback.max_steps,
callback.elapsed_time,
callback.remaining_time
) )
return gr.update(label=label, value=percentage, visible=True) return gr.update(label=label, value=percentage, visible=True)
def get_time() -> str: def get_time() -> str:
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S') return datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
def can_quantize(finetuning_type: str) -> Dict[str, Any]: def can_quantize(finetuning_type: str) -> Dict[str, Any]:

View File

@ -3,11 +3,12 @@
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512 # Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/ # Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
from typing import Optional
import fire import fire
import torch import torch
from typing import Optional from deepspeed.accelerator import get_accelerator # type: ignore
from deepspeed.accelerator import get_accelerator # type: ignore from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
from llmtuner import ChatModel from llmtuner import ChatModel
@ -16,25 +17,13 @@ def calculate_flops(
model_name_or_path: str, model_name_or_path: str,
batch_size: Optional[int] = 1, batch_size: Optional[int] = 1,
seq_length: Optional[int] = 256, seq_length: Optional[int] = 256,
flash_attn: Optional[bool] = False flash_attn: Optional[bool] = False,
): ):
with get_accelerator().device(0): with get_accelerator().device(0):
chat_model = ChatModel(dict( chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="vanilla", flash_attn=flash_attn))
model_name_or_path=model_name_or_path,
template="vanilla",
flash_attn=flash_attn
))
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device) fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
input_dict = { input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
"input_ids": fake_input, flops, macs, params = get_model_profile(chat_model.model, kwargs=input_dict, print_profile=True, detailed=True)
"labels": fake_input.clone()
}
flops, macs, params = get_model_profile(
chat_model.model,
kwargs=input_dict,
print_profile=True,
detailed=True
)
print("FLOPs:", flops) print("FLOPs:", flops)
print("MACs:", macs) print("MACs:", macs)
print("Params:", params) print("Params:", params)

View File

@ -3,12 +3,13 @@
# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16 # Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py # Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
import fire
import math import math
import torch
from tqdm import tqdm
from typing import Optional from typing import Optional
import fire
import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForSeq2Seq from transformers import DataCollatorForSeq2Seq
from llmtuner.data import get_dataset from llmtuner.data import get_dataset
@ -17,8 +18,8 @@ from llmtuner.hparams import get_train_args
from llmtuner.model import load_model_and_tokenizer from llmtuner.model import load_model_and_tokenizer
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
BASE_BS = 4_000_000 # from llama paper BASE_BS = 4_000_000 # from llama paper
def calculate_lr( def calculate_lr(
@ -26,18 +27,20 @@ def calculate_lr(
dataset: str, dataset: str,
cutoff_len: int, # i.e. maximum input length during training cutoff_len: int, # i.e. maximum input length during training
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size) batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
is_mistral: bool, # mistral model uses a smaller learning rate, is_mistral: bool, # mistral model uses a smaller learning rate,
dataset_dir: Optional[str] = "data" dataset_dir: Optional[str] = "data",
): ):
model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict( model_args, data_args, training_args, finetuning_args, _ = get_train_args(
stage="sft", dict(
model_name_or_path=model_name_or_path, stage="sft",
dataset=dataset, model_name_or_path=model_name_or_path,
dataset_dir=dataset_dir, dataset=dataset,
template="default", dataset_dir=dataset_dir,
cutoff_len=cutoff_len, template="default",
output_dir="dummy_dir" cutoff_len=cutoff_len,
)) output_dir="dummy_dir",
)
)
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False) _, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft") trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX) data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
@ -49,14 +52,16 @@ def calculate_lr(
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item() valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
total_tokens += torch.numel(batch["labels"]) total_tokens += torch.numel(batch["labels"])
batch_max_len = cutoff_len * batch_size # max tokens in a batch batch_max_len = cutoff_len * batch_size # max tokens in a batch
valid_ratio = valid_tokens / total_tokens valid_ratio = valid_tokens / total_tokens
batch_valid_len = batch_max_len * valid_ratio batch_valid_len = batch_max_len * valid_ratio
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size) lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size)
lr = lr / 6.0 if is_mistral else lr lr = lr / 6.0 if is_mistral else lr
print("Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format( print(
lr, valid_ratio * 100, batch_valid_len "Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
)) lr, valid_ratio * 100, batch_valid_len
)
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -4,32 +4,28 @@
# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py # Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied # Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
import os
import fire
import json import json
import torch import os
from tqdm import tqdm
from collections import OrderedDict from collections import OrderedDict
from safetensors.torch import save_file
from transformers.modeling_utils import (
shard_checkpoint,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME
)
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
shard_checkpoint,
)
CONFIG_NAME = "config.json" CONFIG_NAME = "config.json"
def save_weight( def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
input_dir: str,
output_dir: str,
shard_size: str,
save_safetensors: bool
):
baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict() baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"): for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"): if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
@ -41,8 +37,8 @@ def save_weight(
if "W_pack" in key: if "W_pack" in key:
proj_size = value.size(0) // 3 proj_size = value.size(0) // 3
llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :] llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size:2*proj_size, :] llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size : 2 * proj_size, :]
llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2*proj_size:, :] llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2 * proj_size :, :]
elif "lm_head" in key: elif "lm_head" in key:
llama2_state_dict[key] = torch.nn.functional.normalize(value) llama2_state_dict[key] = torch.nn.functional.normalize(value)
else: else:
@ -56,7 +52,7 @@ def save_weight(
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"}) save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
else: else:
torch.save(shard, os.path.join(output_dir, shard_file)) torch.save(shard, os.path.join(output_dir, shard_file))
if index is None: if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME))) print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
else: else:
@ -66,10 +62,7 @@ def save_weight(
print("Model weights saved in {}".format(output_dir)) print("Model weights saved in {}".format(output_dir))
def save_config( def save_config(input_dir: str, output_dir: str):
input_dir: str,
output_dir: str
):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f: with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
llama2_config_dict: Dict[str, Any] = json.load(f) llama2_config_dict: Dict[str, Any] = json.load(f)
@ -83,19 +76,14 @@ def save_config(
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME))) print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
def llamafy_baichuan2( def llamafy_baichuan2(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False):
input_dir: str,
output_dir: str,
shard_size: str,
save_safetensors: Optional[bool] = False
):
try: try:
os.makedirs(output_dir, exist_ok=False) os.makedirs(output_dir, exist_ok=False)
except Exception as e: except Exception as e:
raise print("Output dir already exists", e) raise print("Output dir already exists", e)
save_weight(input_dir, output_dir, shard_size, save_safetensors) save_weight(input_dir, output_dir, shard_size, save_safetensors)
save_config(input_dir, output_dir) save_config(input_dir, output_dir)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -3,32 +3,28 @@
# Usage: python llamafy_internlm2.py --input_dir input --output_dir output --shard_size 10GB # Usage: python llamafy_internlm2.py --input_dir input --output_dir output --shard_size 10GB
# Warning: We have found that the converted model cannot infer correctly. It will be fixed later. # Warning: We have found that the converted model cannot infer correctly. It will be fixed later.
import os
import fire
import json import json
import torch import os
from tqdm import tqdm
from collections import OrderedDict from collections import OrderedDict
from safetensors.torch import save_file
from transformers.modeling_utils import (
shard_checkpoint,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME
)
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
shard_checkpoint,
)
CONFIG_NAME = "config.json" CONFIG_NAME = "config.json"
def save_weight( def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
input_dir: str,
output_dir: str,
shard_size: str,
save_safetensors: bool
):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f: with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
internlm2_config_dict: Dict[str, Any] = json.load(f) internlm2_config_dict: Dict[str, Any] = json.load(f)
@ -50,8 +46,10 @@ def save_weight(
q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...] llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[q_size:q_size+kv_size, ...] llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[
llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size+kv_size:, ...] q_size : q_size + kv_size, ...
]
llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size + kv_size :, ...]
elif "wo" in key: elif "wo" in key:
llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value
elif "attention_norm" in key: elif "attention_norm" in key:
@ -85,10 +83,7 @@ def save_weight(
print("Model weights saved in {}".format(output_dir)) print("Model weights saved in {}".format(output_dir))
def save_config( def save_config(input_dir: str, output_dir: str):
input_dir: str,
output_dir: str
):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f: with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
llama2_config_dict: Dict[str, Any] = json.load(f) llama2_config_dict: Dict[str, Any] = json.load(f)
@ -103,12 +98,7 @@ def save_config(
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME))) print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
def llamafy_internlm2( def llamafy_internlm2(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False):
input_dir: str,
output_dir: str,
shard_size: str,
save_safetensors: Optional[bool] = False
):
try: try:
os.makedirs(output_dir, exist_ok=False) os.makedirs(output_dir, exist_ok=False)
except Exception as e: except Exception as e:
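A standalone sketch of the grouped-query `wqkv` split above, with illustrative head counts; it mirrors the indexing in the script, whose header notes the converted model still has an open inference issue:

import torch

head_dim, num_q_heads, num_kv_heads, hidden_size = 4, 8, 2, 32  # illustrative
wqkv = torch.randn((num_q_heads + 2 * num_kv_heads) * head_dim, hidden_size)

q_size = wqkv.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
kv_size = wqkv.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
q_proj = wqkv[:q_size, :]
k_proj = wqkv[q_size : q_size + kv_size, :]
v_proj = wqkv[q_size + kv_size :, :]
assert q_proj.shape == (num_q_heads * head_dim, hidden_size)
assert k_proj.shape == v_proj.shape == (num_kv_heads * head_dim, hidden_size)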

View File

@ -3,39 +3,36 @@
# Usage: python llamafy_qwen.py --input_dir input --output_dir output --shard_size 10GB # Usage: python llamafy_qwen.py --input_dir input --output_dir output --shard_size 10GB
# Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied # Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
import os
import fire
import json import json
import torch import os
from tqdm import tqdm
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, Optional
import fire
import torch
from safetensors import safe_open from safetensors import safe_open
from safetensors.torch import save_file from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import ( from transformers.modeling_utils import (
shard_checkpoint,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
WEIGHTS_INDEX_NAME shard_checkpoint,
) )
from transformers.utils import check_min_version from transformers.utils import check_min_version
from typing import Any, Dict, Optional
try: try:
check_min_version("4.34.0") check_min_version("4.34.0")
except: except Exception:
raise ValueError("Please upgrade `transformers` to 4.34.0") raise ValueError("Please upgrade `transformers` to 4.34.0")
CONFIG_NAME = "config.json" CONFIG_NAME = "config.json"
def save_weight( def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool) -> str:
input_dir: str,
output_dir: str,
shard_size: str,
save_safetensors: bool
) -> str:
qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict() qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"): for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"): if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
@ -57,13 +54,15 @@ def save_weight(
if "attn.c_attn" in key: if "attn.c_attn" in key:
proj_size = value.size(0) // 3 proj_size = value.size(0) // 3
llama2_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...] llama2_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[proj_size:2*proj_size, ...] llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[
llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2*proj_size:, ...] proj_size : 2 * proj_size, ...
]
llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2 * proj_size :, ...]
elif "attn.c_proj" in key: elif "attn.c_proj" in key:
llama2_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value llama2_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = ( llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = torch.zeros_like(
torch.zeros_like(value[:, 0]).squeeze() value[:, 0]
) ).squeeze()
elif "ln_1" in key: elif "ln_1" in key:
llama2_state_dict[key.replace("ln_1", "input_layernorm")] = value llama2_state_dict[key.replace("ln_1", "input_layernorm")] = value
elif "ln_2" in key: elif "ln_2" in key:
@ -99,11 +98,7 @@ def save_weight(
return str(torch_dtype).replace("torch.", "") return str(torch_dtype).replace("torch.", "")
def save_config( def save_config(input_dir: str, output_dir: str, torch_dtype: str):
input_dir: str,
output_dir: str,
torch_dtype: str
):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f: with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
qwen_config_dict: Dict[str, Any] = json.load(f) qwen_config_dict: Dict[str, Any] = json.load(f)
@ -133,12 +128,7 @@ def save_config(
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME))) print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
def llamafy_qwen( def llamafy_qwen(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False):
input_dir: str,
output_dir: str,
shard_size: str,
save_safetensors: Optional[bool] = False
):
try: try:
os.makedirs(output_dir, exist_ok=False) os.makedirs(output_dir, exist_ok=False)
except Exception as e: except Exception as e:
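A standalone sketch of the shard-loading pattern used above, reading every `.safetensors` file in a checkpoint directory into one state dict; the hypothetical `load_sharded_safetensors` helper is not part of the script, and the directory layout is assumed:

import os
from collections import OrderedDict

import torch
from safetensors import safe_open


def load_sharded_safetensors(input_dir: str) -> "OrderedDict[str, torch.Tensor]":
    state_dict: "OrderedDict[str, torch.Tensor]" = OrderedDict()
    for filename in sorted(os.listdir(input_dir)):
        if filename.endswith(".safetensors"):
            with safe_open(os.path.join(input_dir, filename), framework="pt", device="cpu") as f:
                for key in f.keys():
                    state_dict[key] = f.get_tensor(key)
    return state_dict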

View File

@ -4,12 +4,13 @@
# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py # Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py
import os import os
from typing import TYPE_CHECKING, Optional
import fire import fire
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import TYPE_CHECKING, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
if TYPE_CHECKING: if TYPE_CHECKING:
@ -17,7 +18,6 @@ if TYPE_CHECKING:
class Shell(nn.Module): class Shell(nn.Module):
def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
super().__init__() super().__init__()
self.weight = nn.Parameter(weight, requires_grad=False) self.weight = nn.Parameter(weight, requires_grad=False)
@ -26,7 +26,7 @@ class Shell(nn.Module):
def unwrap_model(model: nn.Module, pattern=".base_layer") -> None: def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
for name in set([k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k]): for name in set([k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k]): # noqa: C403
parent_name = ".".join(name.split(".")[:-1]) parent_name = ".".join(name.split(".")[:-1])
child_name = name.split(".")[-1] child_name = name.split(".")[-1]
parent_module = model.get_submodule(parent_name) parent_module = model.get_submodule(parent_name)
@ -35,7 +35,7 @@ def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
weight = getattr(base_layer, "weight", None) weight = getattr(base_layer, "weight", None)
bias = getattr(base_layer, "bias", None) bias = getattr(base_layer, "bias", None)
setattr(parent_module, child_name, Shell(weight, bias)) setattr(parent_module, child_name, Shell(weight, bias))
print("Model unwrapped.") print("Model unwrapped.")
@ -60,7 +60,7 @@ def quantize_loftq(
lora_dropout=0.1, lora_dropout=0.1,
target_modules=[name.strip() for name in lora_target.split(",")], target_modules=[name.strip() for name in lora_target.split(",")],
init_lora_weights="loftq", init_lora_weights="loftq",
loftq_config=loftq_config loftq_config=loftq_config,
) )
# Init LoftQ model # Init LoftQ model
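A minimal sketch of the LoftQ-initialized LoRA setup this script builds; the model loading arguments, rank, alpha and target modules here are illustrative, and the real script additionally unwraps the base-layer shells and saves the quantized backbone:

from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM


def build_loftq_model(model_name_or_path: str):
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True)
    loftq_config = LoftQConfig(loftq_bits=4)  # quantize the frozen backbone to 4 bits
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj"],
        init_lora_weights="loftq",  # LoftQ initialization instead of random adapters
        loftq_config=loftq_config,
    )
    return get_peft_model(model, lora_config)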