This commit is contained in:
BUAADreamer 2024-05-27 20:11:23 +08:00
commit 576b0206c2
12 changed files with 106 additions and 66 deletions

View File

@ -107,7 +107,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `dataset: glaive_toolcall`.
[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `dataset: glaive_toolcall_en`.
[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `use_unsloth: true` argument to activate unsloth patch. It achieves **170%** speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
@ -164,7 +164,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
| [PaliGemma](https://huggingface.co/google) | 3B | q_proj,v_proj | gemma |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | qkv_proj | phi |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen |
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
@ -403,20 +403,6 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui
```
<details><summary>For Alibaba Cloud PAI or AutoDL users</summary>
If you encountered display problems in LLaMA Board on Alibaba Cloud PAI, try using the following command to set environment variables before starting LLaMA Board:
```bash
export GRADIO_SERVER_PORT=7860 GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
```
If you are using AutoDL, please install a specific version of Gradio:
```bash
pip install gradio==4.10.0
```
</details>
#### Use Docker

View File

@ -107,7 +107,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
[24/02/05] Qwen1.5Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
[24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `dataset: glaive_toolcall` 即可使模型获得工具调用能力。
[24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `dataset: glaive_toolcall_zh` 即可使模型获得工具调用能力。
[23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `use_unsloth: true` 参数启用 unsloth 优化。该方法可提供 **170%** 的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
@ -164,7 +164,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
| [PaliGemma](https://huggingface.co/google) | 3B | q_proj,v_proj | gemma |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | qkv_proj | phi |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen |
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
@ -403,22 +403,6 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_s
CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui
```
<details><summary>阿里云 PAI 和 AutoDL 用户指南</summary>
如果您在阿里云 PAI 上使用 LLaMA Board 时遇到显示问题,请尝试在启动前使用以下命令设置环境变量:
```bash
export GRADIO_SERVER_PORT=7860 GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
```
如果您正在使用 AutoDL请安装下述 Gradio 版本:
```bash
pip install gradio==4.10.0
```
</details>
#### 使用 Docker
```bash

View File

@ -262,6 +262,36 @@
"ruozhiba_gpt4": {
"hf_hub_url": "hfl/ruozhiba_gpt4_turbo"
},
"llava_1k_en": {
"hf_hub_url": "BUAADreamer/llava-en-zh-2k",
"subset": "en",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
"images": "images"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"llava_1k_zh": {
"hf_hub_url": "BUAADreamer/llava-en-zh-2k",
"subset": "zh",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
"images": "images"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"llava_150k_en": {
"hf_hub_url": "BUAADreamer/llava-en-zh-300k",
"subset": "en",

View File

@ -6,6 +6,7 @@ stage: dpo
do_train: true
finetuning_type: lora
lora_target: q_proj,v_proj
pref_beta: 0.1
pref_loss: sigmoid # [sigmoid (dpo), orpo, simpo]
### dataset

View File

@ -8,7 +8,6 @@ import torch
from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_TOKEN
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
@ -60,9 +59,9 @@ class HuggingfaceEngine(BaseEngine):
processor is not None
and image is not None
and not hasattr(processor, "image_seq_length")
and IMAGE_TOKEN not in messages[0]["content"]
and template.image_token not in messages[0]["content"]
): # llava-like models
messages[0]["content"] = IMAGE_TOKEN + messages[0]["content"]
messages[0]["content"] = template.image_token + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"]
@ -75,7 +74,7 @@ class HuggingfaceEngine(BaseEngine):
batch_feature = image_processor(image, return_tensors="pt")
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
if hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
prompt_length = len(prompt_ids)

View File

@ -2,7 +2,6 @@ import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_TOKEN
from ..extras.logging import get_logger
from ..extras.misc import get_device_count, infer_optim_dtype
from ..extras.packages import is_vllm_available
@ -67,7 +66,7 @@ class VllmEngine(BaseEngine):
patch_size = config.vision_config.patch_size
self.image_feature_size = (image_size // patch_size) ** 2
engine_args["image_input_type"] = "pixel_values"
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
engine_args["image_feature_size"] = self.image_feature_size
if getattr(config, "is_yi_vl_derived_model", None):
@ -97,9 +96,9 @@ class VllmEngine(BaseEngine):
self.processor is not None
and image is not None
and not hasattr(self.processor, "image_seq_length")
and IMAGE_TOKEN not in messages[0]["content"]
): # llava-like models
messages[0]["content"] = IMAGE_TOKEN * self.image_feature_size + messages[0]["content"]
and self.template.image_token not in messages[0]["content"]
): # llava-like models (TODO: paligemma models)
messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]

View File

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
@ -46,7 +46,7 @@ def preprocess_feedback_dataset(
continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"]
if examples["response"][i][0]["content"]: # desired example
kto_tag = True
@ -82,7 +82,7 @@ def preprocess_feedback_dataset(
kl_response_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
input_ids = prompt_ids + response_ids

View File

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
@ -44,7 +44,7 @@ def preprocess_pairwise_dataset(
continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"]
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
@ -70,7 +70,7 @@ def preprocess_pairwise_dataset(
rejected_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
chosen_input_ids = prompt_ids + chosen_ids

View File

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
@ -37,13 +37,13 @@ def preprocess_supervised_dataset(
continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"]
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = [], []
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")

View File

@ -1,6 +1,5 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ...extras.constants import IMAGE_TOKEN
from ...extras.logging import get_logger
from ..utils import Role
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
@ -37,7 +36,7 @@ def preprocess_unsupervised_dataset(
continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"]
if len(examples["response"][i]) == 1:
messages = examples["prompt"][i] + examples["response"][i]
@ -57,7 +56,7 @@ def preprocess_unsupervised_dataset(
labels += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
model_inputs["input_ids"].append(input_ids)

View File

@ -26,6 +26,7 @@ class Template:
format_separator: "Formatter"
default_system: str
stop_words: List[str]
image_token: str
efficient_eos: bool
replace_eos: bool
force_system: bool
@ -209,6 +210,7 @@ def _register_template(
format_separator: Optional["Formatter"] = None,
default_system: str = "",
stop_words: List[str] = [],
image_token: str = "<image>",
efficient_eos: bool = False,
replace_eos: bool = False,
force_system: bool = False,
@ -256,6 +258,7 @@ def _register_template(
format_separator=format_separator or default_separator_formatter,
default_system=default_system,
stop_words=stop_words,
image_token=image_token,
efficient_eos=efficient_eos,
replace_eos=replace_eos,
force_system=force_system,
@ -730,7 +733,7 @@ _register_template(
_register_template(
name="mistral",
format_user=StringFormatter(slots=[" [INST] {{content}} [/INST]"]),
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
@ -738,7 +741,7 @@ _register_template(
_register_template(
name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
force_system=True,
@ -766,7 +769,6 @@ _register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]),
format_observation=StringFormatter(slots=["<|function_output|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful AI assistant.",
stop_words=["<|end|>"],

View File

@ -22,8 +22,6 @@ FILEEXT2TYPE = {
IGNORE_INDEX = -100
IMAGE_TOKEN = "<image>"
LAYERNORM_NAMES = {"norm", "ln"}
METHODS = ["full", "freeze", "lora"]
@ -327,6 +325,7 @@ register_model_group(
},
"DeepSeek-MoE-16B-v2-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite",
},
"DeepSeek-MoE-236B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2",
@ -338,6 +337,7 @@ register_model_group(
},
"DeepSeek-MoE-16B-v2-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite-Chat",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite-Chat",
},
"DeepSeek-MoE-236B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
@ -430,6 +430,12 @@ register_model_group(
DownloadSource.DEFAULT: "google/gemma-7b-it",
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it",
},
"Gemma-1.1-2B-Chat": {
DownloadSource.DEFAULT: "google/gemma-1.1-2b-it",
},
"Gemma-1.1-7B-Chat": {
DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
},
},
template="gemma",
)
@ -437,16 +443,19 @@ register_model_group(
register_model_group(
models={
"CodeGemma-2B": {
DownloadSource.DEFAULT: "google/codegemma-1.1-2b",
},
"CodeGemma-7B": {
DownloadSource.DEFAULT: "google/codegemma-7b",
},
"CodeGemma-7B-Chat": {
DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it",
DownloadSource.DEFAULT: "google/codegemma-7b-it",
DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
},
"CodeGemma-1.1-2B": {
DownloadSource.DEFAULT: "google/codegemma-1.1-2b",
},
"CodeGemma-1.1-7B-Chat": {
DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it",
},
},
template="gemma",
)
@ -635,6 +644,12 @@ register_model_group(
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
},
"Mistral-7B-v0.3": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3",
},
"Mistral-7B-v0.3-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3",
},
},
template="mistral",
)
@ -656,6 +671,7 @@ register_model_group(
},
"Mixtral-8x22B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-Instruct-v0.1",
},
},
template="mistral",
@ -670,6 +686,9 @@ register_model_group(
"OLMo-7B": {
DownloadSource.DEFAULT: "allenai/OLMo-7B-hf",
},
"OLMo-7B-Chat": {
DownloadSource.DEFAULT: "ssec-uw/OLMo-7B-Instruct-hf",
},
"OLMo-1.7-7B": {
DownloadSource.DEFAULT: "allenai/OLMo-1.7-7B-hf",
},
@ -719,18 +738,23 @@ register_model_group(
models={
"PaliGemma-3B-pt-224": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
},
"PaliGemma-3B-pt-448": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
},
"PaliGemma-3B-pt-896": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
},
"PaliGemma-3B-mix-224": {
DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
},
"PaliGemma-3B-mix-448": {
DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
},
},
vision=True,
@ -753,14 +777,30 @@ register_model_group(
register_model_group(
models={
"Phi3-3.8B-4k-Chat": {
"Phi3-4B-4k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
},
"Phi3-3.8B-128k-Chat": {
"Phi3-4B-128k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
},
"Phi3-7B-8k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
},
"Phi3-7B-128k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
},
"Phi3-14B-8k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
},
"Phi3-14B-128k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
},
},
module="qkv_proj",
template="phi",