forked from p04798526/LLaMA-Factory-Mirror
disentangle model from tuner and rename modules
This commit is contained in:
parent
2f02f688e1
commit
4736344eb1
|
@ -1,9 +1,9 @@
|
|||
# Level: api, webui > chat, eval > tuner > dsets > extras, hparams
|
||||
# Level: api, webui > chat, eval, train > data, model > extras, hparams
|
||||
|
||||
from llmtuner.api import create_app
|
||||
from llmtuner.chat import ChatModel
|
||||
from llmtuner.eval import Evaluator
|
||||
from llmtuner.tuner import export_model, run_exp
|
||||
from llmtuner.train import export_model, run_exp
|
||||
from llmtuner.webui import create_ui, create_web_demo
|
||||
|
||||
|
||||
|
|
|
@ -1,14 +1,8 @@
|
|||
import json
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
from sse_starlette import EventSourceResponse
|
||||
from typing import List, Tuple
|
||||
from pydantic import BaseModel
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.chat import ChatModel
|
||||
from llmtuner.api.protocol import (
|
||||
Role,
|
||||
Finish,
|
||||
|
@ -23,10 +17,28 @@ from llmtuner.api.protocol import (
|
|||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionResponseUsage
|
||||
)
|
||||
from llmtuner.chat import ChatModel
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.extras.packages import (
|
||||
is_fastapi_availble, is_starlette_available, is_uvicorn_available
|
||||
)
|
||||
|
||||
|
||||
if is_fastapi_availble():
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
|
||||
if is_starlette_available():
|
||||
from sse_starlette import EventSourceResponse
|
||||
|
||||
|
||||
if is_uvicorn_available():
|
||||
import uvicorn
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI): # collects GPU memory
|
||||
async def lifespan(app: "FastAPI"): # collects GPU memory
|
||||
yield
|
||||
torch_gc()
|
||||
|
||||
|
@ -38,7 +50,7 @@ def to_json(data: BaseModel) -> str:
|
|||
return data.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
|
||||
def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
|
@ -56,12 +68,12 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
|||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
|
||||
if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
query = request.messages[-1].content
|
||||
prev_messages = request.messages[:-1]
|
||||
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
|
||||
if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
|
||||
system = prev_messages.pop(0).content
|
||||
else:
|
||||
system = None
|
||||
|
@ -73,12 +85,14 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
|||
history.append([prev_messages[i].content, prev_messages[i+1].content])
|
||||
else:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
else:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
if request.stream:
|
||||
generate = predict(query, history, system, request)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
|
||||
response, (prompt_length, response_length) = chat_model.chat(
|
||||
responses = chat_model.chat(
|
||||
query, history, system,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
|
@ -87,18 +101,23 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
|||
num_return_sequences=request.n
|
||||
)
|
||||
|
||||
prompt_length, response_length = 0, 0
|
||||
choices = []
|
||||
for i, response in enumerate(responses):
|
||||
choices.append(ChatCompletionResponseChoice(
|
||||
index=i,
|
||||
message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
|
||||
finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
|
||||
))
|
||||
prompt_length = response.prompt_length
|
||||
response_length += response.response_length
|
||||
|
||||
usage = ChatCompletionResponseUsage(
|
||||
prompt_tokens=prompt_length,
|
||||
completion_tokens=response_length,
|
||||
total_tokens=prompt_length+response_length
|
||||
)
|
||||
|
||||
choices = [ChatCompletionResponseChoice(
|
||||
index=i,
|
||||
message=ChatMessage(role=Role.ASSISTANT, content=choice),
|
||||
finish_reason=Finish.STOP
|
||||
) for i, choice in enumerate(response)]
|
||||
|
||||
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
||||
|
||||
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
|
||||
|
|
|
@ -1 +1 @@
|
|||
from llmtuner.chat.stream_chat import ChatModel
|
||||
from llmtuner.chat.chat_model import ChatModel
|
||||
|
|
|
@ -1,11 +1,21 @@
|
|||
import torch
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
|
||||
from threading import Thread
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
|
||||
from llmtuner.extras.misc import dispatch_model, get_logits_processor
|
||||
from llmtuner.extras.misc import get_logits_processor
|
||||
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
||||
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
|
||||
from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class Response:
|
||||
|
||||
response_text: str
|
||||
response_length: int
|
||||
prompt_length: int
|
||||
finish_reason: Literal["stop", "length"]
|
||||
|
||||
|
||||
class ChatModel:
|
||||
|
@ -18,7 +28,7 @@ class ChatModel:
|
|||
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
||||
self.system_prompt = data_args.system_prompt
|
||||
|
||||
def process_args(
|
||||
def _process_args(
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
|
@ -79,17 +89,30 @@ class ChatModel:
|
|||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
system: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Tuple[List[str], Tuple[int, int]]:
|
||||
gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
|
||||
) -> List[Response]:
|
||||
r"""
|
||||
Args: query, history, system, **input_kwargs
|
||||
|
||||
Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
|
||||
"""
|
||||
gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
|
||||
generate_output = self.model.generate(**gen_kwargs)
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
response_length = 0
|
||||
for i in range(len(response_ids)):
|
||||
response = self.tokenizer.batch_decode(
|
||||
response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
results = []
|
||||
for i in range(len(response)):
|
||||
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
|
||||
response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i])
|
||||
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
|
||||
results.append(Response(
|
||||
response_text=response[i],
|
||||
response_length=response_length,
|
||||
prompt_length=prompt_length,
|
||||
finish_reason="stop" if len(eos_index) else "length"
|
||||
))
|
||||
|
||||
return response, (prompt_length, response_length)
|
||||
return results
|
||||
|
||||
@torch.inference_mode()
|
||||
def stream_chat(
|
||||
|
@ -99,7 +122,7 @@ class ChatModel:
|
|||
system: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Generator[str, None, None]:
|
||||
gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
|
||||
gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
|
||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
from llmtuner.data.loader import get_dataset
|
||||
from llmtuner.data.preprocess import preprocess_dataset
|
||||
from llmtuner.data.template import get_template_and_fix_tokenizer
|
||||
from llmtuner.data.utils import split_dataset
|
|
@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
|
|||
|
||||
from datasets import concatenate_datasets, interleave_datasets, load_dataset
|
||||
|
||||
from llmtuner.dsets.utils import checksum, EXT2TYPE
|
||||
from llmtuner.data.utils import checksum, EXT2TYPE
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
|
@ -5,9 +5,9 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Un
|
|||
|
||||
from datasets import load_from_disk
|
||||
|
||||
from llmtuner.data.template import get_template_and_fix_tokenizer
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
|
@ -1,3 +0,0 @@
|
|||
from llmtuner.dsets.loader import get_dataset
|
||||
from llmtuner.dsets.preprocess import preprocess_dataset
|
||||
from llmtuner.dsets.utils import split_dataset
|
|
@ -1 +1 @@
|
|||
from llmtuner.eval.engine import Evaluator
|
||||
from llmtuner.eval.evaluator import Evaluator
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
CHOICES = ["A", "B", "C", "D"]
|
||||
|
||||
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
|
|
@ -11,12 +11,10 @@ from typing import Any, Dict, List, Optional
|
|||
from datasets import load_dataset
|
||||
from transformers.utils import cached_file
|
||||
|
||||
from llmtuner.eval.constants import CHOICES, SUBJECTS
|
||||
from llmtuner.eval.parser import get_eval_args
|
||||
from llmtuner.data.template import get_template_and_fix_tokenizer
|
||||
from llmtuner.eval.template import get_eval_template
|
||||
from llmtuner.extras.misc import dispatch_model
|
||||
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.extras.constants import CHOICES, SUBJECTS
|
||||
from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer
|
||||
|
||||
|
||||
class Evaluator:
|
|
@ -1,49 +0,0 @@
|
|||
import transformers
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from llmtuner.extras.misc import parse_args
|
||||
from llmtuner.hparams import (
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
EvaluationArguments,
|
||||
FinetuningArguments
|
||||
)
|
||||
|
||||
|
||||
def parse_eval_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
EvaluationArguments,
|
||||
FinetuningArguments
|
||||
]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
EvaluationArguments,
|
||||
FinetuningArguments
|
||||
))
|
||||
return parse_args(parser, args)
|
||||
|
||||
|
||||
def get_eval_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
EvaluationArguments,
|
||||
FinetuningArguments
|
||||
]:
|
||||
model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
|
||||
|
||||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
transformers.set_seed(eval_args.seed)
|
||||
|
||||
return model_args, data_args, eval_args, finetuning_args
|
|
@ -1,7 +1,7 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple
|
||||
|
||||
from llmtuner.eval.constants import CHOICES
|
||||
from llmtuner.extras.constants import CHOICES
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset
|
||||
|
|
|
@ -2,12 +2,24 @@ from collections import defaultdict, OrderedDict
|
|||
from typing import Dict, Optional
|
||||
|
||||
|
||||
CHOICES = ["A", "B", "C", "D"]
|
||||
|
||||
DEFAULT_MODULE = defaultdict(str)
|
||||
|
||||
DEFAULT_TEMPLATE = defaultdict(str)
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
LAYERNORM_NAMES = {"norm", "ln"}
|
||||
|
||||
LOG_FILE_NAME = "trainer_log.jsonl"
|
||||
|
||||
METHODS = ["full", "freeze", "lora"]
|
||||
|
||||
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
|
||||
|
||||
SUPPORTED_MODELS = OrderedDict()
|
||||
|
||||
TRAINING_STAGES = {
|
||||
"Supervised Fine-Tuning": "sft",
|
||||
"Reward Modeling": "rm",
|
||||
|
@ -16,14 +28,6 @@ TRAINING_STAGES = {
|
|||
"Pre-Training": "pt"
|
||||
}
|
||||
|
||||
LAYERNORM_NAMES = {"norm", "ln"}
|
||||
|
||||
SUPPORTED_MODELS = OrderedDict()
|
||||
|
||||
DEFAULT_MODULE = defaultdict(str)
|
||||
|
||||
DEFAULT_TEMPLATE = defaultdict(str)
|
||||
|
||||
|
||||
def register_model_group(
|
||||
models: Dict[str, str],
|
||||
|
|
|
@ -13,14 +13,13 @@ try:
|
|||
is_torch_npu_available
|
||||
)
|
||||
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
||||
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available
|
||||
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
|
||||
except ImportError:
|
||||
_is_fp16_available = torch.cuda.is_available()
|
||||
_is_bf16_available = torch.cuda.is_bf16_supported()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import HfArgumentParser
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
|
@ -65,6 +64,15 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
|||
return trainable_params, all_param
|
||||
|
||||
|
||||
def get_logits_processor() -> "LogitsProcessorList":
|
||||
r"""
|
||||
Gets logits processor that removes NaN and Inf logits.
|
||||
"""
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
return logits_processor
|
||||
|
||||
|
||||
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
|
||||
r"""
|
||||
Infers the optimal dtype according to the model_dtype and device compatibility.
|
||||
|
@ -77,25 +85,6 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
|
|||
return torch.float32
|
||||
|
||||
|
||||
def get_logits_processor() -> "LogitsProcessorList":
|
||||
r"""
|
||||
Gets logits processor that removes NaN and Inf logits.
|
||||
"""
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
return logits_processor
|
||||
|
||||
|
||||
def torch_gc() -> None:
|
||||
r"""
|
||||
Collects GPU memory.
|
||||
"""
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||
if args is not None:
|
||||
return parser.parse_dict(args)
|
||||
|
@ -107,26 +96,11 @@ def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None
|
|||
return parser.parse_args_into_dataclasses()
|
||||
|
||||
|
||||
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
def torch_gc() -> None:
|
||||
r"""
|
||||
Dispatches a pre-trained model to GPUs with balanced memory.
|
||||
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
|
||||
Collects GPU memory.
|
||||
"""
|
||||
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
|
||||
return model
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
from accelerate import dispatch_model
|
||||
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
||||
|
||||
if model._no_split_modules is None:
|
||||
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
|
||||
|
||||
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
|
||||
max_memory = get_balanced_memory(model, **kwargs)
|
||||
# Make sure tied weights are tied before creating the device map.
|
||||
model.tie_weights()
|
||||
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
|
||||
return dispatch_model(model, device_map)
|
||||
else:
|
||||
return model.cuda()
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
import importlib.metadata
|
||||
import importlib.util
|
||||
|
||||
|
||||
def is_package_available(name: str) -> bool:
|
||||
return importlib.util.find_spec(name) is not None
|
||||
|
||||
|
||||
def get_package_version(name: str) -> str:
|
||||
try:
|
||||
return importlib.metadata.version(name)
|
||||
except:
|
||||
return "0.0.0"
|
||||
|
||||
|
||||
_fastapi_available = is_package_available("fastapi")
|
||||
_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
|
||||
_jieba_available = is_package_available("jieba")
|
||||
_matplotlib_available = is_package_available("matplotlib")
|
||||
_nltk_available = is_package_available("nltk")
|
||||
_rouge_available = is_package_available("rouge-chinese")
|
||||
_starlette_available = is_package_available("sse-starlette")
|
||||
_uvicorn_available = is_package_available("uvicorn")
|
||||
|
||||
|
||||
def is_fastapi_availble():
|
||||
return _fastapi_available
|
||||
|
||||
|
||||
def is_flash_attn2_available():
|
||||
return _flash_attn2_available
|
||||
|
||||
|
||||
def is_jieba_available():
|
||||
return _jieba_available
|
||||
|
||||
|
||||
def is_matplotlib_available():
|
||||
return _matplotlib_available
|
||||
|
||||
|
||||
def is_nltk_available():
|
||||
return _nltk_available
|
||||
|
||||
|
||||
def is_rouge_available():
|
||||
return _rouge_available
|
||||
|
||||
|
||||
def is_starlette_available():
|
||||
return _starlette_available
|
||||
|
||||
|
||||
def is_uvicorn_available():
|
||||
return _uvicorn_available
|
|
@ -3,16 +3,19 @@ import torch
|
|||
import torch.nn as nn
|
||||
from typing import Optional, Tuple
|
||||
from transformers.utils import logging
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
is_flash_attn_2_available = False
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
|
||||
|
||||
try:
|
||||
from transformers.models.llama.modeling_llama import repeat_kv
|
||||
except ImportError:
|
||||
print("Please upgrade `transformers`.")
|
||||
|
||||
from llmtuner.extras.packages import is_flash_attn2_available
|
||||
|
||||
|
||||
if is_flash_attn2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
|
||||
from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
|
||||
is_flash_attn_2_available = True
|
||||
except ImportError:
|
||||
is_flash_attn_2_available = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
import os
|
||||
import math
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import List, Optional
|
||||
from transformers.trainer import TRAINER_STATE_NAME
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.packages import is_matplotlib_available
|
||||
|
||||
if is_matplotlib_available():
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
from llmtuner.model.loader import load_model_and_tokenizer
|
||||
from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
|
||||
from llmtuner.model.utils import dispatch_model, generate_model_card
|
|
@ -1,18 +1,12 @@
|
|||
import os
|
||||
import torch
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers.utils import cached_file
|
||||
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||
from peft import (
|
||||
PeftModel,
|
||||
TaskType,
|
||||
LoraConfig,
|
||||
get_peft_model
|
||||
)
|
||||
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core.utils import find_all_linear_modules
|
||||
from llmtuner.model.utils import find_all_linear_modules
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
|
@ -25,10 +25,11 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
|
|||
|
||||
from llmtuner.extras.logging import reset_logging, get_logger
|
||||
from llmtuner.extras.misc import count_parameters, infer_optim_dtype
|
||||
from llmtuner.extras.packages import is_flash_attn2_available
|
||||
from llmtuner.extras.patches import llama_patch as LlamaPatches
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
from llmtuner.tuner.core.adapter import init_adapter, load_valuehead_params
|
||||
from llmtuner.tuner.core.utils import prepare_model_for_training
|
||||
from llmtuner.model.adapter import init_adapter, load_valuehead_params
|
||||
from llmtuner.model.utils import prepare_model_for_training
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
@ -122,7 +123,7 @@ def load_model_and_tokenizer(
|
|||
# Set FlashAttention-2
|
||||
if model_args.flash_attn:
|
||||
if getattr(config, "model_type", None) == "llama":
|
||||
if LlamaPatches.is_flash_attn_2_available:
|
||||
if is_flash_attn2_available():
|
||||
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
|
||||
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
|
||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||
|
@ -131,7 +132,7 @@ def load_model_and_tokenizer(
|
|||
elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
|
||||
logger.info("Current model automatically enables FlashAttention if installed.")
|
||||
else:
|
||||
logger.warning("Current model does not support FlashAttention-2.")
|
||||
logger.warning("Current model does not support FlashAttention.")
|
||||
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
|
||||
LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
|
||||
logger.warning("Using `--flash_attn` for faster training in large context length.")
|
|
@ -11,6 +11,7 @@ from llmtuner.extras.misc import parse_args
|
|||
from llmtuner.hparams import (
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
EvaluationArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
)
|
||||
|
@ -19,51 +20,42 @@ from llmtuner.hparams import (
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_train_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
))
|
||||
_TRAIN_ARGS = [
|
||||
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
|
||||
]
|
||||
_TRAIN_CLS = Tuple[
|
||||
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
|
||||
]
|
||||
_INFER_ARGS = [
|
||||
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
]
|
||||
_INFER_CLS = Tuple[
|
||||
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
]
|
||||
_EVAL_ARGS = [
|
||||
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
|
||||
]
|
||||
_EVAL_CLS = Tuple[
|
||||
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
|
||||
]
|
||||
|
||||
|
||||
def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||
return parse_args(parser, args)
|
||||
|
||||
|
||||
def parse_infer_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
))
|
||||
def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
parser = HfArgumentParser(_INFER_ARGS)
|
||||
return parse_args(parser, args)
|
||||
|
||||
|
||||
def get_train_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
]:
|
||||
def parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||
parser = HfArgumentParser(_EVAL_ARGS)
|
||||
return parse_args(parser, args)
|
||||
|
||||
|
||||
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
|
@ -187,14 +179,7 @@ def get_train_args(
|
|||
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_infer_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
]:
|
||||
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
|
||||
|
||||
if data_args.template is None:
|
||||
|
@ -211,3 +196,17 @@ def get_infer_args(
|
|||
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
|
||||
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||
model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
|
||||
|
||||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
transformers.set_seed(eval_args.seed)
|
||||
|
||||
return model_args, data_args, eval_args, finetuning_args
|
|
@ -12,6 +12,31 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
r"""
|
||||
Dispatches a pre-trained model to GPUs with balanced memory.
|
||||
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
|
||||
"""
|
||||
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
|
||||
return model
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
from accelerate import dispatch_model
|
||||
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
||||
|
||||
if model._no_split_modules is None:
|
||||
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
|
||||
|
||||
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
|
||||
max_memory = get_balanced_memory(model, **kwargs)
|
||||
# Make sure tied weights are tied before creating the device map.
|
||||
model.tie_weights()
|
||||
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
|
||||
return dispatch_model(model, device_map)
|
||||
else:
|
||||
return model.cuda()
|
||||
|
||||
|
||||
def find_all_linear_modules(
|
||||
model: "PreTrainedModel",
|
||||
quantization_bit: Optional[int] = None
|
|
@ -0,0 +1 @@
|
|||
from llmtuner.train.tuner import export_model, run_exp
|
|
@ -0,0 +1 @@
|
|||
from llmtuner.train.dpo.workflow import run_dpo
|
|
@ -4,14 +4,14 @@ from peft import PeftModel
|
|||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.hparams import ModelArguments
|
||||
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
|
||||
from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
|
||||
from llmtuner.model import generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
|
||||
from llmtuner.train.dpo.trainer import CustomDPOTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
|
@ -0,0 +1 @@
|
|||
from llmtuner.train.ppo.workflow import run_ppo
|
|
@ -3,7 +3,7 @@ import sys
|
|||
import math
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
|
@ -14,7 +14,7 @@ from trl.core import PPODecorators, logprobs_from_logits
|
|||
from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
|
||||
from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model
|
||||
from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
|
@ -7,11 +7,11 @@ from typing import TYPE_CHECKING, Optional, List
|
|||
from transformers import DataCollatorWithPadding
|
||||
from transformers.optimization import get_scheduler
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.data import get_dataset, preprocess_dataset
|
||||
from llmtuner.extras.callbacks import SavePeftModelCallback
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
|
||||
from llmtuner.model import load_model_and_tokenizer
|
||||
from llmtuner.train.ppo.trainer import CustomPPOTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
|
@ -0,0 +1 @@
|
|||
from llmtuner.train.pt.workflow import run_pt
|
|
@ -4,9 +4,9 @@ import math
|
|||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import DataCollatorForLanguageModeling, Trainer
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.model import generate_model_card, load_model_and_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
|
@ -0,0 +1 @@
|
|||
from llmtuner.train.rm.workflow import run_rm
|
|
@ -3,13 +3,13 @@
|
|||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import SavePeftModelCallback
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.tuner.rm.metric import compute_accuracy
|
||||
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
|
||||
from llmtuner.tuner.rm.trainer import PairwiseTrainer
|
||||
from llmtuner.model import generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
|
||||
from llmtuner.train.rm.metric import compute_accuracy
|
||||
from llmtuner.train.rm.trainer import PairwiseTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
|
@ -0,0 +1 @@
|
|||
from llmtuner.train.sft.workflow import run_sft
|
|
@ -2,15 +2,23 @@ import numpy as np
|
|||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
|
||||
|
||||
import jieba
|
||||
from rouge_chinese import Rouge
|
||||
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.packages import (
|
||||
is_jieba_available, is_nltk_available, is_rouge_available
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
if is_jieba_available():
|
||||
import jieba
|
||||
|
||||
if is_nltk_available():
|
||||
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
||||
|
||||
if is_rouge_available():
|
||||
from rouge_chinese import Rouge
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComputeMetrics:
|
|
@ -3,13 +3,13 @@
|
|||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.misc import get_logits_processor
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.tuner.sft.metric import ComputeMetrics
|
||||
from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
|
||||
from llmtuner.model import generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.train.sft.metric import ComputeMetrics
|
||||
from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
|
@ -2,12 +2,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
|||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
|
||||
from llmtuner.tuner.pt import run_pt
|
||||
from llmtuner.tuner.sft import run_sft
|
||||
from llmtuner.tuner.rm import run_rm
|
||||
from llmtuner.tuner.ppo import run_ppo
|
||||
from llmtuner.tuner.dpo import run_dpo
|
||||
from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer
|
||||
from llmtuner.train.pt import run_pt
|
||||
from llmtuner.train.sft import run_sft
|
||||
from llmtuner.train.rm import run_rm
|
||||
from llmtuner.train.ppo import run_ppo
|
||||
from llmtuner.train.dpo import run_dpo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
|
@ -1 +0,0 @@
|
|||
from llmtuner.tuner.tune import export_model, run_exp
|
|
@ -1,3 +0,0 @@
|
|||
from llmtuner.tuner.core.parser import get_train_args, get_infer_args
|
||||
from llmtuner.tuner.core.loader import load_model_and_tokenizer
|
||||
from llmtuner.tuner.core.utils import generate_model_card
|
|
@ -1 +0,0 @@
|
|||
from llmtuner.tuner.dpo.workflow import run_dpo
|
|
@ -1 +0,0 @@
|
|||
from llmtuner.tuner.ppo.workflow import run_ppo
|
|
@ -1 +0,0 @@
|
|||
from llmtuner.tuner.pt.workflow import run_pt
|
|
@ -1 +0,0 @@
|
|||
from llmtuner.tuner.rm.workflow import run_rm
|
|
@ -1 +0,0 @@
|
|||
from llmtuner.tuner.sft.workflow import run_sft
|
|
@ -2,7 +2,7 @@ import gradio as gr
|
|||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||
|
||||
from llmtuner.chat.stream_chat import ChatModel
|
||||
from llmtuner.chat import ChatModel
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.hparams import GeneratingArguments
|
||||
from llmtuner.webui.common import get_save_dir
|
||||
|
|
|
@ -4,7 +4,7 @@ import logging
|
|||
import gradio as gr
|
||||
from threading import Thread
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple
|
||||
|
||||
import transformers
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
|
@ -13,7 +13,7 @@ from llmtuner.extras.callbacks import LogCallback
|
|||
from llmtuner.extras.constants import TRAINING_STAGES
|
||||
from llmtuner.extras.logging import LoggerHandler
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.tuner import run_exp
|
||||
from llmtuner.train import run_exp
|
||||
from llmtuner.webui.common import get_module, get_save_dir, load_config
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
|
||||
|
|
|
@ -1,17 +1,20 @@
|
|||
import os
|
||||
import json
|
||||
import gradio as gr
|
||||
import matplotlib.figure
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from llmtuner.extras.packages import is_matplotlib_available
|
||||
from llmtuner.extras.ploting import smooth
|
||||
from llmtuner.webui.common import get_save_dir
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
|
||||
if is_matplotlib_available():
|
||||
import matplotlib.figure
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
|
||||
if not callback.max_steps:
|
||||
|
@ -56,7 +59,7 @@ def get_eval_results(path: os.PathLike) -> str:
|
|||
return "```json\n{}\n```\n".format(result)
|
||||
|
||||
|
||||
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
|
||||
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure":
|
||||
if not base_model:
|
||||
return
|
||||
log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
|
||||
|
|
|
@ -7,12 +7,13 @@ import fire
|
|||
import math
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.data import get_dataset, preprocess_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
|
||||
from llmtuner.model import get_train_args, load_model_and_tokenizer
|
||||
|
||||
|
||||
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
|
||||
|
@ -22,14 +23,16 @@ BASE_BS = 4_000_000 # from llama paper
|
|||
def calculate_lr(
|
||||
model_name_or_path: str,
|
||||
dataset: str,
|
||||
cutoff_len: int, # i.e. maximum input length during training
|
||||
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
||||
is_mistral: bool # mistral model uses a smaller learning rate
|
||||
cutoff_len: int, # i.e. maximum input length during training
|
||||
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
||||
is_mistral: bool, # mistral model uses a smaller learning rate,
|
||||
dataset_dir: Optional[str] = "data"
|
||||
):
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
|
||||
stage="sft",
|
||||
model_name_or_path=model_name_or_path,
|
||||
dataset=dataset,
|
||||
dataset_dir=dataset_dir,
|
||||
template="default",
|
||||
cutoff_len=cutoff_len,
|
||||
output_dir="dummy_dir"
|
||||
|
|
Loading…
Reference in New Issue