support mllm hf inference

This commit is contained in:
hiyouga 2024-04-26 05:34:58 +08:00
parent c20f750d11
commit e057c8de48
27 changed files with 130 additions and 51 deletions

View File

@ -18,7 +18,8 @@ If you are using a custom dataset, please provide your dataset definition in the
"history": "the column name in the dataset containing the histories. (default: None)", "history": "the column name in the dataset containing the histories. (default: None)",
"messages": "the column name in the dataset containing the messages. (default: conversations)", "messages": "the column name in the dataset containing the messages. (default: conversations)",
"system": "the column name in the dataset containing the system prompts. (default: None)", "system": "the column name in the dataset containing the system prompts. (default: None)",
"tools": "the column name in the dataset containing the tool description. (default: None)" "tools": "the column name in the dataset containing the tool description. (default: None)",
"images": "the column name in the dataset containing the image inputs. (default: None)"
}, },
"tags (optional, used for the sharegpt format)": { "tags (optional, used for the sharegpt format)": {
"role_tag": "the key in the message represents the identity. (default: from)", "role_tag": "the key in the message represents the identity. (default: from)",

View File

@ -18,7 +18,8 @@
"history": "数据集代表历史对话的表头名称默认None", "history": "数据集代表历史对话的表头名称默认None",
"messages": "数据集代表消息列表的表头名称默认conversations", "messages": "数据集代表消息列表的表头名称默认conversations",
"system": "数据集代表系统提示的表头名称默认None", "system": "数据集代表系统提示的表头名称默认None",
"tools": "数据集代表工具描述的表头名称默认None" "tools": "数据集代表工具描述的表头名称默认None",
"images": "数据集代表图像输入的表头名称默认None"
}, },
"tags可选用于 sharegpt 格式)": { "tags可选用于 sharegpt 格式)": {
"role_tag": "消息中代表发送者身份的键名默认from", "role_tag": "消息中代表发送者身份的键名默认from",

View File

@ -429,4 +429,4 @@
}, },
"folder": "python" "folder": "python"
} }
} }

Binary file not shown.

Before

Width:  |  Height:  |  Size: 48 KiB

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 68 KiB

After

Width:  |  Height:  |  Size: 16 KiB

View File

@ -68,4 +68,4 @@
"images/3.jpg" "images/3.jpg"
] ]
} }
] ]

View File

@ -9,6 +9,7 @@ examples/
│ ├── ppo.sh: Do PPO training using LoRA │ ├── ppo.sh: Do PPO training using LoRA
│ ├── dpo.sh: Do DPO training using LoRA │ ├── dpo.sh: Do DPO training using LoRA
│ ├── orpo.sh: Do ORPO training using LoRA │ ├── orpo.sh: Do ORPO training using LoRA
│ ├── sft_mllm.sh: Do supervised fine-tuning on multimodal data using LoRA
│ ├── prepare.sh: Save tokenized dataset │ ├── prepare.sh: Save tokenized dataset
│ └── predict.sh: Do batch predict and compute BLEU and ROUGE scores after LoRA tuning │ └── predict.sh: Do batch predict and compute BLEU and ROUGE scores after LoRA tuning
├── qlora_single_gpu/ ├── qlora_single_gpu/

View File

@ -9,6 +9,7 @@ examples/
│ ├── ppo.sh: 基于 LoRA 进行 PPO 训练 │ ├── ppo.sh: 基于 LoRA 进行 PPO 训练
│ ├── dpo.sh: 基于 LoRA 进行 DPO 训练 │ ├── dpo.sh: 基于 LoRA 进行 DPO 训练
│ ├── orpo.sh: 基于 LoRA 进行 ORPO 训练 │ ├── orpo.sh: 基于 LoRA 进行 ORPO 训练
│ ├── sft_mllm.sh: 基于 LoRA 进行多模态指令监督微调
│ ├── prepare.sh: 保存预处理后的数据集 │ ├── prepare.sh: 保存预处理后的数据集
│ └── predict.sh: 基于 LoRA 进行批量预测并计算 BLEU 和 ROUGE 分数 │ └── predict.sh: 基于 LoRA 进行批量预测并计算 BLEU 和 ROUGE 分数
├── qlora_single_gpu/ ├── qlora_single_gpu/

