support HQQ/EETQ #4113

This commit is contained in:
hiyouga 2024-06-27 00:29:42 +08:00
parent addca926de
commit ad144c2265
16 changed files with 134 additions and 57 deletions

View File

@ -48,7 +48,7 @@ Choose your path:
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc. - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc. - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8. - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning. - **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA. - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc. - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
@ -341,7 +341,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]" pip install -e ".[torch,metrics]"
``` ```
Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality
> [!TIP] > [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts. > Use `pip install --no-deps -e .` to resolve package conflicts.

View File

@ -48,7 +48,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。 - **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。 - **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
- **多种精度**32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。 - **多种精度**16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
- **先进算法**GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。 - **先进算法**GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
- **实用技巧**FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。 - **实用技巧**FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
- **实验监控**LlamaBoard、TensorBoard、Wandb、MLflow 等等。 - **实验监控**LlamaBoard、TensorBoard、Wandb、MLflow 等等。
@ -341,7 +341,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]" pip install -e ".[torch,metrics]"
``` ```
可选的额外依赖项torch、torch_npu、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality 可选的额外依赖项torch、torch-npu、metrics、deepspeed、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、qwen、modelscope、quality
> [!TIP] > [!TIP]
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。 > 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。

View File

@ -39,12 +39,14 @@ extra_require = {
"metrics": ["nltk", "jieba", "rouge-chinese"], "metrics": ["nltk", "jieba", "rouge-chinese"],
"deepspeed": ["deepspeed>=0.10.0"], "deepspeed": ["deepspeed>=0.10.0"],
"bitsandbytes": ["bitsandbytes>=0.39.0"], "bitsandbytes": ["bitsandbytes>=0.39.0"],
"vllm": ["vllm>=0.4.3"], "hqq": ["hqq"],
"galore": ["galore-torch"], "eetq": ["eetq"],
"badam": ["badam>=1.2.1"],
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"], "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"], "awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"], "aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3"],
"galore": ["galore-torch"],
"badam": ["badam>=1.2.1"],
"qwen": ["transformers_stream_generator"], "qwen": ["transformers_stream_generator"],
"modelscope": ["modelscope"], "modelscope": ["modelscope"],
"dev": ["ruff", "pytest"], "dev": ["ruff", "pytest"],

View File

@ -1,4 +1,7 @@
# Copyright 2024 the LlamaFactory team. # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -77,6 +77,10 @@ class ModelArguments:
default=True, default=True,
metadata={"help": "Whether or not to use memory-efficient model loading."}, metadata={"help": "Whether or not to use memory-efficient model loading."},
) )
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
default="bitsandbytes",
metadata={"help": "Quantization method to use for on-the-fly quantization."},
)
quantization_bit: Optional[int] = field( quantization_bit: Optional[int] = field(
default=None, default=None,
metadata={"help": "The number of bits to quantize the model using bitsandbytes."}, metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
@ -235,9 +239,6 @@ class ModelArguments:
if self.new_special_tokens is not None: # support multiple special tokens if self.new_special_tokens is not None: # support multiple special tokens
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")] self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
if self.export_quantization_bit is not None and self.export_quantization_dataset is None: if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.") raise ValueError("Quantization dataset is necessary for exporting.")

View File

@ -14,10 +14,12 @@
from .loader import load_config, load_model, load_tokenizer from .loader import load_config, load_model, load_tokenizer
from .model_utils.misc import find_all_linear_modules from .model_utils.misc import find_all_linear_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.valuehead import load_valuehead_params from .model_utils.valuehead import load_valuehead_params
__all__ = [ __all__ = [
"QuantizationMethod",
"load_config", "load_config",
"load_model", "load_model",
"load_tokenizer", "load_tokenizer",

View File

@ -186,11 +186,11 @@ def load_model(
trainable_params, all_param = count_parameters(model) trainable_params, all_param = count_parameters(model)
if is_trainable: if is_trainable:
param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param trainable_params, all_param, 100 * trainable_params / all_param
) )
else: else:
param_stats = "all params: {:d}".format(all_param) param_stats = "all params: {:,}".format(all_param)
logger.info(param_stats) logger.info(param_stats)

View File

@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List
import torch import torch
from datasets import load_dataset from datasets import load_dataset
from transformers import BitsAndBytesConfig, GPTQConfig from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
@ -59,7 +59,7 @@ class QuantizationMethod(str, Enum):
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]: def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
r""" r"""
Prepares the dataset to perform AutoGPTQ. Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
""" """
if os.path.isfile(model_args.export_quantization_dataset): if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None) data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
@ -93,7 +93,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1) word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen] input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen] attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen]
samples.append({"input_ids": input_ids, "attention_mask": attention_mask}) samples.append({"input_ids": input_ids.tolist(), "attention_mask": attention_mask.tolist()})
return samples return samples
@ -105,7 +105,7 @@ def configure_quantization(
init_kwargs: Dict[str, Any], init_kwargs: Dict[str, Any],
) -> None: ) -> None:
r""" r"""
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training) Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
""" """
if getattr(config, "quantization_config", None): # ptq if getattr(config, "quantization_config", None): # ptq
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
@ -131,6 +131,9 @@ def configure_quantization(
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper())) logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
elif model_args.export_quantization_bit is not None: # auto-gptq elif model_args.export_quantization_bit is not None: # auto-gptq
if model_args.export_quantization_bit not in [8, 4, 3, 2]:
raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0") require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory from accelerate.utils import get_max_memory
@ -146,11 +149,11 @@ def configure_quantization(
init_kwargs["max_memory"] = get_max_memory() init_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit)) logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb elif model_args.quantization_bit is not None: # on-the-fly
if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
if model_args.quantization_bit == 8: if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4: elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig( init_kwargs["quantization_config"] = BitsAndBytesConfig(
@ -160,6 +163,8 @@ def configure_quantization(
bnb_4bit_quant_type=model_args.quantization_type, bnb_4bit_quant_type=model_args.quantization_type,
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora
) )
else:
raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")
# Do not assign device map if: # Do not assign device map if:
# 1. deepspeed zero3 or fsdp (train) # 1. deepspeed zero3 or fsdp (train)
@ -173,3 +178,19 @@ def configure_quantization(
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit)) logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
elif model_args.quantization_method == QuantizationMethod.HQQ.value:
if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
require_version("hqq", "To fix: pip install hqq")
init_kwargs["quantization_config"] = HqqConfig(
nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
) # use ATEN kernel (axis=0) for performance
logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
elif model_args.quantization_method == QuantizationMethod.EETQ.value:
if model_args.quantization_bit != 8:
raise ValueError("EETQ only accepts 8-bit quantization.")
require_version("eetq", "To fix: pip install eetq")
init_kwargs["quantization_config"] = EetqConfig()
logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))

