refactor export, fix #1190
parent 273745f9b9
commit ea82f8a82a
@@ -371,7 +371,7 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export \
+    --export_dir path_to_export \
     --fp16
 ```

@@ -1,5 +1,4 @@
 from .data_args import DataArguments
 from .finetuning_args import FinetuningArguments
-from .general_args import GeneralArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
@@ -8,6 +8,10 @@ class FinetuningArguments:
     r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """
+    stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
+        default="sft",
+        metadata={"help": "Which stage will be performed in training."}
+    )
     finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
         default="lora",
         metadata={"help": "Which fine-tuning method to use."}
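For readers unfamiliar with this argument style: the new `stage` field is an ordinary HfArgumentParser dataclass field, so it automatically becomes a `--stage` command-line flag with the listed choices. A minimal, self-contained sketch of that mechanism (the trimmed `DemoFinetuningArguments` class and the example flag values below are illustrative only, not the project's actual class):

    # Standalone sketch: a dataclass field like the one added above maps to a CLI flag.
    from dataclasses import dataclass, field
    from typing import Literal, Optional

    from transformers import HfArgumentParser

    @dataclass
    class DemoFinetuningArguments:
        stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
            default="sft",
            metadata={"help": "Which stage will be performed in training."}
        )
        finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
            default="lora",
            metadata={"help": "Which fine-tuning method to use."}
        )

    parser = HfArgumentParser(DemoFinetuningArguments)
    # Parse an explicit argument list instead of sys.argv for the demo.
    (demo_args,) = parser.parse_args_into_dataclasses(args=["--stage", "ppo"])
    print(demo_args.stage, demo_args.finetuning_type)  # -> ppo lora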
@@ -46,6 +46,10 @@ class ModelArguments:
         default=None,
         metadata={"help": "Adopt scaled rotary positional embeddings."}
     )
+    checkpoint_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
+    )
     flash_attn: Optional[bool] = field(
         default=False,
         metadata={"help": "Enable FlashAttention-2 for faster training."}
@@ -54,14 +58,14 @@ class ModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
     )
-    checkpoint_dir: Optional[str] = field(
-        default=None,
-        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
-    )
     reward_model: Optional[str] = field(
         default=None,
         metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
     )
+    upcast_layernorm: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to upcast the layernorm weights in fp32."}
+    )
     plot_loss: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
@@ -70,9 +74,9 @@ class ModelArguments:
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
-    upcast_layernorm: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Whether to upcast the layernorm weights in fp32."}
+    export_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory to save the exported model."}
     )

     def __post_init__(self):
@@ -5,7 +5,6 @@ import datasets
 import transformers
 from typing import Any, Dict, Optional, Tuple
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
-from transformers.utils.versions import require_version
 from transformers.trainer_utils import get_last_checkpoint

 from llmtuner.extras.logging import get_logger
@@ -13,8 +12,7 @@ from llmtuner.hparams import (
     ModelArguments,
     DataArguments,
     FinetuningArguments,
-    GeneratingArguments,
-    GeneralArguments
+    GeneratingArguments
 )

@@ -39,16 +37,14 @@ def parse_train_args(
     DataArguments,
     Seq2SeqTrainingArguments,
     FinetuningArguments,
-    GeneratingArguments,
-    GeneralArguments
+    GeneratingArguments
 ]:
     parser = HfArgumentParser((
         ModelArguments,
         DataArguments,
         Seq2SeqTrainingArguments,
         FinetuningArguments,
-        GeneratingArguments,
-        GeneralArguments
+        GeneratingArguments
     ))
     return _parse_args(parser, args)

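As context for the dict-based entry points touched later in this commit, HfArgumentParser can populate several dataclasses at once from a plain dict rather than CLI flags. A small self-contained sketch (the two demo dataclasses and the example values are illustrative, not the repository's real argument classes):

    # Standalone sketch: one parser, several dataclasses, filled from a dict.
    from dataclasses import dataclass, field
    from typing import Optional

    from transformers import HfArgumentParser

    @dataclass
    class DemoModelArguments:
        model_name_or_path: Optional[str] = field(default=None)
        export_dir: Optional[str] = field(default=None)

    @dataclass
    class DemoDataArguments:
        template: Optional[str] = field(default=None)

    parser = HfArgumentParser((DemoModelArguments, DemoDataArguments))
    # Each key is routed to the dataclass that declares it; values are placeholders.
    model_args, data_args = parser.parse_dict(
        {"model_name_or_path": "my-base-model", "template": "default", "export_dir": "exported"}
    )
    print(model_args.export_dir, data_args.template)  # -> exported default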
@@ -77,10 +73,9 @@ def get_train_args(
     DataArguments,
     Seq2SeqTrainingArguments,
     FinetuningArguments,
-    GeneratingArguments,
-    GeneralArguments
+    GeneratingArguments
 ]:
-    model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)
+    model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)

     # Setup logging
     if training_args.should_log:
@@ -96,36 +91,36 @@ def get_train_args(
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     data_args.init_for_training()

-    if general_args.stage != "pt" and data_args.template is None:
+    if finetuning_args.stage != "pt" and data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

-    if general_args.stage != "sft" and training_args.predict_with_generate:
+    if finetuning_args.stage != "sft" and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True except SFT.")

-    if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
+    if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")

-    if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
+    if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
         raise ValueError("RM and PPO stages can only be performed with the LoRA method.")

-    if general_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
+    if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
         raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")

-    if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
+    if finetuning_args.stage in ["ppo", "dpo"] and not training_args.do_train:
         raise ValueError("PPO and DPO stages can only be performed at training.")

-    if general_args.stage in ["rm", "dpo"]:
+    if finetuning_args.stage in ["rm", "dpo"]:
         for dataset_attr in data_args.dataset_list:
             if not dataset_attr.ranking:
                 raise ValueError("Please use ranked datasets for reward modeling or DPO training.")

-    if general_args.stage == "ppo" and model_args.reward_model is None:
+    if finetuning_args.stage == "ppo" and model_args.reward_model is None:
         raise ValueError("Reward model is necessary for PPO training.")

-    if general_args.stage == "ppo" and data_args.streaming:
+    if finetuning_args.stage == "ppo" and data_args.streaming:
         raise ValueError("Streaming mode does not suppport PPO training currently.")

-    if general_args.stage == "ppo" and model_args.shift_attn:
+    if finetuning_args.stage == "ppo" and model_args.shift_attn:
         raise ValueError("PPO training is incompatible with S^2-Attn.")

     if training_args.max_steps == -1 and data_args.streaming:
@@ -205,7 +200,7 @@ def get_train_args(
     # Set seed before initializing model.
     transformers.set_seed(training_args.seed)

-    return model_args, data_args, training_args, finetuning_args, generating_args, general_args
+    return model_args, data_args, training_args, finetuning_args, generating_args


 def get_infer_args(
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional

 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
+from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
 from llmtuner.tuner.pt import run_pt
 from llmtuner.tuner.sft import run_sft
 from llmtuner.tuner.rm import run_rm
@@ -17,31 +17,32 @@ logger = get_logger(__name__)


 def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
-    model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args)
+    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
     callbacks = [LogCallback()] if callbacks is None else callbacks

-    if general_args.stage == "pt":
+    if finetuning_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
-    elif general_args.stage == "sft":
+    elif finetuning_args.stage == "sft":
         run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
-    elif general_args.stage == "rm":
+    elif finetuning_args.stage == "rm":
         run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
-    elif general_args.stage == "ppo":
+    elif finetuning_args.stage == "ppo":
         run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
-    elif general_args.stage == "dpo":
+    elif finetuning_args.stage == "dpo":
         run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
     else:
         raise ValueError("Unknown task.")


 def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
-    model_args, _, training_args, finetuning_args, _, _ = get_train_args(args)
+    model_args, _, finetuning_args, _ = get_infer_args(args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+    model.config.use_cache = True
     tokenizer.padding_side = "left" # restore padding side
     tokenizer.init_kwargs["padding_side"] = "left"
-    model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
+    model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size)
     try:
-        tokenizer.save_pretrained(training_args.output_dir)
+        tokenizer.save_pretrained(model_args.export_dir)
     except:
         logger.warning("Cannot save tokenizer, please copy the files manually.")

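Putting the pieces together, the refactored `export_model` can be driven with the same argument names that the README command and the web UI use. A hedged sketch of such a call (the import path and the placeholder paths are assumptions for illustration; only the signature and the argument names are taken from this diff):

    # Hedged sketch: programmatic export with the renamed `export_dir` argument.
    # Import path and paths below are assumptions/placeholders, not repository code.
    from llmtuner import export_model

    export_model(dict(
        model_name_or_path="path_to_llama_model",  # base model (placeholder)
        template="default",
        finetuning_type="lora",
        checkpoint_dir="path_to_checkpoint",       # adapter checkpoint (placeholder)
        export_dir="path_to_export",               # replaces the old output_dir
    ), max_shard_size="10GB")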
@@ -12,7 +12,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
     elem_dict = dict()

     with gr.Row():
-        save_dir = gr.Textbox()
+        export_dir = gr.Textbox()
         max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)

     export_btn = gr.Button()
@@ -28,13 +28,13 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
             engine.manager.get_elem("top.finetuning_type"),
             engine.manager.get_elem("top.template"),
             max_shard_size,
-            save_dir
+            export_dir
         ],
         [info_box]
     )

     elem_dict.update(dict(
-        save_dir=save_dir,
+        export_dir=export_dir,
         max_shard_size=max_shard_size,
         export_btn=export_btn,
         info_box=info_box
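For readers who have not used Gradio, the wiring above follows the standard pattern of a button click feeding input components into a (possibly generator) callback whose yields stream into an output box. A minimal, self-contained sketch of that pattern (the labels and the fake callback are illustrative only, not the project's UI code):

    # Standalone Gradio sketch: a button click streams status strings into a textbox.
    import gradio as gr

    def fake_export(export_dir: str):
        yield "Exporting model..."
        yield f"Model exported to {export_dir}."

    with gr.Blocks() as demo:
        export_dir = gr.Textbox(label="Export dir")
        info_box = gr.Textbox(label="Info", interactive=False)
        export_btn = gr.Button("Export")
        export_btn.click(fake_export, [export_dir], [info_box])

    if __name__ == "__main__":
        demo.launch()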
@@ -531,7 +531,7 @@ LOCALES = {
             "label": "温度系数"
         }
     },
-    "save_dir": {
+    "export_dir": {
         "en": {
             "label": "Export dir",
             "info": "Directory to save exported model."
@@ -587,7 +587,7 @@ ALERTS = {
         "en": "Please select a checkpoint.",
         "zh": "请选择断点。"
     },
-    "err_no_save_dir": {
+    "err_no_export_dir": {
         "en": "Please provide export dir.",
         "zh": "请填写导出目录"
     },
@@ -124,7 +124,7 @@ def save_model(
     finetuning_type: str,
     template: str,
     max_shard_size: int,
-    save_dir: str
+    export_dir: str
 ) -> Generator[str, None, None]:
     if not model_name:
         yield ALERTS["err_no_model"][lang]
@@ -138,8 +138,8 @@ def save_model(
         yield ALERTS["err_no_checkpoint"][lang]
         return

-    if not save_dir:
-        yield ALERTS["err_no_save_dir"][lang]
+    if not export_dir:
+        yield ALERTS["err_no_export_dir"][lang]
         return

     args = dict(
@@ -147,7 +147,7 @@ def save_model(
         checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
         finetuning_type=finetuning_type,
         template=template,
-        output_dir=save_dir
+        export_dir=export_dir
     )

     yield ALERTS["info_exporting"][lang]