View File

@ -1,32 +1,33 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage sft_mm \ --stage sft \
--do_train \ --do_train \
--model_name_or_path llava-hf/llava-1.5-7b-hf \ --model_name_or_path llava-hf/llava-1.5-7b-hf \
--dataset mllm_instruct_example \ --visual_inputs \
--dataset_dir data \ --dataset mllm_demo \
--template default \ --dataset_dir ../../data \
--template vicuna \
--finetuning_type lora \ --finetuning_type lora \
--lora_target all \ --lora_target q_proj,v_proj \
--output_dir saves/llava-1.5-7b/lora/sft \ --output_dir ../../saves/LLaMA2-7B/lora/sft_mllm \
--overwrite_cache \ --overwrite_cache \
--overwrite_output_dir \ --overwrite_output_dir \
--cutoff_len 1024 \ --cutoff_len 1024 \
--preprocessing_num_workers 16 \ --preprocessing_num_workers 16 \
--per_device_train_batch_size 3 \ --per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \ --per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \ --gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \ --lr_scheduler_type cosine \
--logging_steps 1 \ --logging_steps 10 \
--warmup_steps 20 \ --warmup_steps 20 \
--save_steps 100 \ --save_steps 100 \
--eval_steps 100 \ --eval_steps 100 \
--evaluation_strategy steps \ --evaluation_strategy steps \
--load_best_model_at_end \ --load_best_model_at_end \
--learning_rate 5e-5 \ --learning_rate 5e-5 \
--num_train_epochs 100 \ --num_train_epochs 100.0 \
--max_samples 3000 \ --max_samples 3000 \
--val_size 0.1 \ --val_size 0.1 \
--plot_loss \ --plot_loss \
--bf16 --fp16

View File

@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
if TYPE_CHECKING: if TYPE_CHECKING:
from numpy.typing import NDArray
from transformers import PreTrainedModel, PreTrainedTokenizer from transformers import PreTrainedModel, PreTrainedTokenizer
from vllm import AsyncLLMEngine from vllm import AsyncLLMEngine
@ -46,6 +47,7 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ... ) -> List["Response"]: ...
@ -55,6 +57,7 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ... ) -> AsyncGenerator[str, None]: ...

View File

