diff --git a/data/dataset_info.json b/data/dataset_info.json
index d053be1d..9673158f 100644
--- a/data/dataset_info.json
+++ b/data/dataset_info.json
@@ -60,7 +60,7 @@
   },
   "mllm_demo": {
     "file_name": "mllm_demo.json",
-    "file_sha1": "b6709b23657d5c42a701f1c5574f3a6edaa40a20",
+    "file_sha1": "d626cc0ad88a26d0dc9fcb47336821cf486d8bcc",
     "formatting": "sharegpt",
     "columns": {
       "messages": "messages",
diff --git a/data/mllm_demo.json b/data/mllm_demo.json
index 32d6d221..39bda392 100644
--- a/data/mllm_demo.json
+++ b/data/mllm_demo.json
@@ -2,7 +2,7 @@
   {
     "messages": [
       {
-        "content": "<image>Who are they?",
+        "content": "Who are they?",
         "role": "user"
       },
       {
@@ -14,18 +14,18 @@
         "role": "user"
       },
      {
-        "content": "They are celebrating on the soccer field",
+        "content": "They are celebrating on the soccer field.",
         "role": "assistant"
      }
    ],
     "images": [
-      "images/1.jpg"
+      "mllm_demo_data/1.jpg"
     ]
   },
   {
     "messages": [
       {
-        "content": "<image>Who is he?",
+        "content": "Who is he?",
         "role": "user"
       },
       {
@@ -42,13 +42,13 @@
       }
     ],
     "images": [
-      "images/2.jpg"
+      "mllm_demo_data/2.jpg"
     ]
   },
   {
     "messages": [
       {
-        "content": "<image>Please describe this image",
+        "content": "Please describe this image",
         "role": "user"
       },
       {
@@ -65,7 +65,76 @@
       }
     ],
     "images": [
-      "images/3.jpg"
+      "mllm_demo_data/3.jpg"
+    ]
+  },
+  {
+    "messages": [
+      {
+        "content": "他们是谁?",
+        "role": "user"
+      },
+      {
+        "content": "他们是拜仁慕尼黑的凯恩和格雷茨卡。",
+        "role": "assistant"
+      },
+      {
+        "content": "他们在做什么?",
+        "role": "user"
+      },
+      {
+        "content": "他们在足球场上庆祝。",
+        "role": "assistant"
+      }
+    ],
+    "images": [
+      "mllm_demo_data/1.jpg"
+    ]
+  },
+  {
+    "messages": [
+      {
+        "content": "他是谁?",
+        "role": "user"
+      },
+      {
+        "content": "他是来自拜仁慕尼黑的托马斯·穆勒。",
+        "role": "assistant"
+      },
+      {
+        "content": "他为什么在地上?",
+        "role": "user"
+      },
+      {
+        "content": "因为他正在双膝跪地滑行庆祝。",
+        "role": "assistant"
+      }
+    ],
+    "images": [
+      "mllm_demo_data/2.jpg"
+    ]
+  },
+  {
+    "messages": [
+      {
+        "content": "请描述这张图片",
+        "role": "user"
+      },
+      {
+        "content": "中国宇航员桂海潮正在讲话。",
+        "role": "assistant"
+      },
+      {
+        "content": "他取得过哪些成就?",
+        "role": "user"
+      },
+      {
+        "content": "他于2022年6月被任命为神舟十六号任务的有效载荷专家,从而成为2023年5月30日进入太空的首位平民宇航员。他负责在轨操作空间科学实验有效载荷。",
+        "role": "assistant"
+      }
+    ],
+    "images": [
+      "mllm_demo_data/3.jpg"
     ]
   }
 ]
\ No newline at end of file
diff --git a/data/images/1.jpg b/data/mllm_demo_data/1.jpg
similarity index 100%
rename from data/images/1.jpg
rename to data/mllm_demo_data/1.jpg
diff --git a/data/images/2.jpg b/data/mllm_demo_data/2.jpg
similarity index 100%
rename from data/images/2.jpg
rename to data/mllm_demo_data/2.jpg
diff --git a/data/images/3.jpg b/data/mllm_demo_data/3.jpg
similarity index 100%
rename from data/images/3.jpg
rename to data/mllm_demo_data/3.jpg
diff --git a/setup.py b/setup.py
index 9ef881e2..7ff3185f 100644
--- a/setup.py
+++ b/setup.py
@@ -25,7 +25,7 @@ extra_require = {
     "unsloth": ["torch==2.2.0", "unsloth[cu121-ampere-torch220]"],
     "galore": ["galore-torch"],
     "badam": ["badam"],
-    "vllm": ["vllm>=0.3.3"],
+    "vllm": ["vllm>=0.4.0"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py
index 6cb78806..b3a980a5 100644
--- a/src/llmtuner/__init__.py
+++ b/src/llmtuner/__init__.py
@@ -7,5 +7,5 @@
 from .train import export_model, run_exp
 from .webui import create_ui, create_web_demo
 
-__version__ = "0.6.4.dev0"
+__version__ = "0.7.0"
 __all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]
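Note: the `file_sha1` field in `dataset_info.json` is the checksum LLaMA-Factory uses to sanity-check the local dataset file. A minimal sketch for reproducing the new hash after editing `data/mllm_demo.json` (assumes it is run from the repository root; the expected value is the one added in the hunk above):

```python
# Sketch: recompute the SHA-1 registered for mllm_demo.json in dataset_info.json.
import hashlib

with open("data/mllm_demo.json", "rb") as f:
    digest = hashlib.sha1(f.read()).hexdigest()

print(digest)  # expected: d626cc0ad88a26d0dc9fcb47336821cf486d8bcc
```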
diff --git a/src/llmtuner/chat/hf_engine.py b/src/llmtuner/chat/hf_engine.py
index f6f51898..e8f06a73 100644
--- a/src/llmtuner/chat/hf_engine.py
+++ b/src/llmtuner/chat/hf_engine.py
@@ -56,7 +56,7 @@ class HuggingfaceEngine(BaseEngine):
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
         if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
-            messages[0]["content"] = messages[0]["content"] + "<image>"
+            messages[0]["content"] = "<image>" + messages[0]["content"]
 
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         prompt_ids, _ = template.encode_oneturn(
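Note: the hunk above moves the `<image>` placeholder from the end of the first user message to the front, matching the layout LLaVA-style processors expect. A minimal, illustrative sketch of the resulting prompt text (the message list here is made up for the example):

```python
# Illustrative only: mirrors the placeholder handling in HuggingfaceEngine above.
messages = [{"role": "user", "content": "Who are they?"}]

if "<image>" not in messages[0]["content"]:
    messages[0]["content"] = "<image>" + messages[0]["content"]

print(messages[0]["content"])  # -> "<image>Who are they?"
```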
diff --git a/src/llmtuner/chat/vllm_engine.py b/src/llmtuner/chat/vllm_engine.py
index a4caa53b..0f0dc366 100644
--- a/src/llmtuner/chat/vllm_engine.py
+++ b/src/llmtuner/chat/vllm_engine.py
@@ -11,10 +11,13 @@ from .base_engine import BaseEngine, Response
 if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest
+    from vllm.sequence import MultiModalData
 
 
 if TYPE_CHECKING:
+    import torch
     from numpy.typing import NDArray
+    from transformers.image_processing_utils import BaseImageProcessor
 
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
 
@@ -39,20 +42,30 @@ class VllmEngine(BaseEngine):
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
         self.generating_args = generating_args.to_dict()
 
-        engine_args = AsyncEngineArgs(
-            model=model_args.model_name_or_path,
-            trust_remote_code=True,
-            download_dir=model_args.cache_dir,
-            dtype=infer_dtype,
-            max_model_len=model_args.vllm_maxlen,
-            tensor_parallel_size=get_device_count() or 1,
-            gpu_memory_utilization=model_args.vllm_gpu_util,
-            disable_log_stats=True,
-            disable_log_requests=True,
-            enforce_eager=model_args.vllm_enforce_eager,
-            enable_lora=model_args.adapter_name_or_path is not None,
-        )
-        self.model = AsyncLLMEngine.from_engine_args(engine_args)
+        engine_args = {
+            "model": model_args.model_name_or_path,
+            "trust_remote_code": True,
+            "download_dir": model_args.cache_dir,
+            "dtype": infer_dtype,
+            "max_model_len": model_args.vllm_maxlen,
+            "tensor_parallel_size": get_device_count() or 1,
+            "gpu_memory_utilization": model_args.vllm_gpu_util,
+            "disable_log_stats": True,
+            "disable_log_requests": True,
+            "enforce_eager": model_args.vllm_enforce_eager,
+            "enable_lora": model_args.adapter_name_or_path is not None,
+        }
+
+        if model_args.visual_inputs:
+            # TODO: auto derive from config
+            # https://github.com/vllm-project/vllm/pull/3042#issuecomment-1984893549
+            self.image_feature_size = 576
+            engine_args["image_input_type"] = "pixel_values"
+            engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
+            engine_args["image_input_shape"] = "1,3,336,336"
+            engine_args["image_feature_size"] = self.image_feature_size
+
+        self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
         if model_args.adapter_name_or_path is not None:
             self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
         else:
@@ -67,6 +80,9 @@ class VllmEngine(BaseEngine):
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
+        if self.processor is not None and image is not None and "<image>" not in messages[0]["content"]:
+            messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]
+
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         prompt_ids, _ = self.template.encode_oneturn(
             tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
@@ -110,12 +126,21 @@ class VllmEngine(BaseEngine):
             max_tokens=generating_args["max_new_tokens"],
             skip_special_tokens=True,
         )
+
+        if self.processor is not None and image is not None:
+            image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
+            pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
+            multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
+        else:
+            multi_modal_data = None
+
         result_generator = self.model.generate(
             prompt=None,
             sampling_params=sampling_params,
             request_id=request_id,
             prompt_token_ids=prompt_ids,
             lora_request=self.lora_request,
+            multi_modal_data=multi_modal_data,
         )
 
         return result_generator
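Note: with `image_input_type`, `image_token_id`, `image_input_shape`, and `image_feature_size` now passed through `AsyncEngineArgs`, the vLLM backend can serve LLaVA-style checkpoints. A hedged usage sketch follows; the checkpoint name, the `vicuna` template, and the exact `ChatModel.chat(..., image=...)` call are assumptions for illustration and are not part of this diff:

```python
# Sketch: driving the multimodal vLLM backend through ChatModel (assumed invocation).
import numpy as np
from PIL import Image

from llmtuner import ChatModel

chat_model = ChatModel(dict(
    model_name_or_path="llava-hf/llava-1.5-7b-hf",  # assumed checkpoint
    template="vicuna",                              # assumed template for LLaVA-1.5
    visual_inputs=True,
    infer_backend="vllm",
))

image = np.asarray(Image.open("data/mllm_demo_data/1.jpg").convert("RGB"))
messages = [{"role": "user", "content": "Who are they?"}]
for response in chat_model.chat(messages, image=image):  # image kwarg assumed from the engine API above
    print(response.response_text)
```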
diff --git a/src/llmtuner/data/preprocess.py b/src/llmtuner/data/preprocess.py
index 18681872..38211b0c 100644
--- a/src/llmtuner/data/preprocess.py
+++ b/src/llmtuner/data/preprocess.py
@@ -1,14 +1,20 @@
 from functools import partial
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
 
 from ..extras.constants import IGNORE_INDEX
 from ..extras.logging import get_logger
+from ..extras.packages import is_pillow_available
 from .utils import Role
 
 
+if is_pillow_available():
+    from PIL import Image
+
+
 if TYPE_CHECKING:
-    from PIL.Image import Image
+    from numpy.typing import NDArray
+    from PIL.Image import Image as ImageObject
     from transformers import ProcessorMixin, Seq2SeqTrainingArguments
     from transformers.image_processing_utils import BaseImageProcessor
     from transformers.tokenization_utils import PreTrainedTokenizer
@@ -20,12 +26,11 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def _preprocess_visual_inputs(model_inputs: Dict[str, Any], processor: "ProcessorMixin", image: "Image") -> None:
+def _preprocess_visual_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
+    # process visual inputs (currently only supports a single image)
     image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-    pixel_values = image_processor(image, return_tensors="pt")["pixel_values"][0]
-    if "pixel_values" not in model_inputs:
-        model_inputs["pixel_values"] = []
-    model_inputs["pixel_values"].append(pixel_values)
+    image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
+    return image_processor(image, return_tensors="pt")["pixel_values"][0]
 
 
 def preprocess_pretrain_dataset(
@@ -66,11 +71,17 @@ def preprocess_supervised_dataset(
     # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
     # for multiturn examples, we only mask the prompt part in each prompt-response pair.
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+    if processor is not None:
+        model_inputs["pixel_values"] = []
+        preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
 
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
             continue
 
+        if processor is not None:
+            examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
+
         messages = examples["prompt"][i] + examples["response"][i]
         input_ids, labels = [], []
         for turn_idx, (source_ids, target_ids) in enumerate(
@@ -100,8 +111,8 @@ def preprocess_supervised_dataset(
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
-        if processor is not None and "images" in examples:
-            _preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0])
+        if processor is not None:
+            model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
 
     return model_inputs
 
@@ -161,11 +172,17 @@ def preprocess_unsupervised_dataset(
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X` and labels with format `Y <eos>`
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+    if processor is not None:
+        model_inputs["pixel_values"] = []
+        preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
 
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1:
             continue
 
+        if processor is not None:
+            examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
+
         if len(examples["response"][i]) == 1:
             messages = examples["prompt"][i] + examples["response"][i]
         else:
@@ -186,8 +203,8 @@ def preprocess_unsupervised_dataset(
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
-        if processor is not None and "images" in examples:
-            _preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0])
+        if processor is not None:
+            model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
 
     return model_inputs
 
@@ -201,10 +218,17 @@ def preprocess_pairwise_dataset(
 ) -> Dict[str, List[List[int]]]:
     # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
     model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
+    if processor is not None:
+        model_inputs["pixel_values"] = []
+        preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
+
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
             continue
 
+        if processor is not None:
+            examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
+
         chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
         rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
         prompt_ids, chosen_ids = template.encode_oneturn(
@@ -231,8 +255,8 @@ def preprocess_pairwise_dataset(
         model_inputs["prompt_ids"].append(prompt_ids)
         model_inputs["chosen_ids"].append(chosen_ids)
         model_inputs["rejected_ids"].append(rejected_ids)
-        if processor is not None and "images" in examples:
-            _preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0])
+        if processor is not None:
+            model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
 
     return model_inputs
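Note: `_preprocess_visual_inputs` now returns one pixel-value array per example and substitutes a blank white image when an example carries no image, so `pixel_values` stays aligned with `input_ids`. A sketch of what the fallback branch produces, under the assumption of a LLaVA-1.5 processor (CLIP ViT-L/14-336, hence the 3x336x336 shape and the 576-token `image_feature_size` used by the vLLM engine above):

```python
# Sketch: the no-image fallback of _preprocess_visual_inputs.
# "llava-hf/llava-1.5-7b-hf" is an assumed checkpoint for illustration.
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
image_processor = getattr(processor, "image_processor")

blank_image = Image.new("RGB", (100, 100), (255, 255, 255))
pixel_values = image_processor(blank_image, return_tensors="pt")["pixel_values"][0]
print(pixel_values.shape)  # torch.Size([3, 336, 336]); 24 x 24 = 576 patches per image
```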
diff --git a/src/llmtuner/extras/packages.py b/src/llmtuner/extras/packages.py
index aeeba084..a7317eec 100644
--- a/src/llmtuner/extras/packages.py
+++ b/src/llmtuner/extras/packages.py
@@ -48,6 +48,10 @@ def is_nltk_available():
     return _is_package_available("nltk")
 
 
+def is_pillow_available():
+    return _is_package_available("PIL")
+
+
 def is_requests_available():
     return _is_package_available("requests")
 
diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py
index aa046837..977d7cf4 100644
--- a/src/llmtuner/hparams/parser.py
+++ b/src/llmtuner/hparams/parser.py
@@ -89,7 +89,7 @@ def _check_extra_dependencies(
         require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
 
     if model_args.infer_backend == "vllm":
-        require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
+        require_version("vllm>=0.4.0", "To fix: pip install vllm>=0.4.0")
 
     if finetuning_args.use_galore:
         require_version("galore_torch", "To fix: pip install galore_torch")
@@ -320,9 +320,6 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("vLLM only accepts a single adapter. Merge them first.")
 
-        if model_args.visual_inputs:
-            raise ValueError("vLLM engine does not support MLLM yet. Stay tuned.")
-
     if finetuning_args.stage == "rm" and model_args.visual_inputs:
         raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
 
diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py
index 15c1fc83..0a55460c 100644
--- a/src/llmtuner/webui/components/chatbot.py
+++ b/src/llmtuner/webui/components/chatbot.py
@@ -27,10 +27,10 @@ def create_chat_box(
             with gr.Column():
                 role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
                 system = gr.Textbox(show_label=False)
-                tools = gr.Textbox(show_label=False, lines=4)
+                tools = gr.Textbox(show_label=False, lines=3)
 
             with gr.Column() as image_box:
-                image = gr.Image(type="numpy")
+                image = gr.Image(sources=["upload"], type="numpy")
 
             query = gr.Textbox(show_label=False, lines=8)
             submit_btn = gr.Button(variant="primary")
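Note: the Web UI change restricts the image widget to uploads while keeping numpy output, which is the `NDArray` form the chat engines accept for their `image` argument. A standalone sketch of that wiring (the echo callback is a stand-in for the project's actual streaming handler; the `sources` argument requires Gradio 4.x):

```python
# Standalone sketch of the upload-only image box used by the chat tab.
import gradio as gr

def describe(image):
    # stand-in for the real chat handler; `image` arrives as a numpy array because type="numpy"
    if image is None:
        return "no image uploaded"
    return "received an image of shape {}".format(image.shape)

with gr.Blocks() as demo:
    image = gr.Image(sources=["upload"], type="numpy")
    output = gr.Textbox(show_label=False)
    image.change(describe, inputs=[image], outputs=[output])

demo.launch()
```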