add option to disable version check

This commit is contained in:
hiyouga 2024-02-10 22:31:23 +08:00
parent a754f6e9ec
commit 91d09a01ac
4 changed files with 47 additions and 26 deletions

View File

@@ -2,11 +2,11 @@ import importlib.metadata
import importlib.util
def is_package_available(name: str) -> bool:
def _is_package_available(name: str) -> bool:
return importlib.util.find_spec(name) is not None
def get_package_version(name: str) -> str:
def _get_package_version(name: str) -> str:
try:
return importlib.metadata.version(name)
except Exception:
@@ -14,36 +14,40 @@ def get_package_version(name: str) -> str:
def is_fastapi_availble():
return is_package_available("fastapi")
return _is_package_available("fastapi")
def is_flash_attn2_available():
return is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")
def is_jieba_available():
return is_package_available("jieba")
return _is_package_available("jieba")
def is_matplotlib_available():
return is_package_available("matplotlib")
return _is_package_available("matplotlib")
def is_nltk_available():
return is_package_available("nltk")
return _is_package_available("nltk")
def is_requests_available():
return is_package_available("requests")
return _is_package_available("requests")
def is_rouge_available():
return is_package_available("rouge_chinese")
return _is_package_available("rouge_chinese")
def is_starlette_available():
return is_package_available("sse_starlette")
return _is_package_available("sse_starlette")
def is_unsloth_available():
return _is_package_available("unsloth")
def is_uvicorn_available():
return is_package_available("uvicorn")
return _is_package_available("uvicorn")

View File

@@ -132,6 +132,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
default="lora", metadata={"help": "Which fine-tuning method to use."}
)
disable_version_checking: Optional[bool] = field(
default=False, metadata={"help": "Whether or not to disable version checking."}
)
plot_loss: Optional[bool] = field(
default=False, metadata={"help": "Whether or not to save the training loss curves."}
)

View File

@@ -8,8 +8,10 @@ import torch
import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
from ..extras.packages import is_unsloth_available
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
@@ -28,6 +30,14 @@ _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArgu
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def _check_dependencies():
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
@@ -123,8 +133,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
raise ValueError("Please specify `lora_target` in LoRA training.")
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available:
raise ValueError("Install Unsloth: https://github.com/unslothai/unsloth")
_verify_model_args(model_args, finetuning_args)
if not finetuning_args.disable_version_checking:
_check_dependencies()
if (
training_args.do_train
and finetuning_args.finetuning_type == "lora"
@@ -145,7 +161,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
# postprocess training_args
# Post-process training arguments
if (
training_args.local_rank != -1
and training_args.ddp_find_unused_parameters is None
@@ -158,7 +174,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
can_resume_from_checkpoint = False
training_args.resume_from_checkpoint = None
if training_args.resume_from_checkpoint is not None:
logger.warning("Cannot resume from checkpoint in current stage.")
training_args.resume_from_checkpoint = None
else:
can_resume_from_checkpoint = True
@@ -194,7 +212,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
)
)
# postprocess model_args
# Post-process model arguments
model_args.compute_dtype = (
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
)
@@ -212,7 +230,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
)
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args
@@ -220,24 +237,30 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
_set_transformers_logging()
_verify_model_args(model_args, finetuning_args)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
_verify_model_args(model_args, finetuning_args)
if not finetuning_args.disable_version_checking:
_check_dependencies()
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
_set_transformers_logging()
_verify_model_args(model_args, finetuning_args)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
_verify_model_args(model_args, finetuning_args)
if not finetuning_args.disable_version_checking:
_check_dependencies()
transformers.set_seed(eval_args.seed)

View File

@@ -2,7 +2,6 @@ from typing import TYPE_CHECKING, Optional, Tuple
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
@@ -21,13 +20,6 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
def load_model_and_tokenizer(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
@@ -63,7 +55,6 @@ def load_model_and_tokenizer(
model = None
if is_trainable and model_args.use_unsloth:
require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth")
from unsloth import FastLlamaModel, FastMistralModel # type: ignore
unsloth_kwargs = {