@ -8,6 +8,8 @@ from .vllm_engine import VllmEngine
if TYPE_CHECKING: if TYPE_CHECKING:
from numpy.typing import NDArray
from .base_engine import BaseEngine, Response from .base_engine import BaseEngine, Response
@ -36,9 +38,10 @@ class ChatModel:
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop) task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
return task.result() return task.result()
async def achat( async def achat(
@ -46,18 +49,20 @@ class ChatModel:
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
return await self.engine.chat(messages, system, tools, **input_kwargs) return await self.engine.chat(messages, system, tools, image, **input_kwargs)
def stream_chat( def stream_chat(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs, **input_kwargs,
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
generator = self.astream_chat(messages, system, tools, **input_kwargs) generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
while True: while True:
try: try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop) task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@ -70,9 +75,10 @@ class ChatModel:
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs): async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
yield new_token yield new_token
def get_scores( def get_scores(

View File

@ -14,7 +14,9 @@ from .base_engine import BaseEngine, Response
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer from numpy.typing import NDArray
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
from trl import PreTrainedModelWrapper from trl import PreTrainedModelWrapper
from ..data import Template from ..data import Template
@ -30,7 +32,9 @@ class HuggingfaceEngine(BaseEngine):
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
) -> None: ) -> None:
self.can_generate = finetuning_args.stage == "sft" self.can_generate = finetuning_args.stage == "sft"
self.tokenizer = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left" if self.can_generate else "right" self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template) self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.model = load_model( self.model = load_model(
@ -42,13 +46,18 @@ class HuggingfaceEngine(BaseEngine):
def _process_args( def _process_args(
model: "PreTrainedModel", model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template", template: "Template",
generating_args: Dict[str, Any], generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]: ) -> Tuple[Dict[str, Any], int]:
if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
messages[0]["content"] = messages[0]["content"] + "<image>"
paired_messages = messages + [{"role": "assistant", "content": ""}] paired_messages = messages + [{"role": "assistant", "content": ""}]
prompt_ids, _ = template.encode_oneturn( prompt_ids, _ = template.encode_oneturn(
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
@ -95,6 +104,11 @@ class HuggingfaceEngine(BaseEngine):
logits_processor=get_logits_processor(), logits_processor=get_logits_processor(),
) )
if processor is not None and image is not None:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
gen_kwargs["pixel_values"] = pixel_values.to(model.device)
return gen_kwargs, prompt_length return gen_kwargs, prompt_length
@staticmethod @staticmethod
@ -102,15 +116,17 @@ class HuggingfaceEngine(BaseEngine):
def _chat( def _chat(
model: "PreTrainedModel", model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template", template: "Template",
generating_args: Dict[str, Any], generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]: ) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args( gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
) )
generate_output = model.generate(**gen_kwargs) generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:] response_ids = generate_output[:, prompt_length:]
@ -135,15 +151,17 @@ class HuggingfaceEngine(BaseEngine):
def _stream_chat( def _stream_chat(
model: "PreTrainedModel", model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template", template: "Template",
generating_args: Dict[str, Any], generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]: ) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args( gen_kwargs, _ = HuggingfaceEngine._process_args(
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
) )
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer gen_kwargs["streamer"] = streamer
@ -199,6 +217,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
if not self.can_generate: if not self.can_generate:
@ -208,11 +227,13 @@ class HuggingfaceEngine(BaseEngine):
input_args = ( input_args = (
self.model, self.model,
self.tokenizer, self.tokenizer,
self.processor,
self.template, self.template,
self.generating_args, self.generating_args,
messages, messages,
system, system,
tools, tools,
image,
input_kwargs, input_kwargs,
) )
async with self._semaphore: async with self._semaphore:
@ -224,6 +245,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
if not self.can_generate: if not self.can_generate:
@ -233,11 +255,13 @@ class HuggingfaceEngine(BaseEngine):
input_args = ( input_args = (
self.model, self.model,
self.tokenizer, self.tokenizer,
self.processor,
self.template, self.template,
self.generating_args, self.generating_args,
messages, messages,
system, system,
tools, tools,
image,
input_kwargs, input_kwargs,
) )
async with self._semaphore: async with self._semaphore:

View File

@ -12,7 +12,10 @@ if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
if TYPE_CHECKING: if TYPE_CHECKING:
from numpy.typing import NDArray
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@ -29,7 +32,9 @@ class VllmEngine(BaseEngine):
infer_dtype = str(infer_dtype).split(".")[-1] infer_dtype = str(infer_dtype).split(".")[-1]
self.can_generate = finetuning_args.stage == "sft" self.can_generate = finetuning_args.stage == "sft"
self.tokenizer = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left" self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template) self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.generating_args = generating_args.to_dict() self.generating_args = generating_args.to_dict()
@ -58,6 +63,7 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs, **input_kwargs,
) -> AsyncIterator["RequestOutput"]: ) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex) request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
@ -121,10 +127,11 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
final_output = None final_output = None
generator = await self._generate(messages, system, tools, **input_kwargs) generator = await self._generate(messages, system, tools, image, **input_kwargs)
async for request_output in generator: async for request_output in generator:
final_output = request_output final_output = request_output
@ -146,10 +153,11 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
generated_text = "" generated_text = ""
generator = await self._generate(messages, system, tools, **input_kwargs) generator = await self._generate(messages, system, tools, image, **input_kwargs)
async for result in generator: async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :] delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text generated_text = result.outputs[0].text

