support unsloth generate
This commit is contained in:
parent
aa2b79eb23
commit
b1deb0a0b9
|
@ -7,10 +7,11 @@ from transformers.integrations import is_deepspeed_zero3_enabled
|
|||
from ..extras.logging import get_logger
|
||||
from .utils.misc import find_all_linear_modules, find_expanded_modules
|
||||
from .utils.quantization import QuantizationMethod
|
||||
from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
|
||||
|
@ -19,7 +20,11 @@ logger = get_logger(__name__)
|
|||
|
||||
|
||||
def init_adapter(
|
||||
model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool
|
||||
config: "PretrainedConfig",
|
||||
model: "PreTrainedModel",
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
is_trainable: bool,
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Initializes the adapters.
|
||||
|
@ -106,6 +111,10 @@ def init_adapter(
|
|||
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
|
||||
is_mergeable = False
|
||||
|
||||
if model_args.use_unsloth:
|
||||
assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
|
||||
is_mergeable = False
|
||||
|
||||
if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
|
||||
adapter_to_merge = model_args.adapter_name_or_path[:-1]
|
||||
adapter_to_resume = model_args.adapter_name_or_path[-1]
|
||||
|
@ -122,9 +131,15 @@ def init_adapter(
|
|||
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
|
||||
|
||||
if adapter_to_resume is not None: # resume lora training
|
||||
model = PeftModel.from_pretrained(
|
||||
model, adapter_to_resume, is_trainable=is_trainable, offload_folder=model_args.offload_folder
|
||||
)
|
||||
if model_args.use_unsloth:
|
||||
model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
|
||||
else:
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
adapter_to_resume,
|
||||
is_trainable=is_trainable,
|
||||
offload_folder=model_args.offload_folder,
|
||||
)
|
||||
|
||||
if is_trainable and adapter_to_resume is None: # create new lora weights while training
|
||||
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
||||
|
@ -152,14 +167,8 @@ def init_adapter(
|
|||
}
|
||||
|
||||
if model_args.use_unsloth:
|
||||
from unsloth import FastLanguageModel # type: ignore
|
||||
|
||||
unsloth_peft_kwargs = {
|
||||
"model": model,
|
||||
"max_seq_length": model_args.model_max_length,
|
||||
"use_gradient_checkpointing": "unsloth",
|
||||
}
|
||||
model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
|
||||
print(model)
|
||||
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
|
||||
else:
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
|
|
|
@ -3,12 +3,13 @@ from typing import TYPE_CHECKING, Any, Dict
|
|||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..extras.constants import MOD_SUPPORTED_MODELS
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
|
||||
from ..extras.misc import count_parameters, try_download_model_from_ms
|
||||
from .adapter import init_adapter
|
||||
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
|
||||
from .utils.misc import load_valuehead_params, register_autoclass
|
||||
from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
|
||||
from .utils.unsloth import load_unsloth_pretrained_model
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -83,54 +84,30 @@ def load_model(
|
|||
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
|
||||
|
||||
model = None
|
||||
if is_trainable and model_args.use_unsloth:
|
||||
from unsloth import FastLanguageModel # type: ignore
|
||||
lazy_load = False
|
||||
if model_args.use_unsloth:
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
lazy_load = True
|
||||
elif is_trainable:
|
||||
model = load_unsloth_pretrained_model(config, model_args)
|
||||
|
||||
unsloth_kwargs = {
|
||||
"model_name": model_args.model_name_or_path,
|
||||
"max_seq_length": model_args.model_max_length,
|
||||
"dtype": model_args.compute_dtype,
|
||||
"load_in_4bit": model_args.quantization_bit == 4,
|
||||
"token": model_args.hf_hub_token,
|
||||
"device_map": {"": get_current_device()},
|
||||
"rope_scaling": getattr(config, "rope_scaling", None),
|
||||
"fix_tokenizer": False,
|
||||
"trust_remote_code": True,
|
||||
}
|
||||
try:
|
||||
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
|
||||
except NotImplementedError:
|
||||
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
|
||||
model_args.use_unsloth = False
|
||||
|
||||
if model_args.adapter_name_or_path:
|
||||
model_args.adapter_name_or_path = None
|
||||
logger.warning("Unsloth does not support loading adapters.")
|
||||
|
||||
if model is None:
|
||||
if model is None and not lazy_load:
|
||||
init_kwargs["config"] = config
|
||||
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
|
||||
|
||||
if model_args.mixture_of_depths == "load":
|
||||
from MoD import AutoMoDModelForCausalLM
|
||||
|
||||
model = AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
|
||||
model = load_mod_pretrained_model(**init_kwargs)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
|
||||
|
||||
if model_args.mixture_of_depths == "convert":
|
||||
from MoD import apply_mod_to_hf
|
||||
model = convert_pretrained_model_to_mod(model, config, model_args)
|
||||
|
||||
if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
|
||||
raise ValueError("Current model is not supported by mixture-of-depth.")
|
||||
if not lazy_load:
|
||||
patch_model(model, tokenizer, model_args, is_trainable)
|
||||
register_autoclass(config, model, tokenizer)
|
||||
|
||||
model = apply_mod_to_hf(model)
|
||||
model = model.to(model_args.compute_dtype)
|
||||
|
||||
patch_model(model, tokenizer, model_args, is_trainable)
|
||||
register_autoclass(config, model, tokenizer)
|
||||
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable)
|
||||
model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
|
||||
|
||||
if add_valuehead:
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...extras.constants import MOD_SUPPORTED_MODELS
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel":
    r"""
    Loads a pretrained mixture-of-depths model.

    Every keyword argument is forwarded unchanged to
    ``AutoMoDModelForCausalLM.from_pretrained``.
    """
    # The MoD package is imported at call time rather than at module import time.
    from MoD import AutoMoDModelForCausalLM

    mod_model = AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
    return mod_model
|
||||
|
||||
|
||||
def convert_pretrained_model_to_mod(
    model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments"
) -> "PreTrainedModel":
    r"""
    Converts a pretrained model to mixture-of-depths and casts it to the compute dtype.

    Raises:
        ValueError: if ``config.model_type`` is missing or not in ``MOD_SUPPORTED_MODELS``.
    """
    # The MoD package is imported at call time, before the support check.
    from MoD import apply_mod_to_hf

    model_type = getattr(config, "model_type", None)
    if model_type not in MOD_SUPPORTED_MODELS:
        raise ValueError("Current model is not supported by mixture-of-depth.")

    # Apply the conversion first, then move the result to the configured compute dtype.
    return apply_mod_to_hf(model).to(model_args.compute_dtype)
|
|
@ -0,0 +1,85 @@
|
|||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.misc import get_current_device
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _get_unsloth_kwargs(
    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
) -> Dict[str, Any]:
    r"""
    Builds the keyword arguments passed to ``FastLanguageModel.from_pretrained``.

    Args:
        config: model config; only its optional ``rope_scaling`` attribute is read.
        model_name_or_path: value used for the ``model_name`` entry.
        model_args: source of max length, dtype, quantization bit and hub token.
    """
    unsloth_kwargs: Dict[str, Any] = dict(
        model_name=model_name_or_path,
        max_seq_length=model_args.model_max_length,
        dtype=model_args.compute_dtype,
        # 4-bit loading is requested only when quantization_bit is exactly 4.
        load_in_4bit=(model_args.quantization_bit == 4),
        token=model_args.hf_hub_token,
        # Map the whole model (the "" key) onto the current device.
        device_map={"": get_current_device()},
        rope_scaling=getattr(config, "rope_scaling", None),
        fix_tokenizer=False,
        trust_remote_code=True,
        use_gradient_checkpointing="unsloth",
    )
    return unsloth_kwargs
|
||||
|
||||
|
||||
def load_unsloth_pretrained_model(
    config: "PretrainedConfig", model_args: "ModelArguments"
) -> Optional["PreTrainedModel"]:
    r"""
    Optionally loads pretrained model with unsloth.

    Returns the loaded model. If unsloth raises ``NotImplementedError``, logs a
    warning, sets ``model_args.use_unsloth`` to ``False`` and returns ``None``.
    """
    from unsloth import FastLanguageModel

    load_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
    try:
        # from_pretrained returns a pair; only the first element (the model) is kept.
        loaded_model, _ = FastLanguageModel.from_pretrained(**load_kwargs)
    except NotImplementedError:
        logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
        # Turn unsloth off on the shared arguments object; the caller receives None.
        model_args.use_unsloth = False
        return None

    return loaded_model
|
||||
|
||||
|
||||
def get_unsloth_peft_model(
    model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
) -> "PreTrainedModel":
    r"""
    Gets the peft model for the pretrained model with unsloth.

    ``peft_kwargs`` is forwarded as-is, together with the model, the maximum
    sequence length from ``model_args`` and unsloth gradient checkpointing.
    """
    from unsloth import FastLanguageModel

    return FastLanguageModel.get_peft_model(
        **peft_kwargs,
        model=model,
        max_seq_length=model_args.model_max_length,
        use_gradient_checkpointing="unsloth",
    )
|
||||
|
||||
|
||||
def load_unsloth_peft_model(
    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> "PreTrainedModel":
    r"""
    Loads peft model with unsloth.

    Args:
        config: model config, used for ``rope_scaling`` and for the error message.
        model_args: model arguments; ``adapter_name_or_path`` must hold exactly one adapter.
        is_trainable: when ``False`` the loaded model is switched to inference mode.

    Raises:
        ValueError: if unsloth raises ``NotImplementedError`` for this model.
    """
    from unsloth import FastLanguageModel

    # adapter_name_or_path is a list of adapter paths (the caller asserts it has
    # exactly one entry for unsloth), while the kwargs builder expects a single
    # string for `model_name`. A plain string is still accepted as-is.
    adapter_path = model_args.adapter_name_or_path
    if isinstance(adapter_path, (list, tuple)):
        adapter_path = adapter_path[0]

    unsloth_kwargs = _get_unsloth_kwargs(config, adapter_path, model_args)
    try:
        model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
    except NotImplementedError as err:
        # Keep the original exception as the cause so the traceback stays informative.
        raise ValueError(
            "Unsloth does not support model type {}.".format(getattr(config, "model_type", None))
        ) from err

    if not is_trainable:
        FastLanguageModel.for_inference(model)

    return model
|
|
@ -61,6 +61,9 @@ def create_modelcard_and_push(
|
|||
if data_args.dataset is not None:
|
||||
kwargs["dataset"] = [dataset.strip() for dataset in data_args.dataset.split(",")]
|
||||
|
||||
if model_args.use_unsloth:
|
||||
kwargs["tags"] = kwargs["tags"] + ["unsloth"]
|
||||
|
||||
if not training_args.do_train:
|
||||
pass
|
||||
elif training_args.push_to_hub:
|
||||
|
|
|
@ -138,7 +138,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
with gr.Row():
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
|
||||
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1)
|
||||
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
|
||||
lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01)
|
||||
loraplus_lr_ratio = gr.Slider(value=0, minimum=0, maximum=64, step=0.01)
|
||||
create_new_adapter = gr.Checkbox()
|
||||
|
||||
|
|
Loading…
Reference in New Issue