Support AutoGPTQ in LLaMA Board #246

hiyouga 2023-12-16 16:31:30 +08:00
parent 93f64ce9a8
commit 71389be37c
14 changed files with 1032 additions and 65 deletions

View File

@@ -482,7 +482,7 @@ python src/export_model.py \
 > Merging LoRA weights into a quantized model is not supported.

 > [!TIP]
-> Use `--export_quantization_bit 4` and `--export_quantization_dataset data/wiki_demo.txt` to quantize the model.
+> Use `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json` to quantize the model.

 ### API Demo

View File

@@ -482,7 +482,7 @@ python src/export_model.py \
 > 尚不支持量化模型的 LoRA 权重合并及导出。

 > [!TIP]
-> 使用 `--export_quantization_bit 4` 和 `--export_quantization_dataset data/wiki_demo.txt` 量化导出模型。
+> 使用 `--export_quantization_bit 4` 和 `--export_quantization_dataset data/c4_demo.json` 量化导出模型。

 ### API 服务

data/c4_demo.json (new file, 902 additions)

File diff suppressed because one or more lines are too long

View File

@@ -239,6 +239,13 @@
       "prompt": "text"
     }
   },
+  "c4_demo": {
+    "file_name": "c4_demo.json",
+    "file_sha1": "a5a0c86759732f9a5238e447fecd74f28a66cca8",
+    "columns": {
+      "prompt": "text"
+    }
+  },
   "refinedweb": {
     "hf_hub_url": "tiiuae/falcon-refinedweb",
     "columns": {

View File

@@ -26,6 +26,8 @@ LOG_FILE_NAME = "trainer_log.jsonl"
 METHODS = ["full", "freeze", "lora"]

+PEFT_METHODS = ["lora"]
+
 SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]

 SUPPORTED_MODELS = OrderedDict()

View File

@ -1,13 +1,9 @@
import gc import gc
import os import os
import sys
import torch import torch
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple from typing import TYPE_CHECKING, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
from llmtuner.extras.logging import get_logger
logger = get_logger(__name__)
try: try:
from transformers.utils import ( from transformers.utils import (
is_torch_bf16_cpu_available, is_torch_bf16_cpu_available,
@@ -106,22 +102,6 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
         return torch.float32


-def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
-    if args is not None:
-        return parser.parse_dict(args)
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
-        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
-    else:
-        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
-        if unknown_args:
-            logger.warning(parser.format_help())
-            logger.error(f'\nGot unknown args, potentially deprecated arguments: {unknown_args}\n')
-            raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
-        return (*parsed_args,)
-
-
 def torch_gc() -> None:
     r"""
     Collects GPU memory.
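The hunk cuts off at torch_gc's docstring; its body is untouched by this commit. For reference, a function with this name and docstring conventionally runs Python garbage collection and releases cached CUDA memory, along the lines of this sketch:

    import gc
    import torch

    def torch_gc() -> None:
        r"""
        Collects GPU memory.
        """
        gc.collect()
        if torch.cuda.is_available():
            # Release cached blocks back to the driver and clean up IPC handles.
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()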

View File

@@ -1,5 +1,7 @@
 import os
+import sys
 import torch
+import logging
 import datasets
 import transformers
 from typing import Any, Dict, Optional, Tuple
@@ -7,7 +9,6 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.trainer_utils import get_last_checkpoint

 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import parse_args
 from llmtuner.hparams import (
     ModelArguments,
     DataArguments,
@@ -40,6 +41,33 @@ _EVAL_CLS = Tuple[
 ]


+def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
+    if args is not None:
+        return parser.parse_dict(args)
+
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
+        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
+
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
+
+    (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+
+    if unknown_args:
+        print(parser.format_help())
+        print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
+        raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
+
+    return (*parsed_args,)
+
+
+def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
+    datasets.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()
+
+
 def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
     if model_args.quantization_bit is not None:
         if finetuning_args.finetuning_type != "lora":
@@ -56,34 +84,28 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
             raise ValueError("Quantized model only accepts a single adapter. Merge them first.")


-def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
+def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
-    return parse_args(parser, args)
+    return _parse_args(parser, args)


-def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
+def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     parser = HfArgumentParser(_INFER_ARGS)
-    return parse_args(parser, args)
+    return _parse_args(parser, args)


-def parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
+def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     parser = HfArgumentParser(_EVAL_ARGS)
-    return parse_args(parser, args)
+    return _parse_args(parser, args)
 def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
-    model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
+    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)

     # Setup logging
     if training_args.should_log:
-        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
-        transformers.utils.logging.set_verbosity_info()
-
-    log_level = training_args.get_process_log_level()
-    datasets.utils.logging.set_verbosity(log_level)
-    transformers.utils.logging.set_verbosity(log_level)
-    transformers.utils.logging.enable_default_handler()
-    transformers.utils.logging.enable_explicit_format()
+        log_level = training_args.get_process_log_level()
+        _set_transformers_logging(log_level)

     # Check arguments
     data_args.init_for_training(training_args.seed)
@@ -193,7 +215,8 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
-    model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
+    model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
+    _set_transformers_logging()

     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
@@ -204,7 +227,8 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
 def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
-    model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
+    model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
+    _set_transformers_logging()

     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

View File

@@ -76,8 +76,13 @@ def configure_quantization(
     if finetuning_args.export_quantization_bit is not None: # gptq
         require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
         require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+        if getattr(config, "model_type", None) == "chatglm":
+            raise ValueError("ChatGLM model is not supported.")
+
         config_kwargs["quantization_config"] = GPTQConfig(
             bits=finetuning_args.export_quantization_bit,
+            tokenizer=tokenizer,
             dataset=get_quantization_dataset(tokenizer, model_args, finetuning_args)
         )
         config_kwargs["device_map"] = "auto"

View File

@@ -7,6 +7,7 @@ from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
 from llmtuner.extras.constants import (
     DEFAULT_MODULE,
     DEFAULT_TEMPLATE,
+    PEFT_METHODS,
     SUPPORTED_MODELS,
     TRAINING_STAGES,
     DownloadSource
@@ -77,8 +78,11 @@ def get_template(model_name: str) -> str:
 def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
+    if finetuning_type not in PEFT_METHODS:
+        return gr.update(value=[], choices=[], interactive=False)
+
     adapters = []
-    if model_name and finetuning_type == "lora": # full and freeze have no adapter
+    if model_name and finetuning_type == "lora":
         save_dir = get_save_dir(model_name, finetuning_type)
         if save_dir and os.path.isdir(save_dir):
             for adapter in os.listdir(save_dir):
@@ -87,7 +91,7 @@ def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
                 and any([os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES])
             ):
                 adapters.append(adapter)
-    return gr.update(value=[], choices=adapters)
+    return gr.update(value=[], choices=adapters, interactive=True)


 def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
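list_adapters returns a gr.update payload rather than a component, so the caller decides which dropdown receives it; full and freeze now disable the adapter box outright instead of showing an empty list. A compact sketch of that wiring (component names and layout are assumed, not the Board's actual code):

    import gradio as gr

    def list_adapters(model_name: str, finetuning_type: str) -> dict:
        # Simplified stand-in for the function above.
        if finetuning_type not in ["lora"]:  # PEFT_METHODS
            return gr.update(value=[], choices=[], interactive=False)
        return gr.update(value=[], choices=["adapter-a", "adapter-b"], interactive=True)

    with gr.Blocks() as demo:
        model_name = gr.Textbox(value="LLaMA2-7B")
        finetuning_type = gr.Dropdown(choices=["full", "freeze", "lora"], value="lora")
        adapter_path = gr.Dropdown(multiselect=True)
        # Changing the method re-renders the dropdown with the returned update.
        finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path])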

View File

@@ -21,8 +21,11 @@ def next_page(page_index: int, total_num: int) -> int:
 def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
-    with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
-        dataset_info = json.load(f)
+    try:
+        with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
+            dataset_info = json.load(f)
+    except:
+        return gr.update(interactive=False)

     if (
         len(dataset) > 0
View File

@@ -10,6 +10,9 @@ if TYPE_CHECKING:
     from llmtuner.webui.engine import Engine


+GPTQ_BITS = ["8", "4", "3", "2"]
+
+
 def save_model(
     lang: str,
     model_name: str,
@@ -18,6 +21,8 @@ def save_model(
     finetuning_type: str,
     template: str,
     max_shard_size: int,
+    export_quantization_bit: int,
+    export_quantization_dataset: str,
     export_dir: str
 ) -> Generator[str, None, None]:
     error = ""
@@ -25,23 +30,32 @@ def save_model(
         error = ALERTS["err_no_model"][lang]
     elif not model_path:
         error = ALERTS["err_no_path"][lang]
-    elif not adapter_path:
-        error = ALERTS["err_no_adapter"][lang]
     elif not export_dir:
         error = ALERTS["err_no_export_dir"][lang]
+    elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
+        error = ALERTS["err_no_dataset"][lang]
+    elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
+        error = ALERTS["err_no_adapter"][lang]

     if error:
         gr.Warning(error)
         yield error
         return

+    if adapter_path:
+        adapter_name_or_path = ",".join([get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path])
+    else:
+        adapter_name_or_path = None
+
     args = dict(
         model_name_or_path=model_path,
-        adapter_name_or_path=",".join([get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]),
+        adapter_name_or_path=adapter_name_or_path,
         finetuning_type=finetuning_type,
         template=template,
         export_dir=export_dir,
-        export_size=max_shard_size
+        export_size=max_shard_size,
+        export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
+        export_quantization_dataset=export_quantization_dataset
     )

     yield ALERTS["info_exporting"][lang]
@@ -51,9 +65,11 @@ def save_model(
 def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Row():
-        export_dir = gr.Textbox()
         max_shard_size = gr.Slider(value=1, minimum=1, maximum=100)
+        export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
+        export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
+        export_dir = gr.Textbox()

     export_btn = gr.Button()
     info_box = gr.Textbox(show_label=False, interactive=False)
@@ -67,14 +83,18 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
             engine.manager.get_elem_by_name("top.finetuning_type"),
             engine.manager.get_elem_by_name("top.template"),
             max_shard_size,
+            export_quantization_bit,
+            export_quantization_dataset,
             export_dir
         ],
         [info_box]
     )

     return dict(
-        export_dir=export_dir,
         max_shard_size=max_shard_size,
+        export_quantization_bit=export_quantization_bit,
+        export_quantization_dataset=export_quantization_dataset,
+        export_dir=export_dir,
         export_btn=export_btn,
         info_box=info_box
     )
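The args dict assembled in save_model mirrors the flags of src/export_model.py, so the new form fields make the web export equivalent to a scripted GPTQ export. A hedged sketch of the same call done programmatically (export_model as an importable entry point is assumed from the repo layout of this era; adjust names to the installed version):

    from llmtuner import export_model  # assumed entry point, see src/export_model.py

    export_model(dict(
        model_name_or_path="path_to_llama2_model",        # placeholder path
        template="default",
        finetuning_type="lora",
        export_dir="path_to_export",
        export_size=1,                                    # max shard size in GB
        export_quantization_bit=4,                        # one of GPTQ_BITS: 8 / 4 / 3 / 2
        export_quantization_dataset="data/c4_demo.json",  # calibration corpus
        adapter_name_or_path=None                         # quantized export skips adapter merging
    ))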

View File

@@ -20,7 +20,7 @@ def create_top() -> Dict[str, "Component"]:
     with gr.Row():
         finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
-        adapter_path = gr.Dropdown(multiselect=True, scale=5)
+        adapter_path = gr.Dropdown(multiselect=True, scale=5, allow_custom_value=True)
         refresh_btn = gr.Button(scale=1)

     with gr.Accordion(label="Advanced config", open=False) as advanced_tab:

View File

@@ -94,7 +94,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
         with gr.Row():
             dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
-            reward_model = gr.Dropdown(scale=3)
+            reward_model = gr.Dropdown(scale=3, allow_custom_value=True)
             refresh_btn = gr.Button(scale=1)

     refresh_btn.click(

View File

@@ -432,11 +432,11 @@ LOCALES = {
     "reward_model": {
         "en": {
             "label": "Reward model",
-            "info": "Checkpoint of the reward model for PPO training. (Needs to refresh checkpoints)"
+            "info": "Adapter of the reward model for PPO training. (Needs to refresh adapters)"
         },
         "zh": {
             "label": "奖励模型",
-            "info": "PPO 训练中奖励模型的断点路径。（需要刷新断点）"
+            "info": "PPO 训练中奖励模型的适配器路径。（需要刷新适配器）"
         }
     },
     "cmd_preview_btn": {
@@ -585,6 +585,36 @@ LOCALES = {
             "label": "温度系数"
         }
     },
+    "max_shard_size": {
+        "en": {
+            "label": "Max shard size (GB)",
+            "info": "The maximum size for a model file."
+        },
+        "zh": {
+            "label": "最大分块大小（GB）",
+            "info": "单个模型文件的最大大小。"
+        }
+    },
+    "export_quantization_bit": {
+        "en": {
+            "label": "Export quantization bit.",
+            "info": "Quantizing the exported model."
+        },
+        "zh": {
+            "label": "导出量化等级",
+            "info": "量化导出模型。"
+        }
+    },
+    "export_quantization_dataset": {
+        "en": {
+            "label": "Export quantization dataset.",
+            "info": "The calibration dataset used for quantization."
+        },
+        "zh": {
+            "label": "导出量化数据集",
+            "info": "量化过程中使用的校准数据集。"
+        }
+    },
     "export_dir": {
         "en": {
             "label": "Export dir",
@@ -595,16 +625,6 @@ LOCALES = {
             "info": "保存导出模型的文件夹路径。"
         }
     },
-    "max_shard_size": {
-        "en": {
-            "label": "Max shard size (GB)",
-            "info": "The maximum size for a model file."
-        },
-        "zh": {
-            "label": "最大分块大小（GB）",
-            "info": "模型文件的最大大小。"
-        }
-    },
     "export_btn": {
         "en": {
             "value": "Export"