View File

@ -8,7 +8,7 @@ from .utils import Role
if TYPE_CHECKING: if TYPE_CHECKING:
from PIL import Image from PIL.Image import Image
from transformers import ProcessorMixin, Seq2SeqTrainingArguments from transformers import ProcessorMixin, Seq2SeqTrainingArguments
from transformers.image_processing_utils import BaseImageProcessor from transformers.image_processing_utils import BaseImageProcessor
from transformers.tokenization_utils import PreTrainedTokenizer from transformers.tokenization_utils import PreTrainedTokenizer
@ -271,7 +271,11 @@ def get_preprocess_and_print_func(
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Tuple[Callable, Callable]: ) -> Tuple[Callable, Callable]:
if stage == "pt": if stage == "pt":
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args) preprocess_func = partial(
preprocess_pretrain_dataset,
tokenizer=tokenizer,
data_args=data_args,
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not training_args.predict_with_generate: elif stage == "sft" and not training_args.predict_with_generate:
if data_args.packing: if data_args.packing:

View File

@ -21,7 +21,7 @@ from .template import get_eval_template
class Evaluator: class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args) self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.tokenizer = load_tokenizer(self.model_args) self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2 self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template) self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
self.model = load_model(self.tokenizer, self.model_args, finetuning_args) self.model = load_model(self.tokenizer, self.model_args, finetuning_args)

View File

@ -196,6 +196,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if model_args.infer_backend == "vllm": if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.") raise ValueError("vLLM backend is only available for API, CLI and Web.")
if model_args.visual_inputs and data_args.packing:
raise ValueError("Cannot use packing in MLLM fine-tuning.")
_verify_model_args(model_args, finetuning_args) _verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args) _check_extra_dependencies(model_args, finetuning_args, training_args)

View File

@ -24,8 +24,9 @@ def run_dpo(
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None, callbacks: Optional[List["TrainerCallback"]] = None,
): ):
tokenizer = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm") tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = PairwiseDataCollatorWithPadding( data_collator = PairwiseDataCollatorWithPadding(

View File

@ -24,8 +24,9 @@ def run_orpo(
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None, callbacks: Optional[List["TrainerCallback"]] = None,
): ):
tokenizer = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm") tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = PairwiseDataCollatorWithPadding( data_collator = PairwiseDataCollatorWithPadding(

View File

@ -27,8 +27,9 @@ def run_ppo(
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None, callbacks: Optional[List["TrainerCallback"]] = None,
): ):
tokenizer = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo") tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training

View File

@ -25,8 +25,9 @@ def run_pt(
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None, callbacks: Optional[List["TrainerCallback"]] = None,
): ):
tokenizer = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt") tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

View File

@ -25,8 +25,9 @@ def run_rm(
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None, callbacks: Optional[List["TrainerCallback"]] = None,
): ):
tokenizer = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm") tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)

View File

@ -29,9 +29,9 @@ def run_sft(
callbacks: Optional[List["TrainerCallback"]] = None, callbacks: Optional[List["TrainerCallback"]] = None,
): ):
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=training_args.do_train) dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
if training_args.predict_with_generate: if training_args.predict_with_generate:
tokenizer.padding_side = "left" # use left-padding in generation tokenizer.padding_side = "left" # use left-padding in generation

View File

@ -52,7 +52,7 @@ def export_model(args: Optional[Dict[str, Any]] = None):
if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None: if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
raise ValueError("Please merge adapters before quantizing the model.") raise ValueError("Please merge adapters before quantizing the model.")
tokenizer = load_tokenizer(model_args) tokenizer = load_tokenizer(model_args)["tokenizer"]
get_template_and_fix_tokenizer(tokenizer, data_args.template) get_template_and_fix_tokenizer(tokenizer, data_args.template)
model = load_model(tokenizer, model_args, finetuning_args) # must after fixing tokenizer to resize vocab model = load_model(tokenizer, model_args, finetuning_args) # must after fixing tokenizer to resize vocab

View File

