vllm + lora support
This commit is contained in:
parent
8d66435664
commit
28571da80a
|
@ -10,6 +10,7 @@ from .base_engine import BaseEngine, Response
|
|||
|
||||
if is_vllm_available():
|
||||
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
@ -24,7 +25,8 @@ class VllmEngine(BaseEngine):
|
|||
generating_args: "GeneratingArguments",
|
||||
) -> None:
|
||||
config = load_config(model_args) # may download model from ms hub
|
||||
load_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
infer_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
infer_dtype = str(infer_dtype).split(".")[-1]
|
||||
|
||||
self.can_generate = finetuning_args.stage == "sft"
|
||||
self.tokenizer = load_tokenizer(model_args)
|
||||
|
@ -36,15 +38,20 @@ class VllmEngine(BaseEngine):
|
|||
model=model_args.model_name_or_path,
|
||||
trust_remote_code=True,
|
||||
download_dir=model_args.cache_dir,
|
||||
dtype=str(load_dtype).split(".")[-1],
|
||||
dtype=infer_dtype,
|
||||
max_model_len=model_args.vllm_maxlen,
|
||||
tensor_parallel_size=get_device_count() or 1,
|
||||
gpu_memory_utilization=model_args.vllm_gpu_util,
|
||||
disable_log_stats=True,
|
||||
disable_log_requests=True,
|
||||
enforce_eager=model_args.vllm_enforce_eager,
|
||||
enable_lora=model_args.adapter_name_or_path is not None,
|
||||
)
|
||||
self.model = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
|
||||
else:
|
||||
self.lora_request = None
|
||||
|
||||
async def _generate(
|
||||
self,
|
||||
|
@ -98,7 +105,11 @@ class VllmEngine(BaseEngine):
|
|||
skip_special_tokens=True,
|
||||
)
|
||||
result_generator = self.model.generate(
|
||||
prompt=None, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_ids
|
||||
prompt=None,
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
prompt_token_ids=prompt_ids,
|
||||
lora_request=self.lora_request,
|
||||
)
|
||||
return result_generator
|
||||
|
||||
|
|
|
@ -308,15 +308,15 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
|||
if finetuning_args.stage != "sft":
|
||||
raise ValueError("vLLM engine only supports auto-regressive models.")
|
||||
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
raise ValueError("vLLM engine does not support quantization.")
|
||||
raise ValueError("vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).")
|
||||
|
||||
if model_args.rope_scaling is not None:
|
||||
raise ValueError("vLLM engine does not support RoPE scaling.")
|
||||
|
||||
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
|
||||
raise ValueError("vLLM only accepts a single adapter. Merge them first.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args)
|
||||
|
||||
|
|
Loading…
Reference in New Issue