This commit is contained in:
hiyouga 2024-03-12 15:53:29 +08:00
parent c901aa63ff
commit 07f9b754a7
3 changed files with 19 additions and 2 deletions

View File

@ -31,9 +31,11 @@ class VllmEngine(BaseEngine):
model=model_args.model_name_or_path,
trust_remote_code=True,
max_model_len=model_args.vllm_maxlen,
tensor_parallel_size=get_device_count(),
tensor_parallel_size=get_device_count() or 1,
gpu_memory_utilization=model_args.vllm_gpu_util,
disable_log_stats=True,
disable_log_requests=True,
enforce_eager=model_args.vllm_enforce_eager,
)
self.model = AsyncLLMEngine.from_engine_args(engine_args)
self.tokenizer = load_tokenizer(model_args)

View File

@ -89,6 +89,14 @@ class ModelArguments:
default=2048,
metadata={"help": "Maximum input length of the vLLM engine."},
)
vllm_gpu_util: float = field(
default=0.9,
metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
)
vllm_enforce_eager: bool = field(
default=False,
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
)
hf_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."},

View File

@ -124,7 +124,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available:
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available():
raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
if finetuning_args.use_dora:
@ -141,6 +141,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if training_args.fp16 or training_args.bf16:
raise ValueError("Turn off mixed precision training when using `pure_bf16`.")
if (
finetuning_args.use_galore
and finetuning_args.galore_layerwise
and training_args.parallel_mode.value == "distributed"
):
raise ValueError("Distributed training does not support layer-wise GaLore.")
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")