View File

@ -23,7 +23,7 @@ from ..data import Role
from ..extras.constants import PEFT_METHODS from ..extras.constants import PEFT_METHODS
from ..extras.misc import torch_gc from ..extras.misc import torch_gc
from ..extras.packages import is_gradio_available from ..extras.packages import is_gradio_available
from .common import get_save_dir from .common import QUANTIZATION_BITS, get_save_dir
from .locales import ALERTS from .locales import ALERTS
@ -76,11 +76,17 @@ class WebChatModel(ChatModel):
yield error yield error
return return
if get("top.quantization_bit") in QUANTIZATION_BITS:
quantization_bit = int(get("top.quantization_bit"))
else:
quantization_bit = None
yield ALERTS["info_loading"][lang] yield ALERTS["info_loading"][lang]
args = dict( args = dict(
model_name_or_path=model_path, model_name_or_path=model_path,
finetuning_type=finetuning_type, finetuning_type=finetuning_type,
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, quantization_bit=quantization_bit,
quantization_method=get("top.quantization_method"),
template=get("top.template"), template=get("top.template"),
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"), use_unsloth=(get("top.booster") == "unsloth"),

View File

@ -47,6 +47,8 @@ DEFAULT_CONFIG_DIR = "config"
DEFAULT_DATA_DIR = "data" DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves" DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user_config.yaml" USER_CONFIG = "user_config.yaml"
QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"]
GPTQ_BITS = ["8", "4", "3", "2"]
def get_save_dir(*paths: str) -> os.PathLike: def get_save_dir(*paths: str) -> os.PathLike:

View File

@ -18,7 +18,7 @@ from ...extras.constants import PEFT_METHODS
from ...extras.misc import torch_gc from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available from ...extras.packages import is_gradio_available
from ...train.tuner import export_model from ...train.tuner import export_model
from ..common import get_save_dir from ..common import GPTQ_BITS, get_save_dir
from ..locales import ALERTS from ..locales import ALERTS
@ -32,9 +32,6 @@ if TYPE_CHECKING:
from ..engine import Engine from ..engine import Engine
GPTQ_BITS = ["8", "4", "3", "2"]
def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown": def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0: if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
return gr.Dropdown(value="none", interactive=False) return gr.Dropdown(value="none", interactive=False)

View File

@ -18,7 +18,7 @@ from ...data import TEMPLATES
from ...extras.constants import METHODS, SUPPORTED_MODELS from ...extras.constants import METHODS, SUPPORTED_MODELS
from ...extras.packages import is_gradio_available from ...extras.packages import is_gradio_available
from ..common import get_model_info, list_checkpoints, save_config from ..common import get_model_info, list_checkpoints, save_config
from ..utils import can_quantize from ..utils import can_quantize, can_quantize_to
if is_gradio_available(): if is_gradio_available():
@ -43,10 +43,11 @@ def create_top() -> Dict[str, "Component"]:
with gr.Accordion(open=False) as advanced_tab: with gr.Accordion(open=False) as advanced_tab:
with gr.Row(): with gr.Row():
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2) quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1)
template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2) quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=1)
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3) template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1)
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3) rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2)
booster = gr.Radio(choices=["auto", "flashattn2", "unsloth"], value="auto", scale=2)
visual_inputs = gr.Checkbox(scale=1) visual_inputs = gr.Checkbox(scale=1)
model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then( model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
@ -58,6 +59,7 @@ def create_top() -> Dict[str, "Component"]:
list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
) )
checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False) checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
quantization_method.change(can_quantize_to, [quantization_method], [quantization_bit], queue=False)
return dict( return dict(
lang=lang, lang=lang,
@ -67,6 +69,7 @@ def create_top() -> Dict[str, "Component"]:
checkpoint_path=checkpoint_path, checkpoint_path=checkpoint_path,
advanced_tab=advanced_tab, advanced_tab=advanced_tab,
quantization_bit=quantization_bit, quantization_bit=quantization_bit,
quantization_method=quantization_method,
template=template, template=template,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
booster=booster, booster=booster,

View File

@ -85,15 +85,29 @@ LOCALES = {
"quantization_bit": { "quantization_bit": {
"en": { "en": {
"label": "Quantization bit", "label": "Quantization bit",
"info": "Enable 4/8-bit model quantization (QLoRA).", "info": "Enable quantization (QLoRA).",
}, },
"ru": { "ru": {
"label": "Уровень квантования", "label": "Уровень квантования",
"info": "Включить 4/8-битное квантование модели (QLoRA).", "info": "Включить квантование (QLoRA).",
}, },
"zh": { "zh": {
"label": "量化等级", "label": "量化等级",
"info": "启用 4/8 比特模型量化QLoRA", "info": "启用量化QLoRA",
},
},
"quantization_method": {
"en": {
"label": "Quantization method",
"info": "Quantization algorithm to use.",
},
"ru": {
"label": "Метод квантования",
"info": "Алгоритм квантования, который следует использовать.",
},
"zh": {
"label": "量化方法",
"info": "使用的量化算法。",
}, },
}, },
"template": { "template": {

View File

@ -71,6 +71,7 @@ class Manager:
self._id_to_elem["top.finetuning_type"], self._id_to_elem["top.finetuning_type"],
self._id_to_elem["top.checkpoint_path"], self._id_to_elem["top.checkpoint_path"],
self._id_to_elem["top.quantization_bit"], self._id_to_elem["top.quantization_bit"],
self._id_to_elem["top.quantization_method"],
self._id_to_elem["top.template"], self._id_to_elem["top.template"],
self._id_to_elem["top.rope_scaling"], self._id_to_elem["top.rope_scaling"],
self._id_to_elem["top.booster"], self._id_to_elem["top.booster"],

View File

@ -22,7 +22,7 @@ from transformers.trainer import TRAINING_ARGS_NAME
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
from ..extras.misc import is_gpu_or_npu_available, torch_gc from ..extras.misc import is_gpu_or_npu_available, torch_gc
from ..extras.packages import is_gradio_available from ..extras.packages import is_gradio_available
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
from .locales import ALERTS, LOCALES from .locales import ALERTS, LOCALES
from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
@ -104,6 +104,11 @@ class Runner:
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type") model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config() user_config = load_config()
if get("top.quantization_bit") in QUANTIZATION_BITS:
quantization_bit = int(get("top.quantization_bit"))
else:
quantization_bit = None
args = dict( args = dict(
stage=TRAINING_STAGES[get("train.training_stage")], stage=TRAINING_STAGES[get("train.training_stage")],
do_train=True, do_train=True,
@ -111,7 +116,8 @@ class Runner:
cache_dir=user_config.get("cache_dir", None), cache_dir=user_config.get("cache_dir", None),
preprocessing_num_workers=16, preprocessing_num_workers=16,
finetuning_type=finetuning_type, finetuning_type=finetuning_type,
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, quantization_bit=quantization_bit,
quantization_method=get("top.quantization_method"),
template=get("top.template"), template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
@ -234,13 +240,19 @@ class Runner:
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type") model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config() user_config = load_config()
if get("top.quantization_bit") in QUANTIZATION_BITS:
quantization_bit = int(get("top.quantization_bit"))
else:
quantization_bit = None
args = dict( args = dict(
stage="sft", stage="sft",
model_name_or_path=get("top.model_path"), model_name_or_path=get("top.model_path"),
cache_dir=user_config.get("cache_dir", None), cache_dir=user_config.get("cache_dir", None),
preprocessing_num_workers=16, preprocessing_num_workers=16,
finetuning_type=finetuning_type, finetuning_type=finetuning_type,
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, quantization_bit=quantization_bit,
quantization_method=get("top.quantization_method"),
template=get("top.template"), template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",

View File

@ -25,6 +25,7 @@ from yaml import safe_dump, safe_load
from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES
from ..extras.packages import is_gradio_available, is_matplotlib_available from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import gen_loss_plot from ..extras.ploting import gen_loss_plot
from ..model import QuantizationMethod
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir
from .locales import ALERTS from .locales import ALERTS
@ -55,6 +56,18 @@ def can_quantize(finetuning_type: str) -> "gr.Dropdown":
return gr.Dropdown(interactive=True) return gr.Dropdown(interactive=True)
def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
r"""
Returns the available quantization bits.
"""
if quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
return gr.Dropdown(choices=["none", "8", "4"])
elif quantization_method == QuantizationMethod.HQQ.value:
return gr.Dropdown(choices=["none", "8", "6", "5", "4", "3", "2", "1"])
elif quantization_method == QuantizationMethod.EETQ.value:
return gr.Dropdown(choices=["none", "8"])
def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]: def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
r""" r"""
Modifys states after changing the training stage. Modifys states after changing the training stage.