@ -91,7 +91,7 @@ def create_ref_model(
) )
ref_model_args = ModelArguments(**ref_model_args_dict) ref_model_args = ModelArguments(**ref_model_args_dict)
ref_finetuning_args = FinetuningArguments(finetuning_type="lora") ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
tokenizer = load_tokenizer(ref_model_args) tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
ref_model = load_model( ref_model = load_model(
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
) )
@ -100,7 +100,7 @@ def create_ref_model(
if finetuning_args.finetuning_type == "lora": if finetuning_args.finetuning_type == "lora":
ref_model = None ref_model = None
else: else:
tokenizer = load_tokenizer(model_args) tokenizer = load_tokenizer(model_args)["tokenizer"]
ref_model = load_model( ref_model = load_model(
tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
) )
@ -147,7 +147,7 @@ def create_reward_model(
) )
reward_model_args = ModelArguments(**reward_model_args_dict) reward_model_args = ModelArguments(**reward_model_args_dict)
reward_finetuning_args = FinetuningArguments(finetuning_type="lora") reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
tokenizer = load_tokenizer(reward_model_args) tokenizer = load_tokenizer(reward_model_args)["tokenizer"]
reward_model = load_model( reward_model = load_model(
tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
) )

View File

@ -2,6 +2,8 @@ import json
import os import os
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
from numpy.typing import NDArray
from ..chat import ChatModel from ..chat import ChatModel
from ..data import Role from ..data import Role
from ..extras.misc import torch_gc from ..extras.misc import torch_gc
@ -112,6 +114,7 @@ class WebChatModel(ChatModel):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: str, system: str,
tools: str, tools: str,
image: Optional[NDArray],
max_new_tokens: int, max_new_tokens: int,
top_p: float, top_p: float,
temperature: float, temperature: float,
@ -119,7 +122,7 @@ class WebChatModel(ChatModel):
chatbot[-1][1] = "" chatbot[-1][1] = ""
response = "" response = ""
for new_text in self.stream_chat( for new_text in self.stream_chat(
messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature messages, system, tools, image, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
): ):
response += new_text response += new_text
if tools: if tools:

View File

@ -23,9 +23,15 @@ def create_chat_box(
messages = gr.State([]) messages = gr.State([])
with gr.Row(): with gr.Row():
with gr.Column(scale=4): with gr.Column(scale=4):
role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value) with gr.Row():
system = gr.Textbox(show_label=False) with gr.Column():
tools = gr.Textbox(show_label=False, lines=2) role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
system = gr.Textbox(show_label=False)
tools = gr.Textbox(show_label=False, lines=4)
with gr.Column():
image = gr.Image(type="numpy")
query = gr.Textbox(show_label=False, lines=8) query = gr.Textbox(show_label=False, lines=8)
submit_btn = gr.Button(variant="primary") submit_btn = gr.Button(variant="primary")
@ -43,7 +49,7 @@ def create_chat_box(
[chatbot, messages, query], [chatbot, messages, query],
).then( ).then(
engine.chatter.stream, engine.chatter.stream,
[chatbot, messages, system, tools, max_new_tokens, top_p, temperature], [chatbot, messages, system, tools, image, max_new_tokens, top_p, temperature],
[chatbot, messages], [chatbot, messages],
) )
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages]) clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
@ -56,6 +62,7 @@ def create_chat_box(
role=role, role=role,
system=system, system=system,
tools=tools, tools=tools,
image=image,
query=query, query=query,
submit_btn=submit_btn, submit_btn=submit_btn,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,

View File

@ -1073,6 +1073,17 @@ LOCALES = {
"placeholder": "工具列表(非必填)", "placeholder": "工具列表(非必填)",
}, },
}, },
"image": {
"en": {
"label": "Image (optional)",
},
"ru": {
"label": "Изображение (по желанию)",
},
"zh": {
"label": "图像(非必填)",
},
},
"query": { "query": {
"en": { "en": {
"placeholder": "Input...", "placeholder": "Input...",