forked from p04798526/LLaMA-Factory-Mirror

commit 4736344eb1 (parent 2f02f688e1)

    disentangle model from tuner and rename modules
@@ -1,9 +1,9 @@
-# Level: api, webui > chat, eval > tuner > dsets > extras, hparams
+# Level: api, webui > chat, eval, train > data, model > extras, hparams

 from llmtuner.api import create_app
 from llmtuner.chat import ChatModel
 from llmtuner.eval import Evaluator
-from llmtuner.tuner import export_model, run_exp
+from llmtuner.train import export_model, run_exp
 from llmtuner.webui import create_ui, create_web_demo
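Only the internal layering changes; the package-level entry points keep their names. A minimal sketch of driving them after the rename, assuming the argument-dict constructors used throughout the project (model path, dataset name and output directory are placeholders):

```python
# Sketch only: model path, dataset name and output directory are placeholders.
from llmtuner import ChatModel, run_exp

# Inference still goes through llmtuner.chat.
chat_model = ChatModel(dict(model_name_or_path="path/to/model", template="default"))
for response in chat_model.chat("Hello!"):
    print(response.response_text)

# Training now lives under llmtuner.train (formerly llmtuner.tuner).
run_exp(dict(
    stage="sft",
    model_name_or_path="path/to/model",
    dataset="alpaca_gpt4_en",
    template="default",
    output_dir="outputs/sft"
))
```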
@@ -1,14 +1,8 @@
 import json
-import uvicorn
-from fastapi import FastAPI, HTTPException, status
-from fastapi.middleware.cors import CORSMiddleware
-from contextlib import asynccontextmanager
-from sse_starlette import EventSourceResponse
 from typing import List, Tuple
 from pydantic import BaseModel
+from contextlib import asynccontextmanager

-from llmtuner.extras.misc import torch_gc
-from llmtuner.chat import ChatModel
 from llmtuner.api.protocol import (
     Role,
     Finish,
@@ -23,10 +17,28 @@ from llmtuner.api.protocol import (
     ChatCompletionResponseStreamChoice,
     ChatCompletionResponseUsage
 )
+from llmtuner.chat import ChatModel
+from llmtuner.extras.misc import torch_gc
+from llmtuner.extras.packages import (
+    is_fastapi_availble, is_starlette_available, is_uvicorn_available
+)
+
+
+if is_fastapi_availble():
+    from fastapi import FastAPI, HTTPException, status
+    from fastapi.middleware.cors import CORSMiddleware
+
+
+if is_starlette_available():
+    from sse_starlette import EventSourceResponse
+
+
+if is_uvicorn_available():
+    import uvicorn
+
+
 @asynccontextmanager
-async def lifespan(app: FastAPI): # collects GPU memory
+async def lifespan(app: "FastAPI"): # collects GPU memory
     yield
     torch_gc()
@@ -38,7 +50,7 @@ def to_json(data: BaseModel) -> str:
     return data.json(exclude_unset=True, ensure_ascii=False)


-def create_app(chat_model: ChatModel) -> FastAPI:
+def create_app(chat_model: "ChatModel") -> "FastAPI":
     app = FastAPI(lifespan=lifespan)

     app.add_middleware(
@@ -56,12 +68,12 @@ def create_app(chat_model: ChatModel) -> FastAPI:

     @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
     async def create_chat_completion(request: ChatCompletionRequest):
-        if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
+        if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

         query = request.messages[-1].content
         prev_messages = request.messages[:-1]
-        if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
+        if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
             system = prev_messages.pop(0).content
         else:
             system = None
@@ -73,12 +85,14 @@ def create_app(chat_model: ChatModel) -> FastAPI:
                     history.append([prev_messages[i].content, prev_messages[i+1].content])
                 else:
                     raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
+        else:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

         if request.stream:
             generate = predict(query, history, system, request)
             return EventSourceResponse(generate, media_type="text/event-stream")

-        response, (prompt_length, response_length) = chat_model.chat(
+        responses = chat_model.chat(
             query, history, system,
             do_sample=request.do_sample,
             temperature=request.temperature,
@@ -87,18 +101,23 @@ def create_app(chat_model: ChatModel) -> FastAPI:
             num_return_sequences=request.n
         )

+        prompt_length, response_length = 0, 0
+        choices = []
+        for i, response in enumerate(responses):
+            choices.append(ChatCompletionResponseChoice(
+                index=i,
+                message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
+                finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
+            ))
+            prompt_length = response.prompt_length
+            response_length += response.response_length
+
         usage = ChatCompletionResponseUsage(
             prompt_tokens=prompt_length,
             completion_tokens=response_length,
             total_tokens=prompt_length+response_length
         )

-        choices = [ChatCompletionResponseChoice(
-            index=i,
-            message=ChatMessage(role=Role.ASSISTANT, content=choice),
-            finish_reason=Finish.STOP
-        ) for i, choice in enumerate(response)]
-
         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)

     async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
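For orientation, a minimal client call against the /v1/chat/completions route built above; the host, port and model name are assumptions rather than part of this commit:

```python
# Hypothetical client sketch; host, port and model name are placeholders.
import requests

payload = {
    "model": "my-finetuned-model",
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": False,
    "n": 1
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
resp.raise_for_status()
for choice in resp.json()["choices"]:
    print(choice["index"], choice["message"]["content"])
```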
@@ -1 +1 @@
-from llmtuner.chat.stream_chat import ChatModel
+from llmtuner.chat.chat_model import ChatModel

@@ -1,11 +1,21 @@
 import torch
-from typing import Any, Dict, Generator, List, Optional, Tuple
+from dataclasses import dataclass
+from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
 from threading import Thread
 from transformers import GenerationConfig, TextIteratorStreamer

-from llmtuner.extras.misc import dispatch_model, get_logits_processor
+from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.template import get_template_and_fix_tokenizer
-from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
+from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer
+
+
+@dataclass
+class Response:
+
+    response_text: str
+    response_length: int
+    prompt_length: int
+    finish_reason: Literal["stop", "length"]


 class ChatModel:
@@ -18,7 +28,7 @@ class ChatModel:
         self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
         self.system_prompt = data_args.system_prompt

-    def process_args(
+    def _process_args(
         self,
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
@@ -79,17 +89,30 @@ class ChatModel:
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
         **input_kwargs
-    ) -> Tuple[List[str], Tuple[int, int]]:
-        gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
+    ) -> List[Response]:
+        r"""
+        Args: query, history, system, **input_kwargs
+
+        Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
+        """
+        gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
         generate_output = self.model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
-        response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
-        response_length = 0
-        for i in range(len(response_ids)):
+        response = self.tokenizer.batch_decode(
+            response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
+        )
+        results = []
+        for i in range(len(response)):
             eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
-            response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i])
+            response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
+            results.append(Response(
+                response_text=response[i],
+                response_length=response_length,
+                prompt_length=prompt_length,
+                finish_reason="stop" if len(eos_index) else "length"
+            ))

-        return response, (prompt_length, response_length)
+        return results

     @torch.inference_mode()
     def stream_chat(
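The chat() return type changes from a (texts, (prompt_length, response_length)) tuple to a list of Response objects, so callers read the fields by name. A minimal sketch of an updated caller; the constructor arguments are placeholders:

```python
# Sketch only: the ChatModel constructor arguments are placeholders.
from llmtuner.chat import ChatModel

chat_model = ChatModel(dict(model_name_or_path="path/to/model", template="default"))

# Each returned sequence is now a Response with named fields.
for r in chat_model.chat("How are you?", num_return_sequences=2):
    print(r.finish_reason, r.prompt_length, r.response_length)
    print(r.response_text)
```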
@@ -99,7 +122,7 @@ class ChatModel:
         system: Optional[str] = None,
         **input_kwargs
     ) -> Generator[str, None, None]:
-        gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
+        gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer

@@ -0,0 +1,4 @@
+from llmtuner.data.loader import get_dataset
+from llmtuner.data.preprocess import preprocess_dataset
+from llmtuner.data.template import get_template_and_fix_tokenizer
+from llmtuner.data.utils import split_dataset
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union

 from datasets import concatenate_datasets, interleave_datasets, load_dataset

-from llmtuner.dsets.utils import checksum, EXT2TYPE
+from llmtuner.data.utils import checksum, EXT2TYPE
 from llmtuner.extras.logging import get_logger

 if TYPE_CHECKING:
@@ -5,9 +5,9 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union

 from datasets import load_from_disk

+from llmtuner.data.template import get_template_and_fix_tokenizer
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.template import get_template_and_fix_tokenizer

 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
@@ -1,3 +0,0 @@
-from llmtuner.dsets.loader import get_dataset
-from llmtuner.dsets.preprocess import preprocess_dataset
-from llmtuner.dsets.utils import split_dataset
@@ -1 +1 @@
-from llmtuner.eval.engine import Evaluator
+from llmtuner.eval.evaluator import Evaluator

@@ -1,3 +0,0 @@
-CHOICES = ["A", "B", "C", "D"]
-
-SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]

@@ -11,12 +11,10 @@ from typing import Any, Dict, List, Optional
 from datasets import load_dataset
 from transformers.utils import cached_file

-from llmtuner.eval.constants import CHOICES, SUBJECTS
-from llmtuner.eval.parser import get_eval_args
+from llmtuner.data.template import get_template_and_fix_tokenizer
 from llmtuner.eval.template import get_eval_template
-from llmtuner.extras.misc import dispatch_model
-from llmtuner.extras.template import get_template_and_fix_tokenizer
-from llmtuner.tuner.core import load_model_and_tokenizer
+from llmtuner.extras.constants import CHOICES, SUBJECTS
+from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer


 class Evaluator:
@@ -1,49 +0,0 @@
-import transformers
-from typing import Any, Dict, Optional, Tuple
-from transformers import HfArgumentParser
-
-from llmtuner.extras.misc import parse_args
-from llmtuner.hparams import (
-    ModelArguments,
-    DataArguments,
-    EvaluationArguments,
-    FinetuningArguments
-)
-
-
-def parse_eval_args(
-    args: Optional[Dict[str, Any]] = None
-) -> Tuple[
-    ModelArguments,
-    DataArguments,
-    EvaluationArguments,
-    FinetuningArguments
-]:
-    parser = HfArgumentParser((
-        ModelArguments,
-        DataArguments,
-        EvaluationArguments,
-        FinetuningArguments
-    ))
-    return parse_args(parser, args)
-
-
-def get_eval_args(
-    args: Optional[Dict[str, Any]] = None
-) -> Tuple[
-    ModelArguments,
-    DataArguments,
-    EvaluationArguments,
-    FinetuningArguments
-]:
-    model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
-
-    if data_args.template is None:
-        raise ValueError("Please specify which `template` to use.")
-
-    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Quantization is only compatible with the LoRA method.")
-
-    transformers.set_seed(eval_args.seed)
-
-    return model_args, data_args, eval_args, finetuning_args
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Tuple

-from llmtuner.eval.constants import CHOICES
+from llmtuner.extras.constants import CHOICES

 if TYPE_CHECKING:
     from datasets import Dataset

@@ -2,12 +2,24 @@ from collections import defaultdict, OrderedDict
 from typing import Dict, Optional


+CHOICES = ["A", "B", "C", "D"]
+
+DEFAULT_MODULE = defaultdict(str)
+
+DEFAULT_TEMPLATE = defaultdict(str)
+
 IGNORE_INDEX = -100

+LAYERNORM_NAMES = {"norm", "ln"}
+
 LOG_FILE_NAME = "trainer_log.jsonl"

 METHODS = ["full", "freeze", "lora"]

+SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
+
+SUPPORTED_MODELS = OrderedDict()
+
 TRAINING_STAGES = {
     "Supervised Fine-Tuning": "sft",
     "Reward Modeling": "rm",
@@ -16,14 +28,6 @@ TRAINING_STAGES = {
     "Pre-Training": "pt"
 }

-LAYERNORM_NAMES = {"norm", "ln"}
-
-SUPPORTED_MODELS = OrderedDict()
-
-DEFAULT_MODULE = defaultdict(str)
-
-DEFAULT_TEMPLATE = defaultdict(str)
-

 def register_model_group(
     models: Dict[str, str],
@@ -13,14 +13,13 @@ try:
         is_torch_npu_available
     )
     _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
-    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available
+    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
 except ImportError:
     _is_fp16_available = torch.cuda.is_available()
     _is_bf16_available = torch.cuda.is_bf16_supported()

 if TYPE_CHECKING:
     from transformers import HfArgumentParser
-    from transformers.modeling_utils import PreTrainedModel


 class AverageMeter:
@@ -65,6 +64,15 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param


+def get_logits_processor() -> "LogitsProcessorList":
+    r"""
+    Gets logits processor that removes NaN and Inf logits.
+    """
+    logits_processor = LogitsProcessorList()
+    logits_processor.append(InfNanRemoveLogitsProcessor())
+    return logits_processor
+
+
 def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
     r"""
     Infers the optimal dtype according to the model_dtype and device compatibility.
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
||||||
|
|
||||||
def get_logits_processor() -> "LogitsProcessorList":
|
|
||||||
r"""
|
|
||||||
Gets logits processor that removes NaN and Inf logits.
|
|
||||||
"""
|
|
||||||
logits_processor = LogitsProcessorList()
|
|
||||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
|
||||||
return logits_processor
|
|
||||||
|
|
||||||
|
|
||||||
def torch_gc() -> None:
|
|
||||||
r"""
|
|
||||||
Collects GPU memory.
|
|
||||||
"""
|
|
||||||
gc.collect()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.ipc_collect()
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||||
if args is not None:
|
if args is not None:
|
||||||
return parser.parse_dict(args)
|
return parser.parse_dict(args)
|
||||||
|
@ -107,26 +96,11 @@ def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None
|
||||||
return parser.parse_args_into_dataclasses()
|
return parser.parse_args_into_dataclasses()
|
||||||
|
|
||||||
|
|
||||||
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
def torch_gc() -> None:
|
||||||
r"""
|
r"""
|
||||||
Dispatches a pre-trained model to GPUs with balanced memory.
|
Collects GPU memory.
|
||||||
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
|
|
||||||
"""
|
"""
|
||||||
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
|
gc.collect()
|
||||||
return model
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
if torch.cuda.device_count() > 1:
|
torch.cuda.ipc_collect()
|
||||||
from accelerate import dispatch_model
|
|
||||||
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
|
||||||
|
|
||||||
if model._no_split_modules is None:
|
|
||||||
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
|
|
||||||
|
|
||||||
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
|
|
||||||
max_memory = get_balanced_memory(model, **kwargs)
|
|
||||||
# Make sure tied weights are tied before creating the device map.
|
|
||||||
model.tie_weights()
|
|
||||||
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
|
|
||||||
return dispatch_model(model, device_map)
|
|
||||||
else:
|
|
||||||
return model.cuda()
|
|
||||||
|
|
|
@@ -0,0 +1,55 @@
+import importlib.metadata
+import importlib.util
+
+
+def is_package_available(name: str) -> bool:
+    return importlib.util.find_spec(name) is not None
+
+
+def get_package_version(name: str) -> str:
+    try:
+        return importlib.metadata.version(name)
+    except:
+        return "0.0.0"
+
+
+_fastapi_available = is_package_available("fastapi")
+_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
+_jieba_available = is_package_available("jieba")
+_matplotlib_available = is_package_available("matplotlib")
+_nltk_available = is_package_available("nltk")
+_rouge_available = is_package_available("rouge-chinese")
+_starlette_available = is_package_available("sse-starlette")
+_uvicorn_available = is_package_available("uvicorn")
+
+
+def is_fastapi_availble():
+    return _fastapi_available
+
+
+def is_flash_attn2_available():
+    return _flash_attn2_available
+
+
+def is_jieba_available():
+    return _jieba_available
+
+
+def is_matplotlib_available():
+    return _matplotlib_available
+
+
+def is_nltk_available():
+    return _nltk_available
+
+
+def is_rouge_available():
+    return _rouge_available
+
+
+def is_starlette_available():
+    return _starlette_available
+
+
+def is_uvicorn_available():
+    return _uvicorn_available
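These helpers let the rest of the code guard optional imports instead of failing at import time. A short sketch of the pattern; maybe_plot is a hypothetical helper used only for illustration:

```python
# Sketch of the guarded-import pattern enabled by extras.packages;
# maybe_plot is a hypothetical helper, not part of the commit.
from llmtuner.extras.packages import is_matplotlib_available

if is_matplotlib_available():
    import matplotlib.pyplot as plt


def maybe_plot(values, path="loss.png"):
    if not is_matplotlib_available():
        print("matplotlib is not installed, skipping the plot")
        return
    plt.plot(values)
    plt.savefig(path)
```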
@@ -3,16 +3,19 @@ import torch
 import torch.nn as nn
 from typing import Optional, Tuple
 from transformers.utils import logging
-from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
+from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

-is_flash_attn_2_available = False

 try:
+    from transformers.models.llama.modeling_llama import repeat_kv
+except ImportError:
+    print("Please upgrade `transformers`.")
+
+from llmtuner.extras.packages import is_flash_attn2_available
+
+
+if is_flash_attn2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
     from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
-    is_flash_attn_2_available = True
-except ImportError:
-    is_flash_attn_2_available = False


 logger = logging.get_logger(__name__)
@@ -1,11 +1,14 @@
 import os
 import math
 import json
-import matplotlib.pyplot as plt
 from typing import List, Optional
 from transformers.trainer import TRAINER_STATE_NAME

 from llmtuner.extras.logging import get_logger
+from llmtuner.extras.packages import is_matplotlib_available
+
+if is_matplotlib_available():
+    import matplotlib.pyplot as plt


 logger = get_logger(__name__)

@@ -0,0 +1,3 @@
+from llmtuner.model.loader import load_model_and_tokenizer
+from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
+from llmtuner.model.utils import dispatch_model, generate_model_card
@@ -1,18 +1,12 @@
-import os
 import torch
 from typing import TYPE_CHECKING

 from transformers.utils import cached_file
 from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
-from peft import (
-    PeftModel,
-    TaskType,
-    LoraConfig,
-    get_peft_model
-)
+from peft import PeftModel, TaskType, LoraConfig, get_peft_model

 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core.utils import find_all_linear_modules
+from llmtuner.model.utils import find_all_linear_modules

 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel

@@ -25,10 +25,11 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
 from llmtuner.extras.logging import reset_logging, get_logger
 from llmtuner.extras.misc import count_parameters, infer_optim_dtype
+from llmtuner.extras.packages import is_flash_attn2_available
 from llmtuner.extras.patches import llama_patch as LlamaPatches
 from llmtuner.hparams import FinetuningArguments
-from llmtuner.tuner.core.adapter import init_adapter, load_valuehead_params
-from llmtuner.tuner.core.utils import prepare_model_for_training
+from llmtuner.model.adapter import init_adapter, load_valuehead_params
+from llmtuner.model.utils import prepare_model_for_training

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
@@ -122,7 +123,7 @@ def load_model_and_tokenizer(
     # Set FlashAttention-2
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
-            if LlamaPatches.is_flash_attn_2_available:
+            if is_flash_attn2_available():
                 LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
                 LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
                 logger.info("Using FlashAttention-2 for faster training and inference.")
@@ -131,7 +132,7 @@ def load_model_and_tokenizer(
         elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
             logger.info("Current model automatically enables FlashAttention if installed.")
         else:
-            logger.warning("Current model does not support FlashAttention-2.")
+            logger.warning("Current model does not support FlashAttention.")
     elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
         LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
         logger.warning("Using `--flash_attn` for faster training in large context length.")
@@ -11,6 +11,7 @@ from llmtuner.extras.misc import parse_args
 from llmtuner.hparams import (
     ModelArguments,
     DataArguments,
+    EvaluationArguments,
     FinetuningArguments,
     GeneratingArguments
 )
@@ -19,51 +20,42 @@ from llmtuner.hparams import (
 logger = get_logger(__name__)


-def parse_train_args(
-    args: Optional[Dict[str, Any]] = None
-) -> Tuple[
-    ModelArguments,
-    DataArguments,
-    Seq2SeqTrainingArguments,
-    FinetuningArguments,
-    GeneratingArguments
-]:
-    parser = HfArgumentParser((
-        ModelArguments,
-        DataArguments,
-        Seq2SeqTrainingArguments,
-        FinetuningArguments,
-        GeneratingArguments
-    ))
+_TRAIN_ARGS = [
+    ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
+]
+_TRAIN_CLS = Tuple[
+    ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
+]
+_INFER_ARGS = [
+    ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+]
+_INFER_CLS = Tuple[
+    ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+]
+_EVAL_ARGS = [
+    ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
+]
+_EVAL_CLS = Tuple[
+    ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
+]
+
+
+def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
+    parser = HfArgumentParser(_TRAIN_ARGS)
     return parse_args(parser, args)


-def parse_infer_args(
-    args: Optional[Dict[str, Any]] = None
-) -> Tuple[
-    ModelArguments,
-    DataArguments,
-    FinetuningArguments,
-    GeneratingArguments
-]:
-    parser = HfArgumentParser((
-        ModelArguments,
-        DataArguments,
-        FinetuningArguments,
-        GeneratingArguments
-    ))
+def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
+    parser = HfArgumentParser(_INFER_ARGS)
     return parse_args(parser, args)


-def get_train_args(
-    args: Optional[Dict[str, Any]] = None
-) -> Tuple[
-    ModelArguments,
-    DataArguments,
-    Seq2SeqTrainingArguments,
-    FinetuningArguments,
-    GeneratingArguments
-]:
+def parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
+    parser = HfArgumentParser(_EVAL_ARGS)
+    return parse_args(parser, args)
+
+
+def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)

     # Setup logging
@@ -187,14 +179,7 @@ def get_train_args(
     return model_args, data_args, training_args, finetuning_args, generating_args


-def get_infer_args(
-    args: Optional[Dict[str, Any]] = None
-) -> Tuple[
-    ModelArguments,
-    DataArguments,
-    FinetuningArguments,
-    GeneratingArguments
-]:
+def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)

     if data_args.template is None:
@@ -211,3 +196,17 @@ def get_infer_args(
         raise ValueError("Only LoRA tuning accepts multiple checkpoints.")

     return model_args, data_args, finetuning_args, generating_args
+
+
+def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
+    model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
+
+    if data_args.template is None:
+        raise ValueError("Please specify which `template` to use.")
+
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
+
+    transformers.set_seed(eval_args.seed)
+
+    return model_args, data_args, eval_args, finetuning_args
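With the eval parser folded into llmtuner.model, evaluation arguments are parsed the same way as training and inference ones. A hedged sketch of programmatic use; the model path and task name are placeholders:

```python
# Sketch only: the model path and task name are placeholders.
from llmtuner.model import get_eval_args

model_args, data_args, eval_args, finetuning_args = get_eval_args(dict(
    model_name_or_path="path/to/model",
    template="default",
    task="mmlu"
))
print(model_args.model_name_or_path, eval_args.seed)
```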
@@ -12,6 +12,31 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


+def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
+    r"""
+    Dispatches a pre-trained model to GPUs with balanced memory.
+    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
+    """
+    if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
+        return model
+
+    if torch.cuda.device_count() > 1:
+        from accelerate import dispatch_model
+        from accelerate.utils import infer_auto_device_map, get_balanced_memory
+
+        if model._no_split_modules is None:
+            raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
+
+        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
+        max_memory = get_balanced_memory(model, **kwargs)
+        # Make sure tied weights are tied before creating the device map.
+        model.tie_weights()
+        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
+        return dispatch_model(model, device_map)
+    else:
+        return model.cuda()
+
+
 def find_all_linear_modules(
     model: "PreTrainedModel",
     quantization_bit: Optional[int] = None
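The relocated dispatch_model is what the chat and eval entry points call right after loading. A minimal sketch, assuming the two-argument loader call used by the chat module; the model path is a placeholder:

```python
# Sketch only: the model path is a placeholder.
from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer

model_args, _, finetuning_args, _ = get_infer_args(dict(
    model_name_or_path="path/to/model",
    template="default"
))
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
# Balanced multi-GPU placement; quantized models are returned unchanged,
# and a single-GPU setup falls back to model.cuda().
model = dispatch_model(model)
```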
|
@ -0,0 +1 @@
|
||||||
|
from llmtuner.train.tuner import export_model, run_exp
|
|
@ -0,0 +1 @@
|
||||||
|
from llmtuner.train.dpo.workflow import run_dpo
|
|
@ -4,14 +4,14 @@ from peft import PeftModel
|
||||||
from typing import TYPE_CHECKING, Optional, List
|
from typing import TYPE_CHECKING, Optional, List
|
||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
|
|
||||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
|
||||||
from llmtuner.extras.constants import IGNORE_INDEX
|
from llmtuner.extras.constants import IGNORE_INDEX
|
||||||
from llmtuner.extras.logging import get_logger
|
from llmtuner.extras.logging import get_logger
|
||||||
from llmtuner.extras.ploting import plot_loss
|
from llmtuner.extras.ploting import plot_loss
|
||||||
from llmtuner.hparams import ModelArguments
|
from llmtuner.hparams import ModelArguments
|
||||||
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
|
from llmtuner.model import generate_model_card, load_model_and_tokenizer
|
||||||
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
|
from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
|
||||||
from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
|
from llmtuner.train.dpo.trainer import CustomDPOTrainer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import TrainerCallback
|
from transformers import TrainerCallback
|
|
@ -0,0 +1 @@
|
||||||
|
from llmtuner.train.ppo.workflow import run_ppo
|
|
@@ -3,7 +3,7 @@ import sys
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple

 from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
@@ -14,7 +14,7 @@ from trl.core import PPODecorators, logprobs_from_logits
 from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
-from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model
+from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model

 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -7,11 +7,11 @@ from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorWithPadding
 from transformers.optimization import get_scheduler

-from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.data import get_dataset, preprocess_dataset
 from llmtuner.extras.callbacks import SavePeftModelCallback
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.tuner.core import load_model_and_tokenizer
-from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
+from llmtuner.model import load_model_and_tokenizer
+from llmtuner.train.ppo.trainer import CustomPPOTrainer

 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback

@@ -0,0 +1 @@
+from llmtuner.train.pt.workflow import run_pt
@@ -4,9 +4,9 @@ import math
 from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForLanguageModeling, Trainer

-from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
+from llmtuner.model import generate_model_card, load_model_and_tokenizer

 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback

@@ -0,0 +1 @@
+from llmtuner.train.rm.workflow import run_rm
@@ -3,13 +3,13 @@
 from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments

-from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import SavePeftModelCallback
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
-from llmtuner.tuner.rm.metric import compute_accuracy
-from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
-from llmtuner.tuner.rm.trainer import PairwiseTrainer
+from llmtuner.model import generate_model_card, load_model_and_tokenizer
+from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
+from llmtuner.train.rm.metric import compute_accuracy
+from llmtuner.train.rm.trainer import PairwiseTrainer

 if TYPE_CHECKING:
     from transformers import TrainerCallback

@@ -0,0 +1 @@
+from llmtuner.train.sft.workflow import run_sft
@@ -2,15 +2,23 @@ import numpy as np
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union

-import jieba
-from rouge_chinese import Rouge
-from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
-
 from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.packages import (
+    is_jieba_available, is_nltk_available, is_rouge_available
+)

 if TYPE_CHECKING:
     from transformers.tokenization_utils import PreTrainedTokenizer

+if is_jieba_available():
+    import jieba
+
+if is_nltk_available():
+    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+
+if is_rouge_available():
+    from rouge_chinese import Rouge
+
+
 @dataclass
 class ComputeMetrics:
@@ -3,13 +3,13 @@
 from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments

-from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
-from llmtuner.tuner.sft.metric import ComputeMetrics
-from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
+from llmtuner.model import generate_model_card, load_model_and_tokenizer
+from llmtuner.train.sft.metric import ComputeMetrics
+from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer

 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -2,12 +2,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional

 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
-from llmtuner.tuner.pt import run_pt
-from llmtuner.tuner.sft import run_sft
-from llmtuner.tuner.rm import run_rm
-from llmtuner.tuner.ppo import run_ppo
-from llmtuner.tuner.dpo import run_dpo
+from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer
+from llmtuner.train.pt import run_pt
+from llmtuner.train.sft import run_sft
+from llmtuner.train.rm import run_rm
+from llmtuner.train.ppo import run_ppo
+from llmtuner.train.dpo import run_dpo

 if TYPE_CHECKING:
     from transformers import TrainerCallback

@@ -1 +0,0 @@
-from llmtuner.tuner.tune import export_model, run_exp

@@ -1,3 +0,0 @@
-from llmtuner.tuner.core.parser import get_train_args, get_infer_args
-from llmtuner.tuner.core.loader import load_model_and_tokenizer
-from llmtuner.tuner.core.utils import generate_model_card

@@ -1 +0,0 @@
-from llmtuner.tuner.dpo.workflow import run_dpo

@@ -1 +0,0 @@
-from llmtuner.tuner.ppo.workflow import run_ppo

@@ -1 +0,0 @@
-from llmtuner.tuner.pt.workflow import run_pt

@@ -1 +0,0 @@
-from llmtuner.tuner.rm.workflow import run_rm

@@ -1 +0,0 @@
-from llmtuner.tuner.sft.workflow import run_sft
@@ -2,7 +2,7 @@ import gradio as gr
 from gradio.components import Component # cannot use TYPE_CHECKING here
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple

-from llmtuner.chat.stream_chat import ChatModel
+from llmtuner.chat import ChatModel
 from llmtuner.extras.misc import torch_gc
 from llmtuner.hparams import GeneratingArguments
 from llmtuner.webui.common import get_save_dir
@@ -4,7 +4,7 @@ import logging
 import gradio as gr
 from threading import Thread
 from gradio.components import Component # cannot use TYPE_CHECKING here
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple

 import transformers
 from transformers.trainer import TRAINING_ARGS_NAME
@@ -13,7 +13,7 @@ from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import TRAINING_STAGES
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
-from llmtuner.tuner import run_exp
+from llmtuner.train import run_exp
 from llmtuner.webui.common import get_module, get_save_dir, load_config
 from llmtuner.webui.locales import ALERTS
 from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
@@ -1,17 +1,20 @@
 import os
 import json
 import gradio as gr
-import matplotlib.figure
-import matplotlib.pyplot as plt
 from typing import TYPE_CHECKING, Any, Dict
 from datetime import datetime

+from llmtuner.extras.packages import is_matplotlib_available
 from llmtuner.extras.ploting import smooth
 from llmtuner.webui.common import get_save_dir

 if TYPE_CHECKING:
     from llmtuner.extras.callbacks import LogCallback

+if is_matplotlib_available():
+    import matplotlib.figure
+    import matplotlib.pyplot as plt
+

 def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
     if not callback.max_steps:
@@ -56,7 +59,7 @@ def get_eval_results(path: os.PathLike) -> str:
     return "```json\n{}\n```\n".format(result)


-def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
+def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure":
     if not base_model:
         return
     log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
@@ -7,12 +7,13 @@ import fire
 import math
 import torch
 from tqdm import tqdm
+from typing import Optional
 from torch.utils.data import DataLoader
 from transformers import DataCollatorForSeq2Seq

-from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.data import get_dataset, preprocess_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
+from llmtuner.model import get_train_args, load_model_and_tokenizer


 BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
@@ -22,14 +23,16 @@ BASE_BS = 4_000_000 # from llama paper
 def calculate_lr(
     model_name_or_path: str,
     dataset: str,
     cutoff_len: int, # i.e. maximum input length during training
     batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
-    is_mistral: bool # mistral model uses a smaller learning rate
+    is_mistral: bool, # mistral model uses a smaller learning rate
+    dataset_dir: Optional[str] = "data"
 ):
     model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
         stage="sft",
         model_name_or_path=model_name_or_path,
         dataset=dataset,
+        dataset_dir=dataset_dir,
         template="default",
         cutoff_len=cutoff_len,
         output_dir="dummy_dir"