forked from p04798526/LLaMA-Factory-Mirror
format style
This commit is contained in:
parent
f6d6e00337
commit
638234ceee
|
@ -7,7 +7,7 @@ line-length = 119
|
||||||
target-version = ["py38"]
|
target-version = ["py38"]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
ignore = ["C901", "E501", "E741", "W605"]
|
ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
|
||||||
select = ["C", "E", "F", "I", "W"]
|
select = ["C", "E", "F", "I", "W"]
|
||||||
line-length = 119
|
line-length = 119
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
from llmtuner import ChatModel, create_app
|
from llmtuner import ChatModel, create_app
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
from llmtuner import ChatModel
|
from llmtuner import ChatModel
|
||||||
from llmtuner.extras.misc import torch_gc
|
from llmtuner.extras.misc import torch_gc
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
if platform.system() != "Windows":
|
if platform.system() != "Windows":
|
||||||
import readline
|
import readline # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("Install `readline` for a better experience.")
|
print("Install `readline` for a better experience.")
|
||||||
|
|
||||||
|
|
|
@ -8,12 +8,4 @@ from llmtuner.webui import create_ui, create_web_demo
|
||||||
|
|
||||||
|
|
||||||
__version__ = "0.4.0"
|
__version__ = "0.4.0"
|
||||||
__all__ = [
|
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]
|
||||||
"create_app",
|
|
||||||
"ChatModel",
|
|
||||||
"Evaluator",
|
|
||||||
"export_model",
|
|
||||||
"run_exp",
|
|
||||||
"create_ui",
|
|
||||||
"create_web_demo"
|
|
||||||
]
|
|
||||||
|
|
|
@ -1,30 +1,29 @@
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Tuple
|
import json
|
||||||
from pydantic import BaseModel
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..chat import ChatModel
|
||||||
|
from ..extras.misc import torch_gc
|
||||||
|
from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
|
||||||
from .protocol import (
|
from .protocol import (
|
||||||
Role,
|
|
||||||
Finish,
|
|
||||||
ModelCard,
|
|
||||||
ModelList,
|
|
||||||
ChatMessage,
|
|
||||||
DeltaMessage,
|
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
ChatCompletionStreamResponse,
|
|
||||||
ChatCompletionResponseChoice,
|
ChatCompletionResponseChoice,
|
||||||
ChatCompletionResponseStreamChoice,
|
ChatCompletionResponseStreamChoice,
|
||||||
ChatCompletionResponseUsage,
|
ChatCompletionResponseUsage,
|
||||||
|
ChatCompletionStreamResponse,
|
||||||
|
ChatMessage,
|
||||||
|
DeltaMessage,
|
||||||
|
Finish,
|
||||||
|
ModelCard,
|
||||||
|
ModelList,
|
||||||
|
Role,
|
||||||
ScoreEvaluationRequest,
|
ScoreEvaluationRequest,
|
||||||
ScoreEvaluationResponse
|
ScoreEvaluationResponse,
|
||||||
)
|
|
||||||
from ..chat import ChatModel
|
|
||||||
from ..extras.misc import torch_gc
|
|
||||||
from ..extras.packages import (
|
|
||||||
is_fastapi_availble, is_starlette_available, is_uvicorn_available
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,15 +41,15 @@ if is_uvicorn_available():
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: "FastAPI"): # collects GPU memory
|
async def lifespan(app: "FastAPI"): # collects GPU memory
|
||||||
yield
|
yield
|
||||||
torch_gc()
|
torch_gc()
|
||||||
|
|
||||||
|
|
||||||
def to_json(data: BaseModel) -> str:
|
def to_json(data: BaseModel) -> str:
|
||||||
try: # pydantic v2
|
try: # pydantic v2
|
||||||
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
||||||
except: # pydantic v1
|
except Exception: # pydantic v1
|
||||||
return data.json(exclude_unset=True, ensure_ascii=False)
|
return data.json(exclude_unset=True, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
@ -90,8 +89,8 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||||
history = []
|
history = []
|
||||||
if len(prev_messages) % 2 == 0:
|
if len(prev_messages) % 2 == 0:
|
||||||
for i in range(0, len(prev_messages), 2):
|
for i in range(0, len(prev_messages), 2):
|
||||||
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
|
if prev_messages[i].role == Role.USER and prev_messages[i + 1].role == Role.ASSISTANT:
|
||||||
history.append([prev_messages[i].content, prev_messages[i+1].content])
|
history.append([prev_messages[i].content, prev_messages[i + 1].content])
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||||
else:
|
else:
|
||||||
|
@ -107,65 +106,65 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||||
|
|
||||||
responses = chat_model.chat(
|
responses = chat_model.chat(
|
||||||
query, history, system,
|
query,
|
||||||
|
history,
|
||||||
|
system,
|
||||||
do_sample=request.do_sample,
|
do_sample=request.do_sample,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
max_new_tokens=request.max_tokens,
|
max_new_tokens=request.max_tokens,
|
||||||
num_return_sequences=request.n
|
num_return_sequences=request.n,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_length, response_length = 0, 0
|
prompt_length, response_length = 0, 0
|
||||||
choices = []
|
choices = []
|
||||||
for i, response in enumerate(responses):
|
for i, response in enumerate(responses):
|
||||||
choices.append(ChatCompletionResponseChoice(
|
choices.append(
|
||||||
index=i,
|
ChatCompletionResponseChoice(
|
||||||
message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
|
index=i,
|
||||||
finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
|
message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
|
||||||
))
|
finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH,
|
||||||
|
)
|
||||||
|
)
|
||||||
prompt_length = response.prompt_length
|
prompt_length = response.prompt_length
|
||||||
response_length += response.response_length
|
response_length += response.response_length
|
||||||
|
|
||||||
usage = ChatCompletionResponseUsage(
|
usage = ChatCompletionResponseUsage(
|
||||||
prompt_tokens=prompt_length,
|
prompt_tokens=prompt_length,
|
||||||
completion_tokens=response_length,
|
completion_tokens=response_length,
|
||||||
total_tokens=prompt_length+response_length
|
total_tokens=prompt_length + response_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
||||||
|
|
||||||
def stream_chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
|
def stream_chat_completion(
|
||||||
|
query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest
|
||||||
|
):
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
index=0,
|
index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None
|
||||||
delta=DeltaMessage(role=Role.ASSISTANT, content=""),
|
|
||||||
finish_reason=None
|
|
||||||
)
|
)
|
||||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||||
yield to_json(chunk)
|
yield to_json(chunk)
|
||||||
|
|
||||||
for new_text in chat_model.stream_chat(
|
for new_text in chat_model.stream_chat(
|
||||||
query, history, system,
|
query,
|
||||||
|
history,
|
||||||
|
system,
|
||||||
do_sample=request.do_sample,
|
do_sample=request.do_sample,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
max_new_tokens=request.max_tokens
|
max_new_tokens=request.max_tokens,
|
||||||
):
|
):
|
||||||
if len(new_text) == 0:
|
if len(new_text) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
index=0,
|
index=0, delta=DeltaMessage(content=new_text), finish_reason=None
|
||||||
delta=DeltaMessage(content=new_text),
|
|
||||||
finish_reason=None
|
|
||||||
)
|
)
|
||||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||||
yield to_json(chunk)
|
yield to_json(chunk)
|
||||||
|
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP)
|
||||||
index=0,
|
|
||||||
delta=DeltaMessage(),
|
|
||||||
finish_reason=Finish.STOP
|
|
||||||
)
|
|
||||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||||
yield to_json(chunk)
|
yield to_json(chunk)
|
||||||
yield "[DONE]"
|
yield "[DONE]"
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import time
|
import time
|
||||||
from enum import Enum, unique
|
from enum import Enum, unique
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
@unique
|
@unique
|
||||||
class Role(str, Enum):
|
class Role(str, Enum):
|
||||||
|
|
|
@ -1,18 +1,18 @@
|
||||||
import torch
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
|
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import GenerationConfig, TextIteratorStreamer
|
from transformers import GenerationConfig, TextIteratorStreamer
|
||||||
|
|
||||||
from ..data import get_template_and_fix_tokenizer, Role
|
from ..data import Role, get_template_and_fix_tokenizer
|
||||||
from ..extras.misc import get_logits_processor
|
from ..extras.misc import get_logits_processor
|
||||||
from ..model import dispatch_model, load_model_and_tokenizer
|
|
||||||
from ..hparams import get_infer_args
|
from ..hparams import get_infer_args
|
||||||
|
from ..model import dispatch_model, load_model_and_tokenizer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Response:
|
class Response:
|
||||||
|
|
||||||
response_text: str
|
response_text: str
|
||||||
response_length: int
|
response_length: int
|
||||||
prompt_length: int
|
prompt_length: int
|
||||||
|
@ -20,10 +20,9 @@ class Response:
|
||||||
|
|
||||||
|
|
||||||
class ChatModel:
|
class ChatModel:
|
||||||
|
|
||||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||||
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
|
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
|
||||||
self.can_generate = (finetuning_args.stage == "sft")
|
self.can_generate = finetuning_args.stage == "sft"
|
||||||
self.model, self.tokenizer = load_model_and_tokenizer(
|
self.model, self.tokenizer = load_model_and_tokenizer(
|
||||||
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
||||||
)
|
)
|
||||||
|
@ -37,7 +36,7 @@ class ChatModel:
|
||||||
history: Optional[List[Tuple[str, str]]] = None,
|
history: Optional[List[Tuple[str, str]]] = None,
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
**input_kwargs
|
**input_kwargs,
|
||||||
) -> Tuple[Dict[str, Any], int]:
|
) -> Tuple[Dict[str, Any], int]:
|
||||||
messages = []
|
messages = []
|
||||||
if history is not None:
|
if history is not None:
|
||||||
|
@ -63,16 +62,18 @@ class ChatModel:
|
||||||
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
||||||
|
|
||||||
generating_args = self.generating_args.to_dict()
|
generating_args = self.generating_args.to_dict()
|
||||||
generating_args.update(dict(
|
generating_args.update(
|
||||||
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
|
dict(
|
||||||
temperature=temperature or generating_args["temperature"],
|
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
|
||||||
top_p=top_p or generating_args["top_p"],
|
temperature=temperature or generating_args["temperature"],
|
||||||
top_k=top_k or generating_args["top_k"],
|
top_p=top_p or generating_args["top_p"],
|
||||||
num_return_sequences=num_return_sequences or 1,
|
top_k=top_k or generating_args["top_k"],
|
||||||
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
|
num_return_sequences=num_return_sequences or 1,
|
||||||
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
|
||||||
pad_token_id=self.tokenizer.pad_token_id
|
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||||
))
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
|
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
|
||||||
generating_args["do_sample"] = True
|
generating_args["do_sample"] = True
|
||||||
|
@ -88,7 +89,7 @@ class ChatModel:
|
||||||
gen_kwargs = dict(
|
gen_kwargs = dict(
|
||||||
inputs=input_ids,
|
inputs=input_ids,
|
||||||
generation_config=GenerationConfig(**generating_args),
|
generation_config=GenerationConfig(**generating_args),
|
||||||
logits_processor=get_logits_processor()
|
logits_processor=get_logits_processor(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return gen_kwargs, prompt_length
|
return gen_kwargs, prompt_length
|
||||||
|
@ -100,7 +101,7 @@ class ChatModel:
|
||||||
history: Optional[List[Tuple[str, str]]] = None,
|
history: Optional[List[Tuple[str, str]]] = None,
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
**input_kwargs
|
**input_kwargs,
|
||||||
) -> List[Response]:
|
) -> List[Response]:
|
||||||
r"""
|
r"""
|
||||||
Args: query, history, system, **input_kwargs
|
Args: query, history, system, **input_kwargs
|
||||||
|
@ -117,12 +118,14 @@ class ChatModel:
|
||||||
for i in range(len(response)):
|
for i in range(len(response)):
|
||||||
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
|
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
|
||||||
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
|
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
|
||||||
results.append(Response(
|
results.append(
|
||||||
response_text=response[i],
|
Response(
|
||||||
response_length=response_length,
|
response_text=response[i],
|
||||||
prompt_length=prompt_length,
|
response_length=response_length,
|
||||||
finish_reason="stop" if len(eos_index) else "length"
|
prompt_length=prompt_length,
|
||||||
))
|
finish_reason="stop" if len(eos_index) else "length",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -133,7 +136,7 @@ class ChatModel:
|
||||||
history: Optional[List[Tuple[str, str]]] = None,
|
history: Optional[List[Tuple[str, str]]] = None,
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
**input_kwargs
|
**input_kwargs,
|
||||||
) -> Generator[str, None, None]:
|
) -> Generator[str, None, None]:
|
||||||
gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs)
|
gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs)
|
||||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
@ -145,11 +148,7 @@ class ChatModel:
|
||||||
yield from streamer
|
yield from streamer
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def get_scores(
|
def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
|
||||||
self,
|
|
||||||
batch_input: List[str],
|
|
||||||
**input_kwargs
|
|
||||||
) -> List[float]:
|
|
||||||
max_length = input_kwargs.pop("max_length", None)
|
max_length = input_kwargs.pop("max_length", None)
|
||||||
device = getattr(self.model.pretrained_model, "device", "cuda")
|
device = getattr(self.model.pretrained_model, "device", "cuda")
|
||||||
|
|
||||||
|
@ -159,7 +158,7 @@ class ChatModel:
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
|
max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
add_special_tokens=True
|
add_special_tokens=True,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
input_ids: torch.Tensor = inputs["input_ids"]
|
input_ids: torch.Tensor = inputs["input_ids"]
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from .loader import get_dataset
|
from .loader import get_dataset
|
||||||
from .template import get_template_and_fix_tokenizer, templates
|
from .template import get_template_and_fix_tokenizer, templates
|
||||||
from .utils import split_dataset, Role
|
from .utils import Role, split_dataset
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset", "Role"]
|
__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset", "Role"]
|
||||||
|
|
|
@ -27,7 +27,9 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
|
||||||
|
|
||||||
if dataset_attr.response:
|
if dataset_attr.response:
|
||||||
if isinstance(examples[dataset_attr.response][i], list):
|
if isinstance(examples[dataset_attr.response][i], list):
|
||||||
response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
|
response = [
|
||||||
|
{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
|
response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
|
||||||
else:
|
else:
|
||||||
|
@ -47,10 +49,10 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
|
||||||
dataset_attr.user_tag: Role.USER,
|
dataset_attr.user_tag: Role.USER,
|
||||||
dataset_attr.assistant_tag: Role.ASSISTANT,
|
dataset_attr.assistant_tag: Role.ASSISTANT,
|
||||||
dataset_attr.observation_tag: Role.OBSERVATION,
|
dataset_attr.observation_tag: Role.OBSERVATION,
|
||||||
dataset_attr.function_tag: Role.FUNCTION
|
dataset_attr.function_tag: Role.FUNCTION,
|
||||||
}
|
}
|
||||||
for i, messages in enumerate(examples[dataset_attr.messages]):
|
for i, messages in enumerate(examples[dataset_attr.messages]):
|
||||||
messages = messages[:len(messages) // 2 * 2] # should be multiples of 2
|
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
|
||||||
if len(messages) == 0:
|
if len(messages) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -65,7 +67,9 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
|
||||||
if message[dataset_attr.role_tag] not in accept_tags:
|
if message[dataset_attr.role_tag] not in accept_tags:
|
||||||
raise ValueError("Invalid role tag in {}.".format(messages))
|
raise ValueError("Invalid role tag in {}.".format(messages))
|
||||||
|
|
||||||
prompt.append({"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]})
|
prompt.append(
|
||||||
|
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
|
||||||
|
)
|
||||||
|
|
||||||
last_message = prompt.pop(-1)
|
last_message = prompt.pop(-1)
|
||||||
response.append(last_message)
|
response.append(last_message)
|
||||||
|
@ -98,12 +102,7 @@ def align_dataset(
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
load_from_cache_file=(not data_args.overwrite_cache),
|
load_from_cache_file=(not data_args.overwrite_cache),
|
||||||
desc="Converting format of dataset"
|
desc="Converting format of dataset",
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataset.map(
|
return dataset.map(convert_func, batched=True, remove_columns=column_names, **kwargs)
|
||||||
convert_func,
|
|
||||||
batched=True,
|
|
||||||
remove_columns=column_names,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
|
@ -76,7 +76,11 @@ class ToolFormatter:
|
||||||
required = ", required" if name in tool["parameters"].get("required", []) else ""
|
required = ", required" if name in tool["parameters"].get("required", []) else ""
|
||||||
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
|
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
|
||||||
param_text += " - {name} ({type}{required}): {desc}{enum}\n".format(
|
param_text += " - {name} ({type}{required}): {desc}{enum}\n".format(
|
||||||
name=name, type=param.get("type", ""), required=required, desc=param.get("description", ""), enum=enum
|
name=name,
|
||||||
|
type=param.get("type", ""),
|
||||||
|
required=required,
|
||||||
|
desc=param.get("description", ""),
|
||||||
|
enum=enum,
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
|
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
|
||||||
|
@ -85,9 +89,7 @@ class ToolFormatter:
|
||||||
tool_names.append(tool["name"])
|
tool_names.append(tool["name"])
|
||||||
|
|
||||||
return TOOL_SYSTEM_PROMPT.format(
|
return TOOL_SYSTEM_PROMPT.format(
|
||||||
tool_text=tool_text,
|
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
|
||||||
tool_names=", ".join(tool_names),
|
|
||||||
format_prompt=JSON_FORMAT_PROMPT
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
|
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
|
||||||
|
|
|
@ -1,16 +1,16 @@
|
||||||
import os
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import os
|
||||||
from typing import TYPE_CHECKING, List, Literal, Union
|
from typing import TYPE_CHECKING, List, Literal, Union
|
||||||
|
|
||||||
from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
|
from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
|
||||||
|
|
||||||
from ..extras.constants import FILEEXT2TYPE
|
from ..extras.constants import FILEEXT2TYPE
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from .utils import checksum
|
|
||||||
from .parser import get_dataset_list
|
|
||||||
from .aligner import align_dataset
|
from .aligner import align_dataset
|
||||||
from .template import get_template_and_fix_tokenizer
|
from .parser import get_dataset_list
|
||||||
from .preprocess import get_preprocess_and_print_func
|
from .preprocess import get_preprocess_and_print_func
|
||||||
|
from .template import get_template_and_fix_tokenizer
|
||||||
|
from .utils import checksum
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -18,8 +18,8 @@ if TYPE_CHECKING:
|
||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
|
from ..hparams import DataArguments, ModelArguments
|
||||||
from .parser import DatasetAttr
|
from .parser import DatasetAttr
|
||||||
from ..hparams import ModelArguments, DataArguments
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
@ -44,14 +44,14 @@ def load_single_dataset(
|
||||||
elif dataset_attr.load_from == "file":
|
elif dataset_attr.load_from == "file":
|
||||||
data_files = []
|
data_files = []
|
||||||
local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
||||||
if os.path.isdir(local_path): # is directory
|
if os.path.isdir(local_path): # is directory
|
||||||
for file_name in os.listdir(local_path):
|
for file_name in os.listdir(local_path):
|
||||||
data_files.append(os.path.join(local_path, file_name))
|
data_files.append(os.path.join(local_path, file_name))
|
||||||
if data_path is None:
|
if data_path is None:
|
||||||
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
|
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
|
||||||
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
|
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
|
||||||
raise ValueError("File types should be identical.")
|
raise ValueError("File types should be identical.")
|
||||||
elif os.path.isfile(local_path): # is file
|
elif os.path.isfile(local_path): # is file
|
||||||
data_files.append(local_path)
|
data_files.append(local_path)
|
||||||
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
||||||
else:
|
else:
|
||||||
|
@ -78,12 +78,12 @@ def load_single_dataset(
|
||||||
split=data_args.split,
|
split=data_args.split,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
token=model_args.ms_hub_token,
|
token=model_args.ms_hub_token,
|
||||||
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
|
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||||
).to_hf_dataset()
|
).to_hf_dataset()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
||||||
else:
|
else:
|
||||||
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
||||||
kwargs = {"trust_remote_code": True}
|
kwargs = {"trust_remote_code": True}
|
||||||
else:
|
else:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
@ -97,13 +97,13 @@ def load_single_dataset(
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
token=model_args.hf_hub_token,
|
token=model_args.hf_hub_token,
|
||||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||||
**kwargs
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
||||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||||
|
|
||||||
if data_args.max_samples is not None: # truncate dataset
|
if data_args.max_samples is not None: # truncate dataset
|
||||||
num_samples = min(data_args.max_samples, len(dataset))
|
num_samples = min(data_args.max_samples, len(dataset))
|
||||||
dataset = dataset.select(range(num_samples))
|
dataset = dataset.select(range(num_samples))
|
||||||
|
|
||||||
|
@ -113,7 +113,7 @@ def load_single_dataset(
|
||||||
def merge_dataset(
|
def merge_dataset(
|
||||||
all_datasets: List[Union["Dataset", "IterableDataset"]],
|
all_datasets: List[Union["Dataset", "IterableDataset"]],
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
training_args: "Seq2SeqTrainingArguments"
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
) -> Union["Dataset", "IterableDataset"]:
|
) -> Union["Dataset", "IterableDataset"]:
|
||||||
if len(all_datasets) == 1:
|
if len(all_datasets) == 1:
|
||||||
return all_datasets[0]
|
return all_datasets[0]
|
||||||
|
@ -128,7 +128,7 @@ def merge_dataset(
|
||||||
datasets=all_datasets,
|
datasets=all_datasets,
|
||||||
probabilities=data_args.interleave_probs,
|
probabilities=data_args.interleave_probs,
|
||||||
seed=training_args.seed,
|
seed=training_args.seed,
|
||||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
|
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown mixing strategy.")
|
raise ValueError("Unknown mixing strategy.")
|
||||||
|
@ -160,7 +160,7 @@ def get_dataset(
|
||||||
|
|
||||||
with training_args.main_process_first(desc="load dataset"):
|
with training_args.main_process_first(desc="load dataset"):
|
||||||
all_datasets = []
|
all_datasets = []
|
||||||
for dataset_attr in get_dataset_list(data_args): # TODO: add split
|
for dataset_attr in get_dataset_list(data_args): # TODO: add split
|
||||||
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
|
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
|
||||||
dataset = merge_dataset(all_datasets, data_args, training_args)
|
dataset = merge_dataset(all_datasets, data_args, training_args)
|
||||||
|
|
||||||
|
@ -174,15 +174,10 @@ def get_dataset(
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
load_from_cache_file=(not data_args.overwrite_cache),
|
load_from_cache_file=(not data_args.overwrite_cache),
|
||||||
desc="Running tokenizer on dataset"
|
desc="Running tokenizer on dataset",
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = dataset.map(
|
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
|
||||||
preprocess_func,
|
|
||||||
batched=True,
|
|
||||||
remove_columns=column_names,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
|
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
|
||||||
if training_args.should_save:
|
if training_args.should_save:
|
||||||
|
|
|
@ -1,18 +1,18 @@
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||||
|
|
||||||
from ..extras.constants import DATA_CONFIG
|
from ..extras.constants import DATA_CONFIG
|
||||||
from ..extras.misc import use_modelscope
|
from ..extras.misc import use_modelscope
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..hparams import DataArguments
|
from ..hparams import DataArguments
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DatasetAttr:
|
class DatasetAttr:
|
||||||
|
|
||||||
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
||||||
dataset_name: Optional[str] = None
|
dataset_name: Optional[str] = None
|
||||||
dataset_sha1: Optional[str] = None
|
dataset_sha1: Optional[str] = None
|
||||||
|
@ -49,7 +49,9 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||||
dataset_info = json.load(f)
|
dataset_info = json.load(f)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
if data_args.dataset is not None:
|
if data_args.dataset is not None:
|
||||||
raise ValueError("Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)))
|
raise ValueError(
|
||||||
|
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
|
||||||
|
)
|
||||||
dataset_info = None
|
dataset_info = None
|
||||||
|
|
||||||
if data_args.interleave_probs is not None:
|
if data_args.interleave_probs is not None:
|
||||||
|
@ -74,7 +76,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||||
dataset_attr = DatasetAttr(
|
dataset_attr = DatasetAttr(
|
||||||
"file",
|
"file",
|
||||||
dataset_name=dataset_info[name]["file_name"],
|
dataset_name=dataset_info[name]["file_name"],
|
||||||
dataset_sha1=dataset_info[name].get("file_sha1", None)
|
dataset_sha1=dataset_info[name].get("file_sha1", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset_attr.subset = dataset_info[name].get("subset", None)
|
dataset_attr.subset = dataset_info[name].get("subset", None)
|
||||||
|
|
|
@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
|
||||||
from ..extras.constants import IGNORE_INDEX
|
from ..extras.constants import IGNORE_INDEX
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
@ -17,9 +18,7 @@ logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def preprocess_pretrain_dataset(
|
def preprocess_pretrain_dataset(
|
||||||
examples: Dict[str, List[Any]],
|
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
||||||
tokenizer: "PreTrainedTokenizer",
|
|
||||||
data_args: "DataArguments"
|
|
||||||
) -> Dict[str, List[List[int]]]:
|
) -> Dict[str, List[List[int]]]:
|
||||||
# build grouped texts with format `X1 X2 X3 ...`
|
# build grouped texts with format `X1 X2 X3 ...`
|
||||||
text_examples = [examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))]
|
text_examples = [examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))]
|
||||||
|
@ -35,7 +34,7 @@ def preprocess_pretrain_dataset(
|
||||||
total_length = (total_length // block_size) * block_size
|
total_length = (total_length // block_size) * block_size
|
||||||
# split by chunks of cutoff_len
|
# split by chunks of cutoff_len
|
||||||
result = {
|
result = {
|
||||||
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
||||||
for k, t in concatenated_examples.items()
|
for k, t in concatenated_examples.items()
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
@ -57,9 +56,11 @@ def preprocess_supervised_dataset(
|
||||||
|
|
||||||
messages = examples["prompt"][i] + examples["response"][i]
|
messages = examples["prompt"][i] + examples["response"][i]
|
||||||
input_ids, labels = [], []
|
input_ids, labels = [], []
|
||||||
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||||
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
|
template.encode_multiturn(
|
||||||
)):
|
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
|
||||||
|
)
|
||||||
|
):
|
||||||
if data_args.train_on_prompt:
|
if data_args.train_on_prompt:
|
||||||
source_mask = source_ids
|
source_mask = source_ids
|
||||||
elif turn_idx != 0 and template.efficient_eos:
|
elif turn_idx != 0 and template.efficient_eos:
|
||||||
|
@ -96,9 +97,9 @@ def preprocess_packed_supervised_dataset(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
messages = examples["prompt"][i] + examples["response"][i]
|
messages = examples["prompt"][i] + examples["response"][i]
|
||||||
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||||
tokenizer, messages, examples["system"][i], examples["tools"][i]
|
template.encode_multiturn(tokenizer, messages, examples["system"][i], examples["tools"][i])
|
||||||
)):
|
):
|
||||||
if data_args.train_on_prompt:
|
if data_args.train_on_prompt:
|
||||||
source_mask = source_ids
|
source_mask = source_ids
|
||||||
elif turn_idx != 0 and template.efficient_eos:
|
elif turn_idx != 0 and template.efficient_eos:
|
||||||
|
@ -119,9 +120,9 @@ def preprocess_packed_supervised_dataset(
|
||||||
total_length = (total_length // block_size) * block_size
|
total_length = (total_length // block_size) * block_size
|
||||||
# split by chunks of cutoff_len
|
# split by chunks of cutoff_len
|
||||||
for i in range(0, total_length, block_size):
|
for i in range(0, total_length, block_size):
|
||||||
model_inputs["input_ids"].append(input_ids[i: i + block_size])
|
model_inputs["input_ids"].append(input_ids[i : i + block_size])
|
||||||
model_inputs["attention_mask"].append([1] * block_size)
|
model_inputs["attention_mask"].append([1] * block_size)
|
||||||
model_inputs["labels"].append(labels[i: i + block_size])
|
model_inputs["labels"].append(labels[i : i + block_size])
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
|
@ -191,9 +192,11 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
|
||||||
print("input_ids:\n{}".format(example["input_ids"]))
|
print("input_ids:\n{}".format(example["input_ids"]))
|
||||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||||
print("label_ids:\n{}".format(example["labels"]))
|
print("label_ids:\n{}".format(example["labels"]))
|
||||||
print("labels:\n{}".format(
|
print(
|
||||||
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
|
"labels:\n{}".format(
|
||||||
))
|
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||||
|
@ -232,10 +235,14 @@ def get_preprocess_and_print_func(
|
||||||
|
|
||||||
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
||||||
elif stage == "rm":
|
elif stage == "rm":
|
||||||
preprocess_func = partial(preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
|
preprocess_func = partial(
|
||||||
|
preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||||
|
)
|
||||||
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
|
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
|
||||||
else:
|
else:
|
||||||
preprocess_func = partial(preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
|
preprocess_func = partial(
|
||||||
|
preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||||
|
)
|
||||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||||
|
|
||||||
return preprocess_func, print_function
|
return preprocess_func, print_function
|
||||||
|
|
|
@ -2,8 +2,8 @@ from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
|
from .formatter import FunctionFormatter, StringFormatter, ToolFormatter
|
||||||
from .utils import Role
|
from .utils import Role
|
||||||
from .formatter import StringFormatter, FunctionFormatter, ToolFormatter
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -15,7 +15,6 @@ logger = get_logger(__name__)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Template:
|
class Template:
|
||||||
|
|
||||||
format_user: Callable
|
format_user: Callable
|
||||||
format_assistant: Callable
|
format_assistant: Callable
|
||||||
format_system: Callable
|
format_system: Callable
|
||||||
|
@ -34,7 +33,7 @@ class Template:
|
||||||
messages: List[Dict[str, str]],
|
messages: List[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
cutoff_len: Optional[int] = 1_000_000
|
cutoff_len: Optional[int] = 1_000_000,
|
||||||
) -> Tuple[List[int], List[int]]:
|
) -> Tuple[List[int], List[int]]:
|
||||||
r"""
|
r"""
|
||||||
Returns a single pair of token ids representing prompt and response respectively.
|
Returns a single pair of token ids representing prompt and response respectively.
|
||||||
|
@ -53,7 +52,7 @@ class Template:
|
||||||
messages: List[Dict[str, str]],
|
messages: List[Dict[str, str]],
|
||||||
system: str,
|
system: str,
|
||||||
tools: str,
|
tools: str,
|
||||||
cutoff_len: Optional[int] = 1_000_000
|
cutoff_len: Optional[int] = 1_000_000,
|
||||||
) -> List[Tuple[List[int], List[int]]]:
|
) -> List[Tuple[List[int], List[int]]]:
|
||||||
r"""
|
r"""
|
||||||
Returns multiple pairs of token ids representing prompts and responses respectively.
|
Returns multiple pairs of token ids representing prompts and responses respectively.
|
||||||
|
@ -67,7 +66,7 @@ class Template:
|
||||||
messages: List[Dict[str, str]],
|
messages: List[Dict[str, str]],
|
||||||
system: str,
|
system: str,
|
||||||
tools: str,
|
tools: str,
|
||||||
cutoff_len: int
|
cutoff_len: int,
|
||||||
) -> List[Tuple[List[int], List[int]]]:
|
) -> List[Tuple[List[int], List[int]]]:
|
||||||
r"""
|
r"""
|
||||||
Encodes formatted inputs to pairs of token ids.
|
Encodes formatted inputs to pairs of token ids.
|
||||||
|
@ -102,19 +101,17 @@ class Template:
|
||||||
if total_length >= cutoff_len:
|
if total_length >= cutoff_len:
|
||||||
break
|
break
|
||||||
|
|
||||||
encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
|
encoded_messages[i] = encoded_messages[i][: cutoff_len - total_length]
|
||||||
total_length += len(encoded_messages[i])
|
total_length += len(encoded_messages[i])
|
||||||
|
|
||||||
encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
|
encoded_messages[i + 1] = encoded_messages[i + 1][: max(1, cutoff_len - total_length)]
|
||||||
total_length += len(encoded_messages[i+1])
|
total_length += len(encoded_messages[i + 1])
|
||||||
encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
|
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
|
||||||
|
|
||||||
return encoded_pairs
|
return encoded_pairs
|
||||||
|
|
||||||
def _convert_elements_to_ids(
|
def _convert_elements_to_ids(
|
||||||
self,
|
self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
|
||||||
tokenizer: "PreTrainedTokenizer",
|
|
||||||
elements: List[Union[str, Dict[str, str]]]
|
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
r"""
|
r"""
|
||||||
Converts elements to token ids.
|
Converts elements to token ids.
|
||||||
|
@ -139,14 +136,13 @@ class Template:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Llama2Template(Template):
|
class Llama2Template(Template):
|
||||||
|
|
||||||
def _encode(
|
def _encode(
|
||||||
self,
|
self,
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
messages: List[Dict[str, str]],
|
messages: List[Dict[str, str]],
|
||||||
system: str,
|
system: str,
|
||||||
tools: str,
|
tools: str,
|
||||||
cutoff_len: int
|
cutoff_len: int,
|
||||||
) -> List[Tuple[List[int], List[int]]]:
|
) -> List[Tuple[List[int], List[int]]]:
|
||||||
r"""
|
r"""
|
||||||
Encodes formatted inputs to pairs of token ids.
|
Encodes formatted inputs to pairs of token ids.
|
||||||
|
@ -182,12 +178,12 @@ class Llama2Template(Template):
|
||||||
if total_length >= cutoff_len:
|
if total_length >= cutoff_len:
|
||||||
break
|
break
|
||||||
|
|
||||||
encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
|
encoded_messages[i] = encoded_messages[i][: cutoff_len - total_length]
|
||||||
total_length += len(encoded_messages[i])
|
total_length += len(encoded_messages[i])
|
||||||
|
|
||||||
encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
|
encoded_messages[i + 1] = encoded_messages[i + 1][: max(1, cutoff_len - total_length)]
|
||||||
total_length += len(encoded_messages[i+1])
|
total_length += len(encoded_messages[i + 1])
|
||||||
encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
|
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
|
||||||
|
|
||||||
return encoded_pairs
|
return encoded_pairs
|
||||||
|
|
||||||
|
@ -207,32 +203,26 @@ def register_template(
|
||||||
separator: Optional[List[Union[str, Dict[str, str]]]] = "",
|
separator: Optional[List[Union[str, Dict[str, str]]]] = "",
|
||||||
stop_words: Optional[List[str]] = [],
|
stop_words: Optional[List[str]] = [],
|
||||||
efficient_eos: Optional[bool] = False,
|
efficient_eos: Optional[bool] = False,
|
||||||
replace_eos: Optional[bool] = False
|
replace_eos: Optional[bool] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
template_class = Llama2Template if name.startswith("llama2") else Template
|
template_class = Llama2Template if name.startswith("llama2") else Template
|
||||||
templates[name] = template_class(
|
templates[name] = template_class(
|
||||||
format_user=format_user or StringFormatter(container=["{{content}}"]),
|
format_user=format_user or StringFormatter(container=["{{content}}"]),
|
||||||
format_assistant=format_assistant or StringFormatter(container=[
|
format_assistant=format_assistant or StringFormatter(container=["{{content}}", {"eos_token"}]),
|
||||||
"{{content}}", {"eos_token"}
|
|
||||||
]),
|
|
||||||
format_system=format_system or StringFormatter(container=["{{content}}"]),
|
format_system=format_system or StringFormatter(container=["{{content}}"]),
|
||||||
format_tool=format_tool or ToolFormatter(type="default"),
|
format_tool=format_tool or ToolFormatter(type="default"),
|
||||||
format_observation=format_observation or format_user,
|
format_observation=format_observation or format_user,
|
||||||
format_function=format_function or FunctionFormatter(container=[
|
format_function=format_function
|
||||||
"Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"}
|
or FunctionFormatter(container=["Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"}]),
|
||||||
]),
|
|
||||||
system=system,
|
system=system,
|
||||||
separator=separator,
|
separator=separator,
|
||||||
stop_words=stop_words,
|
stop_words=stop_words,
|
||||||
efficient_eos=efficient_eos,
|
efficient_eos=efficient_eos,
|
||||||
replace_eos=replace_eos
|
replace_eos=replace_eos,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_template_and_fix_tokenizer(
|
def get_template_and_fix_tokenizer(name: str, tokenizer: "PreTrainedTokenizer") -> Template:
|
||||||
name: str,
|
|
||||||
tokenizer: "PreTrainedTokenizer"
|
|
||||||
) -> Template:
|
|
||||||
if tokenizer.eos_token_id is None:
|
if tokenizer.eos_token_id is None:
|
||||||
tokenizer.eos_token = "<|endoftext|>"
|
tokenizer.eos_token = "<|endoftext|>"
|
||||||
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
||||||
|
@ -241,7 +231,7 @@ def get_template_and_fix_tokenizer(
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
||||||
|
|
||||||
if name is None: # for pre-training
|
if name is None: # for pre-training
|
||||||
return None
|
return None
|
||||||
|
|
||||||
template = templates.get(name, None)
|
template = templates.get(name, None)
|
||||||
|
@ -258,8 +248,7 @@ def get_template_and_fix_tokenizer(
|
||||||
|
|
||||||
if stop_words:
|
if stop_words:
|
||||||
tokenizer.add_special_tokens(
|
tokenizer.add_special_tokens(
|
||||||
dict(additional_special_tokens=stop_words),
|
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
|
||||||
replace_additional_special_tokens=False
|
|
||||||
)
|
)
|
||||||
logger.info("Add {} to stop words.".format(",".join(stop_words)))
|
logger.info("Add {} to stop words.".format(",".join(stop_words)))
|
||||||
|
|
||||||
|
@ -268,263 +257,153 @@ def get_template_and_fix_tokenizer(
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="alpaca",
|
name="alpaca",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=["### Instruction:\n{{content}}\n\n### Response:\n"]),
|
||||||
"### Instruction:\n{{content}}\n\n### Response:\n"
|
|
||||||
]),
|
|
||||||
system=(
|
system=(
|
||||||
"Below is an instruction that describes a task. "
|
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
|
||||||
"Write a response that appropriately completes the request."
|
|
||||||
),
|
),
|
||||||
separator=[
|
separator=["\n\n"],
|
||||||
"\n\n"
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="aquila",
|
name="aquila",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=["Human: {{content}}###Assistant:"]),
|
||||||
"Human: {{content}}###Assistant:"
|
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||||
]),
|
|
||||||
format_assistant=StringFormatter(container=[
|
|
||||||
"{{content}}"
|
|
||||||
]),
|
|
||||||
system=(
|
system=(
|
||||||
"A chat between a curious human and an artificial intelligence assistant. "
|
"A chat between a curious human and an artificial intelligence assistant. "
|
||||||
"The assistant gives helpful, detailed, and polite answers to the human's questions."
|
"The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||||
),
|
),
|
||||||
separator=[
|
separator=["###"],
|
||||||
"###"
|
stop_words=["</s>"],
|
||||||
],
|
efficient_eos=True,
|
||||||
stop_words=[
|
|
||||||
"</s>"
|
|
||||||
],
|
|
||||||
efficient_eos=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="baichuan",
|
name="baichuan",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
|
||||||
{"token": "<reserved_102>"},
|
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||||
"{{content}}",
|
efficient_eos=True,
|
||||||
{"token": "<reserved_103>"}
|
|
||||||
]),
|
|
||||||
format_assistant=StringFormatter(container=[
|
|
||||||
"{{content}}"
|
|
||||||
]),
|
|
||||||
efficient_eos=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="baichuan2",
|
name="baichuan2",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=[{"token": "<reserved_106>"}, "{{content}}", {"token": "<reserved_107>"}]),
|
||||||
{"token": "<reserved_106>"},
|
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||||
"{{content}}",
|
efficient_eos=True,
|
||||||
{"token": "<reserved_107>"}
|
|
||||||
]),
|
|
||||||
format_assistant=StringFormatter(container=[
|
|
||||||
"{{content}}"
|
|
||||||
]),
|
|
||||||
efficient_eos=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="belle",
|
name="belle", format_user=StringFormatter(container=["Human: {{content}}\n\nBelle: "]), separator=["\n\n"]
|
||||||
format_user=StringFormatter(container=[
|
|
||||||
"Human: {{content}}\n\nBelle: "
|
|
||||||
]),
|
|
||||||
separator=[
|
|
||||||
"\n\n"
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="bluelm",
|
name="bluelm",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
|
||||||
{"token": "[|Human|]:"},
|
|
||||||
"{{content}}",
|
|
||||||
{"token": "[|AI|]:"}
|
|
||||||
])
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="chatglm2",
|
name="chatglm2",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
|
||||||
"[Round {{idx}}]\n\n问:{{content}}\n\n答:"
|
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||||
]),
|
format_system=StringFormatter(container=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||||
format_assistant=StringFormatter(container=[
|
separator=["\n\n"],
|
||||||
"{{content}}"
|
efficient_eos=True,
|
||||||
]),
|
|
||||||
format_system=StringFormatter(container=[
|
|
||||||
{"token": "[gMASK]"},
|
|
||||||
{"token": "sop"},
|
|
||||||
"{{content}}"
|
|
||||||
]),
|
|
||||||
separator=[
|
|
||||||
"\n\n"
|
|
||||||
],
|
|
||||||
efficient_eos=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="chatglm3",
|
name="chatglm3",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||||
{"token": "<|user|>"},
|
format_assistant=StringFormatter(container=["\n" "{{content}}"]),
|
||||||
"\n",
|
format_system=StringFormatter(
|
||||||
"{{content}}",
|
container=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
|
||||||
{"token": "<|assistant|>"}
|
),
|
||||||
]),
|
format_observation=StringFormatter(container=[{"token": "<|observation|>"}, "\n", "{{content}}"]),
|
||||||
format_assistant=StringFormatter(container=[
|
format_function=FunctionFormatter(container=["{{name}}\n{{arguments}}"]),
|
||||||
"\n"
|
|
||||||
"{{content}}"
|
|
||||||
]),
|
|
||||||
format_system=StringFormatter(container=[
|
|
||||||
{"token": "[gMASK]"},
|
|
||||||
{"token": "sop"},
|
|
||||||
{"token": "<|system|>"},
|
|
||||||
"\n",
|
|
||||||
"{{content}}"
|
|
||||||
]),
|
|
||||||
format_observation=StringFormatter(container=[
|
|
||||||
{"token": "<|observation|>"},
|
|
||||||
"\n",
|
|
||||||
"{{content}}"
|
|
||||||
]),
|
|
||||||
format_function=FunctionFormatter(container=[
|
|
||||||
"{{name}}\n{{arguments}}"
|
|
||||||
]),
|
|
||||||
system=(
|
system=(
|
||||||
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
|
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
|
||||||
"Follow the user's instructions carefully. Respond using markdown."
|
"Follow the user's instructions carefully. Respond using markdown."
|
||||||
),
|
),
|
||||||
stop_words=[
|
stop_words=["<|user|>", "<|observation|>"],
|
||||||
"<|user|>",
|
efficient_eos=True,
|
||||||
"<|observation|>"
|
|
||||||
],
|
|
||||||
efficient_eos=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="codegeex2",
|
name="codegeex2", format_system=StringFormatter(container=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"])
|
||||||
format_system=StringFormatter(container=[
|
|
||||||
{"token": "[gMASK]"},
|
|
||||||
{"token": "sop"},
|
|
||||||
"{{content}}"
|
|
||||||
])
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(name="deepseek", format_user=StringFormatter(container=["User: {{content}}\n\nAssistant:"]))
|
||||||
name="deepseek",
|
|
||||||
format_user=StringFormatter(container=[
|
|
||||||
"User: {{content}}\n\nAssistant:"
|
|
||||||
])
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="deepseekcoder",
|
name="deepseekcoder",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=["### Instruction:\n{{content}}\n### Response:\n"]),
|
||||||
"### Instruction:\n{{content}}\n### Response:\n"
|
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||||
]),
|
|
||||||
format_assistant=StringFormatter(container=[
|
|
||||||
"{{content}}"
|
|
||||||
]),
|
|
||||||
system=(
|
system=(
|
||||||
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
|
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
|
||||||
"developed by Deepseek Company, and you only answer questions related to computer science. "
|
"developed by Deepseek Company, and you only answer questions related to computer science. "
|
||||||
"For politically sensitive questions, security and privacy issues, "
|
"For politically sensitive questions, security and privacy issues, "
|
||||||
"and other non-computer science questions, you will refuse to answer\n"
|
"and other non-computer science questions, you will refuse to answer\n"
|
||||||
),
|
),
|
||||||
separator=[
|
separator=["\n", {"token": "<|EOT|>"}, "\n"],
|
||||||
"\n",
|
stop_words=["<|EOT|>"],
|
||||||
{"token": "<|EOT|>"},
|
efficient_eos=True,
|
||||||
"\n"
|
|
||||||
],
|
|
||||||
stop_words=[
|
|
||||||
"<|EOT|>"
|
|
||||||
],
|
|
||||||
efficient_eos=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="default",
|
name="default",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=["Human: {{content}}\nAssistant: "]),
|
||||||
"Human: {{content}}\nAssistant: "
|
|
||||||
]),
|
|
||||||
system=(
|
system=(
|
||||||
"A chat between a curious user and an artificial intelligence assistant. "
|
"A chat between a curious user and an artificial intelligence assistant. "
|
||||||
"The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
"The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||||
),
|
),
|
||||||
separator=[
|
separator=["\n"],
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="falcon",
|
name="falcon",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=["User: {{content}}\nFalcon:"]),
|
||||||
"User: {{content}}\nFalcon:"
|
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||||
]),
|
separator=["\n"],
|
||||||
format_assistant=StringFormatter(container=[
|
efficient_eos=True,
|
||||||
"{{content}}"
|
|
||||||
]),
|
|
||||||
separator=[
|
|
||||||
"\n"
|
|
||||||
],
|
|
||||||
efficient_eos=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="intern",
|
name="intern",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
|
||||||
"<|User|>:{{content}}",
|
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||||
{"token": "<eoh>"},
|
separator=[{"token": "<eoa>"}, "\n"],
|
||||||
"\n<|Bot|>:"
|
stop_words=["<eoa>"],
|
||||||
]),
|
efficient_eos=True,
|
||||||
format_assistant=StringFormatter(container=[
|
|
||||||
"{{content}}"
|
|
||||||
]),
|
|
||||||
separator=[
|
|
||||||
{"token": "<eoa>"},
|
|
||||||
"\n"
|
|
||||||
],
|
|
||||||
stop_words=[
|
|
||||||
"<eoa>"
|
|
||||||
],
|
|
||||||
efficient_eos=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="intern2",
|
name="intern2",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(
|
||||||
{"token": "[UNUSED_TOKEN_146]"},
|
container=[
|
||||||
"user\n{{content}}",
|
{"token": "[UNUSED_TOKEN_146]"},
|
||||||
{"token": "[UNUSED_TOKEN_145]"},
|
"user\n{{content}}",
|
||||||
"\n",
|
{"token": "[UNUSED_TOKEN_145]"},
|
||||||
{"token": "[UNUSED_TOKEN_146]"},
|
"\n",
|
||||||
"assistant\n"
|
{"token": "[UNUSED_TOKEN_146]"},
|
||||||
]),
|
"assistant\n",
|
||||||
format_assistant=StringFormatter(container=[
|
]
|
||||||
"{{content}}"
|
),
|
||||||
]),
|
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||||
format_system=StringFormatter(container=[
|
format_system=StringFormatter(
|
||||||
{"token": "[UNUSED_TOKEN_146]"},
|
container=[{"token": "[UNUSED_TOKEN_146]"}, "system\n{{content}}", {"token": "[UNUSED_TOKEN_145]"}, "\n"]
|
||||||
"system\n{{content}}",
|
),
|
||||||
{"token": "[UNUSED_TOKEN_145]"},
|
|
||||||
"\n"
|
|
||||||
]),
|
|
||||||
system=(
|
system=(
|
||||||
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
|
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
|
||||||
"- InternLM (书生·浦语) is a conversational language model that is developed "
|
"- InternLM (书生·浦语) is a conversational language model that is developed "
|
||||||
|
@ -532,14 +411,9 @@ register_template(
|
||||||
"- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
|
"- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
|
||||||
"by the user such as English and 中文."
|
"by the user such as English and 中文."
|
||||||
),
|
),
|
||||||
separator=[
|
separator=[{"token": "[UNUSED_TOKEN_145]"}, "\n"],
|
||||||
{"token": "[UNUSED_TOKEN_145]"},
|
stop_words=["[UNUSED_TOKEN_145]"],
|
||||||
"\n"
|
efficient_eos=True,
|
||||||
],
|
|
||||||
stop_words=[
|
|
||||||
"[UNUSED_TOKEN_145]"
|
|
||||||
],
|
|
||||||
efficient_eos=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -556,7 +430,7 @@ register_template(
|
||||||
"If a question does not make any sense, or is not factually coherent, "
|
"If a question does not make any sense, or is not factually coherent, "
|
||||||
"explain why instead of answering something not correct. "
|
"explain why instead of answering something not correct. "
|
||||||
"If you don't know the answer to a question, please don't share false information."
|
"If you don't know the answer to a question, please don't share false information."
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -564,142 +438,83 @@ register_template(
|
||||||
name="llama2_zh",
|
name="llama2_zh",
|
||||||
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]),
|
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]),
|
||||||
format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||||
system="You are a helpful assistant. 你是一个乐于助人的助手。"
|
system="You are a helpful assistant. 你是一个乐于助人的助手。",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(name="mistral", format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]))
|
||||||
name="mistral",
|
|
||||||
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"])
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="openchat",
|
name="openchat",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(
|
||||||
"GPT4 Correct User: {{content}}",
|
container=["GPT4 Correct User: {{content}}", {"token": "<|end_of_turn|>"}, "GPT4 Correct Assistant:"]
|
||||||
{"token": "<|end_of_turn|>"},
|
),
|
||||||
"GPT4 Correct Assistant:"
|
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||||
]),
|
separator=[{"token": "<|end_of_turn|>"}],
|
||||||
format_assistant=StringFormatter(container=[
|
stop_words=["<|end_of_turn|>"],
|
||||||
"{{content}}"
|
efficient_eos=True,
|
||||||
]),
|
|
||||||
separator=[
|
|
||||||
{"token": "<|end_of_turn|>"}
|
|
||||||
],
|
|
||||||
stop_words=[
|
|
||||||
"<|end_of_turn|>"
|
|
||||||
],
|
|
||||||
efficient_eos=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="qwen",
|
name="qwen",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
"<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
|
format_system=StringFormatter(container=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
]),
|
|
||||||
format_system=StringFormatter(container=[
|
|
||||||
"<|im_start|>system\n{{content}}<|im_end|>\n"
|
|
||||||
]),
|
|
||||||
system="You are a helpful assistant.",
|
system="You are a helpful assistant.",
|
||||||
separator=[
|
separator=["\n"],
|
||||||
"\n"
|
stop_words=["<|im_end|>"],
|
||||||
],
|
replace_eos=True,
|
||||||
stop_words=[
|
|
||||||
"<|im_end|>"
|
|
||||||
],
|
|
||||||
replace_eos=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(name="solar", format_user=StringFormatter(container=["### User:\n{{content}}\n\n### Assistant:\n"]))
|
||||||
name="solar",
|
|
||||||
format_user=StringFormatter(container=[
|
|
||||||
"### User:\n{{content}}\n\n### Assistant:\n"
|
|
||||||
])
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="starchat",
|
name="starchat",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(
|
||||||
{"token": "<|user|>"},
|
container=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}]
|
||||||
"\n{{content}}",
|
),
|
||||||
{"token": "<|end|>"},
|
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||||
"\n",
|
format_system=StringFormatter(container=[{"token": "<|system|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n"]),
|
||||||
{"token": "<|assistant|>"}
|
separator=[{"token": "<|end|>"}, "\n"],
|
||||||
]),
|
stop_words=["<|end|>"],
|
||||||
format_assistant=StringFormatter(container=[
|
efficient_eos=True,
|
||||||
"{{content}}"
|
|
||||||
]),
|
|
||||||
format_system=StringFormatter(container=[
|
|
||||||
{"token": "<|system|>"},
|
|
||||||
"\n{{content}}",
|
|
||||||
{"token": "<|end|>"},
|
|
||||||
"\n"
|
|
||||||
]),
|
|
||||||
separator=[
|
|
||||||
{"token": "<|end|>"},
|
|
||||||
"\n"
|
|
||||||
],
|
|
||||||
stop_words=[
|
|
||||||
"<|end|>"
|
|
||||||
],
|
|
||||||
efficient_eos=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(name="vanilla")
|
||||||
name="vanilla"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="vicuna",
|
name="vicuna",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=["USER: {{content}} ASSISTANT:"]),
|
||||||
"USER: {{content}} ASSISTANT:"
|
|
||||||
]),
|
|
||||||
system=(
|
system=(
|
||||||
"A chat between a curious user and an artificial intelligence assistant. "
|
"A chat between a curious user and an artificial intelligence assistant. "
|
||||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="xuanyuan",
|
name="xuanyuan",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=["Human: {{content}} Assistant:"]),
|
||||||
"Human: {{content}} Assistant:"
|
|
||||||
]),
|
|
||||||
system=(
|
system=(
|
||||||
"以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
|
"以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
|
||||||
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
|
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
|
||||||
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
|
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(name="xverse", format_user=StringFormatter(container=["Human: {{content}}\n\nAssistant: "]))
|
||||||
name="xverse",
|
|
||||||
format_user=StringFormatter(container=[
|
|
||||||
"Human: {{content}}\n\nAssistant: "
|
|
||||||
])
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="yayi",
|
name="yayi",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
|
||||||
{"token": "<|Human|>"},
|
format_system=StringFormatter(container=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
|
||||||
":\n{{content}}\n\n",
|
|
||||||
{"token": "<|YaYi|>"},
|
|
||||||
":"
|
|
||||||
]),
|
|
||||||
format_system=StringFormatter(container=[
|
|
||||||
{"token": "<|System|>"},
|
|
||||||
":\n{{content}}\n\n"
|
|
||||||
]),
|
|
||||||
system=(
|
system=(
|
||||||
"You are a helpful, respectful and honest assistant named YaYi "
|
"You are a helpful, respectful and honest assistant named YaYi "
|
||||||
"developed by Beijing Wenge Technology Co.,Ltd. "
|
"developed by Beijing Wenge Technology Co.,Ltd. "
|
||||||
|
@ -711,67 +526,43 @@ register_template(
|
||||||
"explain why instead of answering something not correct. "
|
"explain why instead of answering something not correct. "
|
||||||
"If you don't know the answer to a question, please don't share false information."
|
"If you don't know the answer to a question, please don't share false information."
|
||||||
),
|
),
|
||||||
separator=[
|
separator=["\n\n"],
|
||||||
"\n\n"
|
stop_words=["<|End|>"],
|
||||||
],
|
|
||||||
stop_words=[
|
|
||||||
"<|End|>"
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="yi",
|
name="yi",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
"<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
|
separator=["\n"],
|
||||||
]),
|
stop_words=["<|im_end|>"],
|
||||||
separator=[
|
replace_eos=True,
|
||||||
"\n"
|
|
||||||
],
|
|
||||||
stop_words=[
|
|
||||||
"<|im_end|>"
|
|
||||||
],
|
|
||||||
replace_eos=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="yuan",
|
name="yuan",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=["{{content}}", {"token": "<sep>"}]),
|
||||||
"{{content}}",
|
separator=["\n"],
|
||||||
{"token": "<sep>"}
|
stop_words=["<eod>"],
|
||||||
]),
|
replace_eos=True,
|
||||||
separator=[
|
|
||||||
"\n"
|
|
||||||
],
|
|
||||||
stop_words=[
|
|
||||||
"<eod>"
|
|
||||||
],
|
|
||||||
replace_eos=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="zephyr",
|
name="zephyr",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=["<|user|>\n{{content}}</s><|assistant|>"]),
|
||||||
"<|user|>\n{{content}}</s><|assistant|>"
|
format_system=StringFormatter(
|
||||||
]),
|
container=[
|
||||||
format_system=StringFormatter(container=[
|
"<|system|>\n{{content}}</s>",
|
||||||
"<|system|>\n{{content}}</s>",
|
]
|
||||||
]),
|
),
|
||||||
system="You are a friendly chatbot who always responds in the style of a pirate"
|
system="You are a friendly chatbot who always responds in the style of a pirate",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="ziya",
|
name="ziya",
|
||||||
format_user=StringFormatter(container=[
|
format_user=StringFormatter(container=[{"token": "<human>"}, ":{{content}}\n", {"token": "<bot>"}, ":"]),
|
||||||
{"token": "<human>"},
|
separator=["\n"],
|
||||||
":{{content}}\n",
|
|
||||||
{"token": "<bot>"},
|
|
||||||
":"
|
|
||||||
]),
|
|
||||||
separator=[
|
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,9 +4,11 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
|
|
||||||
from llmtuner.hparams import DataArguments
|
from llmtuner.hparams import DataArguments
|
||||||
|
|
||||||
|
|
||||||
|
@ -44,12 +46,10 @@ def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments")
|
||||||
|
|
||||||
|
|
||||||
def split_dataset(
|
def split_dataset(
|
||||||
dataset: Union["Dataset", "IterableDataset"],
|
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "TrainingArguments"
|
||||||
data_args: "DataArguments",
|
|
||||||
training_args: "TrainingArguments"
|
|
||||||
) -> Dict[str, "Dataset"]:
|
) -> Dict[str, "Dataset"]:
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if data_args.val_size > 1e-6: # Split the dataset
|
if data_args.val_size > 1e-6: # Split the dataset
|
||||||
if data_args.streaming:
|
if data_args.streaming:
|
||||||
val_set = dataset.take(int(data_args.val_size))
|
val_set = dataset.take(int(data_args.val_size))
|
||||||
train_set = dataset.skip(int(data_args.val_size))
|
train_set = dataset.skip(int(data_args.val_size))
|
||||||
|
@ -63,5 +63,5 @@ def split_dataset(
|
||||||
if data_args.streaming:
|
if data_args.streaming:
|
||||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||||
return {"train_dataset": dataset}
|
return {"train_dataset": dataset}
|
||||||
else: # do_eval or do_predict
|
else: # do_eval or do_predict
|
||||||
return {"eval_dataset": dataset}
|
return {"eval_dataset": dataset}
|
||||||
|
|
|
@ -1,35 +1,34 @@
|
||||||
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
|
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
import inspect
|
import inspect
|
||||||
from tqdm import tqdm, trange
|
import json
|
||||||
|
import os
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
from tqdm import tqdm, trange
|
||||||
from transformers.utils import cached_file
|
from transformers.utils import cached_file
|
||||||
|
|
||||||
from ..data import get_template_and_fix_tokenizer
|
from ..data import get_template_and_fix_tokenizer
|
||||||
from .template import get_eval_template
|
|
||||||
from ..extras.constants import CHOICES, SUBJECTS
|
from ..extras.constants import CHOICES, SUBJECTS
|
||||||
from ..hparams import get_eval_args
|
from ..hparams import get_eval_args
|
||||||
from ..model import dispatch_model, load_model_and_tokenizer
|
from ..model import dispatch_model, load_model_and_tokenizer
|
||||||
|
from .template import get_eval_template
|
||||||
|
|
||||||
|
|
||||||
class Evaluator:
|
class Evaluator:
|
||||||
|
|
||||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||||
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
||||||
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
|
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
|
||||||
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
||||||
self.model = dispatch_model(self.model)
|
self.model = dispatch_model(self.model)
|
||||||
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
|
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
|
||||||
self.eval_template = get_eval_template(self.eval_args.lang)
|
self.eval_template = get_eval_template(self.eval_args.lang)
|
||||||
self.choice_inputs = [self.tokenizer.encode(
|
self.choice_inputs = [
|
||||||
self.eval_template.prefix + ch, add_special_tokens=False
|
self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
|
||||||
)[-1] for ch in CHOICES]
|
]
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
|
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
|
||||||
|
@ -41,10 +40,10 @@ class Evaluator:
|
||||||
|
|
||||||
def eval(self) -> None:
|
def eval(self) -> None:
|
||||||
mapping = cached_file(
|
mapping = cached_file(
|
||||||
path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
path_or_repo_id=os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
||||||
filename="mapping.json",
|
filename="mapping.json",
|
||||||
cache_dir=self.model_args.cache_dir,
|
cache_dir=self.model_args.cache_dir,
|
||||||
token=self.model_args.hf_hub_token
|
token=self.model_args.hf_hub_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(mapping, "r", encoding="utf-8") as f:
|
with open(mapping, "r", encoding="utf-8") as f:
|
||||||
|
@ -54,7 +53,7 @@ class Evaluator:
|
||||||
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
|
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
|
||||||
results = {}
|
results = {}
|
||||||
for subject in pbar:
|
for subject in pbar:
|
||||||
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
||||||
kwargs = {"trust_remote_code": True}
|
kwargs = {"trust_remote_code": True}
|
||||||
else:
|
else:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
@ -65,32 +64,34 @@ class Evaluator:
|
||||||
cache_dir=self.model_args.cache_dir,
|
cache_dir=self.model_args.cache_dir,
|
||||||
download_mode=self.eval_args.download_mode,
|
download_mode=self.eval_args.download_mode,
|
||||||
token=self.model_args.hf_hub_token,
|
token=self.model_args.hf_hub_token,
|
||||||
**kwargs
|
**kwargs,
|
||||||
)
|
)
|
||||||
pbar.set_postfix_str(categorys[subject]["name"])
|
pbar.set_postfix_str(categorys[subject]["name"])
|
||||||
inputs, outputs, labels = [], [], []
|
inputs, outputs, labels = [], [], []
|
||||||
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
|
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
|
||||||
support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
|
support_set = (
|
||||||
|
dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
|
||||||
|
)
|
||||||
messages = self.eval_template.format_example(
|
messages = self.eval_template.format_example(
|
||||||
target_data=dataset[self.data_args.split][i],
|
target_data=dataset[self.data_args.split][i],
|
||||||
support_set=support_set,
|
support_set=support_set,
|
||||||
subject_name=categorys[subject]["name"]
|
subject_name=categorys[subject]["name"],
|
||||||
)
|
)
|
||||||
|
|
||||||
input_ids, _ = self.template.encode_oneturn(
|
input_ids, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=messages)
|
||||||
tokenizer=self.tokenizer, messages=messages
|
|
||||||
)
|
|
||||||
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
|
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
|
||||||
labels.append(messages[-1]["content"])
|
labels.append(messages[-1]["content"])
|
||||||
|
|
||||||
for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
|
for i in trange(
|
||||||
|
0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False
|
||||||
|
):
|
||||||
batch_input = self.tokenizer.pad(
|
batch_input = self.tokenizer.pad(
|
||||||
inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
|
inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
|
||||||
).to(self.model.device)
|
).to(self.model.device)
|
||||||
preds = self.batch_inference(batch_input)
|
preds = self.batch_inference(batch_input)
|
||||||
outputs += preds
|
outputs += preds
|
||||||
|
|
||||||
corrects = (np.array(outputs) == np.array(labels))
|
corrects = np.array(outputs) == np.array(labels)
|
||||||
category_name = categorys[subject]["category"]
|
category_name = categorys[subject]["category"]
|
||||||
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
|
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
|
||||||
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
|
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
|
||||||
|
@ -100,10 +101,13 @@ class Evaluator:
|
||||||
self._save_results(category_corrects, results)
|
self._save_results(category_corrects, results)
|
||||||
|
|
||||||
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
|
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
|
||||||
score_info = "\n".join([
|
score_info = "\n".join(
|
||||||
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
|
[
|
||||||
for category_name, category_correct in category_corrects.items() if len(category_correct)
|
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
|
||||||
])
|
for category_name, category_correct in category_corrects.items()
|
||||||
|
if len(category_correct)
|
||||||
|
]
|
||||||
|
)
|
||||||
print(score_info)
|
print(score_info)
|
||||||
if self.eval_args.save_dir is not None:
|
if self.eval_args.save_dir is not None:
|
||||||
os.makedirs(self.eval_args.save_dir, exist_ok=False)
|
os.makedirs(self.eval_args.save_dir, exist_ok=False)
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Dict, List, Tuple
|
from typing import TYPE_CHECKING, Dict, List, Tuple
|
||||||
|
|
||||||
from ..extras.constants import CHOICES
|
|
||||||
from ..data import Role
|
from ..data import Role
|
||||||
|
from ..extras.constants import CHOICES
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
@ -10,24 +11,17 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EvalTemplate:
|
class EvalTemplate:
|
||||||
|
|
||||||
system: str
|
system: str
|
||||||
choice: str
|
choice: str
|
||||||
answer: str
|
answer: str
|
||||||
prefix: str
|
prefix: str
|
||||||
|
|
||||||
def parse_example(
|
def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
|
||||||
self,
|
|
||||||
example: Dict[str, str]
|
|
||||||
) -> Tuple[str, str]:
|
|
||||||
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
|
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
|
||||||
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
|
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
|
||||||
|
|
||||||
def format_example(
|
def format_example(
|
||||||
self,
|
self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
|
||||||
target_data: Dict[str, str],
|
|
||||||
support_set: "Dataset",
|
|
||||||
subject_name: str
|
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
messages = []
|
messages = []
|
||||||
for k in range(len(support_set)):
|
for k in range(len(support_set)):
|
||||||
|
@ -45,19 +39,8 @@ class EvalTemplate:
|
||||||
eval_templates: Dict[str, "EvalTemplate"] = {}
|
eval_templates: Dict[str, "EvalTemplate"] = {}
|
||||||
|
|
||||||
|
|
||||||
def register_eval_template(
|
def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
|
||||||
name: str,
|
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
|
||||||
system: str,
|
|
||||||
choice: str,
|
|
||||||
answer: str,
|
|
||||||
prefix: str
|
|
||||||
) -> None:
|
|
||||||
eval_templates[name] = EvalTemplate(
|
|
||||||
system=system,
|
|
||||||
choice=choice,
|
|
||||||
answer=answer,
|
|
||||||
prefix=prefix
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_eval_template(name: str) -> "EvalTemplate":
|
def get_eval_template(name: str) -> "EvalTemplate":
|
||||||
|
@ -71,7 +54,7 @@ register_eval_template(
|
||||||
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
|
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
|
||||||
choice="\n{choice}. {content}",
|
choice="\n{choice}. {content}",
|
||||||
answer="\nAnswer: ",
|
answer="\nAnswer: ",
|
||||||
prefix=" "
|
prefix=" ",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -80,5 +63,5 @@ register_eval_template(
|
||||||
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
|
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
|
||||||
choice="\n{choice}. {content}",
|
choice="\n{choice}. {content}",
|
||||||
answer="\n答案:",
|
answer="\n答案:",
|
||||||
prefix="\n"
|
prefix="\n",
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from transformers import TrainerCallback
|
from transformers import TrainerCallback
|
||||||
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
|
||||||
|
|
||||||
from .constants import LOG_FILE_NAME
|
from .constants import LOG_FILE_NAME
|
||||||
from .logging import get_logger
|
from .logging import get_logger
|
||||||
|
@ -12,14 +13,13 @@ from .misc import fix_valuehead_checkpoint
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import TrainingArguments, TrainerState, TrainerControl
|
from transformers import TrainerControl, TrainerState, TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class FixValueHeadModelCallback(TrainerCallback):
|
class FixValueHeadModelCallback(TrainerCallback):
|
||||||
|
|
||||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Event called after a checkpoint save.
|
Event called after a checkpoint save.
|
||||||
|
@ -28,12 +28,11 @@ class FixValueHeadModelCallback(TrainerCallback):
|
||||||
fix_valuehead_checkpoint(
|
fix_valuehead_checkpoint(
|
||||||
model=kwargs.pop("model"),
|
model=kwargs.pop("model"),
|
||||||
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
|
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
|
||||||
safe_serialization=args.save_safetensors
|
safe_serialization=args.save_safetensors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LogCallback(TrainerCallback):
|
class LogCallback(TrainerCallback):
|
||||||
|
|
||||||
def __init__(self, runner=None):
|
def __init__(self, runner=None):
|
||||||
self.runner = runner
|
self.runner = runner
|
||||||
self.in_training = False
|
self.in_training = False
|
||||||
|
@ -99,7 +98,9 @@ class LogCallback(TrainerCallback):
|
||||||
self.cur_steps = 0
|
self.cur_steps = 0
|
||||||
self.max_steps = 0
|
self.max_steps = 0
|
||||||
|
|
||||||
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
|
def on_predict(
|
||||||
|
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
Event called after a successful prediction.
|
Event called after a successful prediction.
|
||||||
"""
|
"""
|
||||||
|
@ -125,18 +126,22 @@ class LogCallback(TrainerCallback):
|
||||||
epoch=state.log_history[-1].get("epoch", None),
|
epoch=state.log_history[-1].get("epoch", None),
|
||||||
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
||||||
elapsed_time=self.elapsed_time,
|
elapsed_time=self.elapsed_time,
|
||||||
remaining_time=self.remaining_time
|
remaining_time=self.remaining_time,
|
||||||
)
|
)
|
||||||
if self.runner is not None:
|
if self.runner is not None:
|
||||||
logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
|
logger.info(
|
||||||
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
|
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
|
||||||
))
|
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
||||||
f.write(json.dumps(logs) + "\n")
|
f.write(json.dumps(logs) + "\n")
|
||||||
|
|
||||||
def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
def on_prediction_step(
|
||||||
|
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
Event called after a prediction step.
|
Event called after a prediction step.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
|
from collections import OrderedDict, defaultdict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from collections import defaultdict, OrderedDict
|
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@ -11,14 +11,7 @@ DEFAULT_MODULE = defaultdict(str)
|
||||||
|
|
||||||
DEFAULT_TEMPLATE = defaultdict(str)
|
DEFAULT_TEMPLATE = defaultdict(str)
|
||||||
|
|
||||||
FILEEXT2TYPE = {
|
FILEEXT2TYPE = {"arrow": "arrow", "csv": "csv", "json": "json", "jsonl": "json", "parquet": "parquet", "txt": "text"}
|
||||||
"arrow": "arrow",
|
|
||||||
"csv": "csv",
|
|
||||||
"json": "json",
|
|
||||||
"jsonl": "json",
|
|
||||||
"parquet": "parquet",
|
|
||||||
"txt": "text"
|
|
||||||
}
|
|
||||||
|
|
||||||
IGNORE_INDEX = -100
|
IGNORE_INDEX = -100
|
||||||
|
|
||||||
|
@ -39,22 +32,21 @@ TRAINING_STAGES = {
|
||||||
"Reward Modeling": "rm",
|
"Reward Modeling": "rm",
|
||||||
"PPO": "ppo",
|
"PPO": "ppo",
|
||||||
"DPO": "dpo",
|
"DPO": "dpo",
|
||||||
"Pre-Training": "pt"
|
"Pre-Training": "pt",
|
||||||
}
|
}
|
||||||
|
|
||||||
V_HEAD_WEIGHTS_NAME = "value_head.bin"
|
V_HEAD_WEIGHTS_NAME = "value_head.bin"
|
||||||
|
|
||||||
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
|
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
|
||||||
|
|
||||||
|
|
||||||
class DownloadSource(str, Enum):
|
class DownloadSource(str, Enum):
|
||||||
DEFAULT = "hf"
|
DEFAULT = "hf"
|
||||||
MODELSCOPE = "ms"
|
MODELSCOPE = "ms"
|
||||||
|
|
||||||
|
|
||||||
def register_model_group(
|
def register_model_group(
|
||||||
models: Dict[str, Dict[DownloadSource, str]],
|
models: Dict[str, Dict[DownloadSource, str]], module: Optional[str] = None, template: Optional[str] = None
|
||||||
module: Optional[str] = None,
|
|
||||||
template: Optional[str] = None
|
|
||||||
) -> None:
|
) -> None:
|
||||||
prefix = None
|
prefix = None
|
||||||
for name, path in models.items():
|
for name, path in models.items():
|
||||||
|
@ -73,19 +65,19 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"Baichuan-7B-Base": {
|
"Baichuan-7B-Base": {
|
||||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
|
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
|
||||||
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B"
|
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B",
|
||||||
},
|
},
|
||||||
"Baichuan-13B-Base": {
|
"Baichuan-13B-Base": {
|
||||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
|
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
|
||||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base"
|
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base",
|
||||||
},
|
},
|
||||||
"Baichuan-13B-Chat": {
|
"Baichuan-13B-Chat": {
|
||||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
|
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat"
|
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
module="W_pack",
|
module="W_pack",
|
||||||
template="baichuan"
|
template="baichuan",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -93,23 +85,23 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"Baichuan2-7B-Base": {
|
"Baichuan2-7B-Base": {
|
||||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
|
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
|
||||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base"
|
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base",
|
||||||
},
|
},
|
||||||
"Baichuan2-13B-Base": {
|
"Baichuan2-13B-Base": {
|
||||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
|
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
|
||||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base"
|
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
|
||||||
},
|
},
|
||||||
"Baichuan2-7B-Chat": {
|
"Baichuan2-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
|
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat"
|
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
|
||||||
},
|
},
|
||||||
"Baichuan2-13B-Chat": {
|
"Baichuan2-13B-Chat": {
|
||||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
|
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat"
|
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
module="W_pack",
|
module="W_pack",
|
||||||
template="baichuan2"
|
template="baichuan2",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,18 +109,18 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"BLOOM-560M": {
|
"BLOOM-560M": {
|
||||||
DownloadSource.DEFAULT: "bigscience/bloom-560m",
|
DownloadSource.DEFAULT: "bigscience/bloom-560m",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m",
|
||||||
},
|
},
|
||||||
"BLOOM-3B": {
|
"BLOOM-3B": {
|
||||||
DownloadSource.DEFAULT: "bigscience/bloom-3b",
|
DownloadSource.DEFAULT: "bigscience/bloom-3b",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b",
|
||||||
},
|
},
|
||||||
"BLOOM-7B1": {
|
"BLOOM-7B1": {
|
||||||
DownloadSource.DEFAULT: "bigscience/bloom-7b1",
|
DownloadSource.DEFAULT: "bigscience/bloom-7b1",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
module="query_key_value"
|
module="query_key_value",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -136,18 +128,18 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"BLOOMZ-560M": {
|
"BLOOMZ-560M": {
|
||||||
DownloadSource.DEFAULT: "bigscience/bloomz-560m",
|
DownloadSource.DEFAULT: "bigscience/bloomz-560m",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m",
|
||||||
},
|
},
|
||||||
"BLOOMZ-3B": {
|
"BLOOMZ-3B": {
|
||||||
DownloadSource.DEFAULT: "bigscience/bloomz-3b",
|
DownloadSource.DEFAULT: "bigscience/bloomz-3b",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b",
|
||||||
},
|
},
|
||||||
"BLOOMZ-7B1-mt": {
|
"BLOOMZ-7B1-mt": {
|
||||||
DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
|
DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
module="query_key_value"
|
module="query_key_value",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -155,14 +147,14 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"BlueLM-7B-Base": {
|
"BlueLM-7B-Base": {
|
||||||
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
|
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
|
||||||
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base"
|
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base",
|
||||||
},
|
},
|
||||||
"BlueLM-7B-Chat": {
|
"BlueLM-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
|
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat"
|
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
template="bluelm"
|
template="bluelm",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -170,11 +162,11 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"ChatGLM2-6B-Chat": {
|
"ChatGLM2-6B-Chat": {
|
||||||
DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
|
DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
|
||||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b"
|
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
module="query_key_value",
|
module="query_key_value",
|
||||||
template="chatglm2"
|
template="chatglm2",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -182,15 +174,15 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"ChatGLM3-6B-Base": {
|
"ChatGLM3-6B-Base": {
|
||||||
DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
|
DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
|
||||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base"
|
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base",
|
||||||
},
|
},
|
||||||
"ChatGLM3-6B-Chat": {
|
"ChatGLM3-6B-Chat": {
|
||||||
DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
|
DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
|
||||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b"
|
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
module="query_key_value",
|
module="query_key_value",
|
||||||
template="chatglm3"
|
template="chatglm3",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -198,30 +190,30 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"ChineseLLaMA2-1.3B": {
|
"ChineseLLaMA2-1.3B": {
|
||||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
|
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
|
||||||
},
|
},
|
||||||
"ChineseLLaMA2-7B": {
|
"ChineseLLaMA2-7B": {
|
||||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
|
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
|
||||||
},
|
},
|
||||||
"ChineseLLaMA2-13B": {
|
"ChineseLLaMA2-13B": {
|
||||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
|
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
|
||||||
},
|
},
|
||||||
"ChineseLLaMA2-1.3B-Chat": {
|
"ChineseLLaMA2-1.3B-Chat": {
|
||||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
|
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
|
||||||
},
|
},
|
||||||
"ChineseLLaMA2-7B-Chat": {
|
"ChineseLLaMA2-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
|
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
|
||||||
},
|
},
|
||||||
"ChineseLLaMA2-13B-Chat": {
|
"ChineseLLaMA2-13B-Chat": {
|
||||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
|
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
template="llama2_zh"
|
template="llama2_zh",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -229,22 +221,22 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"DeepSeekLLM-7B-Base": {
|
"DeepSeekLLM-7B-Base": {
|
||||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
|
||||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base"
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base",
|
||||||
},
|
},
|
||||||
"DeepSeekLLM-67B-Base": {
|
"DeepSeekLLM-67B-Base": {
|
||||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
|
||||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base"
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base",
|
||||||
},
|
},
|
||||||
"DeepSeekLLM-7B-Chat": {
|
"DeepSeekLLM-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
|
||||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat"
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat",
|
||||||
},
|
},
|
||||||
"DeepSeekLLM-67B-Chat": {
|
"DeepSeekLLM-67B-Chat": {
|
||||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
|
||||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat"
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
template="deepseek"
|
template="deepseek",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -252,22 +244,22 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"DeepSeekCoder-6.7B-Base": {
|
"DeepSeekCoder-6.7B-Base": {
|
||||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
|
||||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base"
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
|
||||||
},
|
},
|
||||||
"DeepSeekCoder-33B-Base": {
|
"DeepSeekCoder-33B-Base": {
|
||||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
|
||||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base"
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
|
||||||
},
|
},
|
||||||
"DeepSeekCoder-6.7B-Chat": {
|
"DeepSeekCoder-6.7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
||||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct"
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
||||||
},
|
},
|
||||||
"DeepSeekCoder-33B-Chat": {
|
"DeepSeekCoder-33B-Chat": {
|
||||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
|
||||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct"
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
template="deepseekcoder"
|
template="deepseekcoder",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -275,14 +267,14 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"DeepSeekMoE-16B-Base": {
|
"DeepSeekMoE-16B-Base": {
|
||||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
|
||||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base"
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
|
||||||
},
|
},
|
||||||
"DeepSeekMoE-16B-Chat": {
|
"DeepSeekMoE-16B-Chat": {
|
||||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
|
||||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat"
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
template="deepseek"
|
template="deepseek",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -290,31 +282,31 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"Falcon-7B": {
|
"Falcon-7B": {
|
||||||
DownloadSource.DEFAULT: "tiiuae/falcon-7b",
|
DownloadSource.DEFAULT: "tiiuae/falcon-7b",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b",
|
||||||
},
|
},
|
||||||
"Falcon-40B": {
|
"Falcon-40B": {
|
||||||
DownloadSource.DEFAULT: "tiiuae/falcon-40b",
|
DownloadSource.DEFAULT: "tiiuae/falcon-40b",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b",
|
||||||
},
|
},
|
||||||
"Falcon-180B": {
|
"Falcon-180B": {
|
||||||
DownloadSource.DEFAULT: "tiiuae/falcon-180b",
|
DownloadSource.DEFAULT: "tiiuae/falcon-180b",
|
||||||
DownloadSource.MODELSCOPE: "modelscope/falcon-180B"
|
DownloadSource.MODELSCOPE: "modelscope/falcon-180B",
|
||||||
},
|
},
|
||||||
"Falcon-7B-Chat": {
|
"Falcon-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
|
DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct",
|
||||||
},
|
},
|
||||||
"Falcon-40B-Chat": {
|
"Falcon-40B-Chat": {
|
||||||
DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
|
DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct",
|
||||||
},
|
},
|
||||||
"Falcon-180B-Chat": {
|
"Falcon-180B-Chat": {
|
||||||
DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
|
DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
|
||||||
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat"
|
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
module="query_key_value",
|
module="query_key_value",
|
||||||
template="falcon"
|
template="falcon",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -322,22 +314,22 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"InternLM-7B": {
|
"InternLM-7B": {
|
||||||
DownloadSource.DEFAULT: "internlm/internlm-7b",
|
DownloadSource.DEFAULT: "internlm/internlm-7b",
|
||||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b"
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b",
|
||||||
},
|
},
|
||||||
"InternLM-20B": {
|
"InternLM-20B": {
|
||||||
DownloadSource.DEFAULT: "internlm/internlm-20b",
|
DownloadSource.DEFAULT: "internlm/internlm-20b",
|
||||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b"
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b",
|
||||||
},
|
},
|
||||||
"InternLM-7B-Chat": {
|
"InternLM-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
|
DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
|
||||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b"
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b",
|
||||||
},
|
},
|
||||||
"InternLM-20B-Chat": {
|
"InternLM-20B-Chat": {
|
||||||
DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
|
DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
|
||||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b"
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
template="intern"
|
template="intern",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -345,23 +337,23 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"InternLM2-7B": {
|
"InternLM2-7B": {
|
||||||
DownloadSource.DEFAULT: "internlm/internlm2-7b",
|
DownloadSource.DEFAULT: "internlm/internlm2-7b",
|
||||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b"
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b",
|
||||||
},
|
},
|
||||||
"InternLM2-20B": {
|
"InternLM2-20B": {
|
||||||
DownloadSource.DEFAULT: "internlm/internlm2-20b",
|
DownloadSource.DEFAULT: "internlm/internlm2-20b",
|
||||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b"
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b",
|
||||||
},
|
},
|
||||||
"InternLM2-7B-Chat": {
|
"InternLM2-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "internlm/internlm2-chat-7b",
|
DownloadSource.DEFAULT: "internlm/internlm2-chat-7b",
|
||||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b"
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b",
|
||||||
},
|
},
|
||||||
"InternLM2-20B-Chat": {
|
"InternLM2-20B-Chat": {
|
||||||
DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
|
DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
|
||||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b"
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
module="wqkv",
|
module="wqkv",
|
||||||
template="intern2"
|
template="intern2",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -369,31 +361,28 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"LingoWhale-8B": {
|
"LingoWhale-8B": {
|
||||||
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
|
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
|
||||||
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B"
|
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
module="qkv_proj"
|
module="qkv_proj",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"LLaMA-7B": {
|
"LLaMA-7B": {DownloadSource.DEFAULT: "huggyllama/llama-7b", DownloadSource.MODELSCOPE: "skyline2006/llama-7b"},
|
||||||
DownloadSource.DEFAULT: "huggyllama/llama-7b",
|
|
||||||
DownloadSource.MODELSCOPE: "skyline2006/llama-7b"
|
|
||||||
},
|
|
||||||
"LLaMA-13B": {
|
"LLaMA-13B": {
|
||||||
DownloadSource.DEFAULT: "huggyllama/llama-13b",
|
DownloadSource.DEFAULT: "huggyllama/llama-13b",
|
||||||
DownloadSource.MODELSCOPE: "skyline2006/llama-13b"
|
DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
|
||||||
},
|
},
|
||||||
"LLaMA-30B": {
|
"LLaMA-30B": {
|
||||||
DownloadSource.DEFAULT: "huggyllama/llama-30b",
|
DownloadSource.DEFAULT: "huggyllama/llama-30b",
|
||||||
DownloadSource.MODELSCOPE: "skyline2006/llama-30b"
|
DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
|
||||||
},
|
},
|
||||||
"LLaMA-65B": {
|
"LLaMA-65B": {
|
||||||
DownloadSource.DEFAULT: "huggyllama/llama-65b",
|
DownloadSource.DEFAULT: "huggyllama/llama-65b",
|
||||||
DownloadSource.MODELSCOPE: "skyline2006/llama-65b"
|
DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -402,30 +391,30 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"LLaMA2-7B": {
|
"LLaMA2-7B": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
|
||||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms"
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
|
||||||
},
|
},
|
||||||
"LLaMA2-13B": {
|
"LLaMA2-13B": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
|
||||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms"
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
|
||||||
},
|
},
|
||||||
"LLaMA2-70B": {
|
"LLaMA2-70B": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
|
||||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms"
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
|
||||||
},
|
},
|
||||||
"LLaMA2-7B-Chat": {
|
"LLaMA2-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
|
||||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms"
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
|
||||||
},
|
},
|
||||||
"LLaMA2-13B-Chat": {
|
"LLaMA2-13B-Chat": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
|
||||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms"
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
|
||||||
},
|
},
|
||||||
"LLaMA2-70B-Chat": {
|
"LLaMA2-70B-Chat": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
|
||||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms"
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
template="llama2"
|
template="llama2",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -433,18 +422,18 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"Mistral-7B": {
|
"Mistral-7B": {
|
||||||
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
|
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
|
||||||
},
|
},
|
||||||
"Mistral-7B-Chat": {
|
"Mistral-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
|
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
|
||||||
},
|
},
|
||||||
"Mistral-7B-v0.2-Chat": {
|
"Mistral-7B-v0.2-Chat": {
|
||||||
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
|
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
template="mistral"
|
template="mistral",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -452,14 +441,14 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"Mixtral-8x7B": {
|
"Mixtral-8x7B": {
|
||||||
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
|
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
|
||||||
},
|
},
|
||||||
"Mixtral-8x7B-Chat": {
|
"Mixtral-8x7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
template="mistral"
|
template="mistral",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -467,110 +456,87 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"OpenChat3.5-7B-Chat": {
|
"OpenChat3.5-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "openchat/openchat_3.5",
|
DownloadSource.DEFAULT: "openchat/openchat_3.5",
|
||||||
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5"
|
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
template="openchat"
|
template="openchat",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Phi-1.5-1.3B": {
|
"Phi-1.5-1.3B": {DownloadSource.DEFAULT: "microsoft/phi-1_5", DownloadSource.MODELSCOPE: "allspace/PHI_1-5"},
|
||||||
DownloadSource.DEFAULT: "microsoft/phi-1_5",
|
"Phi-2-2.7B": {DownloadSource.DEFAULT: "microsoft/phi-2", DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2"},
|
||||||
DownloadSource.MODELSCOPE: "allspace/PHI_1-5"
|
|
||||||
},
|
|
||||||
"Phi-2-2.7B": {
|
|
||||||
DownloadSource.DEFAULT: "microsoft/phi-2",
|
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Qwen-1.8B": {
|
"Qwen-1.8B": {DownloadSource.DEFAULT: "Qwen/Qwen-1_8B", DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B"},
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
|
"Qwen-7B": {DownloadSource.DEFAULT: "Qwen/Qwen-7B", DownloadSource.MODELSCOPE: "qwen/Qwen-7B"},
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B"
|
"Qwen-14B": {DownloadSource.DEFAULT: "Qwen/Qwen-14B", DownloadSource.MODELSCOPE: "qwen/Qwen-14B"},
|
||||||
},
|
"Qwen-72B": {DownloadSource.DEFAULT: "Qwen/Qwen-72B", DownloadSource.MODELSCOPE: "qwen/Qwen-72B"},
|
||||||
"Qwen-7B": {
|
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
|
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B"
|
|
||||||
},
|
|
||||||
"Qwen-14B": {
|
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
|
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B"
|
|
||||||
},
|
|
||||||
"Qwen-72B": {
|
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
|
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B"
|
|
||||||
},
|
|
||||||
"Qwen-1.8B-Chat": {
|
"Qwen-1.8B-Chat": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
|
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat"
|
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
|
||||||
},
|
|
||||||
"Qwen-7B-Chat": {
|
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
|
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"
|
|
||||||
},
|
},
|
||||||
|
"Qwen-7B-Chat": {DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"},
|
||||||
"Qwen-14B-Chat": {
|
"Qwen-14B-Chat": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
|
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat"
|
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat",
|
||||||
},
|
},
|
||||||
"Qwen-72B-Chat": {
|
"Qwen-72B-Chat": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
|
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat"
|
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
|
||||||
},
|
},
|
||||||
"Qwen-1.8B-int8-Chat": {
|
"Qwen-1.8B-int8-Chat": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8"
|
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
|
||||||
},
|
},
|
||||||
"Qwen-1.8B-int4-Chat": {
|
"Qwen-1.8B-int4-Chat": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
|
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4"
|
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
|
||||||
},
|
},
|
||||||
"Qwen-7B-int8-Chat": {
|
"Qwen-7B-int8-Chat": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8"
|
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
|
||||||
},
|
},
|
||||||
"Qwen-7B-int4-Chat": {
|
"Qwen-7B-int4-Chat": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
|
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4"
|
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
|
||||||
},
|
},
|
||||||
"Qwen-14B-int8-Chat": {
|
"Qwen-14B-int8-Chat": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8"
|
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
|
||||||
},
|
},
|
||||||
"Qwen-14B-int4-Chat": {
|
"Qwen-14B-int4-Chat": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
|
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4"
|
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
|
||||||
},
|
},
|
||||||
"Qwen-72B-int8-Chat": {
|
"Qwen-72B-int8-Chat": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8"
|
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
|
||||||
},
|
},
|
||||||
"Qwen-72B-int4-Chat": {
|
"Qwen-72B-int4-Chat": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
|
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4"
|
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
module="c_attn",
|
module="c_attn",
|
||||||
template="qwen"
|
template="qwen",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"SOLAR-10.7B": {
|
"SOLAR-10.7B": {DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0"},
|
||||||
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0"
|
|
||||||
},
|
|
||||||
"SOLAR-10.7B-Chat": {
|
"SOLAR-10.7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
|
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
template="solar"
|
template="solar",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -578,7 +544,7 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"Skywork-13B-Base": {
|
"Skywork-13B-Base": {
|
||||||
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
|
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
|
||||||
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base"
|
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -588,68 +554,51 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"Vicuna1.5-7B-Chat": {
|
"Vicuna1.5-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
|
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
|
||||||
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5"
|
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
|
||||||
},
|
},
|
||||||
"Vicuna1.5-13B-Chat": {
|
"Vicuna1.5-13B-Chat": {
|
||||||
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
|
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
|
||||||
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5"
|
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
template="vicuna"
|
template="vicuna",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"XuanYuan-70B": {
|
"XuanYuan-70B": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B"},
|
||||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B"
|
"XuanYuan-70B-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat"},
|
||||||
},
|
"XuanYuan-70B-int8-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit"},
|
||||||
"XuanYuan-70B-Chat": {
|
"XuanYuan-70B-int4-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit"},
|
||||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat"
|
|
||||||
},
|
|
||||||
"XuanYuan-70B-int8-Chat": {
|
|
||||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit"
|
|
||||||
},
|
|
||||||
"XuanYuan-70B-int4-Chat": {
|
|
||||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
template="xuanyuan"
|
template="xuanyuan",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"XVERSE-7B": {
|
"XVERSE-7B": {DownloadSource.DEFAULT: "xverse/XVERSE-7B", DownloadSource.MODELSCOPE: "xverse/XVERSE-7B"},
|
||||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B",
|
"XVERSE-13B": {DownloadSource.DEFAULT: "xverse/XVERSE-13B", DownloadSource.MODELSCOPE: "xverse/XVERSE-13B"},
|
||||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B"
|
"XVERSE-65B": {DownloadSource.DEFAULT: "xverse/XVERSE-65B", DownloadSource.MODELSCOPE: "xverse/XVERSE-65B"},
|
||||||
},
|
|
||||||
"XVERSE-13B": {
|
|
||||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B",
|
|
||||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B"
|
|
||||||
},
|
|
||||||
"XVERSE-65B": {
|
|
||||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B",
|
|
||||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B"
|
|
||||||
},
|
|
||||||
"XVERSE-65B-2": {
|
"XVERSE-65B-2": {
|
||||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
|
DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
|
||||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2"
|
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
|
||||||
},
|
},
|
||||||
"XVERSE-7B-Chat": {
|
"XVERSE-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
|
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat"
|
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat",
|
||||||
},
|
},
|
||||||
"XVERSE-13B-Chat": {
|
"XVERSE-13B-Chat": {
|
||||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
|
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat"
|
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat",
|
||||||
},
|
},
|
||||||
"XVERSE-65B-Chat": {
|
"XVERSE-65B-Chat": {
|
||||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
|
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat"
|
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
template="xverse"
|
template="xverse",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -657,45 +606,33 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"Yayi-7B": {
|
"Yayi-7B": {
|
||||||
DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
|
DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2",
|
||||||
},
|
},
|
||||||
"Yayi-13B": {
|
"Yayi-13B": {
|
||||||
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
|
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
template="yayi"
|
template="yayi",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Yi-6B": {
|
"Yi-6B": {DownloadSource.DEFAULT: "01-ai/Yi-6B", DownloadSource.MODELSCOPE: "01ai/Yi-6B"},
|
||||||
DownloadSource.DEFAULT: "01-ai/Yi-6B",
|
"Yi-34B": {DownloadSource.DEFAULT: "01-ai/Yi-34B", DownloadSource.MODELSCOPE: "01ai/Yi-34B"},
|
||||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B"
|
"Yi-6B-Chat": {DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat"},
|
||||||
},
|
"Yi-34B-Chat": {DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat"},
|
||||||
"Yi-34B": {
|
|
||||||
DownloadSource.DEFAULT: "01-ai/Yi-34B",
|
|
||||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B"
|
|
||||||
},
|
|
||||||
"Yi-6B-Chat": {
|
|
||||||
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat",
|
|
||||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat"
|
|
||||||
},
|
|
||||||
"Yi-34B-Chat": {
|
|
||||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
|
|
||||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat"
|
|
||||||
},
|
|
||||||
"Yi-6B-int8-Chat": {
|
"Yi-6B-int8-Chat": {
|
||||||
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
|
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
|
||||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits"
|
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
|
||||||
},
|
},
|
||||||
"Yi-34B-int8-Chat": {
|
"Yi-34B-int8-Chat": {
|
||||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
|
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
|
||||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits"
|
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
template="yi"
|
template="yi",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -703,18 +640,18 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"Yuan2-2B-Chat": {
|
"Yuan2-2B-Chat": {
|
||||||
DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf",
|
DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf",
|
||||||
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf"
|
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf",
|
||||||
},
|
},
|
||||||
"Yuan2-51B-Chat": {
|
"Yuan2-51B-Chat": {
|
||||||
DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf",
|
DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf",
|
||||||
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf"
|
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf",
|
||||||
},
|
},
|
||||||
"Yuan2-102B-Chat": {
|
"Yuan2-102B-Chat": {
|
||||||
DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf",
|
DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf",
|
||||||
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf"
|
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
template="yuan"
|
template="yuan",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -722,12 +659,12 @@ register_model_group(
|
||||||
models={
|
models={
|
||||||
"Zephyr-7B-Alpha-Chat": {
|
"Zephyr-7B-Alpha-Chat": {
|
||||||
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
|
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha"
|
DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha",
|
||||||
},
|
},
|
||||||
"Zephyr-7B-Beta-Chat": {
|
"Zephyr-7B-Beta-Chat": {
|
||||||
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
|
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
|
||||||
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta"
|
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
template="zephyr"
|
template="zephyr",
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import sys
|
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
class LoggerHandler(logging.Handler):
|
class LoggerHandler(logging.Handler):
|
||||||
|
@ -27,8 +27,7 @@ def get_logger(name: str) -> logging.Logger:
|
||||||
Gets a standard logger with a stream hander to stdout.
|
Gets a standard logger with a stream hander to stdout.
|
||||||
"""
|
"""
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(
|
||||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
|
||||||
datefmt="%m/%d/%Y %H:%M:%S"
|
|
||||||
)
|
)
|
||||||
handler = logging.StreamHandler(sys.stdout)
|
handler = logging.StreamHandler(sys.stdout)
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
|
|
|
@ -1,31 +1,33 @@
|
||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
import torch
|
|
||||||
from typing import TYPE_CHECKING, Dict, Tuple
|
from typing import TYPE_CHECKING, Dict, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from peft import PeftModel
|
||||||
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
|
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
WEIGHTS_NAME,
|
|
||||||
SAFE_WEIGHTS_NAME,
|
SAFE_WEIGHTS_NAME,
|
||||||
|
WEIGHTS_NAME,
|
||||||
is_torch_bf16_gpu_available,
|
is_torch_bf16_gpu_available,
|
||||||
is_torch_cuda_available,
|
is_torch_cuda_available,
|
||||||
is_torch_npu_available,
|
is_torch_npu_available,
|
||||||
is_torch_xpu_available
|
is_torch_xpu_available,
|
||||||
)
|
)
|
||||||
from peft import PeftModel
|
|
||||||
|
|
||||||
from .constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
|
from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||||
from .logging import get_logger
|
from .logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
||||||
try:
|
try:
|
||||||
_is_bf16_available = is_torch_bf16_gpu_available()
|
_is_bf16_available = is_torch_bf16_gpu_available()
|
||||||
except:
|
except Exception:
|
||||||
_is_bf16_available = False
|
_is_bf16_available = False
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trl import AutoModelForCausalLMWithValueHead
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
|
|
||||||
from llmtuner.hparams import ModelArguments
|
from llmtuner.hparams import ModelArguments
|
||||||
|
|
||||||
|
|
||||||
|
@ -36,6 +38,7 @@ class AverageMeter:
|
||||||
r"""
|
r"""
|
||||||
Computes and stores the average and current value.
|
Computes and stores the average and current value.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
|
@ -75,9 +78,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||||
|
|
||||||
|
|
||||||
def fix_valuehead_checkpoint(
|
def fix_valuehead_checkpoint(
|
||||||
model: "AutoModelForCausalLMWithValueHead",
|
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
|
||||||
output_dir: str,
|
|
||||||
safe_serialization: bool
|
|
||||||
) -> None:
|
) -> None:
|
||||||
r"""
|
r"""
|
||||||
The model is already unwrapped.
|
The model is already unwrapped.
|
||||||
|
@ -95,6 +96,7 @@ def fix_valuehead_checkpoint(
|
||||||
if safe_serialization:
|
if safe_serialization:
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
|
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
|
||||||
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
|
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
|
||||||
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
|
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
|
||||||
|
@ -112,9 +114,7 @@ def fix_valuehead_checkpoint(
|
||||||
|
|
||||||
os.remove(path_to_checkpoint)
|
os.remove(path_to_checkpoint)
|
||||||
model.pretrained_model.save_pretrained(
|
model.pretrained_model.save_pretrained(
|
||||||
output_dir,
|
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
|
||||||
state_dict=decoder_state_dict or None,
|
|
||||||
safe_serialization=safe_serialization
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if safe_serialization:
|
if safe_serialization:
|
||||||
|
@ -182,11 +182,10 @@ def try_download_model_from_ms(model_args: "ModelArguments") -> None:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from modelscope import snapshot_download
|
from modelscope import snapshot_download
|
||||||
|
|
||||||
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
||||||
model_args.model_name_or_path = snapshot_download(
|
model_args.model_name_or_path = snapshot_download(
|
||||||
model_args.model_name_or_path,
|
model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir
|
||||||
revision=revision,
|
|
||||||
cache_dir=model_args.cache_dir
|
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
||||||
|
|
|
@ -9,7 +9,7 @@ def is_package_available(name: str) -> bool:
|
||||||
def get_package_version(name: str) -> str:
|
def get_package_version(name: str) -> str:
|
||||||
try:
|
try:
|
||||||
return importlib.metadata.version(name)
|
return importlib.metadata.version(name)
|
||||||
except:
|
except Exception:
|
||||||
return "0.0.0"
|
return "0.0.0"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,16 @@
|
||||||
import math
|
import math
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import Optional, Tuple
|
|
||||||
from transformers.utils import logging
|
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
Cache, LlamaAttention, LlamaFlashAttention2, apply_rotary_pos_emb, repeat_kv
|
Cache,
|
||||||
|
LlamaAttention,
|
||||||
|
LlamaFlashAttention2,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
repeat_kv,
|
||||||
)
|
)
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
@ -19,7 +24,7 @@ def llama_torch_attn_forward(
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional["Cache"] = None,
|
past_key_value: Optional["Cache"] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
@ -45,15 +50,17 @@ def llama_torch_attn_forward(
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
||||||
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
||||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||||
num_groups = q_len // groupsz
|
num_groups = q_len // groupsz
|
||||||
|
|
||||||
def shift(state: torch.Tensor) -> torch.Tensor:
|
def shift(state: torch.Tensor) -> torch.Tensor:
|
||||||
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
|
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
|
||||||
state = torch.cat((
|
state = torch.cat(
|
||||||
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
|
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
|
||||||
), dim=2)
|
dim=2,
|
||||||
|
)
|
||||||
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
|
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
||||||
|
@ -68,14 +75,17 @@ def llama_torch_attn_forward(
|
||||||
# upcast attention to fp32
|
# upcast attention to fp32
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||||
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
|
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
||||||
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
attn_output = torch.cat((
|
attn_output = torch.cat(
|
||||||
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
|
(
|
||||||
))
|
attn_output[:, :, : self.num_heads // 2],
|
||||||
|
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
attn_output = self.o_proj(attn_output)
|
attn_output = self.o_proj(attn_output)
|
||||||
|
@ -94,7 +104,7 @@ def llama_flash_attn_forward(
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
# LlamaFlashAttention2 attention does not support output_attentions
|
# LlamaFlashAttention2 attention does not support output_attentions
|
||||||
output_attentions = False
|
output_attentions = False
|
||||||
|
@ -124,9 +134,9 @@ def llama_flash_attn_forward(
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||||
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||||
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||||
|
|
||||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||||
|
|
||||||
|
@ -144,14 +154,16 @@ def llama_flash_attn_forward(
|
||||||
key_states = key_states.to(target_dtype)
|
key_states = key_states.to(target_dtype)
|
||||||
value_states = value_states.to(target_dtype)
|
value_states = value_states.to(target_dtype)
|
||||||
|
|
||||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
||||||
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
||||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||||
num_groups = q_len // groupsz
|
num_groups = q_len // groupsz
|
||||||
|
|
||||||
def shift(state: torch.Tensor) -> torch.Tensor:
|
def shift(state: torch.Tensor) -> torch.Tensor:
|
||||||
state = torch.cat((
|
state = torch.cat(
|
||||||
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
|
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
|
||||||
), dim=2)
|
dim=2,
|
||||||
|
)
|
||||||
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
|
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
||||||
|
@ -162,11 +174,14 @@ def llama_flash_attn_forward(
|
||||||
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
|
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
|
||||||
)
|
)
|
||||||
|
|
||||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
||||||
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
attn_output = torch.cat((
|
attn_output = torch.cat(
|
||||||
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
|
(
|
||||||
))
|
attn_output[:, :, : self.num_heads // 2],
|
||||||
|
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||||
attn_output = self.o_proj(attn_output)
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
import os
|
|
||||||
import math
|
|
||||||
import json
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from transformers.trainer import TRAINER_STATE_NAME
|
from transformers.trainer import TRAINER_STATE_NAME
|
||||||
|
|
||||||
from .logging import get_logger
|
from .logging import get_logger
|
||||||
from .packages import is_matplotlib_available
|
from .packages import is_matplotlib_available
|
||||||
|
|
||||||
|
|
||||||
if is_matplotlib_available():
|
if is_matplotlib_available():
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
@ -20,7 +22,7 @@ def smooth(scalars: List[float]) -> List[float]:
|
||||||
"""
|
"""
|
||||||
last = scalars[0]
|
last = scalars[0]
|
||||||
smoothed = list()
|
smoothed = list()
|
||||||
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
|
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
|
||||||
for next_val in scalars:
|
for next_val in scalars:
|
||||||
smoothed_val = last * weight + (1 - weight) * next_val
|
smoothed_val = last * weight + (1 - weight) * next_val
|
||||||
smoothed.append(smoothed_val)
|
smoothed.append(smoothed_val)
|
||||||
|
@ -29,7 +31,6 @@ def smooth(scalars: List[float]) -> List[float]:
|
||||||
|
|
||||||
|
|
||||||
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
|
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
|
||||||
|
|
||||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from .evaluation_args import EvaluationArguments
|
||||||
from .finetuning_args import FinetuningArguments
|
from .finetuning_args import FinetuningArguments
|
||||||
from .generating_args import GeneratingArguments
|
from .generating_args import GeneratingArguments
|
||||||
from .model_args import ModelArguments
|
from .model_args import ModelArguments
|
||||||
from .parser import get_train_args, get_infer_args, get_eval_args
|
from .parser import get_eval_args, get_infer_args, get_train_args
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -14,5 +14,5 @@ __all__ = [
|
||||||
"ModelArguments",
|
"ModelArguments",
|
||||||
"get_train_args",
|
"get_train_args",
|
||||||
"get_infer_args",
|
"get_infer_args",
|
||||||
"get_eval_args"
|
"get_eval_args",
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from typing import Literal, Optional
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -8,80 +8,66 @@ class DataArguments:
|
||||||
Arguments pertaining to what data we are going to input our model for training and evaluation.
|
Arguments pertaining to what data we are going to input our model for training and evaluation.
|
||||||
"""
|
"""
|
||||||
template: Optional[str] = field(
|
template: Optional[str] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
||||||
metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
|
||||||
)
|
)
|
||||||
dataset: Optional[str] = field(
|
dataset: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
|
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
|
||||||
)
|
)
|
||||||
dataset_dir: Optional[str] = field(
|
dataset_dir: Optional[str] = field(
|
||||||
default="data",
|
default="data", metadata={"help": "Path to the folder containing the datasets."}
|
||||||
metadata={"help": "Path to the folder containing the datasets."}
|
|
||||||
)
|
)
|
||||||
split: Optional[str] = field(
|
split: Optional[str] = field(
|
||||||
default="train",
|
default="train", metadata={"help": "Which dataset split to use for training and evaluation."}
|
||||||
metadata={"help": "Which dataset split to use for training and evaluation."}
|
|
||||||
)
|
)
|
||||||
cutoff_len: Optional[int] = field(
|
cutoff_len: Optional[int] = field(
|
||||||
default=1024,
|
default=1024, metadata={"help": "The maximum length of the model inputs after tokenization."}
|
||||||
metadata={"help": "The maximum length of the model inputs after tokenization."}
|
|
||||||
)
|
)
|
||||||
reserved_label_len: Optional[int] = field(
|
reserved_label_len: Optional[int] = field(
|
||||||
default=1,
|
default=1, metadata={"help": "The maximum length reserved for label after tokenization."}
|
||||||
metadata={"help": "The maximum length reserved for label after tokenization."}
|
|
||||||
)
|
)
|
||||||
train_on_prompt: Optional[bool] = field(
|
train_on_prompt: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Whether to disable the mask on the prompt or not."}
|
||||||
metadata={"help": "Whether to disable the mask on the prompt or not."}
|
|
||||||
)
|
|
||||||
streaming: Optional[bool] = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Enable dataset streaming."}
|
|
||||||
)
|
)
|
||||||
|
streaming: Optional[bool] = field(default=False, metadata={"help": "Enable dataset streaming."})
|
||||||
buffer_size: Optional[int] = field(
|
buffer_size: Optional[int] = field(
|
||||||
default=16384,
|
default=16384, metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
|
||||||
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
|
|
||||||
)
|
)
|
||||||
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
|
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
|
||||||
default="concat",
|
default="concat",
|
||||||
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
|
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
|
||||||
)
|
)
|
||||||
interleave_probs: Optional[str] = field(
|
interleave_probs: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
|
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
|
||||||
)
|
)
|
||||||
overwrite_cache: Optional[bool] = field(
|
overwrite_cache: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Overwrite the cached training and evaluation sets."}
|
||||||
metadata={"help": "Overwrite the cached training and evaluation sets."}
|
|
||||||
)
|
)
|
||||||
preprocessing_num_workers: Optional[int] = field(
|
preprocessing_num_workers: Optional[int] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "The number of processes to use for the preprocessing."}
|
||||||
metadata={"help": "The number of processes to use for the preprocessing."}
|
|
||||||
)
|
)
|
||||||
max_samples: Optional[int] = field(
|
max_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
|
||||||
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
|
|
||||||
)
|
)
|
||||||
eval_num_beams: Optional[int] = field(
|
eval_num_beams: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
|
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
|
||||||
)
|
)
|
||||||
ignore_pad_token_for_loss: Optional[bool] = field(
|
ignore_pad_token_for_loss: Optional[bool] = field(
|
||||||
default=True,
|
default=True,
|
||||||
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
|
metadata={
|
||||||
|
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
|
||||||
|
},
|
||||||
)
|
)
|
||||||
val_size: Optional[float] = field(
|
val_size: Optional[float] = field(
|
||||||
default=0,
|
default=0, metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
|
||||||
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
|
|
||||||
)
|
)
|
||||||
sft_packing: Optional[bool] = field(
|
sft_packing: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
|
||||||
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
|
|
||||||
)
|
)
|
||||||
cache_path: Optional[str] = field(
|
cache_path: Optional[str] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Path to save or load the preprocessed datasets."}
|
||||||
metadata={"help": "Path to save or load the preprocessed datasets."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import os
|
import os
|
||||||
from typing import Literal, Optional
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from datasets import DownloadMode
|
from datasets import DownloadMode
|
||||||
|
|
||||||
|
@ -10,36 +10,18 @@ class EvaluationArguments:
|
||||||
r"""
|
r"""
|
||||||
Arguments pertaining to specify the evaluation parameters.
|
Arguments pertaining to specify the evaluation parameters.
|
||||||
"""
|
"""
|
||||||
task: str = field(
|
task: str = field(metadata={"help": "Name of the evaluation task."})
|
||||||
metadata={"help": "Name of the evaluation task."}
|
|
||||||
)
|
|
||||||
task_dir: Optional[str] = field(
|
task_dir: Optional[str] = field(
|
||||||
default="evaluation",
|
default="evaluation", metadata={"help": "Path to the folder containing the evaluation datasets."}
|
||||||
metadata={"help": "Path to the folder containing the evaluation datasets."}
|
|
||||||
)
|
|
||||||
batch_size: Optional[int] = field(
|
|
||||||
default=4,
|
|
||||||
metadata={"help": "The batch size per GPU for evaluation."}
|
|
||||||
)
|
|
||||||
seed: Optional[int] = field(
|
|
||||||
default=42,
|
|
||||||
metadata={"help": "Random seed to be used with data loaders."}
|
|
||||||
)
|
|
||||||
lang: Optional[Literal["en", "zh"]] = field(
|
|
||||||
default="en",
|
|
||||||
metadata={"help": "Language used at evaluation."}
|
|
||||||
)
|
|
||||||
n_shot: Optional[int] = field(
|
|
||||||
default=5,
|
|
||||||
metadata={"help": "Number of examplars for few-shot learning."}
|
|
||||||
)
|
|
||||||
save_dir: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Path to save the evaluation results."}
|
|
||||||
)
|
)
|
||||||
|
batch_size: Optional[int] = field(default=4, metadata={"help": "The batch size per GPU for evaluation."})
|
||||||
|
seed: Optional[int] = field(default=42, metadata={"help": "Random seed to be used with data loaders."})
|
||||||
|
lang: Optional[Literal["en", "zh"]] = field(default="en", metadata={"help": "Language used at evaluation."})
|
||||||
|
n_shot: Optional[int] = field(default=5, metadata={"help": "Number of examplars for few-shot learning."})
|
||||||
|
save_dir: Optional[str] = field(default=None, metadata={"help": "Path to save the evaluation results."})
|
||||||
download_mode: Optional[DownloadMode] = field(
|
download_mode: Optional[DownloadMode] = field(
|
||||||
default=DownloadMode.REUSE_DATASET_IF_EXISTS,
|
default=DownloadMode.REUSE_DATASET_IF_EXISTS,
|
||||||
metadata={"help": "Download mode used for the evaluation datasets."}
|
metadata={"help": "Download mode used for the evaluation datasets."},
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import json
|
import json
|
||||||
from typing import Literal, Optional
|
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -10,17 +10,18 @@ class FreezeArguments:
|
||||||
"""
|
"""
|
||||||
name_module_trainable: Optional[str] = field(
|
name_module_trainable: Optional[str] = field(
|
||||||
default="mlp",
|
default="mlp",
|
||||||
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
|
metadata={
|
||||||
|
"help": 'Name of trainable modules for partial-parameter (freeze) fine-tuning. \
|
||||||
Use commas to separate multiple modules. \
|
Use commas to separate multiple modules. \
|
||||||
LLaMA choices: [\"mlp\", \"self_attn\"], \
|
LLaMA choices: ["mlp", "self_attn"], \
|
||||||
BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
|
BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
|
||||||
Qwen choices: [\"mlp\", \"attn\"], \
|
Qwen choices: ["mlp", "attn"], \
|
||||||
Phi choices: [\"mlp\", \"mixer\"], \
|
Phi choices: ["mlp", "mixer"], \
|
||||||
Others choices: the same as LLaMA."}
|
Others choices: the same as LLaMA.'
|
||||||
|
},
|
||||||
)
|
)
|
||||||
num_layer_trainable: Optional[int] = field(
|
num_layer_trainable: Optional[int] = field(
|
||||||
default=3,
|
default=3, metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}
|
||||||
metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,37 +32,32 @@ class LoraArguments:
|
||||||
"""
|
"""
|
||||||
additional_target: Optional[str] = field(
|
additional_target: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
|
metadata={
|
||||||
|
"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."
|
||||||
|
},
|
||||||
)
|
)
|
||||||
lora_alpha: Optional[int] = field(
|
lora_alpha: Optional[int] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}
|
||||||
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}
|
|
||||||
)
|
|
||||||
lora_dropout: Optional[float] = field(
|
|
||||||
default=0.0,
|
|
||||||
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
|
|
||||||
)
|
|
||||||
lora_rank: Optional[int] = field(
|
|
||||||
default=8,
|
|
||||||
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
|
|
||||||
)
|
)
|
||||||
|
lora_dropout: Optional[float] = field(default=0.0, metadata={"help": "Dropout rate for the LoRA fine-tuning."})
|
||||||
|
lora_rank: Optional[int] = field(default=8, metadata={"help": "The intrinsic dimension for LoRA fine-tuning."})
|
||||||
lora_target: Optional[str] = field(
|
lora_target: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
|
metadata={
|
||||||
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
"help": 'Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
|
||||||
BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"], \
|
LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
||||||
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
|
||||||
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
|
Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
||||||
Phi choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
|
Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
|
||||||
Others choices: the same as LLaMA."}
|
Phi choices: ["Wqkv", "out_proj", "fc1", "fc2"], \
|
||||||
|
Others choices: the same as LLaMA.'
|
||||||
|
},
|
||||||
)
|
)
|
||||||
lora_bf16_mode: Optional[bool] = field(
|
lora_bf16_mode: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Whether or not to train lora adapters in bf16 precision."}
|
||||||
metadata={"help": "Whether or not to train lora adapters in bf16 precision."}
|
|
||||||
)
|
)
|
||||||
create_new_adapter: Optional[bool] = field(
|
create_new_adapter: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}
|
||||||
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -70,69 +66,53 @@ class RLHFArguments:
|
||||||
r"""
|
r"""
|
||||||
Arguments pertaining to the PPO and DPO training.
|
Arguments pertaining to the PPO and DPO training.
|
||||||
"""
|
"""
|
||||||
dpo_beta: Optional[float] = field(
|
dpo_beta: Optional[float] = field(default=0.1, metadata={"help": "The beta parameter for the DPO loss."})
|
||||||
default=0.1,
|
|
||||||
metadata={"help": "The beta parameter for the DPO loss."}
|
|
||||||
)
|
|
||||||
dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
|
dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
|
||||||
default="sigmoid",
|
default="sigmoid", metadata={"help": "The type of DPO loss to use."}
|
||||||
metadata={"help": "The type of DPO loss to use."}
|
|
||||||
)
|
)
|
||||||
dpo_ftx: Optional[float] = field(
|
dpo_ftx: Optional[float] = field(
|
||||||
default=0,
|
default=0, metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}
|
||||||
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}
|
|
||||||
)
|
)
|
||||||
ppo_buffer_size: Optional[int] = field(
|
ppo_buffer_size: Optional[int] = field(
|
||||||
default=1,
|
default=1,
|
||||||
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}
|
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
|
||||||
)
|
)
|
||||||
ppo_epochs: Optional[int] = field(
|
ppo_epochs: Optional[int] = field(
|
||||||
default=4,
|
default=4, metadata={"help": "The number of epochs to perform in a PPO optimization step."}
|
||||||
metadata={"help": "The number of epochs to perform in a PPO optimization step."}
|
|
||||||
)
|
)
|
||||||
ppo_logger: Optional[str] = field(
|
ppo_logger: Optional[str] = field(
|
||||||
default=None,
|
default=None, metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'}
|
||||||
metadata={"help": "Log with either \"wandb\" or \"tensorboard\" in PPO training."}
|
|
||||||
)
|
)
|
||||||
ppo_score_norm: Optional[bool] = field(
|
ppo_score_norm: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Use score normalization in PPO training."}
|
||||||
metadata={"help": "Use score normalization in PPO training."}
|
|
||||||
)
|
)
|
||||||
ppo_target: Optional[float] = field(
|
ppo_target: Optional[float] = field(
|
||||||
default=6.0,
|
default=6.0, metadata={"help": "Target KL value for adaptive KL control in PPO training."}
|
||||||
metadata={"help": "Target KL value for adaptive KL control in PPO training."}
|
|
||||||
)
|
)
|
||||||
ppo_whiten_rewards: Optional[bool] = field(
|
ppo_whiten_rewards: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
|
||||||
metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
|
|
||||||
)
|
)
|
||||||
ref_model: Optional[str] = field(
|
ref_model: Optional[str] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Path to the reference model used for the PPO or DPO training."}
|
||||||
metadata={"help": "Path to the reference model used for the PPO or DPO training."}
|
|
||||||
)
|
)
|
||||||
ref_model_adapters: Optional[str] = field(
|
ref_model_adapters: Optional[str] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Path to the adapters of the reference model."}
|
||||||
metadata={"help": "Path to the adapters of the reference model."}
|
|
||||||
)
|
)
|
||||||
ref_model_quantization_bit: Optional[int] = field(
|
ref_model_quantization_bit: Optional[int] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "The number of bits to quantize the reference model."}
|
||||||
metadata={"help": "The number of bits to quantize the reference model."}
|
|
||||||
)
|
)
|
||||||
reward_model: Optional[str] = field(
|
reward_model: Optional[str] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Path to the reward model used for the PPO training."}
|
||||||
metadata={"help": "Path to the reward model used for the PPO training."}
|
|
||||||
)
|
)
|
||||||
reward_model_adapters: Optional[str] = field(
|
reward_model_adapters: Optional[str] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Path to the adapters of the reward model."}
|
||||||
metadata={"help": "Path to the adapters of the reward model."}
|
|
||||||
)
|
)
|
||||||
reward_model_quantization_bit: Optional[int] = field(
|
reward_model_quantization_bit: Optional[int] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "The number of bits to quantize the reward model."}
|
||||||
metadata={"help": "The number of bits to quantize the reward model."}
|
|
||||||
)
|
)
|
||||||
reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
|
reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
|
||||||
default="lora",
|
default="lora",
|
||||||
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."}
|
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -142,16 +122,13 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
|
||||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||||
"""
|
"""
|
||||||
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
|
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
|
||||||
default="sft",
|
default="sft", metadata={"help": "Which stage will be performed in training."}
|
||||||
metadata={"help": "Which stage will be performed in training."}
|
|
||||||
)
|
)
|
||||||
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
|
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
|
||||||
default="lora",
|
default="lora", metadata={"help": "Which fine-tuning method to use."}
|
||||||
metadata={"help": "Which fine-tuning method to use."}
|
|
||||||
)
|
)
|
||||||
plot_loss: Optional[bool] = field(
|
plot_loss: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Whether or not to save the training loss curves."}
|
||||||
metadata={"help": "Whether or not to save the training loss curves."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -8,40 +8,37 @@ class GeneratingArguments:
|
||||||
Arguments pertaining to specify the decoding parameters.
|
Arguments pertaining to specify the decoding parameters.
|
||||||
"""
|
"""
|
||||||
do_sample: Optional[bool] = field(
|
do_sample: Optional[bool] = field(
|
||||||
default=True,
|
default=True, metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
|
||||||
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
|
|
||||||
)
|
)
|
||||||
temperature: Optional[float] = field(
|
temperature: Optional[float] = field(
|
||||||
default=0.95,
|
default=0.95, metadata={"help": "The value used to modulate the next token probabilities."}
|
||||||
metadata={"help": "The value used to modulate the next token probabilities."}
|
|
||||||
)
|
)
|
||||||
top_p: Optional[float] = field(
|
top_p: Optional[float] = field(
|
||||||
default=0.7,
|
default=0.7,
|
||||||
metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
|
metadata={
|
||||||
|
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
|
||||||
|
},
|
||||||
)
|
)
|
||||||
top_k: Optional[int] = field(
|
top_k: Optional[int] = field(
|
||||||
default=50,
|
default=50,
|
||||||
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
|
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
|
||||||
)
|
)
|
||||||
num_beams: Optional[int] = field(
|
num_beams: Optional[int] = field(
|
||||||
default=1,
|
default=1, metadata={"help": "Number of beams for beam search. 1 means no beam search."}
|
||||||
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
|
|
||||||
)
|
)
|
||||||
max_length: Optional[int] = field(
|
max_length: Optional[int] = field(
|
||||||
default=512,
|
default=512,
|
||||||
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
|
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
|
||||||
)
|
)
|
||||||
max_new_tokens: Optional[int] = field(
|
max_new_tokens: Optional[int] = field(
|
||||||
default=512,
|
default=512,
|
||||||
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
|
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
|
||||||
)
|
)
|
||||||
repetition_penalty: Optional[float] = field(
|
repetition_penalty: Optional[float] = field(
|
||||||
default=1.0,
|
default=1.0, metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
|
||||||
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
|
|
||||||
)
|
)
|
||||||
length_penalty: Optional[float] = field(
|
length_penalty: Optional[float] = field(
|
||||||
default=1.0,
|
default=1.0, metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
|
||||||
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from typing import Any, Dict, Literal, Optional
|
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from typing import Any, Dict, Literal, Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -11,108 +11,82 @@ class ModelArguments:
|
||||||
metadata={"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."}
|
metadata={"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."}
|
||||||
)
|
)
|
||||||
adapter_name_or_path: Optional[str] = field(
|
adapter_name_or_path: Optional[str] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."}
|
||||||
metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."}
|
|
||||||
)
|
)
|
||||||
cache_dir: Optional[str] = field(
|
cache_dir: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."}
|
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
|
||||||
)
|
)
|
||||||
use_fast_tokenizer: Optional[bool] = field(
|
use_fast_tokenizer: Optional[bool] = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."}
|
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
|
||||||
)
|
)
|
||||||
resize_vocab: Optional[bool] = field(
|
resize_vocab: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."}
|
||||||
metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."}
|
|
||||||
)
|
)
|
||||||
split_special_tokens: Optional[bool] = field(
|
split_special_tokens: Optional[bool] = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
|
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
|
||||||
)
|
)
|
||||||
model_revision: Optional[str] = field(
|
model_revision: Optional[str] = field(
|
||||||
default="main",
|
default="main",
|
||||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
|
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||||
)
|
)
|
||||||
quantization_bit: Optional[int] = field(
|
quantization_bit: Optional[int] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "The number of bits to quantize the model."}
|
||||||
metadata={"help": "The number of bits to quantize the model."}
|
|
||||||
)
|
)
|
||||||
quantization_type: Optional[Literal["fp4", "nf4"]] = field(
|
quantization_type: Optional[Literal["fp4", "nf4"]] = field(
|
||||||
default="nf4",
|
default="nf4", metadata={"help": "Quantization data type to use in int4 training."}
|
||||||
metadata={"help": "Quantization data type to use in int4 training."}
|
|
||||||
)
|
)
|
||||||
double_quantization: Optional[bool] = field(
|
double_quantization: Optional[bool] = field(
|
||||||
default=True,
|
default=True, metadata={"help": "Whether or not to use double quantization in int4 training."}
|
||||||
metadata={"help": "Whether or not to use double quantization in int4 training."}
|
|
||||||
)
|
)
|
||||||
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
|
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}
|
||||||
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}
|
|
||||||
)
|
)
|
||||||
flash_attn: Optional[bool] = field(
|
flash_attn: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Enable FlashAttention-2 for faster training."}
|
||||||
metadata={"help": "Enable FlashAttention-2 for faster training."}
|
|
||||||
)
|
)
|
||||||
shift_attn: Optional[bool] = field(
|
shift_attn: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
|
||||||
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
|
|
||||||
)
|
)
|
||||||
use_unsloth: Optional[bool] = field(
|
use_unsloth: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."}
|
||||||
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."}
|
|
||||||
)
|
)
|
||||||
disable_gradient_checkpointing: Optional[bool] = field(
|
disable_gradient_checkpointing: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Whether or not to disable gradient checkpointing."}
|
||||||
metadata={"help": "Whether or not to disable gradient checkpointing."}
|
|
||||||
)
|
)
|
||||||
upcast_layernorm: Optional[bool] = field(
|
upcast_layernorm: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}
|
||||||
metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}
|
|
||||||
)
|
)
|
||||||
upcast_lmhead_output: Optional[bool] = field(
|
upcast_lmhead_output: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}
|
||||||
metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}
|
|
||||||
)
|
|
||||||
hf_hub_token: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Auth token to log in with Hugging Face Hub."}
|
|
||||||
)
|
|
||||||
ms_hub_token: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Auth token to log in with ModelScope Hub."}
|
|
||||||
)
|
)
|
||||||
|
hf_hub_token: Optional[str] = field(default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."})
|
||||||
|
ms_hub_token: Optional[str] = field(default=None, metadata={"help": "Auth token to log in with ModelScope Hub."})
|
||||||
export_dir: Optional[str] = field(
|
export_dir: Optional[str] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Path to the directory to save the exported model."}
|
||||||
metadata={"help": "Path to the directory to save the exported model."}
|
|
||||||
)
|
)
|
||||||
export_size: Optional[int] = field(
|
export_size: Optional[int] = field(
|
||||||
default=1,
|
default=1, metadata={"help": "The file shard size (in GB) of the exported model."}
|
||||||
metadata={"help": "The file shard size (in GB) of the exported model."}
|
|
||||||
)
|
)
|
||||||
export_quantization_bit: Optional[int] = field(
|
export_quantization_bit: Optional[int] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "The number of bits to quantize the exported model."}
|
||||||
metadata={"help": "The number of bits to quantize the exported model."}
|
|
||||||
)
|
)
|
||||||
export_quantization_dataset: Optional[str] = field(
|
export_quantization_dataset: Optional[str] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}
|
||||||
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}
|
|
||||||
)
|
)
|
||||||
export_quantization_nsamples: Optional[int] = field(
|
export_quantization_nsamples: Optional[int] = field(
|
||||||
default=128,
|
default=128, metadata={"help": "The number of samples used for quantization."}
|
||||||
metadata={"help": "The number of samples used for quantization."}
|
|
||||||
)
|
)
|
||||||
export_quantization_maxlen: Optional[int] = field(
|
export_quantization_maxlen: Optional[int] = field(
|
||||||
default=1024,
|
default=1024, metadata={"help": "The maximum length of the model inputs used for quantization."}
|
||||||
metadata={"help": "The maximum length of the model inputs used for quantization."}
|
|
||||||
)
|
)
|
||||||
export_legacy_format: Optional[bool] = field(
|
export_legacy_format: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}
|
||||||
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}
|
|
||||||
)
|
)
|
||||||
export_hub_model_id: Optional[str] = field(
|
export_hub_model_id: Optional[str] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "The name of the repository if push the model to the Hugging Face hub."}
|
||||||
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
@ -122,7 +96,7 @@ class ModelArguments:
|
||||||
if self.split_special_tokens and self.use_fast_tokenizer:
|
if self.split_special_tokens and self.use_fast_tokenizer:
|
||||||
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
||||||
|
|
||||||
if self.adapter_name_or_path is not None: # support merging multiple lora weights
|
if self.adapter_name_or_path is not None: # support merging multiple lora weights
|
||||||
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
|
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
|
||||||
|
|
||||||
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import torch
|
|
||||||
import logging
|
|
||||||
import datasets
|
|
||||||
import transformers
|
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import datasets
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||||
from transformers.trainer_utils import get_last_checkpoint
|
from transformers.trainer_utils import get_last_checkpoint
|
||||||
|
|
||||||
|
@ -19,24 +20,12 @@ from .model_args import ModelArguments
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
_TRAIN_ARGS = [
|
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||||
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
|
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||||
]
|
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||||
_TRAIN_CLS = Tuple[
|
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||||
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
|
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||||
]
|
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||||
_INFER_ARGS = [
|
|
||||||
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
|
||||||
]
|
|
||||||
_INFER_CLS = Tuple[
|
|
||||||
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
|
||||||
]
|
|
||||||
_EVAL_ARGS = [
|
|
||||||
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
|
|
||||||
]
|
|
||||||
_EVAL_CLS = Tuple[
|
|
||||||
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||||
|
@ -77,7 +66,7 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
|
||||||
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
|
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
|
||||||
if finetuning_args.finetuning_type != "lora":
|
if finetuning_args.finetuning_type != "lora":
|
||||||
raise ValueError("Multiple adapters are only available for LoRA tuning.")
|
raise ValueError("Multiple adapters are only available for LoRA tuning.")
|
||||||
|
|
||||||
if model_args.quantization_bit is not None:
|
if model_args.quantization_bit is not None:
|
||||||
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
|
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
|
||||||
|
|
||||||
|
@ -181,18 +170,22 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||||
training_args_dict = training_args.to_dict()
|
training_args_dict = training_args.to_dict()
|
||||||
training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
|
training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
|
||||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||||
logger.info("Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
|
logger.info(
|
||||||
training_args.resume_from_checkpoint
|
"Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
|
||||||
))
|
training_args.resume_from_checkpoint
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
finetuning_args.stage in ["rm", "ppo"]
|
finetuning_args.stage in ["rm", "ppo"]
|
||||||
and finetuning_args.finetuning_type == "lora"
|
and finetuning_args.finetuning_type == "lora"
|
||||||
and training_args.resume_from_checkpoint is not None
|
and training_args.resume_from_checkpoint is not None
|
||||||
):
|
):
|
||||||
logger.warning("Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
|
logger.warning(
|
||||||
training_args.resume_from_checkpoint
|
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
|
||||||
))
|
training_args.resume_from_checkpoint
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# postprocess model_args
|
# postprocess model_args
|
||||||
model_args.compute_dtype = (
|
model_args.compute_dtype = (
|
||||||
|
@ -201,10 +194,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||||
model_args.model_max_length = data_args.cutoff_len
|
model_args.model_max_length = data_args.cutoff_len
|
||||||
|
|
||||||
# Log on each process the small summary:
|
# Log on each process the small summary:
|
||||||
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
|
logger.info(
|
||||||
training_args.local_rank, training_args.device, training_args.n_gpu,
|
"Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
|
||||||
bool(training_args.local_rank != -1), str(model_args.compute_dtype)
|
training_args.local_rank,
|
||||||
))
|
training_args.device,
|
||||||
|
training_args.n_gpu,
|
||||||
|
bool(training_args.local_rank != -1),
|
||||||
|
str(model_args.compute_dtype),
|
||||||
|
)
|
||||||
|
)
|
||||||
logger.info(f"Training/evaluation parameters {training_args}")
|
logger.info(f"Training/evaluation parameters {training_args}")
|
||||||
|
|
||||||
# Set seed before initializing model.
|
# Set seed before initializing model.
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
import torch
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
|
||||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
|
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from .utils import find_all_linear_modules
|
from .utils import find_all_linear_modules
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
from ..hparams import ModelArguments, FinetuningArguments
|
|
||||||
|
from ..hparams import FinetuningArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def init_adapter(
|
def init_adapter(
|
||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool
|
||||||
model_args: "ModelArguments",
|
|
||||||
finetuning_args: "FinetuningArguments",
|
|
||||||
is_trainable: bool
|
|
||||||
) -> "PreTrainedModel":
|
) -> "PreTrainedModel":
|
||||||
r"""
|
r"""
|
||||||
Initializes the adapters.
|
Initializes the adapters.
|
||||||
|
@ -47,10 +47,10 @@ def init_adapter(
|
||||||
if not num_layers:
|
if not num_layers:
|
||||||
raise ValueError("Current model does not support freeze tuning.")
|
raise ValueError("Current model does not support freeze tuning.")
|
||||||
|
|
||||||
if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
||||||
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
|
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
|
||||||
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
||||||
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
|
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)] # noqa: C416
|
||||||
|
|
||||||
trainable_layers = []
|
trainable_layers = []
|
||||||
for module_name in finetuning_args.name_module_trainable:
|
for module_name in finetuning_args.name_module_trainable:
|
||||||
|
@ -69,7 +69,7 @@ def init_adapter(
|
||||||
|
|
||||||
if model_args.adapter_name_or_path is not None:
|
if model_args.adapter_name_or_path is not None:
|
||||||
is_mergeable = True
|
is_mergeable = True
|
||||||
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
|
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
|
||||||
assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
|
assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
|
||||||
is_mergeable = False
|
is_mergeable = False
|
||||||
|
|
||||||
|
@ -90,10 +90,10 @@ def init_adapter(
|
||||||
if len(adapter_to_merge) > 0:
|
if len(adapter_to_merge) > 0:
|
||||||
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
|
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
|
||||||
|
|
||||||
if adapter_to_resume is not None: # resume lora training
|
if adapter_to_resume is not None: # resume lora training
|
||||||
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable)
|
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable)
|
||||||
|
|
||||||
if is_trainable and adapter_to_resume is None: # create new lora weights while training
|
if is_trainable and adapter_to_resume is None: # create new lora weights while training
|
||||||
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
||||||
target_modules = find_all_linear_modules(model)
|
target_modules = find_all_linear_modules(model)
|
||||||
else:
|
else:
|
||||||
|
@ -103,11 +103,12 @@ def init_adapter(
|
||||||
"r": finetuning_args.lora_rank,
|
"r": finetuning_args.lora_rank,
|
||||||
"target_modules": target_modules,
|
"target_modules": target_modules,
|
||||||
"lora_alpha": finetuning_args.lora_alpha,
|
"lora_alpha": finetuning_args.lora_alpha,
|
||||||
"lora_dropout": finetuning_args.lora_dropout
|
"lora_dropout": finetuning_args.lora_dropout,
|
||||||
}
|
}
|
||||||
|
|
||||||
if model_args.use_unsloth:
|
if model_args.use_unsloth:
|
||||||
from unsloth import FastLlamaModel, FastMistralModel # type: ignore
|
from unsloth import FastLlamaModel, FastMistralModel # type: ignore
|
||||||
|
|
||||||
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
|
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
|
||||||
if "loftq_config" in inspect.signature(FastLlamaModel.get_peft_model).parameters:
|
if "loftq_config" in inspect.signature(FastLlamaModel.get_peft_model).parameters:
|
||||||
unsloth_peft_kwargs["loftq_config"] = {}
|
unsloth_peft_kwargs["loftq_config"] = {}
|
||||||
|
@ -124,7 +125,7 @@ def init_adapter(
|
||||||
task_type=TaskType.CAUSAL_LM,
|
task_type=TaskType.CAUSAL_LM,
|
||||||
inference_mode=False,
|
inference_mode=False,
|
||||||
modules_to_save=finetuning_args.additional_target,
|
modules_to_save=finetuning_args.additional_target,
|
||||||
**peft_kwargs
|
**peft_kwargs,
|
||||||
)
|
)
|
||||||
model = get_peft_model(model, lora_config)
|
model = get_peft_model(model, lora_config)
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple
|
from typing import TYPE_CHECKING, Optional, Tuple
|
||||||
|
|
||||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
from transformers.utils.versions import require_version
|
from transformers.utils.versions import require_version
|
||||||
|
@ -7,12 +8,14 @@ from trl import AutoModelForCausalLMWithValueHead
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
|
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
|
||||||
from .adapter import init_adapter
|
from .adapter import init_adapter
|
||||||
from .patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
|
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
|
||||||
from .utils import load_valuehead_params, register_autoclass
|
from .utils import load_valuehead_params, register_autoclass
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||||
from ..hparams import ModelArguments, FinetuningArguments
|
|
||||||
|
from ..hparams import FinetuningArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
@ -29,7 +32,7 @@ def load_model_and_tokenizer(
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
is_trainable: Optional[bool] = False,
|
is_trainable: Optional[bool] = False,
|
||||||
add_valuehead: Optional[bool] = False
|
add_valuehead: Optional[bool] = False,
|
||||||
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
|
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
|
||||||
r"""
|
r"""
|
||||||
Loads pretrained model and tokenizer.
|
Loads pretrained model and tokenizer.
|
||||||
|
@ -43,7 +46,7 @@ def load_model_and_tokenizer(
|
||||||
"trust_remote_code": True,
|
"trust_remote_code": True,
|
||||||
"cache_dir": model_args.cache_dir,
|
"cache_dir": model_args.cache_dir,
|
||||||
"revision": model_args.model_revision,
|
"revision": model_args.model_revision,
|
||||||
"token": model_args.hf_hub_token
|
"token": model_args.hf_hub_token,
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
@ -51,7 +54,7 @@ def load_model_and_tokenizer(
|
||||||
use_fast=model_args.use_fast_tokenizer,
|
use_fast=model_args.use_fast_tokenizer,
|
||||||
split_special_tokens=model_args.split_special_tokens,
|
split_special_tokens=model_args.split_special_tokens,
|
||||||
padding_side="right",
|
padding_side="right",
|
||||||
**config_kwargs
|
**config_kwargs,
|
||||||
)
|
)
|
||||||
patch_tokenizer(tokenizer)
|
patch_tokenizer(tokenizer)
|
||||||
|
|
||||||
|
@ -61,7 +64,8 @@ def load_model_and_tokenizer(
|
||||||
model = None
|
model = None
|
||||||
if is_trainable and model_args.use_unsloth:
|
if is_trainable and model_args.use_unsloth:
|
||||||
require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth")
|
require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth")
|
||||||
from unsloth import FastLlamaModel, FastMistralModel # type: ignore
|
from unsloth import FastLlamaModel, FastMistralModel # type: ignore
|
||||||
|
|
||||||
unsloth_kwargs = {
|
unsloth_kwargs = {
|
||||||
"model_name": model_args.model_name_or_path,
|
"model_name": model_args.model_name_or_path,
|
||||||
"max_seq_length": model_args.model_max_length,
|
"max_seq_length": model_args.model_max_length,
|
||||||
|
@ -69,7 +73,7 @@ def load_model_and_tokenizer(
|
||||||
"load_in_4bit": model_args.quantization_bit == 4,
|
"load_in_4bit": model_args.quantization_bit == 4,
|
||||||
"token": model_args.hf_hub_token,
|
"token": model_args.hf_hub_token,
|
||||||
"device_map": get_current_device(),
|
"device_map": get_current_device(),
|
||||||
"rope_scaling": getattr(config, "rope_scaling", None)
|
"rope_scaling": getattr(config, "rope_scaling", None),
|
||||||
}
|
}
|
||||||
if getattr(config, "model_type", None) == "llama":
|
if getattr(config, "model_type", None) == "llama":
|
||||||
model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs)
|
model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs)
|
||||||
|
@ -89,7 +93,7 @@ def load_model_and_tokenizer(
|
||||||
config=config,
|
config=config,
|
||||||
torch_dtype=model_args.compute_dtype,
|
torch_dtype=model_args.compute_dtype,
|
||||||
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
|
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
|
||||||
**config_kwargs
|
**config_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
patch_model(model, tokenizer, model_args, is_trainable)
|
patch_model(model, tokenizer, model_args, is_trainable)
|
||||||
|
@ -119,9 +123,11 @@ def load_model_and_tokenizer(
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
trainable_params, all_param = count_parameters(model)
|
trainable_params, all_param = count_parameters(model)
|
||||||
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
logger.info(
|
||||||
trainable_params, all_param, 100 * trainable_params / all_param
|
"trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||||
))
|
trainable_params, all_param, 100 * trainable_params / all_param
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if not is_trainable:
|
if not is_trainable:
|
||||||
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
|
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
import os
|
|
||||||
import math
|
import math
|
||||||
import torch
|
import os
|
||||||
import random
|
import random
|
||||||
|
from contextlib import nullcontext
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||||
from datasets import load_dataset
|
|
||||||
from contextlib import nullcontext
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from datasets import load_dataset
|
||||||
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
|
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
|
||||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
from transformers.utils.versions import require_version
|
from transformers.utils.versions import require_version
|
||||||
|
@ -17,9 +17,11 @@ from ..extras.misc import get_current_device, infer_optim_dtype
|
||||||
from ..extras.packages import is_flash_attn2_available
|
from ..extras.packages import is_flash_attn2_available
|
||||||
from ..extras.patches.llama_patch import apply_llama_patch
|
from ..extras.patches.llama_patch import apply_llama_patch
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import PretrainedConfig, PreTrainedTokenizer
|
from transformers import PretrainedConfig, PreTrainedTokenizer
|
||||||
from trl import AutoModelForCausalLMWithValueHead
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
|
|
||||||
from ..hparams import ModelArguments
|
from ..hparams import ModelArguments
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,7 +42,8 @@ def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToke
|
||||||
Resize token embeddings.
|
Resize token embeddings.
|
||||||
"""
|
"""
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
import deepspeed # type: ignore
|
import deepspeed # type: ignore
|
||||||
|
|
||||||
params = [model.get_input_embeddings().weight]
|
params = [model.get_input_embeddings().weight]
|
||||||
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
|
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
|
||||||
params.append(model.get_output_embeddings().weight)
|
params.append(model.get_output_embeddings().weight)
|
||||||
|
@ -88,7 +91,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
|
||||||
sample_idx = random.randint(0, len(dataset) - 1)
|
sample_idx = random.randint(0, len(dataset) - 1)
|
||||||
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
|
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
|
||||||
if sample["input_ids"].size(1) >= maxlen:
|
if sample["input_ids"].size(1) >= maxlen:
|
||||||
break # TODO: fix large maxlen
|
break # TODO: fix large maxlen
|
||||||
|
|
||||||
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
|
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
|
||||||
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
|
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
|
||||||
|
@ -119,9 +122,9 @@ def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is
|
||||||
scaling_factor = 2.0
|
scaling_factor = 2.0
|
||||||
|
|
||||||
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
|
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
|
||||||
logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
|
logger.info(
|
||||||
model_args.rope_scaling, scaling_factor
|
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
|
||||||
))
|
)
|
||||||
|
|
||||||
|
|
||||||
def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
|
def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
|
||||||
|
@ -146,22 +149,22 @@ def _configure_quantization(
|
||||||
config: "PretrainedConfig",
|
config: "PretrainedConfig",
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
config_kwargs: Dict[str, Any]
|
config_kwargs: Dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
r"""
|
r"""
|
||||||
Priority: GPTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
|
Priority: GPTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
|
||||||
"""
|
"""
|
||||||
if getattr(config, "quantization_config", None): # gptq
|
if getattr(config, "quantization_config", None): # gptq
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||||
|
|
||||||
config_kwargs["device_map"] = {"": get_current_device()}
|
config_kwargs["device_map"] = {"": get_current_device()}
|
||||||
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
||||||
if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
|
if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
|
||||||
quantization_config["use_exllama"] = False # disable exllama
|
quantization_config["use_exllama"] = False # disable exllama
|
||||||
logger.info("Loading {}-bit GPTQ-quantized model.".format(quantization_config.get("bits", -1)))
|
logger.info("Loading {}-bit GPTQ-quantized model.".format(quantization_config.get("bits", -1)))
|
||||||
|
|
||||||
elif model_args.export_quantization_bit is not None: # auto-gptq
|
elif model_args.export_quantization_bit is not None: # auto-gptq
|
||||||
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
|
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
|
||||||
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
|
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
|
||||||
from accelerate.utils import get_max_memory
|
from accelerate.utils import get_max_memory
|
||||||
|
@ -172,13 +175,13 @@ def _configure_quantization(
|
||||||
config_kwargs["quantization_config"] = GPTQConfig(
|
config_kwargs["quantization_config"] = GPTQConfig(
|
||||||
bits=model_args.export_quantization_bit,
|
bits=model_args.export_quantization_bit,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
dataset=_get_quantization_dataset(tokenizer, model_args)
|
dataset=_get_quantization_dataset(tokenizer, model_args),
|
||||||
)
|
)
|
||||||
config_kwargs["device_map"] = "auto"
|
config_kwargs["device_map"] = "auto"
|
||||||
config_kwargs["max_memory"] = get_max_memory()
|
config_kwargs["max_memory"] = get_max_memory()
|
||||||
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
|
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
|
||||||
|
|
||||||
elif model_args.quantization_bit is not None: # bnb
|
elif model_args.quantization_bit is not None: # bnb
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||||
|
|
||||||
|
@ -192,7 +195,7 @@ def _configure_quantization(
|
||||||
load_in_4bit=True,
|
load_in_4bit=True,
|
||||||
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
||||||
bnb_4bit_use_double_quant=model_args.double_quantization,
|
bnb_4bit_use_double_quant=model_args.double_quantization,
|
||||||
bnb_4bit_quant_type=model_args.quantization_type
|
bnb_4bit_quant_type=model_args.quantization_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
config_kwargs["device_map"] = {"": get_current_device()}
|
config_kwargs["device_map"] = {"": get_current_device()}
|
||||||
|
@ -200,9 +203,7 @@ def _configure_quantization(
|
||||||
|
|
||||||
|
|
||||||
def _prepare_model_for_training(
|
def _prepare_model_for_training(
|
||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: Optional[str] = "lm_head"
|
||||||
model_args: "ModelArguments",
|
|
||||||
output_layer_name: Optional[str] = "lm_head"
|
|
||||||
) -> None:
|
) -> None:
|
||||||
r"""
|
r"""
|
||||||
Includes:
|
Includes:
|
||||||
|
@ -222,10 +223,11 @@ def _prepare_model_for_training(
|
||||||
logger.warning("Current model does not support gradient checkpointing.")
|
logger.warning("Current model does not support gradient checkpointing.")
|
||||||
else:
|
else:
|
||||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||||
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
||||||
logger.info("Gradient checkpointing enabled.")
|
logger.info("Gradient checkpointing enabled.")
|
||||||
|
|
||||||
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
|
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
|
||||||
|
|
||||||
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
||||||
return output.to(torch.float32)
|
return output.to(torch.float32)
|
||||||
|
|
||||||
|
@ -244,9 +246,9 @@ def patch_config(
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
config_kwargs: Dict[str, Any],
|
config_kwargs: Dict[str, Any],
|
||||||
is_trainable: bool
|
is_trainable: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
||||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||||
|
|
||||||
if getattr(config, "model_type", None) == "qwen":
|
if getattr(config, "model_type", None) == "qwen":
|
||||||
|
@ -266,10 +268,7 @@ def patch_config(
|
||||||
|
|
||||||
|
|
||||||
def patch_model(
|
def patch_model(
|
||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
|
||||||
tokenizer: "PreTrainedTokenizer",
|
|
||||||
model_args: "ModelArguments",
|
|
||||||
is_trainable: bool
|
|
||||||
) -> None:
|
) -> None:
|
||||||
if "GenerationMixin" not in str(model.generate.__func__):
|
if "GenerationMixin" not in str(model.generate.__func__):
|
||||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
model.generate = MethodType(PreTrainedModel.generate, model)
|
||||||
|
|
|
@ -1,16 +1,19 @@
|
||||||
import torch
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List
|
from typing import TYPE_CHECKING, Any, Dict, List
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
from transformers.utils import cached_file
|
from transformers.utils import cached_file
|
||||||
|
|
||||||
from ..extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
|
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from ..extras.misc import get_current_device
|
from ..extras.misc import get_current_device
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import PretrainedConfig, PreTrainedTokenizer
|
from transformers import PretrainedConfig, PreTrainedTokenizer
|
||||||
from ..hparams import ModelArguments, DataArguments, FinetuningArguments
|
|
||||||
|
from ..hparams import DataArguments, FinetuningArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
@ -21,7 +24,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||||
Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available.
|
Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available.
|
||||||
Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
|
Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
|
||||||
"""
|
"""
|
||||||
if getattr(model, "quantization_method", None): # already set on current device
|
if getattr(model, "quantization_method", None): # already set on current device
|
||||||
return model
|
return model
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -31,7 +34,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||||
and model.config.model_type != "chatglm"
|
and model.config.model_type != "chatglm"
|
||||||
):
|
):
|
||||||
from accelerate import dispatch_model
|
from accelerate import dispatch_model
|
||||||
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
from accelerate.utils import get_balanced_memory, infer_auto_device_map
|
||||||
|
|
||||||
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
|
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
|
||||||
max_memory = get_balanced_memory(model, **kwargs)
|
max_memory = get_balanced_memory(model, **kwargs)
|
||||||
|
@ -55,6 +58,7 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
|
||||||
linear_cls = torch.nn.Linear
|
linear_cls = torch.nn.Linear
|
||||||
elif quantization_method == "bitsandbytes":
|
elif quantization_method == "bitsandbytes":
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
|
||||||
linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
|
linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
|
||||||
else:
|
else:
|
||||||
raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
|
raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
|
||||||
|
@ -65,10 +69,7 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
|
||||||
|
|
||||||
module_names = set()
|
module_names = set()
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if (
|
if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names):
|
||||||
isinstance(module, linear_cls)
|
|
||||||
and not any([output_layer in name for output_layer in output_layer_names])
|
|
||||||
):
|
|
||||||
module_names.add(name.split(".")[-1])
|
module_names.add(name.split(".")[-1])
|
||||||
|
|
||||||
logger.info("Found linear modules: {}".format(",".join(module_names)))
|
logger.info("Found linear modules: {}".format(",".join(module_names)))
|
||||||
|
@ -76,16 +77,14 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
|
||||||
|
|
||||||
|
|
||||||
def get_modelcard_args(
|
def get_modelcard_args(
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments", data_args: "DataArguments", finetuning_args: "FinetuningArguments"
|
||||||
data_args: "DataArguments",
|
|
||||||
finetuning_args: "FinetuningArguments"
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"tasks": "text-generation",
|
"tasks": "text-generation",
|
||||||
"license": "other",
|
"license": "other",
|
||||||
"finetuned_from": model_args.model_name_or_path,
|
"finetuned_from": model_args.model_name_or_path,
|
||||||
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
|
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
|
||||||
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else [])
|
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else []),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -95,14 +94,11 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
|
||||||
|
|
||||||
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
|
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
|
||||||
"""
|
"""
|
||||||
kwargs = {
|
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
|
||||||
"path_or_repo_id": path_or_repo_id,
|
|
||||||
"cache_dir": model_args.cache_dir,
|
|
||||||
"token": model_args.hf_hub_token
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
|
|
||||||
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
|
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
|
||||||
with safe_open(vhead_file, framework="pt", device="cpu") as f:
|
with safe_open(vhead_file, framework="pt", device="cpu") as f:
|
||||||
return {key: f.get_tensor(key) for key in f.keys()}
|
return {key: f.get_tensor(key) for key in f.keys()}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import torch
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Sequence, Tuple
|
from typing import Any, Dict, List, Sequence, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import DataCollatorForSeq2Seq
|
from transformers import DataCollatorForSeq2Seq
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,7 +21,7 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||||
padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
|
padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
|
||||||
padded_tensor[start:end] = feature[start:end]
|
padded_tensor[start:end] = feature[start:end]
|
||||||
padded_labels.append(padded_tensor)
|
padded_labels.append(padded_tensor)
|
||||||
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
|
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
|
||||||
|
|
||||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||||
r"""
|
r"""
|
||||||
|
@ -34,10 +35,12 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||||
for key in ("chosen_ids", "rejected_ids"):
|
for key in ("chosen_ids", "rejected_ids"):
|
||||||
for feature in features:
|
for feature in features:
|
||||||
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
|
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
|
||||||
concatenated_features.append({
|
concatenated_features.append(
|
||||||
"input_ids": feature["prompt_ids"] + feature[key],
|
{
|
||||||
"attention_mask": [1] * (prompt_len + answer_len)
|
"input_ids": feature["prompt_ids"] + feature[key],
|
||||||
})
|
"attention_mask": [1] * (prompt_len + answer_len),
|
||||||
|
}
|
||||||
|
)
|
||||||
label_positions.append((prompt_len, answer_len))
|
label_positions.append((prompt_len, answer_len))
|
||||||
|
|
||||||
batch = self.tokenizer.pad(
|
batch = self.tokenizer.pad(
|
||||||
|
|
|
@ -1,19 +1,20 @@
|
||||||
import torch
|
|
||||||
from contextlib import nullcontext
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from contextlib import nullcontext
|
||||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import BatchEncoding, Trainer
|
from transformers import BatchEncoding, Trainer
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
from trl.trainer.utils import disable_dropout_in_model
|
from trl.trainer.utils import disable_dropout_in_model
|
||||||
|
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
class CustomDPOTrainer(DPOTrainer):
|
class CustomDPOTrainer(DPOTrainer):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
beta: float,
|
beta: float,
|
||||||
|
@ -22,15 +23,15 @@ class CustomDPOTrainer(DPOTrainer):
|
||||||
model: Union["PreTrainedModel", torch.nn.Module],
|
model: Union["PreTrainedModel", torch.nn.Module],
|
||||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
||||||
disable_dropout: Optional[bool] = True,
|
disable_dropout: Optional[bool] = True,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
if disable_dropout:
|
if disable_dropout:
|
||||||
disable_dropout_in_model(model)
|
disable_dropout_in_model(model)
|
||||||
if ref_model is not None:
|
if ref_model is not None:
|
||||||
disable_dropout_in_model(ref_model)
|
disable_dropout_in_model(ref_model)
|
||||||
|
|
||||||
self.use_dpo_data_collator = True # hack to avoid warning
|
self.use_dpo_data_collator = True # hack to avoid warning
|
||||||
self.generate_during_eval = False # disable at evaluation
|
self.generate_during_eval = False # disable at evaluation
|
||||||
self.label_pad_token_id = IGNORE_INDEX
|
self.label_pad_token_id = IGNORE_INDEX
|
||||||
self.padding_value = 0
|
self.padding_value = 0
|
||||||
self.is_encoder_decoder = model.config.is_encoder_decoder
|
self.is_encoder_decoder = model.config.is_encoder_decoder
|
||||||
|
@ -53,42 +54,29 @@ class CustomDPOTrainer(DPOTrainer):
|
||||||
if ref_model is not None:
|
if ref_model is not None:
|
||||||
if self.is_deepspeed_enabled:
|
if self.is_deepspeed_enabled:
|
||||||
if not (
|
if not (
|
||||||
getattr(ref_model, "is_loaded_in_8bit", False)
|
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
|
||||||
or getattr(ref_model, "is_loaded_in_4bit", False)
|
): # quantized models are already set on the correct device
|
||||||
): # quantized models are already set on the correct device
|
|
||||||
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||||
else:
|
else:
|
||||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||||
|
|
||||||
def sft_loss(
|
def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
|
||||||
self,
|
|
||||||
chosen_logits: torch.FloatTensor,
|
|
||||||
chosen_labels: torch.LongTensor
|
|
||||||
) -> torch.Tensor:
|
|
||||||
r"""
|
r"""
|
||||||
Computes supervised cross-entropy loss of given labels under the given logits.
|
Computes supervised cross-entropy loss of given labels under the given logits.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
|
A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
|
||||||
"""
|
"""
|
||||||
all_logps = self.get_batch_logps(
|
all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
|
||||||
chosen_logits,
|
|
||||||
chosen_labels,
|
|
||||||
average_log_prob=True
|
|
||||||
)
|
|
||||||
return -all_logps
|
return -all_logps
|
||||||
|
|
||||||
def concatenated_forward(
|
def concatenated_forward(
|
||||||
self,
|
self, model: "PreTrainedModel", batch: Dict[str, torch.Tensor]
|
||||||
model: "PreTrainedModel",
|
|
||||||
batch: Dict[str, torch.Tensor]
|
|
||||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||||
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
|
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
|
||||||
|
|
||||||
all_logits = model(
|
all_logits = model(
|
||||||
input_ids=batch_copied["input_ids"],
|
input_ids=batch_copied["input_ids"], attention_mask=batch_copied["attention_mask"], return_dict=True
|
||||||
attention_mask=batch_copied["attention_mask"],
|
|
||||||
return_dict=True
|
|
||||||
).logits.to(torch.float32)
|
).logits.to(torch.float32)
|
||||||
|
|
||||||
all_logps = self.get_batch_logps(
|
all_logps = self.get_batch_logps(
|
||||||
|
@ -106,7 +94,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||||
self,
|
self,
|
||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel",
|
||||||
batch: Dict[str, torch.Tensor],
|
batch: Dict[str, torch.Tensor],
|
||||||
train_eval: Optional[Literal["train", "eval"]] = "train"
|
train_eval: Optional[Literal["train", "eval"]] = "train",
|
||||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
|
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
|
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Optional, List
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
|
|
||||||
from ...data import get_dataset, split_dataset
|
from ...data import get_dataset, split_dataset
|
||||||
|
@ -12,8 +13,10 @@ from ...train.dpo.collator import DPODataCollatorWithPadding
|
||||||
from ...train.dpo.trainer import CustomDPOTrainer
|
from ...train.dpo.trainer import CustomDPOTrainer
|
||||||
from ...train.utils import create_modelcard_and_push, create_ref_model
|
from ...train.utils import create_modelcard_and_push, create_ref_model
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import TrainerCallback
|
from transformers import TrainerCallback
|
||||||
|
|
||||||
from ...hparams import DataArguments, FinetuningArguments
|
from ...hparams import DataArguments, FinetuningArguments
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,25 +25,25 @@ def run_dpo(
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
callbacks: Optional[List["TrainerCallback"]] = None
|
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||||
):
|
):
|
||||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
|
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
|
||||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
|
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
|
||||||
data_collator = DPODataCollatorWithPadding(
|
data_collator = DPODataCollatorWithPadding(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
pad_to_multiple_of=8,
|
pad_to_multiple_of=8,
|
||||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create reference model
|
# Create reference model
|
||||||
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
|
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
|
||||||
ref_model = model
|
ref_model = model
|
||||||
else:
|
else:
|
||||||
ref_model = create_ref_model(model_args, finetuning_args)
|
ref_model = create_ref_model(model_args, finetuning_args)
|
||||||
|
|
||||||
# Update arguments
|
# Update arguments
|
||||||
training_args_dict = training_args.to_dict()
|
training_args_dict = training_args.to_dict()
|
||||||
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
|
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
|
||||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||||
|
|
||||||
# Initialize our Trainer
|
# Initialize our Trainer
|
||||||
|
@ -54,7 +57,7 @@ def run_dpo(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
**split_dataset(dataset, data_args, training_args)
|
**split_dataset(dataset, data_args, training_args),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
|
@ -70,7 +73,7 @@ def run_dpo(
|
||||||
# Evaluation
|
# Evaluation
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||||
if id(model) == id(ref_model): # unable to compute rewards without a reference model
|
if id(model) == id(ref_model): # unable to compute rewards without a reference model
|
||||||
remove_keys = [key for key in metrics.keys() if "rewards" in key]
|
remove_keys = [key for key in metrics.keys() if "rewards" in key]
|
||||||
for key in remove_keys:
|
for key in remove_keys:
|
||||||
metrics.pop(key)
|
metrics.pop(key)
|
||||||
|
|
|
@ -1,27 +1,28 @@
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import math
|
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
|
import torch
|
||||||
from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
from tqdm import tqdm
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
|
||||||
from transformers.trainer_pt_utils import remove_dummy_checkpoint
|
from transformers.trainer_pt_utils import remove_dummy_checkpoint
|
||||||
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||||
|
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
|
||||||
from trl import PPOTrainer
|
from trl import PPOTrainer
|
||||||
from trl.core import PPODecorators, logprobs_from_logits
|
from trl.core import PPODecorators, logprobs_from_logits
|
||||||
|
|
||||||
from ...extras.callbacks import LogCallback, FixValueHeadModelCallback
|
from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
|
||||||
from ...extras.logging import get_logger
|
from ...extras.logging import get_logger
|
||||||
from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
|
from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
|
||||||
from .utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
|
from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||||
from trl import AutoModelForCausalLMWithValueHead
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
from ...hparams import ModelArguments, FinetuningArguments, GeneratingArguments
|
|
||||||
|
from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
@ -40,7 +41,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
generating_args: "GeneratingArguments",
|
generating_args: "GeneratingArguments",
|
||||||
callbacks: List["TrainerCallback"],
|
callbacks: List["TrainerCallback"],
|
||||||
reward_model: "AutoModelForCausalLMWithValueHead",
|
reward_model: "AutoModelForCausalLMWithValueHead",
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
PPOTrainer.__init__(self, **kwargs)
|
PPOTrainer.__init__(self, **kwargs)
|
||||||
|
|
||||||
|
@ -52,7 +53,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
self.generation_config = GenerationConfig(
|
self.generation_config = GenerationConfig(
|
||||||
pad_token_id=self.tokenizer.pad_token_id,
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||||
**generating_args.to_dict()
|
**generating_args.to_dict(),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.state = TrainerState()
|
self.state = TrainerState()
|
||||||
|
@ -71,7 +72,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
if not (
|
if not (
|
||||||
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
|
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
|
||||||
or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
|
or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
|
||||||
): # quantized models are already set on the correct device
|
): # quantized models are already set on the correct device
|
||||||
self.reward_model = self._prepare_deepspeed(self.reward_model)
|
self.reward_model = self._prepare_deepspeed(self.reward_model)
|
||||||
else:
|
else:
|
||||||
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
|
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
|
||||||
|
@ -111,9 +112,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
logger.info(" Num examples = {}".format(num_examples))
|
logger.info(" Num examples = {}".format(num_examples))
|
||||||
logger.info(" Num Epochs = {}".format(num_train_epochs))
|
logger.info(" Num Epochs = {}".format(num_train_epochs))
|
||||||
logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
|
logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
|
||||||
logger.info(" Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
|
logger.info(
|
||||||
total_train_batch_size
|
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
|
||||||
))
|
total_train_batch_size
|
||||||
|
)
|
||||||
|
)
|
||||||
logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
|
logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
|
||||||
logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
|
logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
|
||||||
logger.info(" Total training steps = {}".format(max_steps))
|
logger.info(" Total training steps = {}".format(max_steps))
|
||||||
|
@ -138,10 +141,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
# Get inputs
|
# Get inputs
|
||||||
self.tokenizer.padding_side = "right" # change padding side
|
self.tokenizer.padding_side = "right" # change padding side
|
||||||
queries, responses, rewards = [], [], []
|
queries, responses, rewards = [], [], []
|
||||||
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
|
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
|
||||||
mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
|
mini_batch_queries, mini_batch_responses = self.get_inputs(
|
||||||
|
batch[idx : idx + self.config.mini_batch_size]
|
||||||
|
)
|
||||||
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
|
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
|
||||||
queries.extend(mini_batch_queries)
|
queries.extend(mini_batch_queries)
|
||||||
responses.extend(mini_batch_responses)
|
responses.extend(mini_batch_responses)
|
||||||
|
@ -154,7 +159,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
|
|
||||||
# Run PPO step
|
# Run PPO step
|
||||||
stats = self.step(queries, responses, rewards)
|
stats = self.step(queries, responses, rewards)
|
||||||
self.tokenizer.padding_side = "left" # restore padding side
|
self.tokenizer.padding_side = "left" # restore padding side
|
||||||
loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
|
loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
|
||||||
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
|
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
|
||||||
|
|
||||||
|
@ -163,18 +168,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
|
batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
|
||||||
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
|
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
|
||||||
self.log_stats(stats, batch, rewards)
|
self.log_stats(stats, batch, rewards)
|
||||||
except:
|
except Exception:
|
||||||
logger.warning("Failed to save stats due to unknown errors.")
|
logger.warning("Failed to save stats due to unknown errors.")
|
||||||
|
|
||||||
self.state.global_step += 1
|
self.state.global_step += 1
|
||||||
self.log_callback.on_step_end(self.args, self.state, self.control)
|
self.log_callback.on_step_end(self.args, self.state, self.control)
|
||||||
|
|
||||||
if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
|
if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
|
||||||
logs = dict(
|
logs = dict(
|
||||||
loss=round(loss_meter.avg, 4),
|
loss=round(loss_meter.avg, 4),
|
||||||
reward=round(reward_meter.avg, 4),
|
reward=round(reward_meter.avg, 4),
|
||||||
learning_rate=stats["ppo/learning_rate"],
|
learning_rate=stats["ppo/learning_rate"],
|
||||||
epoch=round(step / steps_in_epoch, 2)
|
epoch=round(step / steps_in_epoch, 2),
|
||||||
)
|
)
|
||||||
tqdm.write(str(logs))
|
tqdm.write(str(logs))
|
||||||
logs["step"] = step
|
logs["step"] = step
|
||||||
|
@ -183,10 +188,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
loss_meter.reset()
|
loss_meter.reset()
|
||||||
reward_meter.reset()
|
reward_meter.reset()
|
||||||
|
|
||||||
if (step+1) % self.args.save_steps == 0: # save checkpoint
|
if (step + 1) % self.args.save_steps == 0: # save checkpoint
|
||||||
self.save_model(os.path.join(
|
self.save_model(
|
||||||
self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)
|
os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
|
||||||
))
|
)
|
||||||
self.save_callback.on_save(
|
self.save_callback.on_save(
|
||||||
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
|
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
|
||||||
)
|
)
|
||||||
|
@ -207,35 +212,33 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
if self.model_args.upcast_layernorm:
|
if self.model_args.upcast_layernorm:
|
||||||
layernorm_params = dump_layernorm(self.model)
|
layernorm_params = dump_layernorm(self.model)
|
||||||
|
|
||||||
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
|
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
|
||||||
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||||
for k, v in batch.items():
|
for k, v in batch.items():
|
||||||
batch[k] = v[:, start_index:]
|
batch[k] = v[:, start_index:]
|
||||||
|
|
||||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||||
generate_output: torch.Tensor = unwrapped_model.generate(
|
generate_output: torch.Tensor = unwrapped_model.generate(
|
||||||
generation_config=self.generation_config,
|
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
|
||||||
logits_processor=get_logits_processor(),
|
|
||||||
**batch
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.model_args.upcast_layernorm:
|
if self.model_args.upcast_layernorm:
|
||||||
restore_layernorm(self.model, layernorm_params)
|
restore_layernorm(self.model, layernorm_params)
|
||||||
|
|
||||||
query = batch["input_ids"].detach().cpu()
|
query = batch["input_ids"].detach().cpu()
|
||||||
response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
|
response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
|
||||||
queries, responses = [], []
|
queries, responses = [], []
|
||||||
for i in range(len(query)):
|
for i in range(len(query)):
|
||||||
query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||||
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
|
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
|
||||||
|
|
||||||
if len(response_index) == 0:
|
if len(response_index) == 0:
|
||||||
response_length = 1 # allow empty response
|
response_length = 1 # allow empty response
|
||||||
else:
|
else:
|
||||||
response_length = response_index[-1].item() + 1
|
response_length = response_index[-1].item() + 1
|
||||||
|
|
||||||
queries.append(query[i, query_start_index:]) # remove padding from left
|
queries.append(query[i, query_start_index:]) # remove padding from left
|
||||||
responses.append(response[i, :response_length]) # remove padding from right
|
responses.append(response[i, :response_length]) # remove padding from right
|
||||||
|
|
||||||
return queries, responses
|
return queries, responses
|
||||||
|
|
||||||
|
@ -244,7 +247,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
self,
|
self,
|
||||||
queries: List[torch.Tensor],
|
queries: List[torch.Tensor],
|
||||||
responses: List[torch.Tensor],
|
responses: List[torch.Tensor],
|
||||||
unwrapped_model: "AutoModelForCausalLMWithValueHead"
|
unwrapped_model: "AutoModelForCausalLMWithValueHead",
|
||||||
) -> List[torch.Tensor]:
|
) -> List[torch.Tensor]:
|
||||||
r"""
|
r"""
|
||||||
Computes scores using given reward model.
|
Computes scores using given reward model.
|
||||||
|
@ -264,17 +267,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
|
|
||||||
batch = self.prepare_model_inputs(queries, responses)
|
batch = self.prepare_model_inputs(queries, responses)
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||||
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
|
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
|
||||||
|
|
||||||
if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
|
if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
|
||||||
values = torch.transpose(values, 0, 1)
|
values = torch.transpose(values, 0, 1)
|
||||||
|
|
||||||
rewards = []
|
rewards = []
|
||||||
for i in range(values.size(0)):
|
for i in range(values.size(0)):
|
||||||
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
|
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
|
||||||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
||||||
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
|
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
|
||||||
|
|
||||||
if self.finetuning_args.reward_model_type == "lora":
|
if self.finetuning_args.reward_model_type == "lora":
|
||||||
replace_model(unwrapped_model, target="default")
|
replace_model(unwrapped_model, target="default")
|
||||||
|
@ -289,7 +292,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
responses: torch.Tensor,
|
responses: torch.Tensor,
|
||||||
model_inputs: dict,
|
model_inputs: dict,
|
||||||
return_logits: Optional[bool] = False,
|
return_logits: Optional[bool] = False,
|
||||||
response_masks: Optional[torch.Tensor] = None
|
response_masks: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Calculates model outputs in multiple batches.
|
Calculates model outputs in multiple batches.
|
||||||
|
@ -312,7 +315,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
input_ids = input_kwargs["input_ids"]
|
input_ids = input_kwargs["input_ids"]
|
||||||
attention_mask = input_kwargs["attention_mask"]
|
attention_mask = input_kwargs["attention_mask"]
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||||
logits, _, values = model(**input_kwargs)
|
logits, _, values = model(**input_kwargs)
|
||||||
|
|
||||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||||
|
@ -325,14 +328,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
|
|
||||||
for j in range(len(query_batch)):
|
for j in range(len(query_batch)):
|
||||||
start = len(query_batch[j]) - 1
|
start = len(query_batch[j]) - 1
|
||||||
if attention_mask[j, 0] == 0: # offset left padding
|
if attention_mask[j, 0] == 0: # offset left padding
|
||||||
start += attention_mask[j, :].nonzero()[0].item()
|
start += attention_mask[j, :].nonzero()[0].item()
|
||||||
end = start + len(response_batch[j])
|
end = start + len(response_batch[j])
|
||||||
|
|
||||||
if response_masks is not None:
|
if response_masks is not None:
|
||||||
response_masks_batch = torch.cat(
|
response_masks_batch = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:]
|
||||||
(torch.zeros_like(query_batch[j]), response_masks_batch[j])
|
|
||||||
)[1:]
|
|
||||||
|
|
||||||
masks[j, :start] = 0
|
masks[j, :start] = 0
|
||||||
masks[j, end:] = 0
|
masks[j, end:] = 0
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
import json
|
import json
|
||||||
import torch
|
|
||||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
|
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from ...extras.packages import is_requests_available
|
from ...extras.packages import is_requests_available
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
from trl import AutoModelForCausalLMWithValueHead
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
|
@ -21,16 +23,18 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.
|
||||||
|
|
||||||
|
|
||||||
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
|
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
|
||||||
if target == "reward": # save default head temporarily
|
if target == "reward": # save default head temporarily
|
||||||
valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
|
valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
|
||||||
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
|
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
|
||||||
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
|
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
|
||||||
|
|
||||||
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
|
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
|
||||||
model.v_head.load_state_dict({
|
model.v_head.load_state_dict(
|
||||||
"summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
|
{
|
||||||
"summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone()
|
"summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
|
||||||
})
|
"summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
|
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
|
||||||
|
|
|
@ -1,23 +1,26 @@
|
||||||
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
|
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from trl import PPOConfig
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
from typing import TYPE_CHECKING, Optional, List
|
|
||||||
from transformers import DataCollatorWithPadding
|
from transformers import DataCollatorWithPadding
|
||||||
from transformers.optimization import get_scheduler
|
from transformers.optimization import get_scheduler
|
||||||
|
from trl import PPOConfig
|
||||||
|
|
||||||
from ...data import get_dataset
|
from ...data import get_dataset
|
||||||
from ...extras.callbacks import FixValueHeadModelCallback
|
from ...extras.callbacks import FixValueHeadModelCallback
|
||||||
from ...extras.misc import fix_valuehead_checkpoint
|
from ...extras.misc import fix_valuehead_checkpoint
|
||||||
from ...extras.ploting import plot_loss
|
from ...extras.ploting import plot_loss
|
||||||
from ...model import load_model_and_tokenizer
|
from ...model import load_model_and_tokenizer
|
||||||
from ...train.utils import create_ref_model, create_reward_model
|
|
||||||
from ...train.ppo.trainer import CustomPPOTrainer
|
from ...train.ppo.trainer import CustomPPOTrainer
|
||||||
|
from ...train.utils import create_ref_model, create_reward_model
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||||
from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
|
||||||
|
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
def run_ppo(
|
def run_ppo(
|
||||||
|
@ -26,12 +29,14 @@ def run_ppo(
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
generating_args: "GeneratingArguments",
|
generating_args: "GeneratingArguments",
|
||||||
callbacks: Optional[List["TrainerCallback"]] = None
|
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||||
):
|
):
|
||||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
model, tokenizer = load_model_and_tokenizer(
|
||||||
|
model_args, finetuning_args, training_args.do_train, add_valuehead=True
|
||||||
|
)
|
||||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
|
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
|
||||||
|
|
||||||
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
||||||
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
||||||
|
|
||||||
# Create reference model and reward model
|
# Create reference model and reward model
|
||||||
|
@ -55,7 +60,7 @@ def run_ppo(
|
||||||
use_score_scaling=finetuning_args.ppo_score_norm,
|
use_score_scaling=finetuning_args.ppo_score_norm,
|
||||||
use_score_norm=finetuning_args.ppo_score_norm,
|
use_score_norm=finetuning_args.ppo_score_norm,
|
||||||
whiten_rewards=finetuning_args.ppo_whiten_rewards,
|
whiten_rewards=finetuning_args.ppo_whiten_rewards,
|
||||||
accelerator_kwargs={"step_scheduler_with_optimizer": False}
|
accelerator_kwargs={"step_scheduler_with_optimizer": False},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create optimizer and scheduler
|
# Create optimizer and scheduler
|
||||||
|
@ -70,7 +75,7 @@ def run_ppo(
|
||||||
training_args.lr_scheduler_type,
|
training_args.lr_scheduler_type,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
|
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
|
||||||
num_training_steps=num_training_steps
|
num_training_steps=num_training_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize our Trainer
|
# Initialize our Trainer
|
||||||
|
@ -88,7 +93,7 @@ def run_ppo(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
lr_scheduler=lr_scheduler
|
lr_scheduler=lr_scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
|
@ -97,6 +102,6 @@ def run_ppo(
|
||||||
ppo_trainer.save_model()
|
ppo_trainer.save_model()
|
||||||
if training_args.should_save:
|
if training_args.should_save:
|
||||||
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
|
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
|
||||||
ppo_trainer.save_state() # must be called after save_model to have a folder
|
ppo_trainer.save_state() # must be called after save_model to have a folder
|
||||||
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||||
plot_loss(training_args.output_dir, keys=["loss", "reward"])
|
plot_loss(training_args.output_dir, keys=["loss", "reward"])
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py
|
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import TYPE_CHECKING, Optional, List
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from transformers import DataCollatorForLanguageModeling, Trainer
|
from transformers import DataCollatorForLanguageModeling, Trainer
|
||||||
|
|
||||||
from ...data import get_dataset, split_dataset
|
from ...data import get_dataset, split_dataset
|
||||||
|
@ -9,9 +10,11 @@ from ...extras.ploting import plot_loss
|
||||||
from ...model import load_model_and_tokenizer
|
from ...model import load_model_and_tokenizer
|
||||||
from ...train.utils import create_modelcard_and_push
|
from ...train.utils import create_modelcard_and_push
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||||
from ...hparams import ModelArguments, DataArguments, FinetuningArguments
|
|
||||||
|
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
def run_pt(
|
def run_pt(
|
||||||
|
@ -19,7 +22,7 @@ def run_pt(
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
callbacks: Optional[List["TrainerCallback"]] = None
|
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||||
):
|
):
|
||||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
|
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
|
||||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt")
|
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt")
|
||||||
|
@ -32,7 +35,7 @@ def run_pt(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
**split_dataset(dataset, data_args, training_args)
|
**split_dataset(dataset, data_args, training_args),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import torch
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Sequence
|
from typing import Any, Dict, Sequence
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import DataCollatorWithPadding
|
from transformers import DataCollatorWithPadding
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,8 +21,9 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
|
||||||
features = [
|
features = [
|
||||||
{
|
{
|
||||||
"input_ids": feature["prompt_ids"] + feature[key],
|
"input_ids": feature["prompt_ids"] + feature[key],
|
||||||
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
|
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])),
|
||||||
}
|
}
|
||||||
for key in ("chosen_ids", "rejected_ids") for feature in features
|
for key in ("chosen_ids", "rejected_ids")
|
||||||
|
for feature in features
|
||||||
]
|
]
|
||||||
return super().__call__(features)
|
return super().__call__(features)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import numpy as np
|
|
||||||
from typing import Dict, Sequence, Tuple, Union
|
from typing import Dict, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
|
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
|
||||||
preds, _ = eval_preds
|
preds, _ = eval_preds
|
||||||
|
|
|
@ -1,14 +1,16 @@
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
import torch
|
import os
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
|
|
||||||
from ...extras.logging import get_logger
|
from ...extras.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers.trainer import PredictionOutput
|
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
from transformers.trainer import PredictionOutput
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
@ -21,13 +23,10 @@ class PairwiseTrainer(Trainer):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.can_return_loss = True # override property to return eval_loss
|
self.can_return_loss = True # override property to return eval_loss
|
||||||
|
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
self,
|
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: Optional[bool] = False
|
||||||
model: "PreTrainedModel",
|
|
||||||
inputs: Dict[str, torch.Tensor],
|
|
||||||
return_outputs: Optional[bool] = False
|
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
r"""
|
r"""
|
||||||
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
|
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
|
||||||
|
@ -68,9 +67,9 @@ class PairwiseTrainer(Trainer):
|
||||||
assert div_index > 0
|
assert div_index > 0
|
||||||
chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
|
chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
|
||||||
rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
|
rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
|
||||||
if return_outputs: # use the score on the last token except pad token for inference
|
if return_outputs: # use the score on the last token except pad token for inference
|
||||||
chosen_scores.append(chosen_rewards[i, chosen_length-1])
|
chosen_scores.append(chosen_rewards[i, chosen_length - 1])
|
||||||
rejected_scores.append(rejected_rewards[i, rejected_length-1])
|
rejected_scores.append(rejected_rewards[i, rejected_length - 1])
|
||||||
loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
|
loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
|
||||||
|
|
||||||
loss = loss / batch_size
|
loss = loss / batch_size
|
||||||
|
@ -80,10 +79,7 @@ class PairwiseTrainer(Trainer):
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def save_predictions(
|
def save_predictions(self, predict_results: "PredictionOutput") -> None:
|
||||||
self,
|
|
||||||
predict_results: "PredictionOutput"
|
|
||||||
) -> None:
|
|
||||||
r"""
|
r"""
|
||||||
Saves model predictions to `output_dir`.
|
Saves model predictions to `output_dir`.
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Optional, List
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
|
|
||||||
from ...data import get_dataset, split_dataset
|
from ...data import get_dataset, split_dataset
|
||||||
|
@ -13,9 +14,11 @@ from ...train.rm.metric import compute_accuracy
|
||||||
from ...train.rm.trainer import PairwiseTrainer
|
from ...train.rm.trainer import PairwiseTrainer
|
||||||
from ...train.utils import create_modelcard_and_push
|
from ...train.utils import create_modelcard_and_push
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import TrainerCallback
|
from transformers import TrainerCallback
|
||||||
from ...hparams import ModelArguments, DataArguments, FinetuningArguments
|
|
||||||
|
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
def run_rm(
|
def run_rm(
|
||||||
|
@ -23,15 +26,17 @@ def run_rm(
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
callbacks: Optional[List["TrainerCallback"]] = None
|
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||||
):
|
):
|
||||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
model, tokenizer = load_model_and_tokenizer(
|
||||||
|
model_args, finetuning_args, training_args.do_train, add_valuehead=True
|
||||||
|
)
|
||||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
|
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
|
||||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
|
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
|
||||||
|
|
||||||
# Update arguments
|
# Update arguments
|
||||||
training_args_dict = training_args.to_dict()
|
training_args_dict = training_args.to_dict()
|
||||||
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
|
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
|
||||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||||
|
|
||||||
# Initialize our Trainer
|
# Initialize our Trainer
|
||||||
|
@ -42,7 +47,7 @@ def run_rm(
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
callbacks=callbacks + [FixValueHeadModelCallback()],
|
callbacks=callbacks + [FixValueHeadModelCallback()],
|
||||||
compute_metrics=compute_accuracy,
|
compute_metrics=compute_accuracy,
|
||||||
**split_dataset(dataset, data_args, training_args)
|
**split_dataset(dataset, data_args, training_args),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
import numpy as np
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
from ...extras.packages import (
|
from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
|
||||||
is_jieba_available, is_nltk_available, is_rouge_available
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
@ -14,7 +14,7 @@ if is_jieba_available():
|
||||||
import jieba
|
import jieba
|
||||||
|
|
||||||
if is_nltk_available():
|
if is_nltk_available():
|
||||||
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
|
||||||
|
|
||||||
if is_rouge_available():
|
if is_rouge_available():
|
||||||
from rouge_chinese import Rouge
|
from rouge_chinese import Rouge
|
||||||
|
|
|
@ -1,14 +1,16 @@
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
import torch
|
import os
|
||||||
import numpy as np
|
|
||||||
import torch.nn as nn
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from transformers import Seq2SeqTrainer
|
from transformers import Seq2SeqTrainer
|
||||||
|
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
from ...extras.logging import get_logger
|
from ...extras.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers.trainer import PredictionOutput
|
from transformers.trainer import PredictionOutput
|
||||||
|
|
||||||
|
@ -33,16 +35,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||||
|
|
||||||
Subclass and override to inject custom behavior.
|
Subclass and override to inject custom behavior.
|
||||||
"""
|
"""
|
||||||
labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels
|
labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels
|
||||||
if self.args.predict_with_generate:
|
if self.args.predict_with_generate:
|
||||||
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
|
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
|
||||||
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
|
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
|
||||||
if prompt_len > label_len:
|
if prompt_len > label_len:
|
||||||
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
|
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
|
||||||
if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
|
if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
|
||||||
inputs["labels"] = inputs["labels"][:, :prompt_len]
|
inputs["labels"] = inputs["labels"][:, :prompt_len]
|
||||||
|
|
||||||
loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated)
|
loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated)
|
||||||
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
|
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
|
||||||
)
|
)
|
||||||
if generated_tokens is not None and self.args.predict_with_generate:
|
if generated_tokens is not None and self.args.predict_with_generate:
|
||||||
|
@ -51,23 +53,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||||
|
|
||||||
return loss, generated_tokens, labels
|
return loss, generated_tokens, labels
|
||||||
|
|
||||||
def _pad_tensors_to_target_len(
|
def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
self,
|
|
||||||
src_tensor: torch.Tensor,
|
|
||||||
tgt_tensor: torch.Tensor
|
|
||||||
) -> torch.Tensor:
|
|
||||||
r"""
|
r"""
|
||||||
Pads the tensor to the same length as the target tensor.
|
Pads the tensor to the same length as the target tensor.
|
||||||
"""
|
"""
|
||||||
assert self.tokenizer.pad_token_id is not None, "Pad token is required."
|
assert self.tokenizer.pad_token_id is not None, "Pad token is required."
|
||||||
padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
|
padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
|
||||||
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
|
padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding
|
||||||
return padded_tensor.contiguous() # in contiguous memory
|
return padded_tensor.contiguous() # in contiguous memory
|
||||||
|
|
||||||
def save_predictions(
|
def save_predictions(self, predict_results: "PredictionOutput") -> None:
|
||||||
self,
|
|
||||||
predict_results: "PredictionOutput"
|
|
||||||
) -> None:
|
|
||||||
r"""
|
r"""
|
||||||
Saves model predictions to `output_dir`.
|
Saves model predictions to `output_dir`.
|
||||||
|
|
||||||
|
@ -79,15 +74,23 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||||
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
|
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
|
||||||
logger.info(f"Saving prediction results to {output_prediction_file}")
|
logger.info(f"Saving prediction results to {output_prediction_file}")
|
||||||
|
|
||||||
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
|
labels = np.where(
|
||||||
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
|
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
|
||||||
|
)
|
||||||
|
preds = np.where(
|
||||||
|
predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(len(preds)):
|
for i in range(len(preds)):
|
||||||
pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
|
pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
|
||||||
if len(pad_len):
|
if len(pad_len):
|
||||||
preds[i] = np.concatenate((preds[i][pad_len[0]:], preds[i][:pad_len[0]]), axis=-1) # move pad token to last
|
preds[i] = np.concatenate(
|
||||||
|
(preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1
|
||||||
|
) # move pad token to last
|
||||||
|
|
||||||
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
decoded_labels = self.tokenizer.batch_decode(
|
||||||
|
labels, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||||
|
)
|
||||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||||
|
|
||||||
with open(output_prediction_file, "w", encoding="utf-8") as writer:
|
with open(output_prediction_file, "w", encoding="utf-8") as writer:
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
|
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Optional, List
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
|
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
|
||||||
|
|
||||||
from ...data import get_dataset, split_dataset
|
from ...data import get_dataset, split_dataset
|
||||||
|
@ -15,7 +16,8 @@ from ...train.utils import create_modelcard_and_push
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import TrainerCallback
|
from transformers import TrainerCallback
|
||||||
from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
|
||||||
|
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
def run_sft(
|
def run_sft(
|
||||||
|
@ -24,29 +26,31 @@ def run_sft(
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
generating_args: "GeneratingArguments",
|
generating_args: "GeneratingArguments",
|
||||||
callbacks: Optional[List["TrainerCallback"]] = None
|
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||||
):
|
):
|
||||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
|
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
|
||||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
|
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
|
||||||
|
|
||||||
if training_args.predict_with_generate:
|
if training_args.predict_with_generate:
|
||||||
tokenizer.padding_side = "left" # use left-padding in generation
|
tokenizer.padding_side = "left" # use left-padding in generation
|
||||||
|
|
||||||
if getattr(model, "is_quantized", False) and not training_args.do_train:
|
if getattr(model, "is_quantized", False) and not training_args.do_train:
|
||||||
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
|
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
|
||||||
|
|
||||||
data_collator = DataCollatorForSeq2Seq(
|
data_collator = DataCollatorForSeq2Seq(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
|
pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
|
||||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Override the decoding parameters of Seq2SeqTrainer
|
# Override the decoding parameters of Seq2SeqTrainer
|
||||||
training_args_dict = training_args.to_dict()
|
training_args_dict = training_args.to_dict()
|
||||||
training_args_dict.update(dict(
|
training_args_dict.update(
|
||||||
generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
|
dict(
|
||||||
generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams
|
generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
|
||||||
))
|
generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams,
|
||||||
|
)
|
||||||
|
)
|
||||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||||
|
|
||||||
# Initialize our Trainer
|
# Initialize our Trainer
|
||||||
|
@ -57,7 +61,7 @@ def run_sft(
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
|
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
|
||||||
**split_dataset(dataset, data_args, training_args)
|
**split_dataset(dataset, data_args, training_args),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Keyword arguments for `model.generate`
|
# Keyword arguments for `model.generate`
|
||||||
|
@ -79,7 +83,7 @@ def run_sft(
|
||||||
# Evaluation
|
# Evaluation
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
|
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
|
||||||
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
|
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
|
||||||
metrics.pop("eval_loss", None)
|
metrics.pop("eval_loss", None)
|
||||||
trainer.log_metrics("eval", metrics)
|
trainer.log_metrics("eval", metrics)
|
||||||
trainer.save_metrics("eval", metrics)
|
trainer.save_metrics("eval", metrics)
|
||||||
|
@ -87,7 +91,7 @@ def run_sft(
|
||||||
# Predict
|
# Predict
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
|
predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
|
||||||
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
|
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
|
||||||
predict_results.metrics.pop("predict_loss", None)
|
predict_results.metrics.pop("predict_loss", None)
|
||||||
trainer.log_metrics("predict", predict_results.metrics)
|
trainer.log_metrics("predict", predict_results.metrics)
|
||||||
trainer.save_metrics("predict", predict_results.metrics)
|
trainer.save_metrics("predict", predict_results.metrics)
|
||||||
|
|
|
@ -1,16 +1,18 @@
|
||||||
import torch
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
from ..extras.callbacks import LogCallback
|
from ..extras.callbacks import LogCallback
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from ..hparams import get_train_args, get_infer_args
|
from ..hparams import get_infer_args, get_train_args
|
||||||
from ..model import load_model_and_tokenizer
|
from ..model import load_model_and_tokenizer
|
||||||
from .pt import run_pt
|
|
||||||
from .sft import run_sft
|
|
||||||
from .rm import run_rm
|
|
||||||
from .ppo import run_ppo
|
|
||||||
from .dpo import run_dpo
|
from .dpo import run_dpo
|
||||||
|
from .ppo import run_ppo
|
||||||
|
from .pt import run_pt
|
||||||
|
from .rm import run_rm
|
||||||
|
from .sft import run_sft
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import TrainerCallback
|
from transformers import TrainerCallback
|
||||||
|
@ -64,23 +66,23 @@ def export_model(args: Optional[Dict[str, Any]] = None):
|
||||||
model.save_pretrained(
|
model.save_pretrained(
|
||||||
save_directory=model_args.export_dir,
|
save_directory=model_args.export_dir,
|
||||||
max_shard_size="{}GB".format(model_args.export_size),
|
max_shard_size="{}GB".format(model_args.export_size),
|
||||||
safe_serialization=(not model_args.export_legacy_format)
|
safe_serialization=(not model_args.export_legacy_format),
|
||||||
)
|
)
|
||||||
if model_args.export_hub_model_id is not None:
|
if model_args.export_hub_model_id is not None:
|
||||||
model.push_to_hub(
|
model.push_to_hub(
|
||||||
model_args.export_hub_model_id,
|
model_args.export_hub_model_id,
|
||||||
token=model_args.hf_hub_token,
|
token=model_args.hf_hub_token,
|
||||||
max_shard_size="{}GB".format(model_args.export_size),
|
max_shard_size="{}GB".format(model_args.export_size),
|
||||||
safe_serialization=(not model_args.export_legacy_format)
|
safe_serialization=(not model_args.export_legacy_format),
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tokenizer.padding_side = "left" # restore padding side
|
tokenizer.padding_side = "left" # restore padding side
|
||||||
tokenizer.init_kwargs["padding_side"] = "left"
|
tokenizer.init_kwargs["padding_side"] = "left"
|
||||||
tokenizer.save_pretrained(model_args.export_dir)
|
tokenizer.save_pretrained(model_args.export_dir)
|
||||||
if model_args.export_hub_model_id is not None:
|
if model_args.export_hub_model_id is not None:
|
||||||
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
|
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
|
||||||
except:
|
except Exception:
|
||||||
logger.warning("Cannot save tokenizer, please copy the files manually.")
|
logger.warning("Cannot save tokenizer, please copy the files manually.")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,17 @@
|
||||||
import torch
|
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from ..hparams import ModelArguments, FinetuningArguments
|
from ..hparams import FinetuningArguments, ModelArguments
|
||||||
from ..model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params
|
from ..model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import Seq2SeqTrainingArguments, Trainer
|
from transformers import Seq2SeqTrainingArguments, Trainer
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
from trl import AutoModelForCausalLMWithValueHead
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
|
|
||||||
from ..hparams import DataArguments
|
from ..hparams import DataArguments
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,7 +23,7 @@ def create_modelcard_and_push(
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
finetuning_args: "FinetuningArguments"
|
finetuning_args: "FinetuningArguments",
|
||||||
) -> None:
|
) -> None:
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if training_args.push_to_hub:
|
if training_args.push_to_hub:
|
||||||
|
@ -33,9 +36,7 @@ def create_modelcard_and_push(
|
||||||
|
|
||||||
|
|
||||||
def create_ref_model(
|
def create_ref_model(
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: Optional[bool] = False
|
||||||
finetuning_args: "FinetuningArguments",
|
|
||||||
add_valuehead: Optional[bool] = False
|
|
||||||
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
|
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
|
||||||
r"""
|
r"""
|
||||||
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
|
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
|
||||||
|
@ -44,11 +45,13 @@ def create_ref_model(
|
||||||
"""
|
"""
|
||||||
if finetuning_args.ref_model is not None:
|
if finetuning_args.ref_model is not None:
|
||||||
ref_model_args_dict = model_args.to_dict()
|
ref_model_args_dict = model_args.to_dict()
|
||||||
ref_model_args_dict.update(dict(
|
ref_model_args_dict.update(
|
||||||
model_name_or_path=finetuning_args.ref_model,
|
dict(
|
||||||
adapter_name_or_path=finetuning_args.ref_model_adapters,
|
model_name_or_path=finetuning_args.ref_model,
|
||||||
quantization_bit=finetuning_args.ref_model_quantization_bit
|
adapter_name_or_path=finetuning_args.ref_model_adapters,
|
||||||
))
|
quantization_bit=finetuning_args.ref_model_quantization_bit,
|
||||||
|
)
|
||||||
|
)
|
||||||
ref_model_args = ModelArguments(**ref_model_args_dict)
|
ref_model_args = ModelArguments(**ref_model_args_dict)
|
||||||
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
|
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
|
||||||
ref_model, _ = load_model_and_tokenizer(
|
ref_model, _ = load_model_and_tokenizer(
|
||||||
|
@ -68,9 +71,7 @@ def create_ref_model(
|
||||||
|
|
||||||
|
|
||||||
def create_reward_model(
|
def create_reward_model(
|
||||||
model: "AutoModelForCausalLMWithValueHead",
|
model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
|
||||||
model_args: "ModelArguments",
|
|
||||||
finetuning_args: "FinetuningArguments"
|
|
||||||
) -> "AutoModelForCausalLMWithValueHead":
|
) -> "AutoModelForCausalLMWithValueHead":
|
||||||
r"""
|
r"""
|
||||||
Creates reward model for PPO training.
|
Creates reward model for PPO training.
|
||||||
|
@ -81,24 +82,30 @@ def create_reward_model(
|
||||||
return finetuning_args.reward_model
|
return finetuning_args.reward_model
|
||||||
elif finetuning_args.reward_model_type == "lora":
|
elif finetuning_args.reward_model_type == "lora":
|
||||||
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
|
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
|
||||||
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
|
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
|
||||||
if "default" in name:
|
if "default" in name:
|
||||||
param.data = param.data.to(torch.float32) # trainable params should in fp32
|
param.data = param.data.to(torch.float32) # trainable params should in fp32
|
||||||
vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
|
vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
|
||||||
assert vhead_params is not None, "Reward model is not correctly loaded."
|
assert vhead_params is not None, "Reward model is not correctly loaded."
|
||||||
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
|
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
|
||||||
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
|
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
|
||||||
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
|
model.register_buffer(
|
||||||
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
|
"default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False
|
||||||
|
)
|
||||||
|
model.register_buffer(
|
||||||
|
"default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
|
||||||
|
)
|
||||||
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
|
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
reward_model_args_dict = model_args.to_dict()
|
reward_model_args_dict = model_args.to_dict()
|
||||||
reward_model_args_dict.update(dict(
|
reward_model_args_dict.update(
|
||||||
model_name_or_path=finetuning_args.reward_model,
|
dict(
|
||||||
adapter_name_or_path=finetuning_args.reward_model_adapters,
|
model_name_or_path=finetuning_args.reward_model,
|
||||||
quantization_bit=finetuning_args.reward_model_quantization_bit
|
adapter_name_or_path=finetuning_args.reward_model_adapters,
|
||||||
))
|
quantization_bit=finetuning_args.reward_model_quantization_bit,
|
||||||
|
)
|
||||||
|
)
|
||||||
reward_model_args = ModelArguments(**reward_model_args_dict)
|
reward_model_args = ModelArguments(**reward_model_args_dict)
|
||||||
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
|
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
|
||||||
reward_model, _ = load_model_and_tokenizer(
|
reward_model, _ = load_model_and_tokenizer(
|
||||||
|
|
|
@ -1,24 +1,22 @@
|
||||||
import gradio as gr
|
|
||||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||||
|
|
||||||
from ..chat import ChatModel
|
from ..chat import ChatModel
|
||||||
from ..extras.misc import torch_gc
|
from ..extras.misc import torch_gc
|
||||||
from ..hparams import GeneratingArguments
|
from ..hparams import GeneratingArguments
|
||||||
from .common import get_save_dir
|
from .common import get_save_dir
|
||||||
from .locales import ALERTS
|
from .locales import ALERTS
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .manager import Manager
|
from .manager import Manager
|
||||||
|
|
||||||
|
|
||||||
class WebChatModel(ChatModel):
|
class WebChatModel(ChatModel):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, manager: "Manager", demo_mode: Optional[bool] = False, lazy_init: Optional[bool] = True
|
||||||
manager: "Manager",
|
|
||||||
demo_mode: Optional[bool] = False,
|
|
||||||
lazy_init: Optional[bool] = True
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.manager = manager
|
self.manager = manager
|
||||||
self.demo_mode = demo_mode
|
self.demo_mode = demo_mode
|
||||||
|
@ -26,11 +24,12 @@ class WebChatModel(ChatModel):
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.generating_args = GeneratingArguments()
|
self.generating_args = GeneratingArguments()
|
||||||
|
|
||||||
if not lazy_init: # read arguments from command line
|
if not lazy_init: # read arguments from command line
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if demo_mode: # load demo_config.json if exists
|
if demo_mode: # load demo_config.json if exists
|
||||||
import json
|
import json
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open("demo_config.json", "r", encoding="utf-8") as f:
|
with open("demo_config.json", "r", encoding="utf-8") as f:
|
||||||
args = json.load(f)
|
args = json.load(f)
|
||||||
|
@ -38,7 +37,7 @@ class WebChatModel(ChatModel):
|
||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
print("Please provided model name and template in `demo_config.json`.")
|
print("Please provided model name and template in `demo_config.json`.")
|
||||||
except:
|
except Exception:
|
||||||
print("Cannot find `demo_config.json` at current directory.")
|
print("Cannot find `demo_config.json` at current directory.")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -64,9 +63,12 @@ class WebChatModel(ChatModel):
|
||||||
return
|
return
|
||||||
|
|
||||||
if get("top.adapter_path"):
|
if get("top.adapter_path"):
|
||||||
adapter_name_or_path = ",".join([
|
adapter_name_or_path = ",".join(
|
||||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
|
[
|
||||||
for adapter in get("top.adapter_path")])
|
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
|
||||||
|
for adapter in get("top.adapter_path")
|
||||||
|
]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
adapter_name_or_path = None
|
adapter_name_or_path = None
|
||||||
|
|
||||||
|
@ -79,7 +81,7 @@ class WebChatModel(ChatModel):
|
||||||
template=get("top.template"),
|
template=get("top.template"),
|
||||||
flash_attn=(get("top.booster") == "flash_attn"),
|
flash_attn=(get("top.booster") == "flash_attn"),
|
||||||
use_unsloth=(get("top.booster") == "unsloth"),
|
use_unsloth=(get("top.booster") == "unsloth"),
|
||||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
|
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||||
)
|
)
|
||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
|
|
||||||
|
@ -108,7 +110,7 @@ class WebChatModel(ChatModel):
|
||||||
tools: str,
|
tools: str,
|
||||||
max_new_tokens: int,
|
max_new_tokens: int,
|
||||||
top_p: float,
|
top_p: float,
|
||||||
temperature: float
|
temperature: float,
|
||||||
) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
|
) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
|
||||||
chatbot.append([query, ""])
|
chatbot.append([query, ""])
|
||||||
response = ""
|
response = ""
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
import gradio as gr
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
|
|
||||||
|
import gradio as gr
|
||||||
|
from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
|
||||||
|
|
||||||
from ..extras.constants import (
|
from ..extras.constants import (
|
||||||
DATA_CONFIG,
|
DATA_CONFIG,
|
||||||
|
@ -12,7 +13,7 @@ from ..extras.constants import (
|
||||||
PEFT_METHODS,
|
PEFT_METHODS,
|
||||||
SUPPORTED_MODELS,
|
SUPPORTED_MODELS,
|
||||||
TRAINING_STAGES,
|
TRAINING_STAGES,
|
||||||
DownloadSource
|
DownloadSource,
|
||||||
)
|
)
|
||||||
from ..extras.misc import use_modelscope
|
from ..extras.misc import use_modelscope
|
||||||
|
|
||||||
|
@ -36,7 +37,7 @@ def load_config() -> Dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
with open(get_config_path(), "r", encoding="utf-8") as f:
|
with open(get_config_path(), "r", encoding="utf-8") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
except:
|
except Exception:
|
||||||
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
|
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,7 +60,7 @@ def get_model_path(model_name: str) -> str:
|
||||||
use_modelscope()
|
use_modelscope()
|
||||||
and path_dict.get(DownloadSource.MODELSCOPE)
|
and path_dict.get(DownloadSource.MODELSCOPE)
|
||||||
and model_path == path_dict.get(DownloadSource.DEFAULT)
|
and model_path == path_dict.get(DownloadSource.DEFAULT)
|
||||||
): # replace path
|
): # replace path
|
||||||
model_path = path_dict.get(DownloadSource.MODELSCOPE)
|
model_path = path_dict.get(DownloadSource.MODELSCOPE)
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
|
@ -87,9 +88,8 @@ def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
|
||||||
save_dir = get_save_dir(model_name, finetuning_type)
|
save_dir = get_save_dir(model_name, finetuning_type)
|
||||||
if save_dir and os.path.isdir(save_dir):
|
if save_dir and os.path.isdir(save_dir):
|
||||||
for adapter in os.listdir(save_dir):
|
for adapter in os.listdir(save_dir):
|
||||||
if (
|
if os.path.isdir(os.path.join(save_dir, adapter)) and any(
|
||||||
os.path.isdir(os.path.join(save_dir, adapter))
|
os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES
|
||||||
and any([os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES])
|
|
||||||
):
|
):
|
||||||
adapters.append(adapter)
|
adapters.append(adapter)
|
||||||
return gr.update(value=[], choices=adapters, interactive=True)
|
return gr.update(value=[], choices=adapters, interactive=True)
|
||||||
|
|
|
@ -1,11 +1,16 @@
|
||||||
|
from .chatbot import create_chat_box
|
||||||
|
from .eval import create_eval_tab
|
||||||
|
from .export import create_export_tab
|
||||||
|
from .infer import create_infer_tab
|
||||||
from .top import create_top
|
from .top import create_top
|
||||||
from .train import create_train_tab
|
from .train import create_train_tab
|
||||||
from .eval import create_eval_tab
|
|
||||||
from .infer import create_infer_tab
|
|
||||||
from .export import create_export_tab
|
|
||||||
from .chatbot import create_chat_box
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"create_top", "create_train_tab", "create_eval_tab", "create_infer_tab", "create_export_tab", "create_chat_box"
|
"create_top",
|
||||||
|
"create_train_tab",
|
||||||
|
"create_eval_tab",
|
||||||
|
"create_infer_tab",
|
||||||
|
"create_export_tab",
|
||||||
|
"create_chat_box",
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import gradio as gr
|
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
from ..utils import check_json_schema
|
from ..utils import check_json_schema
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,8 +13,7 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
|
|
||||||
def create_chat_box(
|
def create_chat_box(
|
||||||
engine: "Engine",
|
engine: "Engine", visible: Optional[bool] = False
|
||||||
visible: Optional[bool] = False
|
|
||||||
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
|
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
|
||||||
with gr.Box(visible=visible) as chat_box:
|
with gr.Box(visible=visible) as chat_box:
|
||||||
chatbot = gr.Chatbot()
|
chatbot = gr.Chatbot()
|
||||||
|
@ -38,20 +38,23 @@ def create_chat_box(
|
||||||
engine.chatter.predict,
|
engine.chatter.predict,
|
||||||
[chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
|
[chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
|
||||||
[chatbot, history],
|
[chatbot, history],
|
||||||
show_progress=True
|
show_progress=True,
|
||||||
).then(
|
).then(lambda: gr.update(value=""), outputs=[query])
|
||||||
lambda: gr.update(value=""), outputs=[query]
|
|
||||||
)
|
|
||||||
|
|
||||||
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
|
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
|
||||||
|
|
||||||
return chat_box, chatbot, history, dict(
|
return (
|
||||||
system=system,
|
chat_box,
|
||||||
tools=tools,
|
chatbot,
|
||||||
query=query,
|
history,
|
||||||
submit_btn=submit_btn,
|
dict(
|
||||||
clear_btn=clear_btn,
|
system=system,
|
||||||
max_new_tokens=max_new_tokens,
|
tools=tools,
|
||||||
top_p=top_p,
|
query=query,
|
||||||
temperature=temperature
|
submit_btn=submit_btn,
|
||||||
|
clear_btn=clear_btn,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
top_p=top_p,
|
||||||
|
temperature=temperature,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
import gradio as gr
|
import os
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, Tuple
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
from ...extras.constants import DATA_CONFIG
|
from ...extras.constants import DATA_CONFIG
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from gradio.components import Component
|
from gradio.components import Component
|
||||||
|
|
||||||
|
@ -24,7 +26,7 @@ def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||||
dataset_info = json.load(f)
|
dataset_info = json.load(f)
|
||||||
except:
|
except Exception:
|
||||||
return gr.update(interactive=False)
|
return gr.update(interactive=False)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -48,7 +50,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
|
||||||
elif data_file.endswith(".jsonl"):
|
elif data_file.endswith(".jsonl"):
|
||||||
data = [json.loads(line) for line in f]
|
data = [json.loads(line) for line in f]
|
||||||
else:
|
else:
|
||||||
data = [line for line in f]
|
data = [line for line in f] # noqa: C416
|
||||||
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True)
|
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True)
|
||||||
|
|
||||||
|
|
||||||
|
@ -67,32 +69,17 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
preview_samples = gr.JSON(interactive=False)
|
preview_samples = gr.JSON(interactive=False)
|
||||||
|
|
||||||
dataset.change(
|
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
|
||||||
can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False
|
|
||||||
).then(
|
|
||||||
lambda: 0, outputs=[page_index], queue=False
|
lambda: 0, outputs=[page_index], queue=False
|
||||||
)
|
)
|
||||||
data_preview_btn.click(
|
data_preview_btn.click(
|
||||||
get_preview,
|
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
|
||||||
[dataset_dir, dataset, page_index],
|
|
||||||
[preview_count, preview_samples, preview_box],
|
|
||||||
queue=False
|
|
||||||
)
|
)
|
||||||
prev_btn.click(
|
prev_btn.click(prev_page, [page_index], [page_index], queue=False).then(
|
||||||
prev_page, [page_index], [page_index], queue=False
|
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
|
||||||
).then(
|
|
||||||
get_preview,
|
|
||||||
[dataset_dir, dataset, page_index],
|
|
||||||
[preview_count, preview_samples, preview_box],
|
|
||||||
queue=False
|
|
||||||
)
|
)
|
||||||
next_btn.click(
|
next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
|
||||||
next_page, [page_index, preview_count], [page_index], queue=False
|
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
|
||||||
).then(
|
|
||||||
get_preview,
|
|
||||||
[dataset_dir, dataset, page_index],
|
|
||||||
[preview_count, preview_samples, preview_box],
|
|
||||||
queue=False
|
|
||||||
)
|
)
|
||||||
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
|
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
|
||||||
return dict(
|
return dict(
|
||||||
|
@ -102,5 +89,5 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
|
||||||
prev_btn=prev_btn,
|
prev_btn=prev_btn,
|
||||||
next_btn=next_btn,
|
next_btn=next_btn,
|
||||||
close_btn=close_btn,
|
close_btn=close_btn,
|
||||||
preview_samples=preview_samples
|
preview_samples=preview_samples,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
import gradio as gr
|
|
||||||
from typing import TYPE_CHECKING, Dict
|
from typing import TYPE_CHECKING, Dict
|
||||||
|
|
||||||
from ..common import list_dataset, DEFAULT_DATA_DIR
|
import gradio as gr
|
||||||
|
|
||||||
|
from ..common import DEFAULT_DATA_DIR, list_dataset
|
||||||
from .data import create_preview_box
|
from .data import create_preview_box
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from gradio.components import Component
|
from gradio.components import Component
|
||||||
|
|
||||||
|
@ -31,9 +33,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
predict = gr.Checkbox(value=True)
|
predict = gr.Checkbox(value=True)
|
||||||
|
|
||||||
input_elems.update({cutoff_len, max_samples, batch_size, predict})
|
input_elems.update({cutoff_len, max_samples, batch_size, predict})
|
||||||
elem_dict.update(dict(
|
elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))
|
||||||
cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict
|
|
||||||
))
|
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
|
max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
|
||||||
|
@ -42,9 +42,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
output_dir = gr.Textbox()
|
output_dir = gr.Textbox()
|
||||||
|
|
||||||
input_elems.update({max_new_tokens, top_p, temperature, output_dir})
|
input_elems.update({max_new_tokens, top_p, temperature, output_dir})
|
||||||
elem_dict.update(dict(
|
elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir))
|
||||||
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir
|
|
||||||
))
|
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
cmd_preview_btn = gr.Button()
|
cmd_preview_btn = gr.Button()
|
||||||
|
@ -59,10 +57,16 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
output_box = gr.Markdown()
|
output_box = gr.Markdown()
|
||||||
|
|
||||||
output_elems = [output_box, process_bar]
|
output_elems = [output_box, process_bar]
|
||||||
elem_dict.update(dict(
|
elem_dict.update(
|
||||||
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
|
dict(
|
||||||
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box
|
cmd_preview_btn=cmd_preview_btn,
|
||||||
))
|
start_btn=start_btn,
|
||||||
|
stop_btn=stop_btn,
|
||||||
|
resume_btn=resume_btn,
|
||||||
|
process_bar=process_bar,
|
||||||
|
output_box=output_box,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
|
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
|
||||||
start_btn.click(engine.runner.run_eval, input_elems, output_elems)
|
start_btn.click(engine.runner.run_eval, input_elems, output_elems)
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
import gradio as gr
|
|
||||||
from typing import TYPE_CHECKING, Dict, Generator, List
|
from typing import TYPE_CHECKING, Dict, Generator, List
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
from ...train import export_model
|
from ...train import export_model
|
||||||
from ..common import get_save_dir
|
from ..common import get_save_dir
|
||||||
from ..locales import ALERTS
|
from ..locales import ALERTS
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from gradio.components import Component
|
from gradio.components import Component
|
||||||
|
|
||||||
|
@ -24,7 +26,7 @@ def save_model(
|
||||||
max_shard_size: int,
|
max_shard_size: int,
|
||||||
export_quantization_bit: int,
|
export_quantization_bit: int,
|
||||||
export_quantization_dataset: str,
|
export_quantization_dataset: str,
|
||||||
export_dir: str
|
export_dir: str,
|
||||||
) -> Generator[str, None, None]:
|
) -> Generator[str, None, None]:
|
||||||
error = ""
|
error = ""
|
||||||
if not model_name:
|
if not model_name:
|
||||||
|
@ -44,7 +46,9 @@ def save_model(
|
||||||
return
|
return
|
||||||
|
|
||||||
if adapter_path:
|
if adapter_path:
|
||||||
adapter_name_or_path = ",".join([get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path])
|
adapter_name_or_path = ",".join(
|
||||||
|
[get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
adapter_name_or_path = None
|
adapter_name_or_path = None
|
||||||
|
|
||||||
|
@ -56,7 +60,7 @@ def save_model(
|
||||||
export_dir=export_dir,
|
export_dir=export_dir,
|
||||||
export_size=max_shard_size,
|
export_size=max_shard_size,
|
||||||
export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
|
export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
|
||||||
export_quantization_dataset=export_quantization_dataset
|
export_quantization_dataset=export_quantization_dataset,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield ALERTS["info_exporting"][lang]
|
yield ALERTS["info_exporting"][lang]
|
||||||
|
@ -86,9 +90,9 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
max_shard_size,
|
max_shard_size,
|
||||||
export_quantization_bit,
|
export_quantization_bit,
|
||||||
export_quantization_dataset,
|
export_quantization_dataset,
|
||||||
export_dir
|
export_dir,
|
||||||
],
|
],
|
||||||
[info_box]
|
[info_box],
|
||||||
)
|
)
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
|
@ -97,5 +101,5 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
export_quantization_dataset=export_quantization_dataset,
|
export_quantization_dataset=export_quantization_dataset,
|
||||||
export_dir=export_dir,
|
export_dir=export_dir,
|
||||||
export_btn=export_btn,
|
export_btn=export_btn,
|
||||||
info_box=info_box
|
info_box=info_box,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
import gradio as gr
|
|
||||||
from typing import TYPE_CHECKING, Dict
|
from typing import TYPE_CHECKING, Dict
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
from .chatbot import create_chat_box
|
from .chatbot import create_chat_box
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from gradio.components import Component
|
from gradio.components import Component
|
||||||
|
|
||||||
|
@ -23,18 +25,12 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
|
chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
|
||||||
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
|
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
|
||||||
|
|
||||||
load_btn.click(
|
load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
|
||||||
engine.chatter.load_model, input_elems, [info_box]
|
|
||||||
).then(
|
|
||||||
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
|
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
|
||||||
)
|
)
|
||||||
|
|
||||||
unload_btn.click(
|
unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
|
||||||
engine.chatter.unload_model, input_elems, [info_box]
|
|
||||||
).then(
|
|
||||||
lambda: ([], []), outputs=[chatbot, history]
|
lambda: ([], []), outputs=[chatbot, history]
|
||||||
).then(
|
).then(lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box])
|
||||||
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
|
|
||||||
)
|
|
||||||
|
|
||||||
return elem_dict
|
return elem_dict
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
import gradio as gr
|
|
||||||
from typing import TYPE_CHECKING, Dict
|
from typing import TYPE_CHECKING, Dict
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
from ...data import templates
|
from ...data import templates
|
||||||
from ...extras.constants import METHODS, SUPPORTED_MODELS
|
from ...extras.constants import METHODS, SUPPORTED_MODELS
|
||||||
from ..common import get_model_path, get_template, list_adapters, save_config
|
from ..common import get_model_path, get_template, list_adapters, save_config
|
||||||
from ..utils import can_quantize
|
from ..utils import can_quantize
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from gradio.components import Component
|
from gradio.components import Component
|
||||||
|
|
||||||
|
@ -30,25 +32,19 @@ def create_top() -> Dict[str, "Component"]:
|
||||||
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
|
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
|
||||||
booster = gr.Radio(choices=["none", "flash_attn", "unsloth"], value="none")
|
booster = gr.Radio(choices=["none", "flash_attn", "unsloth"], value="none")
|
||||||
|
|
||||||
model_name.change(
|
model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
|
||||||
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
|
|
||||||
).then(
|
|
||||||
get_model_path, [model_name], [model_path], queue=False
|
get_model_path, [model_name], [model_path], queue=False
|
||||||
).then(
|
).then(
|
||||||
get_template, [model_name], [template], queue=False
|
get_template, [model_name], [template], queue=False
|
||||||
) # do not save config since the below line will save
|
) # do not save config since the below line will save
|
||||||
|
|
||||||
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
|
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
|
||||||
|
|
||||||
finetuning_type.change(
|
finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
|
||||||
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
|
|
||||||
).then(
|
|
||||||
can_quantize, [finetuning_type], [quantization_bit], queue=False
|
can_quantize, [finetuning_type], [quantization_bit], queue=False
|
||||||
)
|
)
|
||||||
|
|
||||||
refresh_btn.click(
|
refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
|
||||||
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
|
|
||||||
)
|
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
lang=lang,
|
lang=lang,
|
||||||
|
@ -61,5 +57,5 @@ def create_top() -> Dict[str, "Component"]:
|
||||||
quantization_bit=quantization_bit,
|
quantization_bit=quantization_bit,
|
||||||
template=template,
|
template=template,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
booster=booster
|
booster=booster,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
import gradio as gr
|
|
||||||
from typing import TYPE_CHECKING, Dict
|
from typing import TYPE_CHECKING, Dict
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
from transformers.trainer_utils import SchedulerType
|
from transformers.trainer_utils import SchedulerType
|
||||||
|
|
||||||
from ...extras.constants import TRAINING_STAGES
|
from ...extras.constants import TRAINING_STAGES
|
||||||
from ..common import list_adapters, list_dataset, DEFAULT_DATA_DIR
|
from ..common import DEFAULT_DATA_DIR, list_adapters, list_dataset
|
||||||
from ..components.data import create_preview_box
|
from ..components.data import create_preview_box
|
||||||
from ..utils import gen_plot
|
from ..utils import gen_plot
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from gradio.components import Component
|
from gradio.components import Component
|
||||||
|
|
||||||
|
@ -29,9 +31,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
||||||
|
|
||||||
input_elems.update({training_stage, dataset_dir, dataset})
|
input_elems.update({training_stage, dataset_dir, dataset})
|
||||||
elem_dict.update(dict(
|
elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
|
||||||
training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems
|
|
||||||
))
|
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
|
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
|
||||||
|
@ -41,25 +41,33 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
compute_type = gr.Radio(choices=["fp16", "bf16", "fp32"], value="fp16")
|
compute_type = gr.Radio(choices=["fp16", "bf16", "fp32"], value="fp16")
|
||||||
|
|
||||||
input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
|
input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
|
||||||
elem_dict.update(dict(
|
elem_dict.update(
|
||||||
cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs,
|
dict(
|
||||||
max_samples=max_samples, compute_type=compute_type
|
cutoff_len=cutoff_len,
|
||||||
))
|
learning_rate=learning_rate,
|
||||||
|
num_train_epochs=num_train_epochs,
|
||||||
|
max_samples=max_samples,
|
||||||
|
compute_type=compute_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
|
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
|
||||||
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
|
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
|
||||||
lr_scheduler_type = gr.Dropdown(
|
lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
|
||||||
choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
|
|
||||||
)
|
|
||||||
max_grad_norm = gr.Textbox(value="1.0")
|
max_grad_norm = gr.Textbox(value="1.0")
|
||||||
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
|
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
|
||||||
|
|
||||||
input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
|
input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
|
||||||
elem_dict.update(dict(
|
elem_dict.update(
|
||||||
batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps,
|
dict(
|
||||||
lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size
|
batch_size=batch_size,
|
||||||
))
|
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||||
|
lr_scheduler_type=lr_scheduler_type,
|
||||||
|
max_grad_norm=max_grad_norm,
|
||||||
|
val_size=val_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Accordion(label="Extra config", open=False) as extra_tab:
|
with gr.Accordion(label="Extra config", open=False) as extra_tab:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -73,10 +81,17 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
upcast_layernorm = gr.Checkbox(value=False)
|
upcast_layernorm = gr.Checkbox(value=False)
|
||||||
|
|
||||||
input_elems.update({logging_steps, save_steps, warmup_steps, neftune_alpha, sft_packing, upcast_layernorm})
|
input_elems.update({logging_steps, save_steps, warmup_steps, neftune_alpha, sft_packing, upcast_layernorm})
|
||||||
elem_dict.update(dict(
|
elem_dict.update(
|
||||||
extra_tab=extra_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps,
|
dict(
|
||||||
neftune_alpha=neftune_alpha, sft_packing=sft_packing, upcast_layernorm=upcast_layernorm
|
extra_tab=extra_tab,
|
||||||
))
|
logging_steps=logging_steps,
|
||||||
|
save_steps=save_steps,
|
||||||
|
warmup_steps=warmup_steps,
|
||||||
|
neftune_alpha=neftune_alpha,
|
||||||
|
sft_packing=sft_packing,
|
||||||
|
upcast_layernorm=upcast_layernorm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
|
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -87,10 +102,16 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
create_new_adapter = gr.Checkbox(scale=1)
|
create_new_adapter = gr.Checkbox(scale=1)
|
||||||
|
|
||||||
input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, create_new_adapter})
|
input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, create_new_adapter})
|
||||||
elem_dict.update(dict(
|
elem_dict.update(
|
||||||
lora_tab=lora_tab, lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target,
|
dict(
|
||||||
additional_target=additional_target, create_new_adapter=create_new_adapter
|
lora_tab=lora_tab,
|
||||||
))
|
lora_rank=lora_rank,
|
||||||
|
lora_dropout=lora_dropout,
|
||||||
|
lora_target=lora_target,
|
||||||
|
additional_target=additional_target,
|
||||||
|
create_new_adapter=create_new_adapter,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
|
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -103,13 +124,13 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
list_adapters,
|
list_adapters,
|
||||||
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
|
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
|
||||||
[reward_model],
|
[reward_model],
|
||||||
queue=False
|
queue=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
input_elems.update({dpo_beta, dpo_ftx, reward_model})
|
input_elems.update({dpo_beta, dpo_ftx, reward_model})
|
||||||
elem_dict.update(dict(
|
elem_dict.update(
|
||||||
rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn
|
dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn)
|
||||||
))
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
cmd_preview_btn = gr.Button()
|
cmd_preview_btn = gr.Button()
|
||||||
|
@ -139,20 +160,28 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
stop_btn.click(engine.runner.set_abort, queue=False)
|
stop_btn.click(engine.runner.set_abort, queue=False)
|
||||||
resume_btn.change(engine.runner.monitor, outputs=output_elems)
|
resume_btn.change(engine.runner.monitor, outputs=output_elems)
|
||||||
|
|
||||||
elem_dict.update(dict(
|
elem_dict.update(
|
||||||
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
|
dict(
|
||||||
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
|
cmd_preview_btn=cmd_preview_btn,
|
||||||
))
|
start_btn=start_btn,
|
||||||
|
stop_btn=stop_btn,
|
||||||
|
output_dir=output_dir,
|
||||||
|
resume_btn=resume_btn,
|
||||||
|
process_bar=process_bar,
|
||||||
|
output_box=output_box,
|
||||||
|
loss_viewer=loss_viewer,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
output_box.change(
|
output_box.change(
|
||||||
gen_plot,
|
gen_plot,
|
||||||
[
|
[
|
||||||
engine.manager.get_elem_by_name("top.model_name"),
|
engine.manager.get_elem_by_name("top.model_name"),
|
||||||
engine.manager.get_elem_by_name("top.finetuning_type"),
|
engine.manager.get_elem_by_name("top.finetuning_type"),
|
||||||
output_dir
|
output_dir,
|
||||||
],
|
],
|
||||||
loss_viewer,
|
loss_viewer,
|
||||||
queue=False
|
queue=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return elem_dict
|
return elem_dict
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import gradio as gr
|
|
||||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
|
||||||
from typing import Any, Dict, Generator, Optional
|
from typing import Any, Dict, Generator, Optional
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||||
|
|
||||||
from .chatter import WebChatModel
|
from .chatter import WebChatModel
|
||||||
from .common import get_model_path, list_dataset, load_config
|
from .common import get_model_path, list_dataset, load_config
|
||||||
from .locales import LOCALES
|
from .locales import LOCALES
|
||||||
|
@ -11,7 +12,6 @@ from .utils import get_time
|
||||||
|
|
||||||
|
|
||||||
class Engine:
|
class Engine:
|
||||||
|
|
||||||
def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None:
|
def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None:
|
||||||
self.demo_mode = demo_mode
|
self.demo_mode = demo_mode
|
||||||
self.pure_chat = pure_chat
|
self.pure_chat = pure_chat
|
||||||
|
@ -26,10 +26,7 @@ class Engine:
|
||||||
user_config = load_config() if not self.demo_mode else {}
|
user_config = load_config() if not self.demo_mode else {}
|
||||||
lang = user_config.get("lang", None) or "en"
|
lang = user_config.get("lang", None) or "en"
|
||||||
|
|
||||||
init_dict = {
|
init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
|
||||||
"top.lang": {"value": lang},
|
|
||||||
"infer.chat_box": {"visible": self.chatter.loaded}
|
|
||||||
}
|
|
||||||
|
|
||||||
if not self.pure_chat:
|
if not self.pure_chat:
|
||||||
init_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
|
init_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
|
||||||
|
@ -49,13 +46,17 @@ class Engine:
|
||||||
else:
|
else:
|
||||||
yield self._form_dict({"eval.resume_btn": {"value": True}})
|
yield self._form_dict({"eval.resume_btn": {"value": True}})
|
||||||
else:
|
else:
|
||||||
yield self._form_dict({
|
yield self._form_dict(
|
||||||
"train.output_dir": {"value": "train_" + get_time()},
|
{
|
||||||
"eval.output_dir": {"value": "eval_" + get_time()},
|
"train.output_dir": {"value": "train_" + get_time()},
|
||||||
})
|
"eval.output_dir": {"value": "eval_" + get_time()},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
|
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
|
||||||
return {
|
return {
|
||||||
component: gr.update(**LOCALES[name][lang])
|
component: gr.update(**LOCALES[name][lang])
|
||||||
for elems in self.manager.all_elems.values() for name, component in elems.items() if name in LOCALES
|
for elems in self.manager.all_elems.values()
|
||||||
|
for name, component in elems.items()
|
||||||
|
if name in LOCALES
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,21 +1,22 @@
|
||||||
import gradio as gr
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
from transformers.utils.versions import require_version
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
from .common import save_config
|
||||||
from .components import (
|
from .components import (
|
||||||
|
create_chat_box,
|
||||||
|
create_eval_tab,
|
||||||
|
create_export_tab,
|
||||||
|
create_infer_tab,
|
||||||
create_top,
|
create_top,
|
||||||
create_train_tab,
|
create_train_tab,
|
||||||
create_eval_tab,
|
|
||||||
create_infer_tab,
|
|
||||||
create_export_tab,
|
|
||||||
create_chat_box
|
|
||||||
)
|
)
|
||||||
from .common import save_config
|
|
||||||
from .css import CSS
|
from .css import CSS
|
||||||
from .engine import Engine
|
from .engine import Engine
|
||||||
|
|
||||||
|
|
||||||
require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")
|
require_version("gradio>=3.38.0,<4.0.0", 'To fix: pip install "gradio>=3.38.0,<4.0.0"')
|
||||||
|
|
||||||
|
|
||||||
def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
|
def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
|
||||||
|
@ -23,11 +24,9 @@ def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
|
||||||
|
|
||||||
with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
|
with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
|
||||||
if demo_mode:
|
if demo_mode:
|
||||||
|
gr.HTML("<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>")
|
||||||
gr.HTML(
|
gr.HTML(
|
||||||
"<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>"
|
'<h3><center>Visit <a href="https://github.com/hiyouga/LLaMA-Factory" target="_blank">'
|
||||||
)
|
|
||||||
gr.HTML(
|
|
||||||
"<h3><center>Visit <a href=\"https://github.com/hiyouga/LLaMA-Factory\" target=\"_blank\">"
|
|
||||||
"LLaMA Factory</a> for details.</center></h3>"
|
"LLaMA Factory</a> for details.</center></h3>"
|
||||||
)
|
)
|
||||||
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
|
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
|
||||||
|
|
|
@ -1,726 +1,220 @@
|
||||||
LOCALES = {
|
LOCALES = {
|
||||||
"lang": {
|
"lang": {"en": {"label": "Lang"}, "zh": {"label": "语言"}},
|
||||||
"en": {
|
"model_name": {"en": {"label": "Model name"}, "zh": {"label": "模型名称"}},
|
||||||
"label": "Lang"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "语言"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"model_name": {
|
|
||||||
"en": {
|
|
||||||
"label": "Model name"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "模型名称"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"model_path": {
|
"model_path": {
|
||||||
"en": {
|
"en": {"label": "Model path", "info": "Path to pretrained model or model identifier from Hugging Face."},
|
||||||
"label": "Model path",
|
"zh": {"label": "模型路径", "info": "本地模型的文件路径或 Hugging Face 的模型标识符。"},
|
||||||
"info": "Path to pretrained model or model identifier from Hugging Face."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "模型路径",
|
|
||||||
"info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"finetuning_type": {
|
|
||||||
"en": {
|
|
||||||
"label": "Finetuning method"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "微调方法"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"adapter_path": {
|
|
||||||
"en": {
|
|
||||||
"label": "Adapter path"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "适配器路径"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"refresh_btn": {
|
|
||||||
"en": {
|
|
||||||
"value": "Refresh adapters"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "刷新适配器"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"advanced_tab": {
|
|
||||||
"en": {
|
|
||||||
"label": "Advanced configurations"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "高级设置"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
"finetuning_type": {"en": {"label": "Finetuning method"}, "zh": {"label": "微调方法"}},
|
||||||
|
"adapter_path": {"en": {"label": "Adapter path"}, "zh": {"label": "适配器路径"}},
|
||||||
|
"refresh_btn": {"en": {"value": "Refresh adapters"}, "zh": {"value": "刷新适配器"}},
|
||||||
|
"advanced_tab": {"en": {"label": "Advanced configurations"}, "zh": {"label": "高级设置"}},
|
||||||
"quantization_bit": {
|
"quantization_bit": {
|
||||||
"en": {
|
"en": {"label": "Quantization bit", "info": "Enable 4/8-bit model quantization (QLoRA)."},
|
||||||
"label": "Quantization bit",
|
"zh": {"label": "量化等级", "info": "启用 4/8 比特模型量化(QLoRA)。"},
|
||||||
"info": "Enable 4/8-bit model quantization (QLoRA)."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "量化等级",
|
|
||||||
"info": "启用 4/8 比特模型量化(QLoRA)。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"template": {
|
"template": {
|
||||||
"en": {
|
"en": {"label": "Prompt template", "info": "The template used in constructing prompts."},
|
||||||
"label": "Prompt template",
|
"zh": {"label": "提示模板", "info": "构建提示词时使用的模板"},
|
||||||
"info": "The template used in constructing prompts."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "提示模板",
|
|
||||||
"info": "构建提示词时使用的模板"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"rope_scaling": {
|
|
||||||
"en": {
|
|
||||||
"label": "RoPE scaling"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "RoPE 插值方法"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"booster": {
|
|
||||||
"en": {
|
|
||||||
"label": "Booster"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "加速方式"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
"rope_scaling": {"en": {"label": "RoPE scaling"}, "zh": {"label": "RoPE 插值方法"}},
|
||||||
|
"booster": {"en": {"label": "Booster"}, "zh": {"label": "加速方式"}},
|
||||||
"training_stage": {
|
"training_stage": {
|
||||||
"en": {
|
"en": {"label": "Stage", "info": "The stage to perform in training."},
|
||||||
"label": "Stage",
|
"zh": {"label": "训练阶段", "info": "目前采用的训练方式。"},
|
||||||
"info": "The stage to perform in training."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "训练阶段",
|
|
||||||
"info": "目前采用的训练方式。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"dataset_dir": {
|
"dataset_dir": {
|
||||||
"en": {
|
"en": {"label": "Data dir", "info": "Path to the data directory."},
|
||||||
"label": "Data dir",
|
"zh": {"label": "数据路径", "info": "数据文件夹的路径。"},
|
||||||
"info": "Path to the data directory."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "数据路径",
|
|
||||||
"info": "数据文件夹的路径。"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"dataset": {
|
|
||||||
"en": {
|
|
||||||
"label": "Dataset"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "数据集"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"data_preview_btn": {
|
|
||||||
"en": {
|
|
||||||
"value": "Preview dataset"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "预览数据集"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"preview_count": {
|
|
||||||
"en": {
|
|
||||||
"label": "Count"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "数量"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"page_index": {
|
|
||||||
"en": {
|
|
||||||
"label": "Page"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "页数"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"prev_btn": {
|
|
||||||
"en": {
|
|
||||||
"value": "Prev"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "上一页"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"next_btn": {
|
|
||||||
"en": {
|
|
||||||
"value": "Next"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "下一页"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"close_btn": {
|
|
||||||
"en": {
|
|
||||||
"value": "Close"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "关闭"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"preview_samples": {
|
|
||||||
"en": {
|
|
||||||
"label": "Samples"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "样例"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
"dataset": {"en": {"label": "Dataset"}, "zh": {"label": "数据集"}},
|
||||||
|
"data_preview_btn": {"en": {"value": "Preview dataset"}, "zh": {"value": "预览数据集"}},
|
||||||
|
"preview_count": {"en": {"label": "Count"}, "zh": {"label": "数量"}},
|
||||||
|
"page_index": {"en": {"label": "Page"}, "zh": {"label": "页数"}},
|
||||||
|
"prev_btn": {"en": {"value": "Prev"}, "zh": {"value": "上一页"}},
|
||||||
|
"next_btn": {"en": {"value": "Next"}, "zh": {"value": "下一页"}},
|
||||||
|
"close_btn": {"en": {"value": "Close"}, "zh": {"value": "关闭"}},
|
||||||
|
"preview_samples": {"en": {"label": "Samples"}, "zh": {"label": "样例"}},
|
||||||
"cutoff_len": {
|
"cutoff_len": {
|
||||||
"en": {
|
"en": {"label": "Cutoff length", "info": "Max tokens in input sequence."},
|
||||||
"label": "Cutoff length",
|
"zh": {"label": "截断长度", "info": "输入序列分词后的最大长度。"},
|
||||||
"info": "Max tokens in input sequence."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "截断长度",
|
|
||||||
"info": "输入序列分词后的最大长度。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"learning_rate": {
|
"learning_rate": {
|
||||||
"en": {
|
"en": {"label": "Learning rate", "info": "Initial learning rate for AdamW."},
|
||||||
"label": "Learning rate",
|
"zh": {"label": "学习率", "info": "AdamW 优化器的初始学习率。"},
|
||||||
"info": "Initial learning rate for AdamW."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "学习率",
|
|
||||||
"info": "AdamW 优化器的初始学习率。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"num_train_epochs": {
|
"num_train_epochs": {
|
||||||
"en": {
|
"en": {"label": "Epochs", "info": "Total number of training epochs to perform."},
|
||||||
"label": "Epochs",
|
"zh": {"label": "训练轮数", "info": "需要执行的训练总轮数。"},
|
||||||
"info": "Total number of training epochs to perform."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "训练轮数",
|
|
||||||
"info": "需要执行的训练总轮数。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"max_samples": {
|
"max_samples": {
|
||||||
"en": {
|
"en": {"label": "Max samples", "info": "Maximum samples per dataset."},
|
||||||
"label": "Max samples",
|
"zh": {"label": "最大样本数", "info": "每个数据集最多使用的样本数。"},
|
||||||
"info": "Maximum samples per dataset."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "最大样本数",
|
|
||||||
"info": "每个数据集最多使用的样本数。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"compute_type": {
|
"compute_type": {
|
||||||
"en": {
|
"en": {"label": "Compute type", "info": "Whether to use fp16 or bf16 mixed precision training."},
|
||||||
"label": "Compute type",
|
"zh": {"label": "计算类型", "info": "是否启用 FP16 或 BF16 混合精度训练。"},
|
||||||
"info": "Whether to use fp16 or bf16 mixed precision training."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "计算类型",
|
|
||||||
"info": "是否启用 FP16 或 BF16 混合精度训练。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"batch_size": {
|
"batch_size": {
|
||||||
"en": {
|
"en": {"label": "Batch size", "info": "Number of samples to process per GPU."},
|
||||||
"label": "Batch size",
|
"zh": {"label": "批处理大小", "info": "每块 GPU 上处理的样本数量。"},
|
||||||
"info": "Number of samples to process per GPU."
|
|
||||||
},
|
|
||||||
"zh":{
|
|
||||||
"label": "批处理大小",
|
|
||||||
"info": "每块 GPU 上处理的样本数量。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"gradient_accumulation_steps": {
|
"gradient_accumulation_steps": {
|
||||||
"en": {
|
"en": {"label": "Gradient accumulation", "info": "Number of gradient accumulation steps."},
|
||||||
"label": "Gradient accumulation",
|
"zh": {"label": "梯度累积", "info": "梯度累积的步数。"},
|
||||||
"info": "Number of gradient accumulation steps."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "梯度累积",
|
|
||||||
"info": "梯度累积的步数。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"lr_scheduler_type": {
|
"lr_scheduler_type": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "LR Scheduler",
|
"label": "LR Scheduler",
|
||||||
"info": "Name of learning rate scheduler.",
|
"info": "Name of learning rate scheduler.",
|
||||||
},
|
},
|
||||||
"zh": {
|
"zh": {"label": "学习率调节器", "info": "采用的学习率调节器名称。"},
|
||||||
"label": "学习率调节器",
|
|
||||||
"info": "采用的学习率调节器名称。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"max_grad_norm": {
|
"max_grad_norm": {
|
||||||
"en": {
|
"en": {"label": "Maximum gradient norm", "info": "Norm for gradient clipping.."},
|
||||||
"label": "Maximum gradient norm",
|
"zh": {"label": "最大梯度范数", "info": "用于梯度裁剪的范数。"},
|
||||||
"info": "Norm for gradient clipping.."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "最大梯度范数",
|
|
||||||
"info": "用于梯度裁剪的范数。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"val_size": {
|
"val_size": {
|
||||||
"en": {
|
"en": {"label": "Val size", "info": "Proportion of data in the dev set."},
|
||||||
"label": "Val size",
|
"zh": {"label": "验证集比例", "info": "验证集占全部样本的百分比。"},
|
||||||
"info": "Proportion of data in the dev set."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "验证集比例",
|
|
||||||
"info": "验证集占全部样本的百分比。"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"extra_tab": {
|
|
||||||
"en": {
|
|
||||||
"label": "Extra configurations"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "其它参数设置"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
"extra_tab": {"en": {"label": "Extra configurations"}, "zh": {"label": "其它参数设置"}},
|
||||||
"logging_steps": {
|
"logging_steps": {
|
||||||
"en": {
|
"en": {"label": "Logging steps", "info": "Number of steps between two logs."},
|
||||||
"label": "Logging steps",
|
"zh": {"label": "日志间隔", "info": "每两次日志输出间的更新步数。"},
|
||||||
"info": "Number of steps between two logs."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "日志间隔",
|
|
||||||
"info": "每两次日志输出间的更新步数。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"save_steps": {
|
"save_steps": {
|
||||||
"en": {
|
"en": {"label": "Save steps", "info": "Number of steps between two checkpoints."},
|
||||||
"label": "Save steps",
|
"zh": {"label": "保存间隔", "info": "每两次断点保存间的更新步数。"},
|
||||||
"info": "Number of steps between two checkpoints."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "保存间隔",
|
|
||||||
"info": "每两次断点保存间的更新步数。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"warmup_steps": {
|
"warmup_steps": {
|
||||||
"en": {
|
"en": {"label": "Warmup steps", "info": "Number of steps used for warmup."},
|
||||||
"label": "Warmup steps",
|
"zh": {"label": "预热步数", "info": "学习率预热采用的步数。"},
|
||||||
"info": "Number of steps used for warmup."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "预热步数",
|
|
||||||
"info": "学习率预热采用的步数。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"neftune_alpha": {
|
"neftune_alpha": {
|
||||||
"en": {
|
"en": {"label": "NEFTune Alpha", "info": "Magnitude of noise adding to embedding vectors."},
|
||||||
"label": "NEFTune Alpha",
|
"zh": {"label": "NEFTune 噪声参数", "info": "嵌入向量所添加的噪声大小。"},
|
||||||
"info": "Magnitude of noise adding to embedding vectors."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "NEFTune 噪声参数",
|
|
||||||
"info": "嵌入向量所添加的噪声大小。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"sft_packing": {
|
"sft_packing": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Pack sequences",
|
"label": "Pack sequences",
|
||||||
"info": "Pack sequences into samples of fixed length in supervised fine-tuning."
|
"info": "Pack sequences into samples of fixed length in supervised fine-tuning.",
|
||||||
},
|
},
|
||||||
"zh": {
|
"zh": {"label": "序列打包", "info": "在有监督微调阶段将序列打包为相同长度的样本。"},
|
||||||
"label": "序列打包",
|
|
||||||
"info": "在有监督微调阶段将序列打包为相同长度的样本。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"upcast_layernorm": {
|
"upcast_layernorm": {
|
||||||
"en": {
|
"en": {"label": "Upcast LayerNorm", "info": "Upcast weights of layernorm in float32."},
|
||||||
"label": "Upcast LayerNorm",
|
"zh": {"label": "缩放归一化层", "info": "将归一化层权重缩放至 32 位精度。"},
|
||||||
"info": "Upcast weights of layernorm in float32."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "缩放归一化层",
|
|
||||||
"info": "将归一化层权重缩放至 32 位精度。"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"lora_tab": {
|
|
||||||
"en": {
|
|
||||||
"label": "LoRA configurations"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "LoRA 参数设置"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
"lora_tab": {"en": {"label": "LoRA configurations"}, "zh": {"label": "LoRA 参数设置"}},
|
||||||
"lora_rank": {
|
"lora_rank": {
|
||||||
"en": {
|
"en": {"label": "LoRA rank", "info": "The rank of LoRA matrices."},
|
||||||
"label": "LoRA rank",
|
"zh": {"label": "LoRA 秩", "info": "LoRA 矩阵的秩。"},
|
||||||
"info": "The rank of LoRA matrices."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "LoRA 秩",
|
|
||||||
"info": "LoRA 矩阵的秩。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"lora_dropout": {
|
"lora_dropout": {
|
||||||
"en": {
|
"en": {"label": "LoRA Dropout", "info": "Dropout ratio of LoRA weights."},
|
||||||
"label": "LoRA Dropout",
|
"zh": {"label": "LoRA 随机丢弃", "info": "LoRA 权重随机丢弃的概率。"},
|
||||||
"info": "Dropout ratio of LoRA weights."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "LoRA 随机丢弃",
|
|
||||||
"info": "LoRA 权重随机丢弃的概率。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"lora_target": {
|
"lora_target": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "LoRA modules (optional)",
|
"label": "LoRA modules (optional)",
|
||||||
"info": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
|
"info": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules.",
|
||||||
},
|
},
|
||||||
"zh": {
|
"zh": {"label": "LoRA 作用模块(非必填)", "info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。"},
|
||||||
"label": "LoRA 作用模块(非必填)",
|
|
||||||
"info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"additional_target": {
|
"additional_target": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Additional modules (optional)",
|
"label": "Additional modules (optional)",
|
||||||
"info": "Name(s) of modules apart from LoRA layers to be set as trainable. Use commas to separate multiple modules."
|
"info": "Name(s) of modules apart from LoRA layers to be set as trainable. Use commas to separate multiple modules.",
|
||||||
},
|
},
|
||||||
"zh": {
|
"zh": {"label": "附加模块(非必填)", "info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。"},
|
||||||
"label": "附加模块(非必填)",
|
|
||||||
"info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"create_new_adapter": {
|
"create_new_adapter": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Create new adapter",
|
"label": "Create new adapter",
|
||||||
"info": "Whether to create a new adapter with randomly initialized weight or not."
|
"info": "Whether to create a new adapter with randomly initialized weight or not.",
|
||||||
},
|
},
|
||||||
"zh": {
|
"zh": {"label": "新建适配器", "info": "是否创建一个经过随机初始化的新适配器。"},
|
||||||
"label": "新建适配器",
|
|
||||||
"info": "是否创建一个经过随机初始化的新适配器。"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"rlhf_tab": {
|
|
||||||
"en": {
|
|
||||||
"label": "RLHF configurations"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "RLHF 参数设置"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
"rlhf_tab": {"en": {"label": "RLHF configurations"}, "zh": {"label": "RLHF 参数设置"}},
|
||||||
"dpo_beta": {
|
"dpo_beta": {
|
||||||
"en": {
|
"en": {"label": "DPO beta", "info": "Value of the beta parameter in the DPO loss."},
|
||||||
"label": "DPO beta",
|
"zh": {"label": "DPO beta 参数", "info": "DPO 损失函数中 beta 超参数大小。"},
|
||||||
"info": "Value of the beta parameter in the DPO loss."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "DPO beta 参数",
|
|
||||||
"info": "DPO 损失函数中 beta 超参数大小。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"dpo_ftx": {
|
"dpo_ftx": {
|
||||||
"en": {
|
"en": {"label": "DPO-ftx weight", "info": "The weight of SFT loss in the DPO-ftx."},
|
||||||
"label": "DPO-ftx weight",
|
"zh": {"label": "DPO-ftx 权重", "info": "DPO-ftx 中 SFT 损失的权重大小。"},
|
||||||
"info": "The weight of SFT loss in the DPO-ftx."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "DPO-ftx 权重",
|
|
||||||
"info": "DPO-ftx 中 SFT 损失的权重大小。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"reward_model": {
|
"reward_model": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Reward model",
|
"label": "Reward model",
|
||||||
"info": "Adapter of the reward model for PPO training. (Needs to refresh adapters)"
|
"info": "Adapter of the reward model for PPO training. (Needs to refresh adapters)",
|
||||||
},
|
},
|
||||||
"zh": {
|
"zh": {"label": "奖励模型", "info": "PPO 训练中奖励模型的适配器路径。(需要刷新适配器)"},
|
||||||
"label": "奖励模型",
|
|
||||||
"info": "PPO 训练中奖励模型的适配器路径。(需要刷新适配器)"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"cmd_preview_btn": {
|
|
||||||
"en": {
|
|
||||||
"value": "Preview command"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "预览命令"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"start_btn": {
|
|
||||||
"en": {
|
|
||||||
"value": "Start"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "开始"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"stop_btn": {
|
|
||||||
"en": {
|
|
||||||
"value": "Abort"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "中断"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
"cmd_preview_btn": {"en": {"value": "Preview command"}, "zh": {"value": "预览命令"}},
|
||||||
|
"start_btn": {"en": {"value": "Start"}, "zh": {"value": "开始"}},
|
||||||
|
"stop_btn": {"en": {"value": "Abort"}, "zh": {"value": "中断"}},
|
||||||
"output_dir": {
|
"output_dir": {
|
||||||
"en": {
|
"en": {"label": "Output dir", "info": "Directory for saving results."},
|
||||||
"label": "Output dir",
|
"zh": {"label": "输出目录", "info": "保存结果的路径。"},
|
||||||
"info": "Directory for saving results."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "输出目录",
|
|
||||||
"info": "保存结果的路径。"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"output_box": {
|
|
||||||
"en": {
|
|
||||||
"value": "Ready."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "准备就绪。"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"loss_viewer": {
|
|
||||||
"en": {
|
|
||||||
"label": "Loss"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "损失"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"predict": {
|
|
||||||
"en": {
|
|
||||||
"label": "Save predictions"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "保存预测结果"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"load_btn": {
|
|
||||||
"en": {
|
|
||||||
"value": "Load model"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "加载模型"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"unload_btn": {
|
|
||||||
"en": {
|
|
||||||
"value": "Unload model"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "卸载模型"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"info_box": {
|
|
||||||
"en": {
|
|
||||||
"value": "Model unloaded, please load a model first."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "模型未加载,请先加载模型。"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"system": {
|
|
||||||
"en": {
|
|
||||||
"placeholder": "System prompt (optional)"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"placeholder": "系统提示词(非必填)"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"tools": {
|
|
||||||
"en": {
|
|
||||||
"placeholder": "Tools (optional)"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"placeholder": "工具列表(非必填)"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"query": {
|
|
||||||
"en": {
|
|
||||||
"placeholder": "Input..."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"placeholder": "输入..."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"submit_btn": {
|
|
||||||
"en": {
|
|
||||||
"value": "Submit"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "提交"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"clear_btn": {
|
|
||||||
"en": {
|
|
||||||
"value": "Clear history"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "清空历史"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"max_length": {
|
|
||||||
"en": {
|
|
||||||
"label": "Maximum length"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "最大长度"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"max_new_tokens": {
|
|
||||||
"en": {
|
|
||||||
"label": "Maximum new tokens"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "最大生成长度"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"top_p": {
|
|
||||||
"en": {
|
|
||||||
"label": "Top-p"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "Top-p 采样值"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"temperature": {
|
|
||||||
"en": {
|
|
||||||
"label": "Temperature"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "温度系数"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
"output_box": {"en": {"value": "Ready."}, "zh": {"value": "准备就绪。"}},
|
||||||
|
"loss_viewer": {"en": {"label": "Loss"}, "zh": {"label": "损失"}},
|
||||||
|
"predict": {"en": {"label": "Save predictions"}, "zh": {"label": "保存预测结果"}},
|
||||||
|
"load_btn": {"en": {"value": "Load model"}, "zh": {"value": "加载模型"}},
|
||||||
|
"unload_btn": {"en": {"value": "Unload model"}, "zh": {"value": "卸载模型"}},
|
||||||
|
"info_box": {"en": {"value": "Model unloaded, please load a model first."}, "zh": {"value": "模型未加载,请先加载模型。"}},
|
||||||
|
"system": {"en": {"placeholder": "System prompt (optional)"}, "zh": {"placeholder": "系统提示词(非必填)"}},
|
||||||
|
"tools": {"en": {"placeholder": "Tools (optional)"}, "zh": {"placeholder": "工具列表(非必填)"}},
|
||||||
|
"query": {"en": {"placeholder": "Input..."}, "zh": {"placeholder": "输入..."}},
|
||||||
|
"submit_btn": {"en": {"value": "Submit"}, "zh": {"value": "提交"}},
|
||||||
|
"clear_btn": {"en": {"value": "Clear history"}, "zh": {"value": "清空历史"}},
|
||||||
|
"max_length": {"en": {"label": "Maximum length"}, "zh": {"label": "最大长度"}},
|
||||||
|
"max_new_tokens": {"en": {"label": "Maximum new tokens"}, "zh": {"label": "最大生成长度"}},
|
||||||
|
"top_p": {"en": {"label": "Top-p"}, "zh": {"label": "Top-p 采样值"}},
|
||||||
|
"temperature": {"en": {"label": "Temperature"}, "zh": {"label": "温度系数"}},
|
||||||
"max_shard_size": {
|
"max_shard_size": {
|
||||||
"en": {
|
"en": {"label": "Max shard size (GB)", "info": "The maximum size for a model file."},
|
||||||
"label": "Max shard size (GB)",
|
"zh": {"label": "最大分块大小(GB)", "info": "单个模型文件的最大大小。"},
|
||||||
"info": "The maximum size for a model file."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "最大分块大小(GB)",
|
|
||||||
"info": "单个模型文件的最大大小。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"export_quantization_bit": {
|
"export_quantization_bit": {
|
||||||
"en": {
|
"en": {"label": "Export quantization bit.", "info": "Quantizing the exported model."},
|
||||||
"label": "Export quantization bit.",
|
"zh": {"label": "导出量化等级", "info": "量化导出模型。"},
|
||||||
"info": "Quantizing the exported model."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "导出量化等级",
|
|
||||||
"info": "量化导出模型。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"export_quantization_dataset": {
|
"export_quantization_dataset": {
|
||||||
"en": {
|
"en": {"label": "Export quantization dataset.", "info": "The calibration dataset used for quantization."},
|
||||||
"label": "Export quantization dataset.",
|
"zh": {"label": "导出量化数据集", "info": "量化过程中使用的校准数据集。"},
|
||||||
"info": "The calibration dataset used for quantization."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "导出量化数据集",
|
|
||||||
"info": "量化过程中使用的校准数据集。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"export_dir": {
|
"export_dir": {
|
||||||
"en": {
|
"en": {"label": "Export dir", "info": "Directory to save exported model."},
|
||||||
"label": "Export dir",
|
"zh": {"label": "导出目录", "info": "保存导出模型的文件夹路径。"},
|
||||||
"info": "Directory to save exported model."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "导出目录",
|
|
||||||
"info": "保存导出模型的文件夹路径。"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"export_btn": {
|
"export_btn": {"en": {"value": "Export"}, "zh": {"value": "开始导出"}},
|
||||||
"en": {
|
|
||||||
"value": "Export"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "开始导出"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
ALERTS = {
|
ALERTS = {
|
||||||
"err_conflict": {
|
"err_conflict": {"en": "A process is in running, please abort it firstly.", "zh": "任务已存在,请先中断训练。"},
|
||||||
"en": "A process is in running, please abort it firstly.",
|
"err_exists": {"en": "You have loaded a model, please unload it first.", "zh": "模型已存在,请先卸载模型。"},
|
||||||
"zh": "任务已存在,请先中断训练。"
|
"err_no_model": {"en": "Please select a model.", "zh": "请选择模型。"},
|
||||||
},
|
"err_no_path": {"en": "Model not found.", "zh": "模型未找到。"},
|
||||||
"err_exists": {
|
"err_no_dataset": {"en": "Please choose a dataset.", "zh": "请选择数据集。"},
|
||||||
"en": "You have loaded a model, please unload it first.",
|
"err_no_adapter": {"en": "Please select an adapter.", "zh": "请选择一个适配器。"},
|
||||||
"zh": "模型已存在,请先卸载模型。"
|
"err_no_export_dir": {"en": "Please provide export dir.", "zh": "请填写导出目录"},
|
||||||
},
|
"err_failed": {"en": "Failed.", "zh": "训练出错。"},
|
||||||
"err_no_model": {
|
|
||||||
"en": "Please select a model.",
|
|
||||||
"zh": "请选择模型。"
|
|
||||||
},
|
|
||||||
"err_no_path": {
|
|
||||||
"en": "Model not found.",
|
|
||||||
"zh": "模型未找到。"
|
|
||||||
},
|
|
||||||
"err_no_dataset": {
|
|
||||||
"en": "Please choose a dataset.",
|
|
||||||
"zh": "请选择数据集。"
|
|
||||||
},
|
|
||||||
"err_no_adapter": {
|
|
||||||
"en": "Please select an adapter.",
|
|
||||||
"zh": "请选择一个适配器。"
|
|
||||||
},
|
|
||||||
"err_no_export_dir": {
|
|
||||||
"en": "Please provide export dir.",
|
|
||||||
"zh": "请填写导出目录"
|
|
||||||
},
|
|
||||||
"err_failed": {
|
|
||||||
"en": "Failed.",
|
|
||||||
"zh": "训练出错。"
|
|
||||||
},
|
|
||||||
"err_demo": {
|
"err_demo": {
|
||||||
"en": "Training is unavailable in demo mode, duplicate the space to a private one first.",
|
"en": "Training is unavailable in demo mode, duplicate the space to a private one first.",
|
||||||
"zh": "展示模式不支持训练,请先复制到私人空间。"
|
"zh": "展示模式不支持训练,请先复制到私人空间。",
|
||||||
},
|
},
|
||||||
"err_device_count": {
|
"err_device_count": {"en": "Multiple GPUs are not supported yet.", "zh": "尚不支持多 GPU 训练。"},
|
||||||
"en": "Multiple GPUs are not supported yet.",
|
"info_aborting": {"en": "Aborted, wait for terminating...", "zh": "训练中断,正在等待线程结束……"},
|
||||||
"zh": "尚不支持多 GPU 训练。"
|
"info_aborted": {"en": "Ready.", "zh": "准备就绪。"},
|
||||||
},
|
"info_finished": {"en": "Finished.", "zh": "训练完毕。"},
|
||||||
"info_aborting": {
|
"info_loading": {"en": "Loading model...", "zh": "加载中……"},
|
||||||
"en": "Aborted, wait for terminating...",
|
"info_unloading": {"en": "Unloading model...", "zh": "卸载中……"},
|
||||||
"zh": "训练中断,正在等待线程结束……"
|
"info_loaded": {"en": "Model loaded, now you can chat with your model!", "zh": "模型已加载,可以开始聊天了!"},
|
||||||
},
|
"info_unloaded": {"en": "Model unloaded.", "zh": "模型已卸载。"},
|
||||||
"info_aborted": {
|
"info_exporting": {"en": "Exporting model...", "zh": "正在导出模型……"},
|
||||||
"en": "Ready.",
|
"info_exported": {"en": "Model exported.", "zh": "模型导出完成。"},
|
||||||
"zh": "准备就绪。"
|
|
||||||
},
|
|
||||||
"info_finished": {
|
|
||||||
"en": "Finished.",
|
|
||||||
"zh": "训练完毕。"
|
|
||||||
},
|
|
||||||
"info_loading": {
|
|
||||||
"en": "Loading model...",
|
|
||||||
"zh": "加载中……"
|
|
||||||
},
|
|
||||||
"info_unloading": {
|
|
||||||
"en": "Unloading model...",
|
|
||||||
"zh": "卸载中……"
|
|
||||||
},
|
|
||||||
"info_loaded": {
|
|
||||||
"en": "Model loaded, now you can chat with your model!",
|
|
||||||
"zh": "模型已加载,可以开始聊天了!"
|
|
||||||
},
|
|
||||||
"info_unloaded": {
|
|
||||||
"en": "Model unloaded.",
|
|
||||||
"zh": "模型已卸载。"
|
|
||||||
},
|
|
||||||
"info_exporting": {
|
|
||||||
"en": "Exporting model...",
|
|
||||||
"zh": "正在导出模型……"
|
|
||||||
},
|
|
||||||
"info_exported": {
|
|
||||||
"en": "Model exported.",
|
|
||||||
"zh": "模型导出完成。"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
from typing import TYPE_CHECKING, Dict, List, Set
|
from typing import TYPE_CHECKING, Dict, List, Set
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from gradio.components import Component
|
from gradio.components import Component
|
||||||
|
|
||||||
|
|
||||||
class Manager:
|
class Manager:
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.all_elems: Dict[str, Dict[str, "Component"]] = {}
|
self.all_elems: Dict[str, Dict[str, "Component"]] = {}
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ class Manager:
|
||||||
self.all_elems["top"]["quantization_bit"],
|
self.all_elems["top"]["quantization_bit"],
|
||||||
self.all_elems["top"]["template"],
|
self.all_elems["top"]["template"],
|
||||||
self.all_elems["top"]["rope_scaling"],
|
self.all_elems["top"]["rope_scaling"],
|
||||||
self.all_elems["top"]["booster"]
|
self.all_elems["top"]["booster"],
|
||||||
}
|
}
|
||||||
|
|
||||||
def list_elems(self) -> List["Component"]:
|
def list_elems(self) -> List["Component"]:
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import logging
|
|
||||||
import gradio as gr
|
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
import transformers
|
import transformers
|
||||||
|
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||||
from transformers.trainer import TRAINING_ARGS_NAME
|
from transformers.trainer import TRAINING_ARGS_NAME
|
||||||
|
|
||||||
from ..extras.callbacks import LogCallback
|
from ..extras.callbacks import LogCallback
|
||||||
|
@ -18,12 +18,12 @@ from .common import get_module, get_save_dir, load_config
|
||||||
from .locales import ALERTS
|
from .locales import ALERTS
|
||||||
from .utils import gen_cmd, get_eval_results, update_process_bar
|
from .utils import gen_cmd, get_eval_results, update_process_bar
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .manager import Manager
|
from .manager import Manager
|
||||||
|
|
||||||
|
|
||||||
class Runner:
|
class Runner:
|
||||||
|
|
||||||
def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
|
def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
|
||||||
self.manager = manager
|
self.manager = manager
|
||||||
self.demo_mode = demo_mode
|
self.demo_mode = demo_mode
|
||||||
|
@ -90,9 +90,12 @@ class Runner:
|
||||||
user_config = load_config()
|
user_config = load_config()
|
||||||
|
|
||||||
if get("top.adapter_path"):
|
if get("top.adapter_path"):
|
||||||
adapter_name_or_path = ",".join([
|
adapter_name_or_path = ",".join(
|
||||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
|
[
|
||||||
for adapter in get("top.adapter_path")])
|
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
|
||||||
|
for adapter in get("top.adapter_path")
|
||||||
|
]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
adapter_name_or_path = None
|
adapter_name_or_path = None
|
||||||
|
|
||||||
|
@ -131,12 +134,12 @@ class Runner:
|
||||||
create_new_adapter=get("train.create_new_adapter"),
|
create_new_adapter=get("train.create_new_adapter"),
|
||||||
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
|
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
|
||||||
fp16=(get("train.compute_type") == "fp16"),
|
fp16=(get("train.compute_type") == "fp16"),
|
||||||
bf16=(get("train.compute_type") == "bf16")
|
bf16=(get("train.compute_type") == "bf16"),
|
||||||
)
|
)
|
||||||
args["disable_tqdm"] = True
|
args["disable_tqdm"] = True
|
||||||
|
|
||||||
if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
|
if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
|
||||||
args["create_new_adapter"] = (args["quantization_bit"] is None)
|
args["create_new_adapter"] = args["quantization_bit"] is None
|
||||||
|
|
||||||
if args["stage"] == "ppo":
|
if args["stage"] == "ppo":
|
||||||
args["reward_model"] = get_save_dir(
|
args["reward_model"] = get_save_dir(
|
||||||
|
@ -161,9 +164,12 @@ class Runner:
|
||||||
user_config = load_config()
|
user_config = load_config()
|
||||||
|
|
||||||
if get("top.adapter_path"):
|
if get("top.adapter_path"):
|
||||||
adapter_name_or_path = ",".join([
|
adapter_name_or_path = ",".join(
|
||||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
|
[
|
||||||
for adapter in get("top.adapter_path")])
|
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
|
||||||
|
for adapter in get("top.adapter_path")
|
||||||
|
]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
adapter_name_or_path = None
|
adapter_name_or_path = None
|
||||||
|
|
||||||
|
@ -187,7 +193,7 @@ class Runner:
|
||||||
max_new_tokens=get("eval.max_new_tokens"),
|
max_new_tokens=get("eval.max_new_tokens"),
|
||||||
top_p=get("eval.top_p"),
|
top_p=get("eval.top_p"),
|
||||||
temperature=get("eval.temperature"),
|
temperature=get("eval.temperature"),
|
||||||
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir"))
|
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
|
||||||
)
|
)
|
||||||
|
|
||||||
if get("eval.predict"):
|
if get("eval.predict"):
|
||||||
|
@ -197,7 +203,9 @@ class Runner:
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
def _preview(
|
||||||
|
self, data: Dict[Component, Any], do_train: bool
|
||||||
|
) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||||
error = self._initialize(data, do_train, from_preview=True)
|
error = self._initialize(data, do_train, from_preview=True)
|
||||||
if error:
|
if error:
|
||||||
gr.Warning(error)
|
gr.Warning(error)
|
||||||
|
@ -235,9 +243,11 @@ class Runner:
|
||||||
get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
|
get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
|
||||||
self.running = True
|
self.running = True
|
||||||
lang = get("top.lang")
|
lang = get("top.lang")
|
||||||
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
|
output_dir = get_save_dir(
|
||||||
"{}.output_dir".format("train" if self.do_train else "eval")
|
get("top.model_name"),
|
||||||
))
|
get("top.finetuning_type"),
|
||||||
|
get("{}.output_dir".format("train" if self.do_train else "eval")),
|
||||||
|
)
|
||||||
|
|
||||||
while self.thread.is_alive():
|
while self.thread.is_alive():
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
|
|
|
@ -1,13 +1,15 @@
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
import gradio as gr
|
import os
|
||||||
from typing import TYPE_CHECKING, Any, Dict
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
from ..extras.packages import is_matplotlib_available
|
from ..extras.packages import is_matplotlib_available
|
||||||
from ..extras.ploting import smooth
|
from ..extras.ploting import smooth
|
||||||
from .common import get_save_dir
|
from .common import get_save_dir
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..extras.callbacks import LogCallback
|
from ..extras.callbacks import LogCallback
|
||||||
|
|
||||||
|
@ -22,16 +24,13 @@ def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
|
||||||
|
|
||||||
percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
|
percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
|
||||||
label = "Running {:d}/{:d}: {} < {}".format(
|
label = "Running {:d}/{:d}: {} < {}".format(
|
||||||
callback.cur_steps,
|
callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
|
||||||
callback.max_steps,
|
|
||||||
callback.elapsed_time,
|
|
||||||
callback.remaining_time
|
|
||||||
)
|
)
|
||||||
return gr.update(label=label, value=percentage, visible=True)
|
return gr.update(label=label, value=percentage, visible=True)
|
||||||
|
|
||||||
|
|
||||||
def get_time() -> str:
|
def get_time() -> str:
|
||||||
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
return datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
||||||
|
|
||||||
|
|
||||||
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
||||||
|
|
|
@ -3,11 +3,12 @@
|
||||||
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
|
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
|
||||||
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
|
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
from typing import Optional
|
from deepspeed.accelerator import get_accelerator # type: ignore
|
||||||
from deepspeed.accelerator import get_accelerator # type: ignore
|
from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
|
||||||
from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
|
|
||||||
|
|
||||||
from llmtuner import ChatModel
|
from llmtuner import ChatModel
|
||||||
|
|
||||||
|
@ -16,25 +17,13 @@ def calculate_flops(
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
batch_size: Optional[int] = 1,
|
batch_size: Optional[int] = 1,
|
||||||
seq_length: Optional[int] = 256,
|
seq_length: Optional[int] = 256,
|
||||||
flash_attn: Optional[bool] = False
|
flash_attn: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
with get_accelerator().device(0):
|
with get_accelerator().device(0):
|
||||||
chat_model = ChatModel(dict(
|
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="vanilla", flash_attn=flash_attn))
|
||||||
model_name_or_path=model_name_or_path,
|
|
||||||
template="vanilla",
|
|
||||||
flash_attn=flash_attn
|
|
||||||
))
|
|
||||||
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
|
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
|
||||||
input_dict = {
|
input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
|
||||||
"input_ids": fake_input,
|
flops, macs, params = get_model_profile(chat_model.model, kwargs=input_dict, print_profile=True, detailed=True)
|
||||||
"labels": fake_input.clone()
|
|
||||||
}
|
|
||||||
flops, macs, params = get_model_profile(
|
|
||||||
chat_model.model,
|
|
||||||
kwargs=input_dict,
|
|
||||||
print_profile=True,
|
|
||||||
detailed=True
|
|
||||||
)
|
|
||||||
print("FLOPs:", flops)
|
print("FLOPs:", flops)
|
||||||
print("MACs:", macs)
|
print("MACs:", macs)
|
||||||
print("Params:", params)
|
print("Params:", params)
|
||||||
|
|
|
@ -3,12 +3,13 @@
|
||||||
# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
|
# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
|
||||||
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
|
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
|
||||||
|
|
||||||
import fire
|
|
||||||
import math
|
import math
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
from tqdm import tqdm
|
||||||
from transformers import DataCollatorForSeq2Seq
|
from transformers import DataCollatorForSeq2Seq
|
||||||
|
|
||||||
from llmtuner.data import get_dataset
|
from llmtuner.data import get_dataset
|
||||||
|
@ -17,8 +18,8 @@ from llmtuner.hparams import get_train_args
|
||||||
from llmtuner.model import load_model_and_tokenizer
|
from llmtuner.model import load_model_and_tokenizer
|
||||||
|
|
||||||
|
|
||||||
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
|
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
|
||||||
BASE_BS = 4_000_000 # from llama paper
|
BASE_BS = 4_000_000 # from llama paper
|
||||||
|
|
||||||
|
|
||||||
def calculate_lr(
|
def calculate_lr(
|
||||||
|
@ -26,18 +27,20 @@ def calculate_lr(
|
||||||
dataset: str,
|
dataset: str,
|
||||||
cutoff_len: int, # i.e. maximum input length during training
|
cutoff_len: int, # i.e. maximum input length during training
|
||||||
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
||||||
is_mistral: bool, # mistral model uses a smaller learning rate,
|
is_mistral: bool, # mistral model uses a smaller learning rate,
|
||||||
dataset_dir: Optional[str] = "data"
|
dataset_dir: Optional[str] = "data",
|
||||||
):
|
):
|
||||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
|
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
|
||||||
stage="sft",
|
dict(
|
||||||
model_name_or_path=model_name_or_path,
|
stage="sft",
|
||||||
dataset=dataset,
|
model_name_or_path=model_name_or_path,
|
||||||
dataset_dir=dataset_dir,
|
dataset=dataset,
|
||||||
template="default",
|
dataset_dir=dataset_dir,
|
||||||
cutoff_len=cutoff_len,
|
template="default",
|
||||||
output_dir="dummy_dir"
|
cutoff_len=cutoff_len,
|
||||||
))
|
output_dir="dummy_dir",
|
||||||
|
)
|
||||||
|
)
|
||||||
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
|
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
|
||||||
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
|
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
|
||||||
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
||||||
|
@ -49,14 +52,16 @@ def calculate_lr(
|
||||||
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
|
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
|
||||||
total_tokens += torch.numel(batch["labels"])
|
total_tokens += torch.numel(batch["labels"])
|
||||||
|
|
||||||
batch_max_len = cutoff_len * batch_size # max tokens in a batch
|
batch_max_len = cutoff_len * batch_size # max tokens in a batch
|
||||||
valid_ratio = valid_tokens / total_tokens
|
valid_ratio = valid_tokens / total_tokens
|
||||||
batch_valid_len = batch_max_len * valid_ratio
|
batch_valid_len = batch_max_len * valid_ratio
|
||||||
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size)
|
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size)
|
||||||
lr = lr / 6.0 if is_mistral else lr
|
lr = lr / 6.0 if is_mistral else lr
|
||||||
print("Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
|
print(
|
||||||
lr, valid_ratio * 100, batch_valid_len
|
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
|
||||||
))
|
lr, valid_ratio * 100, batch_valid_len
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -4,32 +4,28 @@
|
||||||
# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
|
# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
|
||||||
# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
|
# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
|
||||||
|
|
||||||
import os
|
|
||||||
import fire
|
|
||||||
import json
|
import json
|
||||||
import torch
|
import os
|
||||||
from tqdm import tqdm
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from safetensors.torch import save_file
|
|
||||||
from transformers.modeling_utils import (
|
|
||||||
shard_checkpoint,
|
|
||||||
SAFE_WEIGHTS_NAME,
|
|
||||||
SAFE_WEIGHTS_INDEX_NAME,
|
|
||||||
WEIGHTS_NAME,
|
|
||||||
WEIGHTS_INDEX_NAME
|
|
||||||
)
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers.modeling_utils import (
|
||||||
|
SAFE_WEIGHTS_INDEX_NAME,
|
||||||
|
SAFE_WEIGHTS_NAME,
|
||||||
|
WEIGHTS_INDEX_NAME,
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
shard_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
CONFIG_NAME = "config.json"
|
CONFIG_NAME = "config.json"
|
||||||
|
|
||||||
|
|
||||||
def save_weight(
|
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
|
||||||
input_dir: str,
|
|
||||||
output_dir: str,
|
|
||||||
shard_size: str,
|
|
||||||
save_safetensors: bool
|
|
||||||
):
|
|
||||||
baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||||
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
|
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
|
||||||
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
|
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
|
||||||
|
@ -41,8 +37,8 @@ def save_weight(
|
||||||
if "W_pack" in key:
|
if "W_pack" in key:
|
||||||
proj_size = value.size(0) // 3
|
proj_size = value.size(0) // 3
|
||||||
llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
|
llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
|
||||||
llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size:2*proj_size, :]
|
llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size : 2 * proj_size, :]
|
||||||
llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2*proj_size:, :]
|
llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2 * proj_size :, :]
|
||||||
elif "lm_head" in key:
|
elif "lm_head" in key:
|
||||||
llama2_state_dict[key] = torch.nn.functional.normalize(value)
|
llama2_state_dict[key] = torch.nn.functional.normalize(value)
|
||||||
else:
|
else:
|
||||||
|
@ -56,7 +52,7 @@ def save_weight(
|
||||||
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
|
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
|
||||||
else:
|
else:
|
||||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||||
|
|
||||||
if index is None:
|
if index is None:
|
||||||
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
|
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
|
||||||
else:
|
else:
|
||||||
|
@ -66,10 +62,7 @@ def save_weight(
|
||||||
print("Model weights saved in {}".format(output_dir))
|
print("Model weights saved in {}".format(output_dir))
|
||||||
|
|
||||||
|
|
||||||
def save_config(
|
def save_config(input_dir: str, output_dir: str):
|
||||||
input_dir: str,
|
|
||||||
output_dir: str
|
|
||||||
):
|
|
||||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
||||||
llama2_config_dict: Dict[str, Any] = json.load(f)
|
llama2_config_dict: Dict[str, Any] = json.load(f)
|
||||||
|
|
||||||
|
@ -83,19 +76,14 @@ def save_config(
|
||||||
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
||||||
|
|
||||||
|
|
||||||
def llamafy_baichuan2(
|
def llamafy_baichuan2(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False):
|
||||||
input_dir: str,
|
|
||||||
output_dir: str,
|
|
||||||
shard_size: str,
|
|
||||||
save_safetensors: Optional[bool] = False
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
os.makedirs(output_dir, exist_ok=False)
|
os.makedirs(output_dir, exist_ok=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise print("Output dir already exists", e)
|
raise print("Output dir already exists", e)
|
||||||
|
|
||||||
save_weight(input_dir, output_dir, shard_size, save_safetensors)
|
save_weight(input_dir, output_dir, shard_size, save_safetensors)
|
||||||
save_config(input_dir, output_dir)
|
save_config(input_dir, output_dir)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -3,32 +3,28 @@
|
||||||
# Usage: python llamafy_internlm2.py --input_dir input --output_dir output --shard_size 10GB
|
# Usage: python llamafy_internlm2.py --input_dir input --output_dir output --shard_size 10GB
|
||||||
# Warning: We have found that the converted model cannot infer correctly. It will be fixed later.
|
# Warning: We have found that the converted model cannot infer correctly. It will be fixed later.
|
||||||
|
|
||||||
import os
|
|
||||||
import fire
|
|
||||||
import json
|
import json
|
||||||
import torch
|
import os
|
||||||
from tqdm import tqdm
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from safetensors.torch import save_file
|
|
||||||
from transformers.modeling_utils import (
|
|
||||||
shard_checkpoint,
|
|
||||||
SAFE_WEIGHTS_NAME,
|
|
||||||
SAFE_WEIGHTS_INDEX_NAME,
|
|
||||||
WEIGHTS_NAME,
|
|
||||||
WEIGHTS_INDEX_NAME
|
|
||||||
)
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers.modeling_utils import (
|
||||||
|
SAFE_WEIGHTS_INDEX_NAME,
|
||||||
|
SAFE_WEIGHTS_NAME,
|
||||||
|
WEIGHTS_INDEX_NAME,
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
shard_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
CONFIG_NAME = "config.json"
|
CONFIG_NAME = "config.json"
|
||||||
|
|
||||||
|
|
||||||
def save_weight(
|
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
|
||||||
input_dir: str,
|
|
||||||
output_dir: str,
|
|
||||||
shard_size: str,
|
|
||||||
save_safetensors: bool
|
|
||||||
):
|
|
||||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
||||||
internlm2_config_dict: Dict[str, Any] = json.load(f)
|
internlm2_config_dict: Dict[str, Any] = json.load(f)
|
||||||
|
|
||||||
|
@ -50,8 +46,10 @@ def save_weight(
|
||||||
q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
|
q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
|
||||||
kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
|
kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
|
||||||
llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
|
llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
|
||||||
llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[q_size:q_size+kv_size, ...]
|
llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[
|
||||||
llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size+kv_size:, ...]
|
q_size : q_size + kv_size, ...
|
||||||
|
]
|
||||||
|
llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size + kv_size :, ...]
|
||||||
elif "wo" in key:
|
elif "wo" in key:
|
||||||
llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value
|
llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value
|
||||||
elif "attention_norm" in key:
|
elif "attention_norm" in key:
|
||||||
|
@ -85,10 +83,7 @@ def save_weight(
|
||||||
print("Model weights saved in {}".format(output_dir))
|
print("Model weights saved in {}".format(output_dir))
|
||||||
|
|
||||||
|
|
||||||
def save_config(
|
def save_config(input_dir: str, output_dir: str):
|
||||||
input_dir: str,
|
|
||||||
output_dir: str
|
|
||||||
):
|
|
||||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
||||||
llama2_config_dict: Dict[str, Any] = json.load(f)
|
llama2_config_dict: Dict[str, Any] = json.load(f)
|
||||||
|
|
||||||
|
@ -103,12 +98,7 @@ def save_config(
|
||||||
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
||||||
|
|
||||||
|
|
||||||
def llamafy_internlm2(
|
def llamafy_internlm2(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False):
|
||||||
input_dir: str,
|
|
||||||
output_dir: str,
|
|
||||||
shard_size: str,
|
|
||||||
save_safetensors: Optional[bool] = False
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
os.makedirs(output_dir, exist_ok=False)
|
os.makedirs(output_dir, exist_ok=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -3,39 +3,36 @@
|
||||||
# Usage: python llamafy_qwen.py --input_dir input --output_dir output --shard_size 10GB
|
# Usage: python llamafy_qwen.py --input_dir input --output_dir output --shard_size 10GB
|
||||||
# Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
|
# Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
|
||||||
|
|
||||||
import os
|
|
||||||
import fire
|
|
||||||
import json
|
import json
|
||||||
import torch
|
import os
|
||||||
from tqdm import tqdm
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import torch
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
|
from tqdm import tqdm
|
||||||
from transformers.modeling_utils import (
|
from transformers.modeling_utils import (
|
||||||
shard_checkpoint,
|
|
||||||
SAFE_WEIGHTS_NAME,
|
|
||||||
SAFE_WEIGHTS_INDEX_NAME,
|
SAFE_WEIGHTS_INDEX_NAME,
|
||||||
|
SAFE_WEIGHTS_NAME,
|
||||||
|
WEIGHTS_INDEX_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
WEIGHTS_INDEX_NAME
|
shard_checkpoint,
|
||||||
)
|
)
|
||||||
from transformers.utils import check_min_version
|
from transformers.utils import check_min_version
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
check_min_version("4.34.0")
|
check_min_version("4.34.0")
|
||||||
except:
|
except Exception:
|
||||||
raise ValueError("Please upgrade `transformers` to 4.34.0")
|
raise ValueError("Please upgrade `transformers` to 4.34.0")
|
||||||
|
|
||||||
|
|
||||||
CONFIG_NAME = "config.json"
|
CONFIG_NAME = "config.json"
|
||||||
|
|
||||||
|
|
||||||
def save_weight(
|
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool) -> str:
|
||||||
input_dir: str,
|
|
||||||
output_dir: str,
|
|
||||||
shard_size: str,
|
|
||||||
save_safetensors: bool
|
|
||||||
) -> str:
|
|
||||||
qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||||
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
|
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
|
||||||
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
|
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
|
||||||
|
@ -57,13 +54,15 @@ def save_weight(
|
||||||
if "attn.c_attn" in key:
|
if "attn.c_attn" in key:
|
||||||
proj_size = value.size(0) // 3
|
proj_size = value.size(0) // 3
|
||||||
llama2_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
|
llama2_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
|
||||||
llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[proj_size:2*proj_size, ...]
|
llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[
|
||||||
llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2*proj_size:, ...]
|
proj_size : 2 * proj_size, ...
|
||||||
|
]
|
||||||
|
llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2 * proj_size :, ...]
|
||||||
elif "attn.c_proj" in key:
|
elif "attn.c_proj" in key:
|
||||||
llama2_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
|
llama2_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
|
||||||
llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = (
|
llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = torch.zeros_like(
|
||||||
torch.zeros_like(value[:, 0]).squeeze()
|
value[:, 0]
|
||||||
)
|
).squeeze()
|
||||||
elif "ln_1" in key:
|
elif "ln_1" in key:
|
||||||
llama2_state_dict[key.replace("ln_1", "input_layernorm")] = value
|
llama2_state_dict[key.replace("ln_1", "input_layernorm")] = value
|
||||||
elif "ln_2" in key:
|
elif "ln_2" in key:
|
||||||
|
@ -99,11 +98,7 @@ def save_weight(
|
||||||
return str(torch_dtype).replace("torch.", "")
|
return str(torch_dtype).replace("torch.", "")
|
||||||
|
|
||||||
|
|
||||||
def save_config(
|
def save_config(input_dir: str, output_dir: str, torch_dtype: str):
|
||||||
input_dir: str,
|
|
||||||
output_dir: str,
|
|
||||||
torch_dtype: str
|
|
||||||
):
|
|
||||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
||||||
qwen_config_dict: Dict[str, Any] = json.load(f)
|
qwen_config_dict: Dict[str, Any] = json.load(f)
|
||||||
|
|
||||||
|
@ -133,12 +128,7 @@ def save_config(
|
||||||
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
||||||
|
|
||||||
|
|
||||||
def llamafy_qwen(
|
def llamafy_qwen(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False):
|
||||||
input_dir: str,
|
|
||||||
output_dir: str,
|
|
||||||
shard_size: str,
|
|
||||||
save_safetensors: Optional[bool] = False
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
os.makedirs(output_dir, exist_ok=False)
|
os.makedirs(output_dir, exist_ok=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -4,12 +4,13 @@
|
||||||
# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py
|
# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import TYPE_CHECKING, Optional
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
||||||
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
|
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -17,7 +18,6 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
|
|
||||||
class Shell(nn.Module):
|
class Shell(nn.Module):
|
||||||
|
|
||||||
def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
|
def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(weight, requires_grad=False)
|
self.weight = nn.Parameter(weight, requires_grad=False)
|
||||||
|
@ -26,7 +26,7 @@ class Shell(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
|
def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
|
||||||
for name in set([k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k]):
|
for name in set([k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k]): # noqa: C403
|
||||||
parent_name = ".".join(name.split(".")[:-1])
|
parent_name = ".".join(name.split(".")[:-1])
|
||||||
child_name = name.split(".")[-1]
|
child_name = name.split(".")[-1]
|
||||||
parent_module = model.get_submodule(parent_name)
|
parent_module = model.get_submodule(parent_name)
|
||||||
|
@ -35,7 +35,7 @@ def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
|
||||||
weight = getattr(base_layer, "weight", None)
|
weight = getattr(base_layer, "weight", None)
|
||||||
bias = getattr(base_layer, "bias", None)
|
bias = getattr(base_layer, "bias", None)
|
||||||
setattr(parent_module, child_name, Shell(weight, bias))
|
setattr(parent_module, child_name, Shell(weight, bias))
|
||||||
|
|
||||||
print("Model unwrapped.")
|
print("Model unwrapped.")
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ def quantize_loftq(
|
||||||
lora_dropout=0.1,
|
lora_dropout=0.1,
|
||||||
target_modules=[name.strip() for name in lora_target.split(",")],
|
target_modules=[name.strip() for name in lora_target.split(",")],
|
||||||
init_lora_weights="loftq",
|
init_lora_weights="loftq",
|
||||||
loftq_config=loftq_config
|
loftq_config=loftq_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init LoftQ model
|
# Init LoftQ model
|
||||||
|
|
Loading…
Reference in New Issue