diff --git a/.dockerignore b/.dockerignore
index 2ac0e11d..23ad75a8 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -4,10 +4,10 @@
.venv
cache
data
+docker
+saves
hf_cache
output
-examples
.dockerignore
.gitattributes
.gitignore
-Dockerfile
diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml
index 1d962200..768adea6 100644
--- a/.github/ISSUE_TEMPLATE/bug-report.yml
+++ b/.github/ISSUE_TEMPLATE/bug-report.yml
@@ -38,7 +38,9 @@ body:
请合理使用 Markdown 标签来格式化您的文本。
placeholder: |
+ ```bash
llamafactory-cli train ...
+ ```
- type: textarea
id: expected-behavior
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index b31e9d19..d23d6be3 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -5,3 +5,4 @@ Fixes # (issue)
## Before submitting
- [ ] Did you read the [contributor guideline](https://github.com/hiyouga/LLaMA-Factory/blob/main/.github/CONTRIBUTING.md)?
+- [ ] Did you write any new necessary tests?
diff --git a/.github/workflows/label_issue.yml b/.github/workflows/label_issue.yml
new file mode 100644
index 00000000..ffd644a7
--- /dev/null
+++ b/.github/workflows/label_issue.yml
@@ -0,0 +1,27 @@
+name: label_issue
+
+on:
+ issues:
+ types:
+ - opened
+
+jobs:
+ label_issue:
+ runs-on: ubuntu-latest
+
+ steps:
+ - env:
+ GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ ISSUE_URL: ${{ github.event.issue.html_url }}
+ ISSUE_TITLE: ${{ github.event.issue.title }}
+ run: |
+ LABEL=pending
+ NPU_KEYWORDS=(npu ascend huawei 华为 昇腾)
+ ISSUE_TITLE_LOWER=$(echo $ISSUE_TITLE | tr '[:upper:]' '[:lower:]')
+ for KEYWORD in ${NPU_KEYWORDS[@]}; do
+ if [[ $ISSUE_TITLE_LOWER == *$KEYWORD* ]] && [[ $ISSUE_TITLE_LOWER != *input* ]]; then
+ LABEL=pending,npu
+ break
+ fi
+ done
+ gh issue edit $ISSUE_URL --add-label $LABEL
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
new file mode 100644
index 00000000..15c7153e
--- /dev/null
+++ b/.github/workflows/publish.yml
@@ -0,0 +1,40 @@
+name: publish
+
+on:
+ release:
+ types:
+ - published
+
+jobs:
+ publish:
+ name: Upload release to PyPI
+
+ runs-on: ubuntu-latest
+
+ environment:
+ name: release
+ url: https://pypi.org/p/llamafactory
+
+ permissions:
+ id-token: write
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: "3.8"
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install build
+
+ - name: Build package
+ run: |
+ python -m build
+
+ - name: Publish package
+ uses: pypa/gh-action-pypi-publish@release/v1
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 32edf6a8..73d77de5 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -19,21 +19,27 @@ on:
jobs:
tests:
runs-on: ubuntu-latest
+
steps:
- - uses: actions/checkout@v4
+ - name: Checkout
+ uses: actions/checkout@v4
+
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.8"
cache: "pip"
cache-dependency-path: "setup.py"
+
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- python -m pip install .[torch,dev]
+ python -m pip install ".[torch,dev]"
+
- name: Check quality
run: |
make style && make quality
+
- name: Test with pytest
run: |
make test
diff --git a/.gitignore b/.gitignore
index 0355c666..82e6e9e6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,6 +160,8 @@ cython_debug/
.idea/
# custom .gitignore
-user.config
-saves/
cache/
+config/
+saves/
+output/
+wandb/
diff --git a/CITATION.cff b/CITATION.cff
index 4caf3787..01b4c9fd 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -12,12 +12,16 @@ authors:
given-names: "Yanhan"
- family-names: "Luo"
given-names: "Zheyan"
+- family-names: "Feng"
+ given-names: "Zhangchi"
- family-names: "Ma"
given-names: "Yongqiang"
title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models"
url: "https://arxiv.org/abs/2403.13372"
preferred-citation:
- type: article
+ type: conference-paper
+ conference:
+ name: "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)"
authors:
- family-names: "Zheng"
given-names: "Yaowei"
@@ -29,9 +33,12 @@ preferred-citation:
given-names: "Yanhan"
- family-names: "Luo"
given-names: "Zheyan"
+ - family-names: "Feng"
+ given-names: "Zhangchi"
- family-names: "Ma"
given-names: "Yongqiang"
- journal: "arXiv preprint arXiv:2403.13372"
title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models"
url: "https://arxiv.org/abs/2403.13372"
year: 2024
+ publisher: "Association for Computational Linguistics"
+ address: "Bangkok, Thailand"
diff --git a/Dockerfile b/Dockerfile
deleted file mode 100644
index 0a35e355..00000000
--- a/Dockerfile
+++ /dev/null
@@ -1,14 +0,0 @@
-FROM nvcr.io/nvidia/pytorch:24.01-py3
-
-WORKDIR /app
-
-COPY requirements.txt /app/
-RUN pip install -r requirements.txt
-
-COPY . /app/
-RUN pip install -e .[metrics,bitsandbytes,qwen]
-
-VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]
-EXPOSE 7860
-
-CMD [ "llamafactory-cli", "webui" ]
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 00000000..82c51f63
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+include LICENSE requirements.txt
diff --git a/Makefile b/Makefile
index 65be047b..3f13b215 100644
--- a/Makefile
+++ b/Makefile
@@ -11,4 +11,4 @@ style:
ruff format $(check_dirs)
test:
- pytest tests/
+ CUDA_VISIBLE_DEVICES= pytest tests/
diff --git a/README.md b/README.md
index fb6c5782..3d3feae5 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
-[![Citation](https://img.shields.io/badge/citation-44-green)](#projects-using-llama-factory)
+[![Citation](https://img.shields.io/badge/citation-71-green)](#projects-using-llama-factory)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -15,7 +15,7 @@
[![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
-👋 Join our [WeChat](assets/wechat.jpg).
+👋 Join our [WeChat](assets/wechat.jpg) or [NPU user group](assets/wechat_npu.jpg).
\[ English | [中文](README_zh.md) \]
@@ -48,8 +48,8 @@ Choose your path:
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
-- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
-- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
+- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
+- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
@@ -71,9 +71,9 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog
-[24/06/07] We supported fine-tuning the **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** series models.
+[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
-[24/06/05] We supported fine-tuning the **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** models.
+[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
@@ -151,35 +151,32 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Supported Models
-| Model | Model size | Template |
-| -------------------------------------------------------- | -------------------------------- | --------- |
-| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
-| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
-| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
-| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
-| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
-| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
-| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
-| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
-| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
-| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
-| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
-| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
-| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
-| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
-| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
-| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
-| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
-| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
-| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
-| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
-| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
-| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
-| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
-| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
+| Model | Model size | Template |
+| ------------------------------------------------------------ | -------------------------------- | --------- |
+| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
+| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
+| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
+| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
+| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
+| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
+| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
+| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
+| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
+| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
+| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
+| [Llama 3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
+| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
+| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
+| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
+| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
+| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
+| [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
+| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
+| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
+| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
+| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
+| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE]
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
@@ -259,6 +256,9 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
+- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
+- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
+- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
@@ -335,10 +335,10 @@ huggingface-cli login
```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
-pip install -e '.[torch,metrics]'
+pip install -e ".[torch,metrics]"
```
-Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality
+Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality
> [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts.
@@ -357,9 +357,7 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
For Ascend NPU users
-Join [NPU user group](assets/wechat_npu.jpg).
-
-To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e '.[torch-npu,metrics]'`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
+To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
```bash
# replace the url according to your CANN version and devices
@@ -382,15 +380,12 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
| torch-npu | 2.1.0 | 2.1.0.post3 |
| deepspeed | 0.13.2 | 0.13.2 |
-Docker image:
-
-- 32GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
-- 64GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
-
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations.
+Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
+
### Data Preparation
@@ -405,9 +400,9 @@ Please refer to [data/README.md](data/README.md) for checking the details about
Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```
See [examples/README.md](examples/README.md) for advanced usage (including distributed training).
@@ -417,34 +412,89 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
-#### Use local environment
-
```bash
-CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui
+llamafactory-cli webui
```
-
+### Build Docker
-#### Use Docker
+For CUDA users:
```bash
-docker build -f ./Dockerfile -t llama-factory:latest .
-docker run --gpus=all \
- -v ./hf_cache:/root/.cache/huggingface/ \
+cd docker/docker-cuda/
+docker-compose up -d
+docker-compose exec llamafactory bash
+```
+
+For Ascend NPU users:
+
+```bash
+cd docker/docker-npu/
+docker-compose up -d
+docker-compose exec llamafactory bash
+```
+
+Build without Docker Compose
+
+For CUDA users:
+
+```bash
+docker build -f ./docker/docker-cuda/Dockerfile \
+ --build-arg INSTALL_BNB=false \
+ --build-arg INSTALL_VLLM=false \
+ --build-arg INSTALL_DEEPSPEED=false \
+ --build-arg INSTALL_FLASHATTN=false \
+ --build-arg PIP_INDEX=https://pypi.org/simple \
+ -t llamafactory:latest .
+
+docker run -dit --gpus=all \
+ -v ./hf_cache:/root/.cache/huggingface \
+ -v ./ms_cache:/root/.cache/modelscope \
-v ./data:/app/data \
-v ./output:/app/output \
-p 7860:7860 \
+ -p 8000:8000 \
--shm-size 16G \
- --name llama_factory \
- -d llama-factory:latest
+ --name llamafactory \
+ llamafactory:latest
+
+docker exec -it llamafactory bash
```
-#### Use Docker Compose
+For Ascend NPU users:
```bash
-docker compose -f ./docker-compose.yml up -d
+# Choose docker image upon your environment
+docker build -f ./docker/docker-npu/Dockerfile \
+ --build-arg INSTALL_DEEPSPEED=false \
+ --build-arg PIP_INDEX=https://pypi.org/simple \
+ -t llamafactory:latest .
+
+# Change `device` upon your resources
+docker run -dit \
+ -v ./hf_cache:/root/.cache/huggingface \
+ -v ./ms_cache:/root/.cache/modelscope \
+ -v ./data:/app/data \
+ -v ./output:/app/output \
+ -v /usr/local/dcmi:/usr/local/dcmi \
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
+ -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+ -v /etc/ascend_install.info:/etc/ascend_install.info \
+ -p 7860:7860 \
+ -p 8000:8000 \
+ --device /dev/davinci0 \
+ --device /dev/davinci_manager \
+ --device /dev/devmm_svm \
+ --device /dev/hisi_hdc \
+ --shm-size 16G \
+ --name llamafactory \
+ llamafactory:latest
+
+docker exec -it llamafactory bash
```
+
+
Details about volume
- hf_cache: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
@@ -456,7 +506,7 @@ docker compose -f ./docker-compose.yml up -d
### Deploy with OpenAI-style API and vLLM
```bash
-CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
+API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
```
> [!TIP]
@@ -474,7 +524,7 @@ Train the model by specifying a model ID of the ModelScope Hub as the `model_nam
### Use W&B Logger
-To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments.
+To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.
```yaml
report_to: wandb
@@ -494,38 +544,63 @@ If you have a project that should be incorporated, please contact via email or c
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
-1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
-1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
+1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
+1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
-1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
+1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
-1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
+1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
-1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
+1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
-1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
+1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
-1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2404.17140)
-1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
+1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
+1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
+1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
+1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
+1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
+1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
+1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
+1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
+1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
+1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
+1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
+1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
+1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
+1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
+1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
+1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
+1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
+1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
+1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
+1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
+1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
+1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
+1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
+1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
+1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
+1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
+1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh’s Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
@@ -533,6 +608,8 @@ If you have a project that should be incorporated, please contact via email or c
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
+1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
+1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
@@ -540,17 +617,19 @@ If you have a project that should be incorporated, please contact via email or c
This repository is licensed under the [Apache-2.0 License](LICENSE).
-Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## Citation
If this work is helpful, please kindly cite as:
```bibtex
-@article{zheng2024llamafactory,
+@inproceedings{zheng2024llamafactory,
title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
- author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Yongqiang Ma},
- journal={arXiv preprint arXiv:2403.13372},
+ author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
+ booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
+ address={Bangkok, Thailand},
+ publisher={Association for Computational Linguistics},
year={2024},
url={http://arxiv.org/abs/2403.13372}
}
diff --git a/README_zh.md b/README_zh.md
index 142254df..cb5a42e4 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -4,7 +4,7 @@
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
-[![Citation](https://img.shields.io/badge/citation-44-green)](#使用了-llama-factory-的项目)
+[![Citation](https://img.shields.io/badge/citation-71-green)](#使用了-llama-factory-的项目)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -15,7 +15,7 @@
[![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
-👋 加入我们的[微信群](assets/wechat.jpg)。
+👋 加入我们的[微信群](assets/wechat.jpg)或 [NPU 用户群](assets/wechat_npu.jpg)。
\[ [English](README.md) | 中文 \]
@@ -48,8 +48,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
-- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
-- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
+- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
+- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
@@ -71,9 +71,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 更新日志
-[24/06/07] 我们支持了 **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** 系列模型的微调。
+[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
-[24/06/05] 我们支持了 **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** 模型的微调。
+[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。
[24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
@@ -151,35 +151,32 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 模型
-| 模型名 | 模型大小 | Template |
-| -------------------------------------------------------- | -------------------------------- | --------- |
-| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
-| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
-| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
-| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
-| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
-| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
-| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
-| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
-| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
-| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
-| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
-| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
-| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
-| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
-| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
-| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
-| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
-| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
-| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
-| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
-| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
-| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
-| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
-| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
+| 模型名 | 模型大小 | Template |
+| ------------------------------------------------------------ | -------------------------------- | --------- |
+| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
+| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
+| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
+| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
+| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
+| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
+| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
+| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
+| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
+| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
+| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
+| [Llama 3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
+| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
+| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
+| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
+| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
+| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
+| [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
+| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
+| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
+| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
+| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
+| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE]
> 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
@@ -259,6 +256,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
+- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
+- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
+- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
@@ -335,10 +335,10 @@ huggingface-cli login
```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
-pip install -e '.[torch,metrics]'
+pip install -e ".[torch,metrics]"
```
-可选的额外依赖项:torch、torch_npu、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality
+可选的额外依赖项:torch、torch-npu、metrics、deepspeed、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、qwen、modelscope、quality
> [!TIP]
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
@@ -357,9 +357,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
昇腾 NPU 用户指南
-加入 [NPU 用户群](assets/wechat_npu.jpg)。
-
-在昇腾 NPU 设备上安装 LLaMA Factory 时,需要指定额外依赖项,使用 `pip install -e '.[torch-npu,metrics]'` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
+在昇腾 NPU 设备上安装 LLaMA Factory 时,需要指定额外依赖项,使用 `pip install -e ".[torch-npu,metrics]"` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
```bash
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
@@ -382,15 +380,12 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
| torch-npu | 2.1.0 | 2.1.0.post3 |
| deepspeed | 0.13.2 | 0.13.2 |
-Docker 镜像:
-
-- 32GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
-- 64GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
-
请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`。
+下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
+
### 数据准备
@@ -405,9 +400,9 @@ Docker 镜像:
下面三行命令分别对 Llama3-8B-Instruct 模型进行 LoRA **微调**、**推理**和**合并**。
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```
高级用法请参考 [examples/README_zh.md](examples/README_zh.md)(包括多 GPU 微调)。
@@ -417,32 +412,89 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_s
### LLaMA Board 可视化微调(由 [Gradio](https://github.com/gradio-app/gradio) 驱动)
-#### 使用本地环境
-
```bash
-CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui
+llamafactory-cli webui
```
-#### 使用 Docker
+### 构建 Docker
+
+CUDA 用户:
```bash
-docker build -f ./Dockerfile -t llama-factory:latest .
-docker run --gpus=all \
- -v ./hf_cache:/root/.cache/huggingface/ \
+cd docker/docker-cuda/
+docker-compose up -d
+docker-compose exec llamafactory bash
+```
+
+昇腾 NPU 用户:
+
+```bash
+cd docker/docker-npu/
+docker-compose up -d
+docker-compose exec llamafactory bash
+```
+
+不使用 Docker Compose 构建
+
+CUDA 用户:
+
+```bash
+docker build -f ./docker/docker-cuda/Dockerfile \
+ --build-arg INSTALL_BNB=false \
+ --build-arg INSTALL_VLLM=false \
+ --build-arg INSTALL_DEEPSPEED=false \
+ --build-arg INSTALL_FLASHATTN=false \
+ --build-arg PIP_INDEX=https://pypi.org/simple \
+ -t llamafactory:latest .
+
+docker run -dit --gpus=all \
+ -v ./hf_cache:/root/.cache/huggingface \
+ -v ./ms_cache:/root/.cache/modelscope \
-v ./data:/app/data \
-v ./output:/app/output \
-p 7860:7860 \
+ -p 8000:8000 \
--shm-size 16G \
- --name llama_factory \
- -d llama-factory:latest
+ --name llamafactory \
+ llamafactory:latest
+
+docker exec -it llamafactory bash
```
-#### 使用 Docker Compose
+昇腾 NPU 用户:
```bash
-docker compose -f ./docker-compose.yml up -d
+# 根据您的环境选择镜像
+docker build -f ./docker/docker-npu/Dockerfile \
+ --build-arg INSTALL_DEEPSPEED=false \
+ --build-arg PIP_INDEX=https://pypi.org/simple \
+ -t llamafactory:latest .
+
+# 根据您的资源更改 `device`
+docker run -dit \
+ -v ./hf_cache:/root/.cache/huggingface \
+ -v ./ms_cache:/root/.cache/modelscope \
+ -v ./data:/app/data \
+ -v ./output:/app/output \
+ -v /usr/local/dcmi:/usr/local/dcmi \
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
+ -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+ -v /etc/ascend_install.info:/etc/ascend_install.info \
+ -p 7860:7860 \
+ -p 8000:8000 \
+ --device /dev/davinci0 \
+ --device /dev/davinci_manager \
+ --device /dev/devmm_svm \
+ --device /dev/hisi_hdc \
+ --shm-size 16G \
+ --name llamafactory \
+ llamafactory:latest
+
+docker exec -it llamafactory bash
```
+
+
数据卷详情
- hf_cache:使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。
@@ -454,7 +506,7 @@ docker compose -f ./docker-compose.yml up -d
### 利用 vLLM 部署 OpenAI API
```bash
-CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
+API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
```
> [!TIP]
@@ -472,7 +524,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
### 使用 W&B 面板
-若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请添加下面的参数。
+若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请在 yaml 文件中添加下面的参数。
```yaml
report_to: wandb
@@ -492,38 +544,63 @@ run_name: test_run # 可选
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
-1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
-1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
+1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
+1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
-1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
+1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
-1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
+1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
-1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
+1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
-1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
+1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
-1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2404.17140)
-1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
+1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
+1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
+1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
+1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
+1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
+1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
+1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
+1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
+1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
+1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
+1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
+1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
+1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
+1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
+1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
+1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
+1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
+1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
+1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
+1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
+1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
+1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
+1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
+1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
+1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
+1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
+1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh’s Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
@@ -531,6 +608,8 @@ run_name: test_run # 可选
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**:中文多模态医学大模型,基于 LLaVA-1.5-7B 在中文多模态医疗数据上微调而得。
+1. **[AutoRE](https://github.com/THUDM/AutoRE)**:基于大语言模型的文档级关系抽取系统。
+1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: 在 Windows 主机上利用英伟达 RTX 设备进行大型语言模型微调的开发包。
@@ -538,17 +617,19 @@ run_name: test_run # 可选
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
-使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## 引用
如果您觉得此项目有帮助,请考虑以下列格式引用
```bibtex
-@article{zheng2024llamafactory,
- title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
- author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Yongqiang Ma},
- journal={arXiv preprint arXiv:2403.13372},
+@inproceedings{zheng2024llamafactory,
+ title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
+ author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
+ booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
+ address={Bangkok, Thailand},
+ publisher={Association for Computational Linguistics},
year={2024},
url={http://arxiv.org/abs/2403.13372}
}
diff --git a/assets/wechat.jpg b/assets/wechat.jpg
index e0e89b78..f26119a9 100644
Binary files a/assets/wechat.jpg and b/assets/wechat.jpg differ
diff --git a/assets/wechat_npu.jpg b/assets/wechat_npu.jpg
index 97b6d5dd..2c7e0817 100644
Binary files a/assets/wechat_npu.jpg and b/assets/wechat_npu.jpg differ
diff --git a/data/dataset_info.json b/data/dataset_info.json
index 8c5cbb45..f8ffd407 100644
--- a/data/dataset_info.json
+++ b/data/dataset_info.json
@@ -248,6 +248,21 @@
"ruozhiba_gpt4": {
"hf_hub_url": "hfl/ruozhiba_gpt4_turbo"
},
+ "neo_sft": {
+ "hf_hub_url": "m-a-p/neo_sft_phase2",
+ "formatting": "sharegpt"
+ },
+ "magpie_pro_300k": {
+ "hf_hub_url": "Magpie-Align/Magpie-Pro-300K-Filtered",
+ "formatting": "sharegpt"
+ },
+ "web_instruct": {
+ "hf_hub_url": "TIGER-Lab/WebInstructSub",
+ "columns": {
+ "prompt": "question",
+ "response": "answer"
+ }
+ },
"llava_1k_en": {
"hf_hub_url": "BUAADreamer/llava-en-zh-2k",
"subset": "en",
@@ -520,13 +535,13 @@
"prompt": "text"
}
},
- "fileweb": {
+ "fineweb": {
"hf_hub_url": "HuggingFaceFW/fineweb",
"columns": {
"prompt": "text"
}
},
- "fileweb_edu": {
+ "fineweb_edu": {
"hf_hub_url": "HuggingFaceFW/fineweb-edu",
"columns": {
"prompt": "text"
diff --git a/data/glaive_toolcall_en_demo.json b/data/glaive_toolcall_en_demo.json
index c17465a3..4dcd4383 100644
--- a/data/glaive_toolcall_en_demo.json
+++ b/data/glaive_toolcall_en_demo.json
@@ -1,4 +1,41 @@
[
+ {
+ "conversations": [
+ {
+ "from": "human",
+ "value": "Hi, I have some ingredients and I want to cook something. Can you help me find a recipe?"
+ },
+ {
+ "from": "gpt",
+ "value": "Of course! I can help you with that. Please tell me what ingredients you have."
+ },
+ {
+ "from": "human",
+ "value": "I have chicken, bell peppers, and rice."
+ },
+ {
+ "from": "function_call",
+ "value": "{\"name\": \"search_recipes\", \"arguments\": {\"ingredients\": [\"chicken\", \"bell peppers\", \"rice\"]}}"
+ },
+ {
+ "from": "observation",
+ "value": "{\"recipes\": [{\"name\": \"Chicken and Bell Pepper Stir Fry\", \"ingredients\": [\"chicken\", \"bell peppers\", \"rice\"], \"instructions\": \"Cut the chicken into small pieces. Slice the bell peppers. Cook the rice. Stir fry the chicken and bell peppers. Serve over rice.\"}, {\"name\": \"Chicken and Rice Casserole\", \"ingredients\": [\"chicken\", \"bell peppers\", \"rice\"], \"instructions\": \"Cook the chicken and rice separately. Mix them together with the bell peppers in a casserole dish. Bake until golden brown.\"}]}"
+ },
+ {
+ "from": "gpt",
+ "value": "I found two recipes for you. The first one is \"Chicken and Bell Pepper Stir Fry\". The instructions are: Cut the chicken into small pieces. Slice the bell peppers. Cook the rice. Stir fry the chicken and bell peppers. Serve over rice. The second one is \"Chicken and Rice Casserole\". The instructions are: Cook the chicken and rice separately. Mix them together with the bell peppers in a casserole dish. Bake until golden brown. Which one would you like to try?"
+ },
+ {
+ "from": "human",
+ "value": "They both sound delicious, but I think I'll try the stir fry. Can you order the ingredients for me?"
+ },
+ {
+ "from": "gpt",
+ "value": "I'm sorry, but as an AI, I don't have the capability to perform external tasks such as ordering ingredients. However, I can help you find more recipes or provide cooking tips if you need."
+ }
+ ],
+ "tools": "[{\"name\": \"search_recipes\", \"description\": \"Search for recipes based on ingredients\", \"parameters\": {\"type\": \"object\", \"properties\": {\"ingredients\": {\"type\": \"array\", \"items\": {\"type\": \"string\"}, \"description\": \"The ingredients to search for\"}}, \"required\": [\"ingredients\"]}}]"
+ },
{
"conversations": [
{
@@ -155,43 +192,6 @@
],
"tools": "[]"
},
- {
- "conversations": [
- {
- "from": "human",
- "value": "Hi, I have some ingredients and I want to cook something. Can you help me find a recipe?"
- },
- {
- "from": "gpt",
- "value": "Of course! I can help you with that. Please tell me what ingredients you have."
- },
- {
- "from": "human",
- "value": "I have chicken, bell peppers, and rice."
- },
- {
- "from": "function_call",
- "value": "{\"name\": \"search_recipes\", \"arguments\": {\"ingredients\": [\"chicken\", \"bell peppers\", \"rice\"]}}"
- },
- {
- "from": "observation",
- "value": "{\"recipes\": [{\"name\": \"Chicken and Bell Pepper Stir Fry\", \"ingredients\": [\"chicken\", \"bell peppers\", \"rice\"], \"instructions\": \"Cut the chicken into small pieces. Slice the bell peppers. Cook the rice. Stir fry the chicken and bell peppers. Serve over rice.\"}, {\"name\": \"Chicken and Rice Casserole\", \"ingredients\": [\"chicken\", \"bell peppers\", \"rice\"], \"instructions\": \"Cook the chicken and rice separately. Mix them together with the bell peppers in a casserole dish. Bake until golden brown.\"}]}"
- },
- {
- "from": "gpt",
- "value": "I found two recipes for you. The first one is \"Chicken and Bell Pepper Stir Fry\". The instructions are: Cut the chicken into small pieces. Slice the bell peppers. Cook the rice. Stir fry the chicken and bell peppers. Serve over rice. The second one is \"Chicken and Rice Casserole\". The instructions are: Cook the chicken and rice separately. Mix them together with the bell peppers in a casserole dish. Bake until golden brown. Which one would you like to try?"
- },
- {
- "from": "human",
- "value": "They both sound delicious, but I think I'll try the stir fry. Can you order the ingredients for me?"
- },
- {
- "from": "gpt",
- "value": "I'm sorry, but as an AI, I don't have the capability to perform external tasks such as ordering ingredients. However, I can help you find more recipes or provide cooking tips if you need."
- }
- ],
- "tools": "[{\"name\": \"search_recipes\", \"description\": \"Search for recipes based on ingredients\", \"parameters\": {\"type\": \"object\", \"properties\": {\"ingredients\": {\"type\": \"array\", \"items\": {\"type\": \"string\"}, \"description\": \"The ingredients to search for\"}}, \"required\": [\"ingredients\"]}}]"
- },
{
"conversations": [
{
diff --git a/data/glaive_toolcall_zh_demo.json b/data/glaive_toolcall_zh_demo.json
index 06efbac0..f40bda68 100644
--- a/data/glaive_toolcall_zh_demo.json
+++ b/data/glaive_toolcall_zh_demo.json
@@ -1,25 +1,4 @@
[
- {
- "conversations": [
- {
- "from": "human",
- "value": "假设你有一个需要随机数的Java程序,范围在0到10之间。你可以使用什么代码片段来生成这样的数字?\n\n下面是一个示例代码片段,它使用Math.random()方法来生成一个在0到1之间的随机数,然后缩放并四舍五入结果,得到一个在0到10之间的随机数:\n\n```java\nint randomNumber = (int) (Math.random() * 11);\n```\n\n在Java中是否有更有效的方法来生成随机数?你会如何修改代码以生成具有不同范围或分布的随机数?"
- },
- {
- "from": "gpt",
- "value": "您提供的代码片段是Java中生成随机数的一种常见且高效的方法。然而,如果您需要对随机数生成有更多的控制,您可以使用java.util.Random类,它提供了更多生成随机数的选项。\n\n要生成一个在不同范围内的随机数,您可以修改代码以调整缩放因子。例如,要生成一个在1到10之间的随机数,您可以使用以下代码:\n\nint randomNumber = (int) (Math.random() * 10) + 1;\n\n要生成具有不同分布的随机数,您可以使用各种技术,如正态分布、指数分布或自定义分布。生成正态分布的一种方法是使用Box-Muller变换。这里有一个代码示例,它生成具有正态分布的随机数:\n\nRandom random = new Random();\ndouble u1 = random.nextDouble();\ndouble u2 = random.nextDouble();\ndouble z = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);\ndouble mean = 5;\ndouble stdDev = 2;\nint randomNumber = (int) (z * stdDev + mean);\n\n这段代码生成一个具有均值5和标准差2的正态分布随机数。您可以调整均值和标准差来生成具有不同分布的随机数。"
- },
- {
- "from": "human",
- "value": "你能解释一下算法分析中的大O记法概念吗?"
- },
- {
- "from": "gpt",
- "value": "是的,大O表示法是一种描述算法性能随输入数据大小增加的方法。它用于估计算法在最坏情况下的时间复杂度或空间复杂度。大O中的“O”代表“数量级”,这意味着随着输入大小的增加,算法所需的操作数以一定的速率增长。\n\n例如,如果我们有一个需要n步才能解决大小为n的问题的算法,那么时间复杂度可以表示为O(n),这意味着操作数随着输入大小线性增长。如果我们有一个需要n^2步才能解决大小为n的问题的算法,那么时间复杂度可以表示为O(n^2),这意味着操作数随着输入大小二次方增长。\n\n还有其他几种大O表示法,如O(log n)、O(n log n)、O(2^n)等,它们代表了随着输入大小增加的不同操作数增长速率。通过分析算法的大O表示法,我们可以比较不同算法的性能,并选择给定问题最有效的算法。\n\n您是否有关于大O表示法或算法分析的特定问题?"
- }
- ],
- "tools": "[]"
- },
{
"conversations": [
{
@@ -41,6 +20,27 @@
],
"tools": "[{\"name\": \"generate_invoice\", \"description\": \"生成发票\", \"parameters\": {\"type\": \"object\", \"properties\": {\"customer_name\": {\"type\": \"string\", \"description\": \"客户名称\"}, \"items\": {\"type\": \"array\", \"items\": {\"type\": \"object\", \"properties\": {\"name\": {\"type\": \"string\", \"description\": \"The item name\"}, \"quantity\": {\"type\": \"integer\", \"description\": \"The quantity of the item\"}, \"price\": {\"type\": \"number\", \"description\": \"The price per unit\"}}, \"required\": [\"name\", \"quantity\", \"price\"]}}}, \"required\": [\"customer_name\", \"items\"]}}, {\"name\": \"generate_password\", \"description\": \"生成随机密码\", \"parameters\": {\"type\": \"object\", \"properties\": {\"length\": {\"type\": \"integer\", \"description\": \"密码的长度\"}}, \"required\": [\"length\"]}}]"
},
+ {
+ "conversations": [
+ {
+ "from": "human",
+ "value": "假设你有一个需要随机数的Java程序,范围在0到10之间。你可以使用什么代码片段来生成这样的数字?\n\n下面是一个示例代码片段,它使用Math.random()方法来生成一个在0到1之间的随机数,然后缩放并四舍五入结果,得到一个在0到10之间的随机数:\n\n```java\nint randomNumber = (int) (Math.random() * 11);\n```\n\n在Java中是否有更有效的方法来生成随机数?你会如何修改代码以生成具有不同范围或分布的随机数?"
+ },
+ {
+ "from": "gpt",
+ "value": "您提供的代码片段是Java中生成随机数的一种常见且高效的方法。然而,如果您需要对随机数生成有更多的控制,您可以使用java.util.Random类,它提供了更多生成随机数的选项。\n\n要生成一个在不同范围内的随机数,您可以修改代码以调整缩放因子。例如,要生成一个在1到10之间的随机数,您可以使用以下代码:\n\nint randomNumber = (int) (Math.random() * 10) + 1;\n\n要生成具有不同分布的随机数,您可以使用各种技术,如正态分布、指数分布或自定义分布。生成正态分布的一种方法是使用Box-Muller变换。这里有一个代码示例,它生成具有正态分布的随机数:\n\nRandom random = new Random();\ndouble u1 = random.nextDouble();\ndouble u2 = random.nextDouble();\ndouble z = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);\ndouble mean = 5;\ndouble stdDev = 2;\nint randomNumber = (int) (z * stdDev + mean);\n\n这段代码生成一个具有均值5和标准差2的正态分布随机数。您可以调整均值和标准差来生成具有不同分布的随机数。"
+ },
+ {
+ "from": "human",
+ "value": "你能解释一下算法分析中的大O记法概念吗?"
+ },
+ {
+ "from": "gpt",
+ "value": "是的,大O表示法是一种描述算法性能随输入数据大小增加的方法。它用于估计算法在最坏情况下的时间复杂度或空间复杂度。大O中的“O”代表“数量级”,这意味着随着输入大小的增加,算法所需的操作数以一定的速率增长。\n\n例如,如果我们有一个需要n步才能解决大小为n的问题的算法,那么时间复杂度可以表示为O(n),这意味着操作数随着输入大小线性增长。如果我们有一个需要n^2步才能解决大小为n的问题的算法,那么时间复杂度可以表示为O(n^2),这意味着操作数随着输入大小二次方增长。\n\n还有其他几种大O表示法,如O(log n)、O(n log n)、O(2^n)等,它们代表了随着输入大小增加的不同操作数增长速率。通过分析算法的大O表示法,我们可以比较不同算法的性能,并选择给定问题最有效的算法。\n\n您是否有关于大O表示法或算法分析的特定问题?"
+ }
+ ],
+ "tools": "[]"
+ },
{
"conversations": [
{
diff --git a/docker-compose.yml b/docker-compose.yml
deleted file mode 100644
index 9602a3e3..00000000
--- a/docker-compose.yml
+++ /dev/null
@@ -1,23 +0,0 @@
-version: '3.8'
-
-services:
- llama-factory:
- build:
- dockerfile: Dockerfile
- context: .
- container_name: llama_factory
- volumes:
- - ./hf_cache:/root/.cache/huggingface/
- - ./data:/app/data
- - ./output:/app/output
- ports:
- - "7860:7860"
- ipc: host
- deploy:
- resources:
- reservations:
- devices:
- - driver: nvidia
- count: "all"
- capabilities: [gpu]
- restart: unless-stopped
diff --git a/docker/docker-cuda/Dockerfile b/docker/docker-cuda/Dockerfile
new file mode 100644
index 00000000..d94aa970
--- /dev/null
+++ b/docker/docker-cuda/Dockerfile
@@ -0,0 +1,58 @@
+# Use the NVIDIA official image with PyTorch 2.3.0
+# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html
+FROM nvcr.io/nvidia/pytorch:24.02-py3
+
+# Define environments
+ENV MAX_JOBS=4
+ENV FLASH_ATTENTION_FORCE_BUILD=TRUE
+
+# Define installation arguments
+ARG INSTALL_BNB=false
+ARG INSTALL_VLLM=false
+ARG INSTALL_DEEPSPEED=false
+ARG INSTALL_FLASHATTN=false
+ARG PIP_INDEX=https://pypi.org/simple
+
+# Set the working directory
+WORKDIR /app
+
+# Install the requirements
+COPY requirements.txt /app
+RUN pip config set global.index-url "$PIP_INDEX" && \
+ pip config set global.extra-index-url "$PIP_INDEX" && \
+ python -m pip install --upgrade pip && \
+ python -m pip install -r requirements.txt
+
+# Rebuild flash attention
+RUN pip uninstall -y transformer-engine flash-attn && \
+ if [ "$INSTALL_FLASHATTN" == "true" ]; then \
+ pip uninstall -y ninja && pip install ninja && \
+ pip install --no-cache-dir flash-attn --no-build-isolation; \
+ fi
+
+# Copy the rest of the application into the image
+COPY . /app
+
+# Install the LLaMA Factory
+RUN EXTRA_PACKAGES="metrics"; \
+ if [ "$INSTALL_BNB" == "true" ]; then \
+ EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \
+ fi; \
+ if [ "$INSTALL_VLLM" == "true" ]; then \
+ EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \
+ fi; \
+ if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
+ EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
+ fi; \
+ pip install -e ".[$EXTRA_PACKAGES]"
+
+# Set up volumes
+VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]
+
+# Expose port 7860 for the LLaMA Board
+ENV GRADIO_SERVER_PORT 7860
+EXPOSE 7860
+
+# Expose port 8000 for the API service
+ENV API_PORT 8000
+EXPOSE 8000
diff --git a/docker/docker-cuda/docker-compose.yml b/docker/docker-cuda/docker-compose.yml
new file mode 100644
index 00000000..16267dc3
--- /dev/null
+++ b/docker/docker-cuda/docker-compose.yml
@@ -0,0 +1,32 @@
+services:
+ llamafactory:
+ build:
+ dockerfile: ./docker/docker-cuda/Dockerfile
+ context: ../..
+ args:
+ INSTALL_BNB: false
+ INSTALL_VLLM: false
+ INSTALL_DEEPSPEED: false
+ INSTALL_FLASHATTN: false
+ PIP_INDEX: https://pypi.org/simple
+ container_name: llamafactory
+ volumes:
+ - ../../hf_cache:/root/.cache/huggingface
+ - ../../ms_cache:/root/.cache/modelscope
+ - ../../data:/app/data
+ - ../../output:/app/output
+ ports:
+ - "7860:7860"
+ - "8000:8000"
+ ipc: host
+ tty: true
+ stdin_open: true
+ command: bash
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: "all"
+ capabilities: [gpu]
+ restart: unless-stopped
diff --git a/docker/docker-npu/Dockerfile b/docker/docker-npu/Dockerfile
new file mode 100644
index 00000000..34cf9616
--- /dev/null
+++ b/docker/docker-npu/Dockerfile
@@ -0,0 +1,45 @@
+# Use the Ubuntu 22.04 image with CANN 8.0.rc1
+# More versions can be found at https://hub.docker.com/r/cosdt/cann/tags
+# FROM cosdt/cann:8.0.rc1-910-ubuntu22.04
+FROM cosdt/cann:8.0.rc1-910b-ubuntu22.04
+# FROM cosdt/cann:8.0.rc1-910-openeuler22.03
+# FROM cosdt/cann:8.0.rc1-910b-openeuler22.03
+
+# Define environments
+ENV DEBIAN_FRONTEND=noninteractive
+
+# Define installation arguments
+ARG INSTALL_DEEPSPEED=false
+ARG PIP_INDEX=https://pypi.org/simple
+ARG TORCH_INDEX=https://download.pytorch.org/whl/cpu
+
+# Set the working directory
+WORKDIR /app
+
+# Install the requirements
+COPY requirements.txt /app
+RUN pip config set global.index-url "$PIP_INDEX" && \
+ pip config set global.extra-index-url "$TORCH_INDEX" && \
+ python -m pip install --upgrade pip && \
+ python -m pip install -r requirements.txt
+
+# Copy the rest of the application into the image
+COPY . /app
+
+# Install the LLaMA Factory
+RUN EXTRA_PACKAGES="torch-npu,metrics"; \
+ if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
+ EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
+ fi; \
+ pip install -e ".[$EXTRA_PACKAGES]"
+
+# Set up volumes
+VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]
+
+# Expose port 7860 for the LLaMA Board
+ENV GRADIO_SERVER_PORT 7860
+EXPOSE 7860
+
+# Expose port 8000 for the API service
+ENV API_PORT 8000
+EXPOSE 8000
diff --git a/docker/docker-npu/docker-compose.yml b/docker/docker-npu/docker-compose.yml
new file mode 100644
index 00000000..657cba9f
--- /dev/null
+++ b/docker/docker-npu/docker-compose.yml
@@ -0,0 +1,31 @@
+services:
+ llamafactory:
+ build:
+ dockerfile: ./docker/docker-npu/Dockerfile
+ context: ../..
+ args:
+ INSTALL_DEEPSPEED: false
+ PIP_INDEX: https://pypi.org/simple
+ container_name: llamafactory
+ volumes:
+ - ../../hf_cache:/root/.cache/huggingface
+ - ../../ms_cache:/root/.cache/modelscope
+ - ../../data:/app/data
+ - ../../output:/app/output
+ - /usr/local/dcmi:/usr/local/dcmi
+ - /usr/local/bin/npu-smi:/usr/local/bin/npu-smi
+ - /usr/local/Ascend/driver:/usr/local/Ascend/driver
+ - /etc/ascend_install.info:/etc/ascend_install.info
+ ports:
+ - "7860:7860"
+ - "8000:8000"
+ ipc: host
+ tty: true
+ stdin_open: true
+ command: bash
+ devices:
+ - /dev/davinci0
+ - /dev/davinci_manager
+ - /dev/devmm_svm
+ - /dev/hisi_hdc
+ restart: unless-stopped
diff --git a/evaluation/ceval/ceval.py b/evaluation/ceval/ceval.py
index 4111d6b4..48442d50 100644
--- a/evaluation/ceval/ceval.py
+++ b/evaluation/ceval/ceval.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import os
import datasets
diff --git a/evaluation/cmmlu/cmmlu.py b/evaluation/cmmlu/cmmlu.py
index 37efb328..5ff548a4 100644
--- a/evaluation/cmmlu/cmmlu.py
+++ b/evaluation/cmmlu/cmmlu.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import os
import datasets
diff --git a/evaluation/mmlu/mmlu.py b/evaluation/mmlu/mmlu.py
index a4530250..1065fb31 100644
--- a/evaluation/mmlu/mmlu.py
+++ b/evaluation/mmlu/mmlu.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import os
import datasets
diff --git a/examples/README.md b/examples/README.md
index f985d552..d5aca5ad 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -4,59 +4,59 @@ Make sure to execute these commands in the `LLaMA-Factory` directory.
## Table of Contents
-- [LoRA Fine-Tuning on A Single GPU](#lora-fine-tuning-on-a-single-gpu)
-- [QLoRA Fine-Tuning on a Single GPU](#qlora-fine-tuning-on-a-single-gpu)
-- [LoRA Fine-Tuning on Multiple GPUs](#lora-fine-tuning-on-multiple-gpus)
-- [LoRA Fine-Tuning on Multiple NPUs](#lora-fine-tuning-on-multiple-npus)
-- [Full-Parameter Fine-Tuning on Multiple GPUs](#full-parameter-fine-tuning-on-multiple-gpus)
+- [LoRA Fine-Tuning](#lora-fine-tuning)
+- [QLoRA Fine-Tuning](#qlora-fine-tuning)
+- [Full-Parameter Fine-Tuning](#full-parameter-fine-tuning)
- [Merging LoRA Adapters and Quantization](#merging-lora-adapters-and-quantization)
- [Inferring LoRA Fine-Tuned Models](#inferring-lora-fine-tuned-models)
- [Extras](#extras)
+Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.
+
## Examples
-### LoRA Fine-Tuning on A Single GPU
+### LoRA Fine-Tuning
#### (Continuous) Pre-Training
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
```
#### Supervised Fine-Tuning
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```
#### Multimodal Supervised Fine-Tuning
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
```
#### Reward Modeling
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
```
#### PPO Training
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
```
#### DPO/ORPO/SimPO Training
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```
#### KTO Training
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
```
#### Preprocess Dataset
@@ -64,95 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
It is useful for large dataset, use `tokenized_path` in config to load the preprocessed dataset.
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml
+llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
```
#### Evaluating on MMLU/CMMLU/C-Eval Benchmarks
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml
+llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
```
#### Batch Predicting and Computing BLEU and ROUGE Scores
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml
-```
-
-### QLoRA Fine-Tuning on a Single GPU
-
-#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes Quantization (Recommended)
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml
-```
-
-#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
-```
-
-#### Supervised Fine-Tuning with 4-bit AWQ Quantization
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
-```
-
-#### Supervised Fine-Tuning with 2-bit AQLM Quantization
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
-```
-
-### LoRA Fine-Tuning on Multiple GPUs
-
-#### Supervised Fine-Tuning on Single Node
-
-```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
```
#### Supervised Fine-Tuning on Multiple Nodes
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
```
-### LoRA Fine-Tuning on Multiple NPUs
+### QLoRA Fine-Tuning
-#### Supervised Fine-Tuning with DeepSpeed ZeRO-0
+#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)
```bash
-ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
```
-### Full-Parameter Fine-Tuning on Multiple GPUs
+#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml
+```
+
+#### Supervised Fine-Tuning with 4-bit AWQ Quantization
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml
+```
+
+#### Supervised Fine-Tuning with 2-bit AQLM Quantization
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
+```
+
+### Full-Parameter Fine-Tuning
#### Supervised Fine-Tuning on Single Node
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```
#### Supervised Fine-Tuning on Multiple Nodes
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```
#### Batch Predicting and Computing BLEU and ROUGE Scores
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml
+llamafactory-cli train examples/train_full/llama3_full_predict.yaml
```
### Merging LoRA Adapters and Quantization
@@ -162,35 +146,33 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llam
Note: DO NOT use quantized model or `quantization_bit` when merging LoRA adapters.
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```
#### Quantizing Model using AutoGPTQ
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
+llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
```
### Inferring LoRA Fine-Tuned Models
-Use `CUDA_VISIBLE_DEVICES=0,1` to infer models on multiple devices.
-
#### Use CLI
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
```
#### Use Web UI
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
```
#### Launch OpenAI-style API
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml
+llamafactory-cli api examples/inference/llama3_lora_sft.yaml
```
### Extras
@@ -198,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y
#### Full-Parameter Fine-Tuning using GaLore
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
```
#### Full-Parameter Fine-Tuning using BAdam
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
```
#### LoRA+ Fine-Tuning
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
+llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
+```
+
+#### PiSSA Fine-Tuning
+
+```bash
+llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
```
#### Mixture-of-Depths Fine-Tuning
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
```
#### LLaMA-Pro Fine-Tuning
```bash
bash examples/extras/llama_pro/expand.sh
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
+llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
```
#### FSDP+QLoRA Fine-Tuning
```bash
-bash examples/extras/fsdp_qlora/single_node.sh
+bash examples/extras/fsdp_qlora/train.sh
```
diff --git a/examples/README_zh.md b/examples/README_zh.md
index cf5bbf49..d96bf882 100644
--- a/examples/README_zh.md
+++ b/examples/README_zh.md
@@ -4,59 +4,59 @@
## 目录
-- [单 GPU LoRA 微调](#单-gpu-lora-微调)
-- [单 GPU QLoRA 微调](#单-gpu-qlora-微调)
-- [多 GPU LoRA 微调](#多-gpu-lora-微调)
-- [多 NPU LoRA 微调](#多-npu-lora-微调)
-- [多 GPU 全参数微调](#多-gpu-全参数微调)
+- [LoRA 微调](#lora-微调)
+- [QLoRA 微调](#qlora-微调)
+- [全参数微调](#全参数微调)
- [合并 LoRA 适配器与模型量化](#合并-lora-适配器与模型量化)
- [推理 LoRA 模型](#推理-lora-模型)
- [杂项](#杂项)
+使用 `CUDA_VISIBLE_DEVICES`(GPU)或 `ASCEND_RT_VISIBLE_DEVICES`(NPU)选择计算设备。
+
## 示例
-### 单 GPU LoRA 微调
+### LoRA 微调
#### (增量)预训练
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
```
#### 指令监督微调
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```
#### 多模态指令监督微调
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
```
#### 奖励模型训练
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
```
#### PPO 训练
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
```
#### DPO/ORPO/SimPO 训练
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```
#### KTO 训练
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
```
#### 预处理数据集
@@ -64,95 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
对于大数据集有帮助,在配置中使用 `tokenized_path` 以加载预处理后的数据集。
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml
+llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
```
#### 在 MMLU/CMMLU/C-Eval 上评估
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml
+llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
```
#### 批量预测并计算 BLEU 和 ROUGE 分数
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
```
-### 单 GPU QLoRA 微调
-
-#### 基于 4/8 比特 Bitsandbytes 量化进行指令监督微调(推荐)
+#### 多机指令监督微调
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml
-```
-
-#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
-```
-
-#### 基于 4 比特 AWQ 量化进行指令监督微调
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
-```
-
-#### 基于 2 比特 AQLM 量化进行指令监督微调
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
-```
-
-### 多 GPU LoRA 微调
-
-#### 在单机上进行指令监督微调
-
-```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
-```
-
-#### 在多机上进行指令监督微调
-
-```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```
#### 使用 DeepSpeed ZeRO-3 平均分配显存
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
```
-### 多 NPU LoRA 微调
+### QLoRA 微调
-#### 使用 DeepSpeed ZeRO-0 进行指令监督微调
+#### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐)
```bash
-ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
```
-### 多 GPU 全参数微调
+#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml
+```
+
+#### 基于 4 比特 AWQ 量化进行指令监督微调
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml
+```
+
+#### 基于 2 比特 AQLM 量化进行指令监督微调
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
+```
+
+### 全参数微调
#### 在单机上进行指令监督微调
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```
#### 在多机上进行指令监督微调
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```
#### 批量预测并计算 BLEU 和 ROUGE 分数
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml
+llamafactory-cli train examples/train_full/llama3_full_predict.yaml
```
### 合并 LoRA 适配器与模型量化
@@ -162,35 +146,33 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llam
注:请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```
#### 使用 AutoGPTQ 量化模型
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
+llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
```
### 推理 LoRA 模型
-使用 `CUDA_VISIBLE_DEVICES=0,1` 进行多卡推理。
-
#### 使用命令行接口
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
```
#### 使用浏览器界面
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
```
#### 启动 OpenAI 风格 API
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml
+llamafactory-cli api examples/inference/llama3_lora_sft.yaml
```
### 杂项
@@ -198,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y
#### 使用 GaLore 进行全参数训练
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
```
#### 使用 BAdam 进行全参数训练
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
```
#### LoRA+ 微调
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
+llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
+```
+
+#### PiSSA 微调
+
+```bash
+llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
```
#### 深度混合微调
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
```
#### LLaMA-Pro 微调
```bash
bash examples/extras/llama_pro/expand.sh
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
+llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
```
#### FSDP+QLoRA 微调
```bash
-bash examples/extras/fsdp_qlora/single_node.sh
+bash examples/extras/fsdp_qlora/train.sh
```
diff --git a/examples/full_multi_gpu/llama3_full_sft.yaml b/examples/extras/badam/llama3_full_sft.yaml
similarity index 81%
rename from examples/full_multi_gpu/llama3_full_sft.yaml
rename to examples/extras/badam/llama3_full_sft.yaml
index 40b62f24..31d61c33 100644
--- a/examples/full_multi_gpu/llama3_full_sft.yaml
+++ b/examples/extras/badam/llama3_full_sft.yaml
@@ -5,10 +5,11 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
stage: sft
do_train: true
finetuning_type: full
-
-### ddp
-ddp_timeout: 180000000
-deepspeed: examples/deepspeed/ds_z3_config.json
+use_badam: true
+badam_mode: layer
+badam_switch_mode: ascending
+badam_switch_interval: 50
+badam_verbose: 2
### dataset
dataset: identity,alpaca_en_demo
@@ -27,12 +28,11 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
-gradient_accumulation_steps: 2
+gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
### eval
val_size: 0.1
diff --git a/examples/extras/badam/llama3_lora_sft.yaml b/examples/extras/badam/llama3_full_sft_ds3.yaml
similarity index 91%
rename from examples/extras/badam/llama3_lora_sft.yaml
rename to examples/extras/badam/llama3_full_sft_ds3.yaml
index a78de2fa..f2d7309f 100644
--- a/examples/extras/badam/llama3_lora_sft.yaml
+++ b/examples/extras/badam/llama3_full_sft_ds3.yaml
@@ -6,9 +6,11 @@ stage: sft
do_train: true
finetuning_type: full
use_badam: true
+badam_mode: layer
badam_switch_mode: ascending
badam_switch_interval: 50
badam_verbose: 2
+deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: identity,alpaca_en_demo
@@ -32,7 +34,6 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-pure_bf16: true
### eval
val_size: 0.1
diff --git a/examples/extras/fsdp_qlora/llama3_lora_sft.yaml b/examples/extras/fsdp_qlora/llama3_lora_sft.yaml
index 084269ef..6c80ef58 100644
--- a/examples/extras/fsdp_qlora/llama3_lora_sft.yaml
+++ b/examples/extras/fsdp_qlora/llama3_lora_sft.yaml
@@ -8,9 +8,6 @@ do_train: true
finetuning_type: lora
lora_target: all
-### ddp
-ddp_timeout: 180000000
-
### dataset
dataset: identity,alpaca_en_demo
template: llama3
@@ -33,7 +30,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/extras/fsdp_qlora/single_node.sh b/examples/extras/fsdp_qlora/train.sh
similarity index 100%
rename from examples/extras/fsdp_qlora/single_node.sh
rename to examples/extras/fsdp_qlora/train.sh
diff --git a/examples/extras/llama_pro/llama3_freeze_sft.yaml b/examples/extras/llama_pro/llama3_freeze_sft.yaml
index 444a1113..5e7e90bb 100644
--- a/examples/extras/llama_pro/llama3_freeze_sft.yaml
+++ b/examples/extras/llama_pro/llama3_freeze_sft.yaml
@@ -31,7 +31,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/extras/loraplus/llama3_lora_sft.yaml b/examples/extras/loraplus/llama3_lora_sft.yaml
index 1ba654ec..062a312b 100644
--- a/examples/extras/loraplus/llama3_lora_sft.yaml
+++ b/examples/extras/loraplus/llama3_lora_sft.yaml
@@ -30,7 +30,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/extras/mod/llama3_full_sft.yaml b/examples/extras/mod/llama3_full_sft.yaml
index df03c1e0..085febfc 100644
--- a/examples/extras/mod/llama3_full_sft.yaml
+++ b/examples/extras/mod/llama3_full_sft.yaml
@@ -31,6 +31,7 @@ num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
pure_bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml b/examples/extras/pissa/llama3_lora_sft.yaml
similarity index 88%
rename from examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml
rename to examples/extras/pissa/llama3_lora_sft.yaml
index b308dcab..05077b6c 100644
--- a/examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml
+++ b/examples/extras/pissa/llama3_lora_sft.yaml
@@ -1,12 +1,14 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-quantization_bit: 4
### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
+pissa_init: true
+pissa_iter: 4
+pissa_convert: true
### dataset
dataset: identity,alpaca_en_demo
@@ -30,7 +32,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/full_multi_gpu/llama3_full_predict.yaml b/examples/train_full/llama3_full_predict.yaml
similarity index 100%
rename from examples/full_multi_gpu/llama3_full_predict.yaml
rename to examples/train_full/llama3_full_predict.yaml
diff --git a/examples/lora_multi_gpu/llama3_lora_sft.yaml b/examples/train_full/llama3_full_sft_ds3.yaml
similarity index 83%
rename from examples/lora_multi_gpu/llama3_lora_sft.yaml
rename to examples/train_full/llama3_full_sft_ds3.yaml
index 348e53b9..c983ad5c 100644
--- a/examples/lora_multi_gpu/llama3_lora_sft.yaml
+++ b/examples/train_full/llama3_full_sft_ds3.yaml
@@ -4,11 +4,8 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
### method
stage: sft
do_train: true
-finetuning_type: lora
-lora_target: all
-
-### ddp
-ddp_timeout: 180000000
+finetuning_type: full
+deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: identity,alpaca_en_demo
@@ -19,7 +16,7 @@ overwrite_cache: true
preprocessing_num_workers: 16
### output
-output_dir: saves/llama3-8b/lora/sft
+output_dir: saves/llama3-8b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
@@ -32,7 +29,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/lora_single_gpu/llama3_lora_dpo.yaml b/examples/train_lora/llama3_lora_dpo.yaml
similarity index 87%
rename from examples/lora_single_gpu/llama3_lora_dpo.yaml
rename to examples/train_lora/llama3_lora_dpo.yaml
index 78344330..d87c0669 100644
--- a/examples/lora_single_gpu/llama3_lora_dpo.yaml
+++ b/examples/train_lora/llama3_lora_dpo.yaml
@@ -7,7 +7,7 @@ do_train: true
finetuning_type: lora
lora_target: all
pref_beta: 0.1
-pref_loss: sigmoid # [sigmoid (dpo), orpo, simpo]
+pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
### dataset
dataset: dpo_en_demo
@@ -31,7 +31,8 @@ learning_rate: 5.0e-6
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/lora_single_gpu/llama3_lora_eval.yaml b/examples/train_lora/llama3_lora_eval.yaml
similarity index 100%
rename from examples/lora_single_gpu/llama3_lora_eval.yaml
rename to examples/train_lora/llama3_lora_eval.yaml
diff --git a/examples/lora_single_gpu/llama3_lora_kto.yaml b/examples/train_lora/llama3_lora_kto.yaml
similarity index 93%
rename from examples/lora_single_gpu/llama3_lora_kto.yaml
rename to examples/train_lora/llama3_lora_kto.yaml
index d5234c0a..08208c25 100644
--- a/examples/lora_single_gpu/llama3_lora_kto.yaml
+++ b/examples/train_lora/llama3_lora_kto.yaml
@@ -6,6 +6,7 @@ stage: kto
do_train: true
finetuning_type: lora
lora_target: all
+pref_beta: 0.1
### dataset
dataset: kto_en_demo
@@ -29,7 +30,8 @@ learning_rate: 5.0e-6
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/lora_single_gpu/llama3_lora_ppo.yaml b/examples/train_lora/llama3_lora_ppo.yaml
similarity index 95%
rename from examples/lora_single_gpu/llama3_lora_ppo.yaml
rename to examples/train_lora/llama3_lora_ppo.yaml
index 98c842f9..512e90ea 100644
--- a/examples/lora_single_gpu/llama3_lora_ppo.yaml
+++ b/examples/train_lora/llama3_lora_ppo.yaml
@@ -30,7 +30,8 @@ learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### generate
max_new_tokens: 512
diff --git a/examples/lora_single_gpu/llama3_lora_predict.yaml b/examples/train_lora/llama3_lora_predict.yaml
similarity index 95%
rename from examples/lora_single_gpu/llama3_lora_predict.yaml
rename to examples/train_lora/llama3_lora_predict.yaml
index a127d248..148c8635 100644
--- a/examples/lora_single_gpu/llama3_lora_predict.yaml
+++ b/examples/train_lora/llama3_lora_predict.yaml
@@ -22,3 +22,4 @@ overwrite_output_dir: true
### eval
per_device_eval_batch_size: 1
predict_with_generate: true
+ddp_timeout: 180000000
diff --git a/examples/lora_single_gpu/llama3_lora_pretrain.yaml b/examples/train_lora/llama3_lora_pretrain.yaml
similarity index 94%
rename from examples/lora_single_gpu/llama3_lora_pretrain.yaml
rename to examples/train_lora/llama3_lora_pretrain.yaml
index db435ca9..5e8aaaef 100644
--- a/examples/lora_single_gpu/llama3_lora_pretrain.yaml
+++ b/examples/train_lora/llama3_lora_pretrain.yaml
@@ -28,7 +28,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/lora_single_gpu/llama3_lora_reward.yaml b/examples/train_lora/llama3_lora_reward.yaml
similarity index 91%
rename from examples/lora_single_gpu/llama3_lora_reward.yaml
rename to examples/train_lora/llama3_lora_reward.yaml
index 1ce42ea4..96c32238 100644
--- a/examples/lora_single_gpu/llama3_lora_reward.yaml
+++ b/examples/train_lora/llama3_lora_reward.yaml
@@ -25,11 +25,12 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
-learning_rate: 1.0e-5
+learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/lora_single_gpu/llama3_lora_sft.yaml b/examples/train_lora/llama3_lora_sft.yaml
similarity index 95%
rename from examples/lora_single_gpu/llama3_lora_sft.yaml
rename to examples/train_lora/llama3_lora_sft.yaml
index 651b636f..55a8077e 100644
--- a/examples/lora_single_gpu/llama3_lora_sft.yaml
+++ b/examples/train_lora/llama3_lora_sft.yaml
@@ -29,7 +29,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/lora_multi_npu/llama3_lora_sft_ds.yaml b/examples/train_lora/llama3_lora_sft_ds0.yaml
similarity index 97%
rename from examples/lora_multi_npu/llama3_lora_sft_ds.yaml
rename to examples/train_lora/llama3_lora_sft_ds0.yaml
index a0ec8aa1..f1442faa 100644
--- a/examples/lora_multi_npu/llama3_lora_sft_ds.yaml
+++ b/examples/train_lora/llama3_lora_sft_ds0.yaml
@@ -6,9 +6,6 @@ stage: sft
do_train: true
finetuning_type: lora
lora_target: all
-
-### ddp
-ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z0_config.json
### dataset
@@ -33,7 +30,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/lora_multi_gpu/llama3_lora_sft_ds.yaml b/examples/train_lora/llama3_lora_sft_ds3.yaml
similarity index 97%
rename from examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
rename to examples/train_lora/llama3_lora_sft_ds3.yaml
index 1c432fa7..66e7007e 100644
--- a/examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
+++ b/examples/train_lora/llama3_lora_sft_ds3.yaml
@@ -6,9 +6,6 @@ stage: sft
do_train: true
finetuning_type: lora
lora_target: all
-
-### ddp
-ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
@@ -33,7 +30,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/lora_single_gpu/llama3_preprocess.yaml b/examples/train_lora/llama3_preprocess.yaml
similarity index 100%
rename from examples/lora_single_gpu/llama3_preprocess.yaml
rename to examples/train_lora/llama3_preprocess.yaml
diff --git a/examples/lora_single_gpu/llava1_5_lora_sft.yaml b/examples/train_lora/llava1_5_lora_sft.yaml
similarity index 95%
rename from examples/lora_single_gpu/llava1_5_lora_sft.yaml
rename to examples/train_lora/llava1_5_lora_sft.yaml
index df510a93..ec03f82c 100644
--- a/examples/lora_single_gpu/llava1_5_lora_sft.yaml
+++ b/examples/train_lora/llava1_5_lora_sft.yaml
@@ -30,7 +30,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml b/examples/train_qlora/llama3_lora_sft_aqlm.yaml
similarity index 95%
rename from examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
rename to examples/train_qlora/llama3_lora_sft_aqlm.yaml
index d54d6af6..3519d46b 100644
--- a/examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
+++ b/examples/train_qlora/llama3_lora_sft_aqlm.yaml
@@ -29,7 +29,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/qlora_single_gpu/llama3_lora_sft_awq.yaml b/examples/train_qlora/llama3_lora_sft_awq.yaml
similarity index 95%
rename from examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
rename to examples/train_qlora/llama3_lora_sft_awq.yaml
index 5cef178a..df48669b 100644
--- a/examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
+++ b/examples/train_qlora/llama3_lora_sft_awq.yaml
@@ -29,7 +29,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml b/examples/train_qlora/llama3_lora_sft_gptq.yaml
similarity index 95%
rename from examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
rename to examples/train_qlora/llama3_lora_sft_gptq.yaml
index b950042e..61fa9bb4 100644
--- a/examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
+++ b/examples/train_qlora/llama3_lora_sft_gptq.yaml
@@ -29,7 +29,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/train_qlora/llama3_lora_sft_otfq.yaml b/examples/train_qlora/llama3_lora_sft_otfq.yaml
new file mode 100644
index 00000000..80a05768
--- /dev/null
+++ b/examples/train_qlora/llama3_lora_sft_otfq.yaml
@@ -0,0 +1,41 @@
+### model
+model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+quantization_bit: 4
+quantization_method: bitsandbytes # choices: [bitsandbytes (4/8), hqq (2/3/4/5/6/8), eetq (8)]
+
+### method
+stage: sft
+do_train: true
+finetuning_type: lora
+lora_target: all
+
+### dataset
+dataset: identity,alpaca_en_demo
+template: llama3
+cutoff_len: 1024
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+
+### output
+output_dir: saves/llama3-8b/lora/sft
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 1.0e-4
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+
+### eval
+val_size: 0.1
+per_device_eval_batch_size: 1
+eval_strategy: steps
+eval_steps: 500
diff --git a/requirements.txt b/requirements.txt
index 9e00555e..7380add4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,7 @@ accelerate>=0.30.1
peft>=0.11.1
trl>=0.8.6
gradio>=4.0.0
+pandas>=2.0.0
scipy
einops
sentencepiece
@@ -17,3 +18,4 @@ matplotlib>=3.7.0
fire
packaging
pyyaml
+numpy<2.0.0
diff --git a/scripts/cal_flops.py b/scripts/cal_flops.py
index ac87e0ab..32526d89 100644
--- a/scripts/cal_flops.py
+++ b/scripts/cal_flops.py
@@ -1,7 +1,20 @@
# coding=utf-8
-# Calculates the flops of pre-trained models.
-# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
-# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
+# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
+#
+# This code is inspired by the Microsoft's DeepSpeed library.
+# https://www.deepspeed.ai/tutorials/flops-profiler/
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import fire
import torch
@@ -17,6 +30,10 @@ def calculate_flops(
seq_length: int = 256,
flash_attn: str = "auto",
):
+ r"""
+ Calculates the flops of pre-trained models.
+ Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
+ """
with get_accelerator().device(0):
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
diff --git a/scripts/cal_lr.py b/scripts/cal_lr.py
index bfa32cc9..ad6992cb 100644
--- a/scripts/cal_lr.py
+++ b/scripts/cal_lr.py
@@ -1,7 +1,20 @@
# coding=utf-8
-# Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
-# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
-# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
+# Copyright 2024 imoneoi and the LlamaFactory team.
+#
+# This code is inspired by the imoneoi's OpenChat library.
+# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import math
from typing import Literal
@@ -32,6 +45,10 @@ def calculate_lr(
cutoff_len: int = 1024, # i.e. maximum input length during training
is_mistral: bool = False, # mistral model uses a smaller learning rate,
):
+ r"""
+ Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
+ Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
+ """
model_args, data_args, training_args, _, _ = get_train_args(
dict(
stage=stage,
diff --git a/scripts/cal_ppl.py b/scripts/cal_ppl.py
index 387b756c..fb503629 100644
--- a/scripts/cal_ppl.py
+++ b/scripts/cal_ppl.py
@@ -1,6 +1,17 @@
# coding=utf-8
-# Calculates the ppl on the dataset of the pre-trained models.
-# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import json
from dataclasses import dataclass
@@ -56,6 +67,10 @@ def cal_ppl(
max_samples: Optional[int] = None,
train_on_prompt: bool = False,
):
+ r"""
+ Calculates the ppl on the dataset of the pre-trained models.
+ Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
+ """
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
dict(
stage=stage,
diff --git a/scripts/length_cdf.py b/scripts/length_cdf.py
index 7739dcf0..4cdf01e6 100644
--- a/scripts/length_cdf.py
+++ b/scripts/length_cdf.py
@@ -1,6 +1,17 @@
# coding=utf-8
-# Calculates the distribution of the input lengths in the dataset.
-# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from collections import defaultdict
@@ -19,6 +30,10 @@ def length_cdf(
template: str = "default",
interval: int = 1000,
):
+ r"""
+ Calculates the distribution of the input lengths in the dataset.
+ Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
+ """
model_args, data_args, training_args, _, _ = get_train_args(
dict(
stage="sft",
diff --git a/scripts/llama_pro.py b/scripts/llama_pro.py
index 727998ae..17bf6fc2 100644
--- a/scripts/llama_pro.py
+++ b/scripts/llama_pro.py
@@ -1,7 +1,20 @@
# coding=utf-8
-# Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
-# Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
-# Inspired by: https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
+# Copyright 2024 Tencent Inc. and the LlamaFactory team.
+#
+# This code is inspired by the Tencent's LLaMA-Pro library.
+# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import json
import os
@@ -37,6 +50,10 @@ def block_expansion(
shard_size: Optional[str] = "2GB",
save_safetensors: Optional[bool] = False,
):
+ r"""
+ Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
+ Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
+ """
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
num_layers = getattr(config, "num_hidden_layers")
setattr(config, "num_hidden_layers", num_layers + num_expand)
@@ -103,7 +120,7 @@ def block_expansion(
json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir))
- print("Fine-tune this model with:")
+ print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir))
print("finetuning_type: freeze")
print("freeze_trainable_layers: {}".format(num_expand))
diff --git a/scripts/llamafy_baichuan2.py b/scripts/llamafy_baichuan2.py
index 1ae58879..19284f5f 100644
--- a/scripts/llamafy_baichuan2.py
+++ b/scripts/llamafy_baichuan2.py
@@ -1,8 +1,17 @@
# coding=utf-8
-# Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
-# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
-# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
-# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import json
import os
@@ -79,6 +88,11 @@ def save_config(input_dir: str, output_dir: str):
def llamafy_baichuan2(
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
):
+ r"""
+ Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
+ Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
+ Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
+ """
try:
os.makedirs(output_dir, exist_ok=False)
except Exception as e:
diff --git a/scripts/llamafy_qwen.py b/scripts/llamafy_qwen.py
index 69cf3e8e..e5b59483 100644
--- a/scripts/llamafy_qwen.py
+++ b/scripts/llamafy_qwen.py
@@ -1,7 +1,17 @@
# coding=utf-8
-# Converts the Qwen models in the same format as LLaMA2.
-# Usage: python llamafy_qwen.py --input_dir input --output_dir output
-# Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import json
import os
@@ -131,6 +141,11 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
def llamafy_qwen(
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
):
+ r"""
+ Converts the Qwen models in the same format as LLaMA2.
+ Usage: python llamafy_qwen.py --input_dir input --output_dir output
+ Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
+ """
try:
os.makedirs(output_dir, exist_ok=False)
except Exception as e:
diff --git a/scripts/loftq_init.py b/scripts/loftq_init.py
index 7f244316..4d2c01b9 100644
--- a/scripts/loftq_init.py
+++ b/scripts/loftq_init.py
@@ -1,14 +1,25 @@
# coding=utf-8
-# Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
-# Usage: python loftq_init.py --model_name_or_path path_to_model --save_dir output_dir
-# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is based on the HuggingFace's PEFT library.
+# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
import fire
-import torch
-import torch.nn as nn
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -17,65 +28,61 @@ if TYPE_CHECKING:
from transformers import PreTrainedModel
-class Shell(nn.Module):
- def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
- super().__init__()
- self.weight = nn.Parameter(weight, requires_grad=False)
- if bias is not None:
- self.bias = nn.Parameter(bias, requires_grad=False)
-
-
-def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
- for name in {k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k}:
- parent_name = ".".join(name.split(".")[:-1])
- child_name = name.split(".")[-1]
- parent_module = model.get_submodule(parent_name)
- child_module = getattr(parent_module, child_name)
- base_layer = getattr(child_module, "base_layer")
- weight = getattr(base_layer, "weight", None)
- bias = getattr(base_layer, "bias", None)
- setattr(parent_module, child_name, Shell(weight, bias))
-
- print("Model unwrapped.")
-
-
def quantize_loftq(
model_name_or_path: str,
- save_dir: str,
- loftq_bits: Optional[int] = 4,
- loftq_iter: Optional[int] = 1,
- lora_alpha: Optional[int] = None,
- lora_rank: Optional[int] = 16,
- lora_target: Optional[str] = "q_proj,v_proj",
- save_safetensors: Optional[bool] = False,
+ output_dir: str,
+ loftq_bits: int = 4,
+ loftq_iter: int = 4,
+ lora_alpha: int = None,
+ lora_rank: int = 16,
+ lora_dropout: float = 0,
+ lora_target: tuple = ("q_proj", "v_proj"),
+ save_safetensors: bool = True,
):
+ r"""
+ Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
+ Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
+ """
+ if isinstance(lora_target, str):
+ lora_target = [name.strip() for name in lora_target.split(",")]
+
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
+
loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=True,
r=lora_rank,
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
- lora_dropout=0.1,
- target_modules=[name.strip() for name in lora_target.split(",")],
+ lora_dropout=lora_dropout,
+ target_modules=lora_target,
init_lora_weights="loftq",
loftq_config=loftq_config,
)
# Init LoftQ model
- lora_model = get_peft_model(model, lora_config)
- base_model: "PreTrainedModel" = lora_model.get_base_model()
+ print("Initializing LoftQ weights, it may be take several minutes, wait patiently.")
+ peft_model = get_peft_model(model, lora_config)
+ loftq_dir = os.path.join(output_dir, "loftq_init")
# Save LoftQ model
- setattr(lora_model.base_model.peft_config["default"], "base_model_name_or_path", save_dir)
- setattr(lora_model.base_model.peft_config["default"], "init_lora_weights", True)
- lora_model.save_pretrained(os.path.join(save_dir, "adapters"), safe_serialization=save_safetensors)
+ setattr(peft_model.peft_config["default"], "base_model_name_or_path", output_dir)
+ setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
+ peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
+ print("Adapter weights saved in {}".format(loftq_dir))
# Save base model
- unwrap_model(base_model)
- base_model.save_pretrained(save_dir, safe_serialization=save_safetensors)
- tokenizer.save_pretrained(save_dir)
+ base_model: "PreTrainedModel" = peft_model.unload()
+ base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
+ tokenizer.save_pretrained(output_dir)
+ print("Model weights saved in {}".format(output_dir))
+
+ print("- Fine-tune this model with:")
+ print("model_name_or_path: {}".format(output_dir))
+ print("adapter_name_or_path: {}".format(loftq_dir))
+ print("finetuning_type: lora")
+ print("quantization_bit: {}".format(loftq_bits))
if __name__ == "__main__":
diff --git a/scripts/pissa_init.py b/scripts/pissa_init.py
new file mode 100644
index 00000000..ad9d161c
--- /dev/null
+++ b/scripts/pissa_init.py
@@ -0,0 +1,86 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is based on the HuggingFace's PEFT library.
+# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import TYPE_CHECKING
+
+import fire
+from peft import LoraConfig, TaskType, get_peft_model
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+
+
+def quantize_pissa(
+ model_name_or_path: str,
+ output_dir: str,
+ pissa_iter: int = 4,
+ lora_alpha: int = None,
+ lora_rank: int = 16,
+ lora_dropout: float = 0,
+ lora_target: tuple = ("q_proj", "v_proj"),
+ save_safetensors: bool = True,
+):
+ r"""
+ Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
+ Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
+ """
+ if isinstance(lora_target, str):
+ lora_target = [name.strip() for name in lora_target.split(",")]
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
+
+ lora_config = LoraConfig(
+ task_type=TaskType.CAUSAL_LM,
+ r=lora_rank,
+ lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
+ lora_dropout=lora_dropout,
+ target_modules=lora_target,
+ init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
+ )
+
+ # Init PiSSA model
+ peft_model = get_peft_model(model, lora_config)
+ pissa_dir = os.path.join(output_dir, "pissa_init")
+
+ # Save PiSSA model
+ setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
+ peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
+ print("Adapter weights saved in {}".format(pissa_dir))
+
+ # Save base model
+ base_model: "PreTrainedModel" = peft_model.unload()
+ base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
+ tokenizer.save_pretrained(output_dir)
+ print("Model weights saved in {}".format(output_dir))
+
+ print("- Fine-tune this model with:")
+ print("model_name_or_path: {}".format(output_dir))
+ print("adapter_name_or_path: {}".format(pissa_dir))
+ print("finetuning_type: lora")
+ print("pissa_init: false")
+ print("pissa_convert: true")
+ print("- and optionally with:")
+ print("quantization_bit: 4")
+
+
+if __name__ == "__main__":
+ fire.Fire(quantize_pissa)
diff --git a/scripts/test_toolcall.py b/scripts/test_toolcall.py
index 7e460017..6f6fd06c 100644
--- a/scripts/test_toolcall.py
+++ b/scripts/test_toolcall.py
@@ -1,3 +1,18 @@
+# coding=utf-8
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import os
from typing import Sequence
diff --git a/setup.py b/setup.py
index 405ac46e..d43c311c 100644
--- a/setup.py
+++ b/setup.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
import re
@@ -23,14 +37,16 @@ extra_require = {
"torch": ["torch>=1.13.1"],
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
- "deepspeed": ["deepspeed>=0.10.0,<=0.14.0"],
+ "deepspeed": ["deepspeed>=0.10.0"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
- "vllm": ["vllm>=0.4.3"],
- "galore": ["galore-torch"],
- "badam": ["badam"],
- "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
+ "hqq": ["hqq"],
+ "eetq": ["eetq"],
+ "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
+ "vllm": ["vllm>=0.4.3"],
+ "galore": ["galore-torch"],
+ "badam": ["badam>=1.2.1"],
"qwen": ["transformers_stream_generator"],
"modelscope": ["modelscope"],
"dev": ["ruff", "pytest"],
diff --git a/src/api.py b/src/api.py
index 3655e393..0f925497 100644
--- a/src/api.py
+++ b/src/api.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
import uvicorn
diff --git a/src/llamafactory/__init__.py b/src/llamafactory/__init__.py
index 78230937..9d732777 100644
--- a/src/llamafactory/__init__.py
+++ b/src/llamafactory/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
# Level: api, webui > chat, eval, train > data, model > hparams > extras
from .cli import VERSION
diff --git a/src/llamafactory/api/app.py b/src/llamafactory/api/app.py
index 21edab2f..c1264617 100644
--- a/src/llamafactory/api/app.py
+++ b/src/llamafactory/api/app.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from contextlib import asynccontextmanager
from typing import Optional
diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py
index 98957bc1..72b2ae50 100644
--- a/src/llamafactory/api/chat.py
+++ b/src/llamafactory/api/chat.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import base64
import io
import json
@@ -78,9 +92,11 @@ def _process_request(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
- name = message.tool_calls[0].function.name
- arguments = message.tool_calls[0].function.arguments
- content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
+ tool_calls = [
+ {"name": tool_call.function.name, "arguments": tool_call.function.arguments}
+ for tool_call in message.tool_calls
+ ]
+ content = json.dumps(tool_calls, ensure_ascii=False)
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
elif isinstance(message.content, list):
for input_item in message.content:
@@ -104,7 +120,7 @@ def _process_request(
if isinstance(tool_list, list) and len(tool_list):
try:
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
- except Exception:
+ except json.JSONDecodeError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else:
tools = None
@@ -146,15 +162,17 @@ async def create_chat_completion_response(
choices = []
for i, response in enumerate(responses):
if tools:
- result = chat_model.engine.template.format_tools.extract(response.response_text)
+ result = chat_model.engine.template.extract_tool(response.response_text)
else:
result = response.response_text
- if isinstance(result, tuple):
- name, arguments = result
- function = Function(name=name, arguments=arguments)
- tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
- response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
+ if isinstance(result, list):
+ tool_calls = []
+ for tool in result:
+ function = Function(name=tool[0], arguments=tool[1])
+ tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
+
+ response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
finish_reason = Finish.TOOL
else:
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
diff --git a/src/llamafactory/api/common.py b/src/llamafactory/api/common.py
index 5ad9a071..d1ac94de 100644
--- a/src/llamafactory/api/common.py
+++ b/src/llamafactory/api/common.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
from typing import TYPE_CHECKING, Any, Dict
diff --git a/src/llamafactory/api/protocol.py b/src/llamafactory/api/protocol.py
index 055fa781..a69132ea 100644
--- a/src/llamafactory/api/protocol.py
+++ b/src/llamafactory/api/protocol.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import time
from enum import Enum, unique
from typing import Any, Dict, List, Optional, Union
diff --git a/src/llamafactory/chat/__init__.py b/src/llamafactory/chat/__init__.py
index a1a79de6..07276d48 100644
--- a/src/llamafactory/chat/__init__.py
+++ b/src/llamafactory/chat/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .base_engine import BaseEngine
from .chat_model import ChatModel
diff --git a/src/llamafactory/chat/base_engine.py b/src/llamafactory/chat/base_engine.py
index 65b6c59c..ccdf4c92 100644
--- a/src/llamafactory/chat/base_engine.py
+++ b/src/llamafactory/chat/base_engine.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
@@ -36,11 +50,6 @@ class BaseEngine(ABC):
generating_args: "GeneratingArguments",
) -> None: ...
- @abstractmethod
- async def start(
- self,
- ) -> None: ...
-
@abstractmethod
async def chat(
self,
diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py
index 281ef0c1..5c83fa67 100644
--- a/src/llamafactory/chat/chat_model.py
+++ b/src/llamafactory/chat/chat_model.py
@@ -1,3 +1,20 @@
+# Copyright 2024 THUDM and the LlamaFactory team.
+#
+# This code is inspired by the THUDM's ChatGLM implementation.
+# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import asyncio
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
@@ -14,7 +31,7 @@ if TYPE_CHECKING:
from .base_engine import BaseEngine, Response
-def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
+def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
asyncio.set_event_loop(loop)
loop.run_forever()
@@ -32,7 +49,6 @@ class ChatModel:
self._loop = asyncio.new_event_loop()
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
self._thread.start()
- asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)
def chat(
self,
diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 28e6a409..22a24339 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import asyncio
import concurrent.futures
import os
@@ -40,11 +54,19 @@ class HuggingfaceEngine(BaseEngine):
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left" if self.can_generate else "right"
- self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
self.model = load_model(
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
) # must after fixing tokenizer to resize vocab
self.generating_args = generating_args.to_dict()
+ try:
+ asyncio.get_event_loop()
+ except RuntimeError:
+ logger.warning("There is no current event loop, creating a new one.")
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
+ self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
@staticmethod
def _process_args(
@@ -245,9 +267,6 @@ class HuggingfaceEngine(BaseEngine):
return scores
- async def start(self) -> None:
- self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
-
async def chat(
self,
messages: Sequence[Dict[str, str]],
@@ -272,7 +291,7 @@ class HuggingfaceEngine(BaseEngine):
image,
input_kwargs,
)
- async with self._semaphore:
+ async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._chat, *input_args)
@@ -300,7 +319,7 @@ class HuggingfaceEngine(BaseEngine):
image,
input_kwargs,
)
- async with self._semaphore:
+ async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
stream = self._stream_chat(*input_args)
while True:
@@ -319,6 +338,6 @@ class HuggingfaceEngine(BaseEngine):
loop = asyncio.get_running_loop()
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
- async with self._semaphore:
+ async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._get_scores, *input_args)
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 87ce8684..f0d23676 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -1,10 +1,24 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger
from ..extras.misc import get_device_count
-from ..extras.packages import is_vllm_available
+from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
from ..model import load_config, load_tokenizer
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response
@@ -13,7 +27,11 @@ from .base_engine import BaseEngine, Response
if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
- from vllm.sequence import MultiModalData
+
+ if is_vllm_version_greater_than_0_5():
+ from vllm.multimodal.image import ImagePixelData
+ else:
+ from vllm.sequence import MultiModalData
if TYPE_CHECKING:
@@ -41,14 +59,14 @@ class VllmEngine(BaseEngine):
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
- self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
self.generating_args = generating_args.to_dict()
engine_args = {
"model": model_args.model_name_or_path,
"trust_remote_code": True,
"download_dir": model_args.cache_dir,
- "dtype": model_args.vllm_dtype,
+ "dtype": model_args.infer_dtype,
"max_model_len": model_args.vllm_maxlen,
"tensor_parallel_size": get_device_count() or 1,
"gpu_memory_utilization": model_args.vllm_gpu_util,
@@ -106,7 +124,10 @@ class VllmEngine(BaseEngine):
if self.processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
- multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
+ if is_vllm_version_greater_than_0_5():
+ multi_modal_data = ImagePixelData(image=pixel_values)
+ else: # TODO: remove vllm 0.4.3 support
+ multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
else:
multi_modal_data = None
@@ -162,9 +183,6 @@ class VllmEngine(BaseEngine):
)
return result_generator
- async def start(self) -> None:
- pass
-
async def chat(
self,
messages: Sequence[Dict[str, str]],
diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py
index 5042e53c..48eb2898 100644
--- a/src/llamafactory/cli.py
+++ b/src/llamafactory/cli.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
import random
import subprocess
@@ -60,7 +74,7 @@ class Command(str, Enum):
def main():
- command = sys.argv.pop(1)
+ command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
if command == Command.API:
run_api()
elif command == Command.CHAT:
@@ -77,7 +91,7 @@ def main():
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
- subprocess.run(
+ process = subprocess.run(
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
@@ -92,6 +106,7 @@ def main():
),
shell=True,
)
+ sys.exit(process.returncode)
else:
run_exp()
elif command == Command.WEBDEMO:
diff --git a/src/llamafactory/data/__init__.py b/src/llamafactory/data/__init__.py
index b08691d3..307853bc 100644
--- a/src/llamafactory/data/__init__.py
+++ b/src/llamafactory/data/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
from .data_utils import Role, split_dataset
from .loader import get_dataset
diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py
index 434956af..299bdca3 100644
--- a/src/llamafactory/data/aligner.py
+++ b/src/llamafactory/data/aligner.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union
@@ -10,6 +24,7 @@ from .data_utils import Role
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
+ from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments
from .parser import DatasetAttr
@@ -175,7 +190,10 @@ def convert_sharegpt(
def align_dataset(
- dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
+ dataset: Union["Dataset", "IterableDataset"],
+ dataset_attr: "DatasetAttr",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
@@ -208,7 +226,7 @@ def align_dataset(
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
- load_from_cache_file=(not data_args.overwrite_cache),
+ load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
desc="Converting format of dataset",
)
diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 1dc8dd8d..e4859ff5 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass
from typing import Any, Dict, Sequence
diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py
index 9b313112..76ded47e 100644
--- a/src/llamafactory/data/data_utils.py
+++ b/src/llamafactory/data/data_utils.py
@@ -1,5 +1,19 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from enum import Enum, unique
-from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Sequence, Set, Union
from datasets import concatenate_datasets, interleave_datasets
@@ -16,6 +30,9 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
+SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
+
+
@unique
class Role(str, Enum):
USER = "user"
@@ -25,13 +42,6 @@ class Role(str, Enum):
OBSERVATION = "observation"
-def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
- max_target_len = int(max_len * (target_len / (source_len + target_len)))
- max_target_len = max(max_target_len, reserved_label_len)
- max_source_len = max_len - min(max_target_len, target_len)
- return max_source_len, max_target_len
-
-
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]],
data_args: "DataArguments",
diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py
index 0cd3d6c1..c1653a76 100644
--- a/src/llamafactory/data/formatter.py
+++ b/src/llamafactory/data/formatter.py
@@ -1,83 +1,36 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
-from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
-
-SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
-
-
-JSON_FORMAT_PROMPT = (
- """, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
-)
-
-
-TOOL_SYSTEM_PROMPT = (
- "You have access to the following tools:\n{tool_text}"
- "Use the following format if using a tool:\n"
- "```\n"
- "Action: tool name (one of [{tool_names}]).\n"
- "Action Input: the input to the tool{format_prompt}.\n"
- "```\n"
-)
-
-
-def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
- tool_text = ""
- tool_names = []
- for tool in tools:
- param_text = ""
- for name, param in tool["parameters"]["properties"].items():
- required = ", required" if name in tool["parameters"].get("required", []) else ""
- enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
- items = (
- ", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
- )
- param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
- name=name,
- type=param.get("type", ""),
- required=required,
- desc=param.get("description", ""),
- enum=enum,
- items=items,
- )
-
- tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
- name=tool["name"], desc=tool.get("description", ""), args=param_text
- )
- tool_names.append(tool["name"])
-
- return TOOL_SYSTEM_PROMPT.format(
- tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
- )
-
-
-def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
- regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
- action_match = re.search(regex, content)
- if not action_match:
- return content
-
- tool_name = action_match.group(1).strip()
- tool_input = action_match.group(2).strip().strip('"').strip("```")
- try:
- arguments = json.loads(tool_input)
- except json.JSONDecodeError:
- return content
-
- return tool_name, json.dumps(arguments, ensure_ascii=False)
+from .data_utils import SLOTS
+from .tool_utils import DefaultToolUtils, GLM4ToolUtils
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
- tool_format: Optional[Literal["default"]] = None
+ tool_format: Optional[Literal["default", "glm4"]] = None
@abstractmethod
def apply(self, **kwargs) -> SLOTS: ...
- def extract(self, content: str) -> Union[str, Tuple[str, str]]:
+ def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
raise NotImplementedError
@@ -128,34 +81,37 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
- has_name, has_args = False, False
- for slot in filter(lambda s: isinstance(s, str), self.slots):
- if "{{name}}" in slot:
- has_name = True
- if "{{arguments}}" in slot:
- has_args = True
-
- if not has_name or not has_args:
- raise ValueError("Name and arguments placeholders are required in the function formatter.")
+ if self.tool_format == "default":
+ self.slots = DefaultToolUtils.get_function_slots() + self.slots
+ elif self.tool_format == "glm4":
+ self.slots = GLM4ToolUtils.get_function_slots() + self.slots
+ else:
+ raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
+ functions: List[Tuple[str, str]] = []
try:
- function = json.loads(content)
- name = function["name"]
- arguments = json.dumps(function["arguments"], ensure_ascii=False)
- except Exception:
- name, arguments = "", ""
+ tool_calls = json.loads(content)
+ if not isinstance(tool_calls, list): # parallel function call
+ tool_calls = [tool_calls]
+
+ for tool_call in tool_calls:
+ functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
+
+ except json.JSONDecodeError:
+ functions = []
elements = []
- for slot in self.slots:
- if isinstance(slot, str):
- slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
- elements.append(slot)
- elif isinstance(slot, (dict, set)):
- elements.append(slot)
- else:
- raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+ for name, arguments in functions:
+ for slot in self.slots:
+ if isinstance(slot, str):
+ slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
+ elements.append(slot)
+ elif isinstance(slot, (dict, set)):
+ elements.append(slot)
+ else:
+ raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
return elements
@@ -163,25 +119,22 @@ class FunctionFormatter(Formatter):
@dataclass
class ToolFormatter(Formatter):
def __post_init__(self):
- if self.tool_format is None:
- raise ValueError("Tool format was not found.")
+ if self.tool_format == "default":
+ self._tool_formatter = DefaultToolUtils.tool_formatter
+ self._tool_extractor = DefaultToolUtils.tool_extractor
+ elif self.tool_format == "glm4":
+ self._tool_formatter = GLM4ToolUtils.tool_formatter
+ self._tool_extractor = GLM4ToolUtils.tool_extractor
+ else:
+ raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
tools = json.loads(content)
- if not len(tools):
- return [""]
-
- if self.tool_format == "default":
- return [default_tool_formatter(tools)]
- else:
- raise NotImplementedError
- except Exception:
+ return [self._tool_formatter(tools) if len(tools) != 0 else ""]
+ except json.JSONDecodeError:
return [""]
- def extract(self, content: str) -> Union[str, Tuple[str, str]]:
- if self.tool_format == "default":
- return default_tool_extractor(content)
- else:
- raise NotImplementedError
+ def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+ return self._tool_extractor(content)
diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 2c236c76..8e7062db 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import inspect
import os
import sys
@@ -18,8 +32,7 @@ from .template import get_template_and_fix_tokenizer
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
- from transformers import ProcessorMixin, Seq2SeqTrainingArguments
- from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
from ..hparams import DataArguments, ModelArguments
from .parser import DatasetAttr
@@ -32,6 +45,7 @@ def load_single_dataset(
dataset_attr: "DatasetAttr",
model_args: "ModelArguments",
data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
logger.info("Loading dataset {}...".format(dataset_attr))
data_path, data_name, data_dir, data_files = None, None, None, None
@@ -123,7 +137,7 @@ def load_single_dataset(
max_samples = min(data_args.max_samples, len(dataset))
dataset = dataset.select(range(max_samples))
- return align_dataset(dataset, dataset_attr, data_args)
+ return align_dataset(dataset, dataset_attr, data_args, training_args)
def get_dataset(
@@ -134,7 +148,7 @@ def get_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> Union["Dataset", "IterableDataset"]:
- template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
+ template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
@@ -157,7 +171,8 @@ def get_dataset(
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
raise ValueError("The dataset is not applicable in the current training stage.")
- all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
+ all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))
+
dataset = merge_dataset(all_datasets, data_args, training_args)
with training_args.main_process_first(desc="pre-process dataset"):
@@ -169,7 +184,7 @@ def get_dataset(
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
- load_from_cache_file=(not data_args.overwrite_cache),
+ load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
desc="Running tokenizer on dataset",
)
diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py
index ec97bfc1..4bebcd68 100644
--- a/src/llamafactory/data/parser.py
+++ b/src/llamafactory/data/parser.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import os
from dataclasses import dataclass
diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py
index cf207d7e..3a80900c 100644
--- a/src/llamafactory/data/preprocess.py
+++ b/src/llamafactory/data/preprocess.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
@@ -13,8 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu
if TYPE_CHECKING:
- from transformers import ProcessorMixin, Seq2SeqTrainingArguments
- from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
from ..hparams import DataArguments
from .template import Template
diff --git a/src/llamafactory/data/processors/feedback.py b/src/llamafactory/data/processors/feedback.py
index 98d83658..7ba05e23 100644
--- a/src/llamafactory/data/processors/feedback.py
+++ b/src/llamafactory/data/processors/feedback.py
@@ -1,13 +1,26 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
if TYPE_CHECKING:
- from transformers import ProcessorMixin
- from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..template import Template
@@ -42,12 +55,8 @@ def _encode_feedback_example(
else:
kl_messages = prompt + [kl_response[1]]
- prompt_ids, response_ids = template.encode_oneturn(
- tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
- )
- _, kl_response_ids = template.encode_oneturn(
- tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
- )
+ prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
+ _, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
if template.efficient_eos:
response_ids += [tokenizer.eos_token_id]
@@ -57,6 +66,12 @@ def _encode_feedback_example(
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+ # do not consider the kl_response
+ source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
+ prompt_ids = prompt_ids[:source_len]
+ response_ids = response_ids[:target_len]
+ kl_response_ids = kl_response_ids[:target_len]
+
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
kl_input_ids = prompt_ids + kl_response_ids
diff --git a/src/llamafactory/data/processors/pairwise.py b/src/llamafactory/data/processors/pairwise.py
index fe984efa..c6001e6e 100644
--- a/src/llamafactory/data/processors/pairwise.py
+++ b/src/llamafactory/data/processors/pairwise.py
@@ -1,13 +1,26 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
if TYPE_CHECKING:
- from transformers import ProcessorMixin
- from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..template import Template
@@ -31,12 +44,8 @@ def _encode_pairwise_example(
chosen_messages = prompt + [response[0]]
rejected_messages = prompt + [response[1]]
- prompt_ids, chosen_ids = template.encode_oneturn(
- tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
- )
- _, rejected_ids = template.encode_oneturn(
- tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
- )
+ prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
+ _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
@@ -46,6 +55,13 @@ def _encode_pairwise_example(
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+ source_len, target_len = infer_seqlen(
+ len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
+ ) # consider the response is more important
+ prompt_ids = prompt_ids[:source_len]
+ chosen_ids = chosen_ids[:target_len]
+ rejected_ids = rejected_ids[:target_len]
+
chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
diff --git a/src/llamafactory/data/processors/pretrain.py b/src/llamafactory/data/processors/pretrain.py
index 87727b55..67d6009b 100644
--- a/src/llamafactory/data/processors/pretrain.py
+++ b/src/llamafactory/data/processors/pretrain.py
@@ -1,9 +1,26 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
- from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers import PreTrainedTokenizer
from ...hparams import DataArguments
@@ -12,7 +29,8 @@ def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
- text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
+ eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
+ text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
if not data_args.packing:
if data_args.template == "gemma":
diff --git a/src/llamafactory/data/processors/processor_utils.py b/src/llamafactory/data/processors/processor_utils.py
index 9903a053..455908ae 100644
--- a/src/llamafactory/data/processors/processor_utils.py
+++ b/src/llamafactory/data/processors/processor_utils.py
@@ -1,5 +1,19 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import bisect
-from typing import TYPE_CHECKING, List, Sequence
+from typing import TYPE_CHECKING, List, Sequence, Tuple
from ...extras.packages import is_pillow_available
@@ -62,3 +76,16 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") ->
"""
image_seq_length = getattr(processor, "image_seq_length")
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
+
+
+def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
+ if target_len * 2 < cutoff_len: # truncate source
+ max_target_len = cutoff_len
+ elif source_len * 2 < cutoff_len: # truncate target
+ max_target_len = cutoff_len - source_len
+ else: # truncate both
+ max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
+
+ new_target_len = min(max_target_len, target_len)
+ new_source_len = max(cutoff_len - new_target_len, 0)
+ return new_source_len, new_target_len
diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
index 35640174..8ef55321 100644
--- a/src/llamafactory/data/processors/supervised.py
+++ b/src/llamafactory/data/processors/supervised.py
@@ -1,14 +1,27 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen
if TYPE_CHECKING:
- from transformers import ProcessorMixin
- from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..template import Template
@@ -38,10 +51,17 @@ def _encode_supervised_example(
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
- encoded_pairs = template.encode_multiturn(
- tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
- )
+ encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
+ total_length = 1 if template.efficient_eos else 0
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
+ if total_length >= data_args.cutoff_len:
+ break
+
+ source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
+ source_ids = source_ids[:source_len]
+ target_ids = target_ids[:target_len]
+ total_length += source_len + target_len
+
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py
index f711eeac..b3fc85c9 100644
--- a/src/llamafactory/data/processors/unsupervised.py
+++ b/src/llamafactory/data/processors/unsupervised.py
@@ -1,13 +1,26 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.logging import get_logger
from ..data_utils import Role
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
if TYPE_CHECKING:
- from transformers import ProcessorMixin
- from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..template import Template
@@ -34,9 +47,7 @@ def _encode_unsupervised_example(
else:
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
- input_ids, labels = template.encode_oneturn(
- tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
- )
+ input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
@@ -44,6 +55,9 @@ def _encode_unsupervised_example(
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
+ source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len)
+ input_ids = input_ids[:source_len]
+ labels = labels[:target_len]
return input_ids, labels
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index b600c567..aefd5195 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -1,8 +1,22 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from ..extras.logging import get_logger
-from .data_utils import Role, infer_max_len
+from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
@@ -24,69 +38,74 @@ class Template:
format_observation: "Formatter"
format_tools: "Formatter"
format_separator: "Formatter"
+ format_prefix: "Formatter"
default_system: str
stop_words: List[str]
image_token: str
efficient_eos: bool
replace_eos: bool
- force_system: bool
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
- messages: List[Dict[str, str]],
+ messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
- cutoff_len: int = 1_000_000,
- reserved_label_len: int = 1,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
- encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
+ encoded_messages = self._encode(tokenizer, messages, system, tools)
prompt_ids = []
- for query_ids, resp_ids in encoded_pairs[:-1]:
- prompt_ids += query_ids + resp_ids
- prompt_ids = prompt_ids + encoded_pairs[-1][0]
- answer_ids = encoded_pairs[-1][1]
+ for encoded_ids in encoded_messages[:-1]:
+ prompt_ids += encoded_ids
+
+ answer_ids = encoded_messages[-1]
return prompt_ids, answer_ids
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
- messages: List[Dict[str, str]],
+ messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
- cutoff_len: int = 1_000_000,
- reserved_label_len: int = 1,
- ) -> Sequence[Tuple[List[int], List[int]]]:
+ ) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
- return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
+ encoded_messages = self._encode(tokenizer, messages, system, tools)
+ return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+
+ def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+ r"""
+ Extracts tool message.
+ """
+ return self.format_tools.extract(content)
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
- messages: List[Dict[str, str]],
+ messages: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
- cutoff_len: int,
- reserved_label_len: int,
- ) -> Sequence[Tuple[List[int], List[int]]]:
+ ) -> List[List[int]]:
r"""
Encodes formatted inputs to pairs of token ids.
- Turn 0: system + query resp
- Turn t: sep + query resp
+ Turn 0: prefix + system + query resp
+ Turn t: sep + query resp
"""
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
- if i == 0 and (system or tools or self.force_system):
- tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
- elements += self.format_system.apply(content=(system + tool_text))
- elif i > 0 and i % 2 == 0:
+
+ if i == 0:
+ elements += self.format_prefix.apply()
+ if system or tools:
+ tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+ elements += self.format_system.apply(content=(system + tool_text))
+
+ if i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
@@ -102,11 +121,9 @@ class Template:
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
- return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
+ return encoded_messages
- def _convert_elements_to_ids(
- self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
- ) -> List[int]:
+ def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
r"""
Converts elements to token ids.
"""
@@ -127,57 +144,34 @@ class Template:
return token_ids
- def _make_pairs(
- self,
- encoded_messages: Sequence[List[int]],
- cutoff_len: int,
- reserved_label_len: int,
- ) -> Sequence[Tuple[List[int], List[int]]]:
- encoded_pairs = []
- total_length = 0
- for i in range(0, len(encoded_messages), 2):
- if total_length >= cutoff_len:
- break
-
- max_source_len, max_target_len = infer_max_len(
- source_len=len(encoded_messages[i]),
- target_len=len(encoded_messages[i + 1]),
- max_len=(cutoff_len - total_length),
- reserved_label_len=reserved_label_len,
- )
- source_ids = encoded_messages[i][:max_source_len]
- target_ids = encoded_messages[i + 1][:max_target_len]
- total_length += len(source_ids) + len(target_ids)
- encoded_pairs.append((source_ids, target_ids))
-
- return encoded_pairs
-
@dataclass
class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
- messages: List[Dict[str, str]],
+ messages: Sequence[Dict[str, str]],
system: str,
tools: str,
- cutoff_len: int,
- reserved_label_len: int,
- ) -> Sequence[Tuple[List[int], List[int]]]:
+ ) -> List[List[int]]:
r"""
Encodes formatted inputs to pairs of token ids.
- Turn 0: system + query resp
- Turn t: sep + query resp
+ Turn 0: prefix + system + query resp
+ Turn t: sep + query resp
"""
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
+
system_text = ""
- if i == 0 and (system or tools or self.force_system):
- tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
- system_text = self.format_system.apply(content=(system + tool_text))[0]
- elif i > 0 and i % 2 == 0:
+ if i == 0:
+ elements += self.format_prefix.apply()
+ if system or tools:
+ tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+ system_text = self.format_system.apply(content=(system + tool_text))[0]
+
+ if i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
@@ -193,7 +187,7 @@ class Llama2Template(Template):
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
- return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
+ return encoded_messages
TEMPLATES: Dict[str, Template] = {}
@@ -208,12 +202,12 @@ def _register_template(
format_observation: Optional["Formatter"] = None,
format_tools: Optional["Formatter"] = None,
format_separator: Optional["Formatter"] = None,
+ format_prefix: Optional["Formatter"] = None,
default_system: str = "",
- stop_words: List[str] = [],
+ stop_words: Sequence[str] = [],
image_token: str = "",
efficient_eos: bool = False,
replace_eos: bool = False,
- force_system: bool = False,
) -> None:
r"""
Registers a chat template.
@@ -245,9 +239,10 @@ def _register_template(
template_class = Llama2Template if name.startswith("llama2") else Template
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
- default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
+ default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
+ default_prefix_formatter = EmptyFormatter()
TEMPLATES[name] = template_class(
format_user=format_user or default_user_formatter,
format_assistant=format_assistant or default_assistant_formatter,
@@ -256,12 +251,12 @@ def _register_template(
format_observation=format_observation or format_user or default_user_formatter,
format_tools=format_tools or default_tool_formatter,
format_separator=format_separator or default_separator_formatter,
+ format_prefix=format_prefix or default_prefix_formatter,
default_system=default_system,
stop_words=stop_words,
image_token=image_token,
efficient_eos=efficient_eos,
replace_eos=replace_eos,
- force_system=force_system,
)
@@ -307,6 +302,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
jinja_template = ""
+ prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
+ if prefix:
+ jinja_template += "{{ " + prefix + " }}"
+
if template.default_system:
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
@@ -315,11 +314,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
)
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
- if isinstance(template, Llama2Template):
- pass
- elif template.force_system:
- jinja_template += "{{ " + system_message + " }}"
- else:
+ if not isinstance(template, Llama2Template):
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
jinja_template += "{% for message in messages %}"
@@ -346,6 +341,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
def get_template_and_fix_tokenizer(
tokenizer: "PreTrainedTokenizer",
name: Optional[str] = None,
+ tool_format: Optional[str] = None,
) -> Template:
if name is None:
template = TEMPLATES["empty"] # placeholder
@@ -354,6 +350,12 @@ def get_template_and_fix_tokenizer(
if template is None:
raise ValueError("Template {} does not exist.".format(name))
+ if tool_format is not None:
+ logger.info("Using tool format: {}.".format(tool_format))
+ eos_slots = [] if template.efficient_eos else [{"eos_token"}]
+ template.format_tools = ToolFormatter(tool_format=tool_format)
+ template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
+
stop_words = template.stop_words
if template.replace_eos:
if not stop_words:
@@ -435,9 +437,8 @@ _register_template(
_register_template(
name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
- force_system=True,
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
@@ -450,11 +451,7 @@ _register_template(
_register_template(
name="breeze",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
- default_system=(
- "You are a helpful AI assistant built by MediaTek Research. "
- "The user you are helping speaks Traditional Chinese and comes from Taiwan."
- ),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
)
@@ -462,10 +459,9 @@ _register_template(
_register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
- format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
+ format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
efficient_eos=True,
- force_system=True,
)
@@ -473,32 +469,13 @@ _register_template(
name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
- format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
- format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
+ format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
+ format_function=FunctionFormatter(slots=[], tool_format="glm4"),
format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
),
- stop_words=["<|user|>", "<|observation|>"],
- efficient_eos=True,
- force_system=True,
-)
-
-
-_register_template(
- name="chatglm3_system",
- format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
- format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
- format_system=StringFormatter(
- slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
- ),
- format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
- format_observation=StringFormatter(
- slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
- ),
- default_system=(
- "You are ChatGLM3, a large language model trained by Zhipu.AI. "
- "Follow the user's instructions carefully. Respond using markdown."
- ),
+ format_tools=ToolFormatter(tool_format="glm4"),
+ format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
)
@@ -529,8 +506,7 @@ _register_template(
_register_template(
name="codegeex2",
- format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
- force_system=True,
+ format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
)
@@ -544,21 +520,15 @@ _register_template(
)
]
),
- format_system=StringFormatter(
- slots=[{"bos_token"}, "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]
- ),
- default_system=(
- "You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users "
- "by providing thorough responses. You are trained by Cohere."
- ),
+ format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}"]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
- force_system=True,
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
@@ -591,30 +561,28 @@ _register_template(
_register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
- force_system=True,
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
- format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
- format_separator=EmptyFormatter(slots=["\n<|EOT|>\n"]),
+ format_assistant=StringFormatter(slots=["\n{{content}}\n"]),
+ format_separator=EmptyFormatter(slots=["\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer\n"
),
- stop_words=["<|EOT|>"],
- efficient_eos=True,
)
_register_template(
name="default",
- format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
+ format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
format_system=StringFormatter(slots=["{{content}}\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
)
@@ -622,11 +590,7 @@ _register_template(
_register_template(
name="empty",
- format_user=StringFormatter(slots=["{{content}}"]),
- format_assistant=StringFormatter(slots=["{{content}}"]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
efficient_eos=True,
- force_system=True,
)
@@ -648,13 +612,12 @@ _register_template(
_register_template(
name="gemma",
format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_observation=StringFormatter(
slots=["tool\n{{content}}\nmodel\n"]
),
format_separator=EmptyFormatter(slots=["\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
- force_system=True,
)
@@ -662,36 +625,33 @@ _register_template(
name="glm4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
- format_system=StringFormatter(slots=["[gMASK]{{content}}"]),
- format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+ format_function=FunctionFormatter(slots=[], tool_format="glm4"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+ format_tools=ToolFormatter(tool_format="glm4"),
+ format_prefix=EmptyFormatter(slots=["[gMASK]"]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
- force_system=True,
)
_register_template(
name="intern",
- format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": ""}, "\n<|Bot|>:"]),
- format_separator=EmptyFormatter(slots=[{"token": ""}, "\n"]),
+ format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
+ format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
+ format_separator=EmptyFormatter(slots=["\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=[""],
- efficient_eos=True,
+ efficient_eos=True, # internlm tokenizer cannot set eos_token_id
)
_register_template(
name="intern2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
- format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
- format_separator=EmptyFormatter(slots=["\n"]),
- default_system=(
- "You are an AI assistant whose name is InternLM (书生·浦语).\n"
- "- InternLM (书生·浦语) is a conversational language model that is developed "
- "by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
- "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
- "by the user such as English and 中文."
- ),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_separator=EmptyFormatter(slots=["<|im_end|>\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"],
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
)
@@ -700,7 +660,6 @@ _register_template(
_register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
- format_assistant=StringFormatter(slots=[" {{content}} ", {"eos_token"}]),
format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]),
)
@@ -723,9 +682,7 @@ _register_template(
)
]
),
- format_system=StringFormatter(
- slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
- ),
+ format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_observation=StringFormatter(
slots=[
(
@@ -734,7 +691,7 @@ _register_template(
)
]
),
- default_system="You are a helpful assistant.",
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
)
@@ -743,24 +700,21 @@ _register_template(
_register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
- force_system=True,
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
- format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
- force_system=True,
+ format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
)
_register_template(
name="openchat",
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
- force_system=True,
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
@@ -774,27 +728,25 @@ _register_template(
)
]
),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
- force_system=True,
)
_register_template(
name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
- force_system=True,
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
- format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
- default_system="You are a helpful AI assistant.",
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|end|>"],
replace_eos=True,
)
@@ -827,7 +779,6 @@ _register_template(
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|end|>"],
replace_eos=True,
- force_system=True,
)
diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py
new file mode 100644
index 00000000..ac5565d5
--- /dev/null
+++ b/src/llamafactory/data/tool_utils.py
@@ -0,0 +1,140 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import re
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple, Union
+
+from .data_utils import SLOTS
+
+
+DEFAULT_TOOL_PROMPT = (
+ "You have access to the following tools:\n{tool_text}"
+ "Use the following format if using a tool:\n"
+ "```\n"
+ "Action: tool name (one of [{tool_names}]).\n"
+ "Action Input: the input to the tool, in a JSON format representing the kwargs "
+ """(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
+ "```\n"
+)
+
+
+GLM4_TOOL_PROMPT = (
+ "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
+ "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
+)
+
+
+@dataclass
+class ToolUtils(ABC):
+ @staticmethod
+ @abstractmethod
+ def get_function_slots() -> SLOTS: ...
+
+ @staticmethod
+ @abstractmethod
+ def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
+
+ @staticmethod
+ @abstractmethod
+ def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
+
+
+class DefaultToolUtils(ToolUtils):
+ @staticmethod
+ def get_function_slots() -> SLOTS:
+ return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
+
+ @staticmethod
+ def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+ tool_text = ""
+ tool_names = []
+ for tool in tools:
+ param_text = ""
+ for name, param in tool["parameters"]["properties"].items():
+ required, enum, items = "", "", ""
+ if name in tool["parameters"].get("required", []):
+ required = ", required"
+
+ if param.get("enum", None):
+ enum = ", should be one of [{}]".format(", ".join(param["enum"]))
+
+ if param.get("items", None):
+ items = ", where each item should be {}".format(param["items"].get("type", ""))
+
+ param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
+ name=name,
+ type=param.get("type", ""),
+ required=required,
+ desc=param.get("description", ""),
+ enum=enum,
+ items=items,
+ )
+
+ tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
+ name=tool["name"], desc=tool.get("description", ""), args=param_text
+ )
+ tool_names.append(tool["name"])
+
+ return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
+
+ @staticmethod
+ def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+ regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
+ action_match: List[Tuple[str, str]] = re.findall(regex, content)
+ if not action_match:
+ return content
+
+ results = []
+ for match in action_match:
+ tool_name = match[0].strip()
+ tool_input = match[1].strip().strip('"').strip("```")
+ try:
+ arguments = json.loads(tool_input)
+ results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
+ except json.JSONDecodeError:
+ return content
+
+ return results
+
+
+class GLM4ToolUtils(ToolUtils):
+ @staticmethod
+ def get_function_slots() -> SLOTS:
+ return ["{{name}}\n{{arguments}}"]
+
+ @staticmethod
+ def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+ tool_text = ""
+ for tool in tools:
+ tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
+ name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
+ )
+
+ return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
+
+ @staticmethod
+ def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+ if "\n" not in content:
+ return content
+
+ tool_name, tool_input = content.split("\n", maxsplit=1)
+ try:
+ arguments = json.loads(tool_input)
+ except json.JSONDecodeError:
+ return content
+
+ return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
diff --git a/src/llamafactory/eval/evaluator.py b/src/llamafactory/eval/evaluator.py
index 192f4815..d3140793 100644
--- a/src/llamafactory/eval/evaluator.py
+++ b/src/llamafactory/eval/evaluator.py
@@ -1,4 +1,41 @@
-# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
+# Copyright 2024 the LlamaFactory team.
+#
+# This code is inspired by the Dan's test library.
+# https://github.com/hendrycks/test/blob/master/evaluate_flan.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# MIT License
+#
+# Copyright (c) 2020 Dan Hendrycks
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
import inspect
import json
@@ -26,9 +63,7 @@ class Evaluator:
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
self.eval_template = get_eval_template(self.eval_args.lang)
- self.choice_inputs = [
- self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
- ]
+ self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
@torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
diff --git a/src/llamafactory/eval/template.py b/src/llamafactory/eval/template.py
index a4a6ef0e..7d524e7c 100644
--- a/src/llamafactory/eval/template.py
+++ b/src/llamafactory/eval/template.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple
@@ -10,7 +24,6 @@ class EvalTemplate:
system: str
choice: str
answer: str
- prefix: str
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
r"""
@@ -42,8 +55,8 @@ class EvalTemplate:
eval_templates: Dict[str, "EvalTemplate"] = {}
-def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
- eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
+def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
+ eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer)
def get_eval_template(name: str) -> "EvalTemplate":
@@ -56,8 +69,7 @@ _register_eval_template(
name="en",
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}",
- answer="\nAnswer: ",
- prefix=" ",
+ answer="\nAnswer:",
)
@@ -66,5 +78,4 @@ _register_eval_template(
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}",
answer="\n答案:",
- prefix=" ",
)
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 466b1269..6029d84f 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional
@@ -404,6 +418,18 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
},
+ "DeepSeek-MoE-Coder-16B-Base": {
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
+ },
+ "DeepSeek-MoE-Coder-236B-Base": {
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base",
+ },
+ "DeepSeek-MoE-Coder-16B-Chat": {
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
+ },
+ "DeepSeek-MoE-Coder-236B-Chat": {
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
+ },
},
template="deepseek",
)
@@ -496,6 +522,18 @@ register_model_group(
"Gemma-1.1-7B-Chat": {
DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
},
+ "Gemma-2-9B": {
+ DownloadSource.DEFAULT: "google/gemma-2-9b",
+ },
+ "Gemma-2-27B": {
+ DownloadSource.DEFAULT: "google/gemma-2-27b",
+ },
+ "Gemma-2-9B-Chat": {
+ DownloadSource.DEFAULT: "google/gemma-2-9b-it",
+ },
+ "Gemma-2-27B-Chat": {
+ DownloadSource.DEFAULT: "google/gemma-2-27b-it",
+ },
},
template="gemma",
)
@@ -568,7 +606,7 @@ register_model_group(
register_model_group(
models={
- "Jambda-v0.1": {
+ "Jamba-v0.1": {
DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
}
@@ -683,6 +721,21 @@ register_model_group(
)
+register_model_group(
+ models={
+ "MiniCPM-2B-SFT-Chat": {
+ DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
+ DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
+ },
+ "MiniCPM-2B-DPO-Chat": {
+ DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16",
+ DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16",
+ },
+ },
+ template="cpm",
+)
+
+
register_model_group(
models={
"Mistral-7B-v0.1": {
diff --git a/src/llamafactory/extras/env.py b/src/llamafactory/extras/env.py
index 1d4e43f1..14876048 100644
--- a/src/llamafactory/extras/env.py
+++ b/src/llamafactory/extras/env.py
@@ -1,3 +1,20 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import platform
import accelerate
@@ -9,7 +26,7 @@ import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
-VERSION = "0.8.1.dev0"
+VERSION = "0.8.3.dev0"
def print_env() -> None:
diff --git a/src/llamafactory/extras/logging.py b/src/llamafactory/extras/logging.py
index 430b8a48..67622212 100644
--- a/src/llamafactory/extras/logging.py
+++ b/src/llamafactory/extras/logging.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import logging
import os
import sys
diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index fc33f77e..20c752c5 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -1,13 +1,29 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's PEFT library.
+# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import gc
import os
-from typing import TYPE_CHECKING, Dict, Tuple
+from typing import TYPE_CHECKING, Tuple
import torch
-from peft import PeftModel
-from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
+import transformers.dynamic_module_utils
+from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
+from transformers.dynamic_module_utils import get_relative_imports
from transformers.utils import (
- SAFE_WEIGHTS_NAME,
- WEIGHTS_NAME,
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_mps_available,
@@ -16,7 +32,6 @@ from transformers.utils import (
)
from transformers.utils.versions import require_version
-from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from .logging import get_logger
@@ -28,8 +43,6 @@ except Exception:
if TYPE_CHECKING:
- from trl import AutoModelForCausalLMWithValueHead
-
from ..hparams import ModelArguments
@@ -58,6 +71,9 @@ class AverageMeter:
def check_dependencies() -> None:
+ r"""
+ Checks the version of the required packages.
+ """
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
else:
@@ -68,7 +84,7 @@ def check_dependencies() -> None:
require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
-def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
+def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.
"""
@@ -79,7 +95,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
- # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
+ # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize
if param.__class__.__name__ == "Params4bit":
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
num_bytes = param.quant_storage.itemsize
@@ -97,55 +113,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
return trainable_params, all_param
-def fix_valuehead_checkpoint(
- model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
-) -> None:
- r"""
- The model is already unwrapped.
-
- There are three cases:
- 1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
- 2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
- 3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
-
- We assume `stage3_gather_16bit_weights_on_model_save=true`.
- """
- if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
- return
-
- if safe_serialization:
- from safetensors import safe_open
- from safetensors.torch import save_file
-
- path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
- with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
- state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
- else:
- path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
- state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
-
- decoder_state_dict = {}
- v_head_state_dict = {}
- for name, param in state_dict.items():
- if name.startswith("v_head."):
- v_head_state_dict[name] = param
- else:
- decoder_state_dict[name.replace("pretrained_model.", "")] = param
-
- os.remove(path_to_checkpoint)
- model.pretrained_model.save_pretrained(
- output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
- )
-
- if safe_serialization:
- save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
- else:
- torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
-
- logger.info("Value head model saved at: {}".format(output_dir))
-
-
-def get_current_device() -> torch.device:
+def get_current_device() -> "torch.device":
r"""
Gets the current available device.
"""
@@ -184,7 +152,14 @@ def get_logits_processor() -> "LogitsProcessorList":
return logits_processor
-def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
+def has_tokenized_data(path: "os.PathLike") -> bool:
+ r"""
+ Checks if the path has a tokenized dataset.
+ """
+ return os.path.isdir(path) and len(os.listdir(path)) > 0
+
+
+def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
@@ -203,11 +178,9 @@ def is_gpu_or_npu_available() -> bool:
return is_torch_npu_available() or is_torch_cuda_available()
-def has_tokenized_data(path: os.PathLike) -> bool:
- r"""
- Checks if the path has a tokenized dataset.
- """
- return os.path.isdir(path) and len(os.listdir(path)) > 0
+def skip_check_imports() -> None:
+ if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
+ transformers.dynamic_module_utils.check_imports = get_relative_imports
def torch_gc() -> None:
diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py
index 4c9e6492..0a84a293 100644
--- a/src/llamafactory/extras/packages.py
+++ b/src/llamafactory/extras/packages.py
@@ -1,5 +1,23 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import importlib.metadata
import importlib.util
+from functools import lru_cache
from typing import TYPE_CHECKING
from packaging import version
@@ -24,10 +42,6 @@ def is_fastapi_available():
return _is_package_available("fastapi")
-def is_flash_attn2_available():
- return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0")
-
-
def is_galore_available():
return _is_package_available("galore_torch")
@@ -36,18 +50,10 @@ def is_gradio_available():
return _is_package_available("gradio")
-def is_jieba_available():
- return _is_package_available("jieba")
-
-
def is_matplotlib_available():
return _is_package_available("matplotlib")
-def is_nltk_available():
- return _is_package_available("nltk")
-
-
def is_pillow_available():
return _is_package_available("PIL")
@@ -60,10 +66,6 @@ def is_rouge_available():
return _is_package_available("rouge_chinese")
-def is_sdpa_available():
- return _get_package_version("torch") > version.parse("2.1.1")
-
-
def is_starlette_available():
return _is_package_available("sse_starlette")
@@ -74,3 +76,8 @@ def is_uvicorn_available():
def is_vllm_available():
return _is_package_available("vllm")
+
+
+@lru_cache
+def is_vllm_version_greater_than_0_5():
+ return _get_package_version("vllm") >= version.parse("0.5.0")
diff --git a/src/llamafactory/extras/ploting.py b/src/llamafactory/extras/ploting.py
index dea23bbe..596d55e7 100644
--- a/src/llamafactory/extras/ploting.py
+++ b/src/llamafactory/extras/ploting.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import math
import os
diff --git a/src/llamafactory/hparams/__init__.py b/src/llamafactory/hparams/__init__.py
index d1ee98dd..cfe448c1 100644
--- a/src/llamafactory/hparams/__init__.py
+++ b/src/llamafactory/hparams/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index d2d53ec8..e351fccf 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -1,3 +1,20 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass, field
from typing import Literal, Optional
@@ -28,10 +45,6 @@ class DataArguments:
default=1024,
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
)
- reserved_label_len: int = field(
- default=1,
- metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."},
- )
train_on_prompt: bool = field(
default=False,
metadata={"help": "Whether to disable the mask on the prompt or not."},
@@ -90,15 +103,16 @@ class DataArguments:
"help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."
},
)
+ tool_format: Optional[str] = field(
+ default=None,
+ metadata={"help": "Tool format to use for constructing function calling examples."},
+ )
tokenized_path: Optional[str] = field(
default=None,
metadata={"help": "Path to save or load the tokenized datasets."},
)
def __post_init__(self):
- if self.reserved_label_len >= self.cutoff_len:
- raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
-
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
raise ValueError("Streaming mode should have an integer val size.")
diff --git a/src/llamafactory/hparams/evaluation_args.py b/src/llamafactory/hparams/evaluation_args.py
index 5a05f6f6..a7f221ca 100644
--- a/src/llamafactory/hparams/evaluation_args.py
+++ b/src/llamafactory/hparams/evaluation_args.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from dataclasses import dataclass, field
from typing import Literal, Optional
diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 08af31e4..3867c0ec 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -1,5 +1,19 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass, field
-from typing import Literal, Optional
+from typing import List, Literal, Optional
@dataclass
@@ -94,6 +108,18 @@ class LoraArguments:
default=False,
metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
)
+ pissa_init: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to initialize a PiSSA adapter."},
+ )
+ pissa_iter: int = field(
+ default=16,
+ metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."},
+ )
+ pissa_convert: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."},
+ )
create_new_adapter: bool = field(
default=False,
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
@@ -319,20 +345,19 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
return [item.strip() for item in arg.split(",")]
return arg
- self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules)
- self.freeze_extra_modules = split_arg(self.freeze_extra_modules)
- self.lora_alpha = self.lora_alpha or self.lora_rank * 2
- self.lora_target = split_arg(self.lora_target)
- self.additional_target = split_arg(self.additional_target)
- self.galore_target = split_arg(self.galore_target)
+ self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
+ self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
+ self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
+ self.lora_target: List[str] = split_arg(self.lora_target)
+ self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
+ self.galore_target: List[str] = split_arg(self.galore_target)
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
+ self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
- self.use_ref_model = self.pref_loss not in ["orpo", "simpo"]
-
if self.stage == "ppo" and self.reward_model is None:
raise ValueError("`reward_model` is necessary for PPO training.")
@@ -354,5 +379,11 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
+ if self.pissa_init and self.finetuning_type != "lora":
+ raise ValueError("`pissa_init` is only valid for LoRA training.")
+
+ if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
+ raise ValueError("Cannot use PiSSA for current training stage.")
+
if self.train_mm_proj_only and self.finetuning_type != "full":
raise ValueError("`train_mm_proj_only` is only valid for full training.")
diff --git a/src/llamafactory/hparams/generating_args.py b/src/llamafactory/hparams/generating_args.py
index 0ee17d1a..7ebb4eed 100644
--- a/src/llamafactory/hparams/generating_args.py
+++ b/src/llamafactory/hparams/generating_args.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 6352a420..087c8c38 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -1,5 +1,28 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import asdict, dataclass, field
-from typing import Any, Dict, Literal, Optional
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
+
+from typing_extensions import Self
+
+
+if TYPE_CHECKING:
+ import torch
@dataclass
@@ -22,6 +45,10 @@ class ModelArguments:
)
},
)
+ adapter_folder: Optional[str] = field(
+ default=None,
+ metadata={"help": "The folder containing the adapter weights to load."},
+ )
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
@@ -50,6 +77,10 @@ class ModelArguments:
default=True,
metadata={"help": "Whether or not to use memory-efficient model loading."},
)
+ quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
+ default="bitsandbytes",
+ metadata={"help": "Quantization method to use for on-the-fly quantization."},
+ )
quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
@@ -70,7 +101,7 @@ class ModelArguments:
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
- flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
+ flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
default="auto",
metadata={"help": "Enable FlashAttention for faster training and inference."},
)
@@ -127,13 +158,9 @@ class ModelArguments:
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
)
vllm_max_lora_rank: int = field(
- default=8,
+ default=32,
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
)
- vllm_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
- default="auto",
- metadata={"help": "Data type for model weights and activations in the vLLM engine."},
- )
offload_folder: str = field(
default="offload",
metadata={"help": "Path to offload model weights."},
@@ -142,6 +169,10 @@ class ModelArguments:
default=True,
metadata={"help": "Whether or not to use KV cache in generation."},
)
+ infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
+ default="auto",
+ metadata={"help": "Data type for model weights and activations at inference."},
+ )
hf_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."},
@@ -192,9 +223,9 @@ class ModelArguments:
)
def __post_init__(self):
- self.compute_dtype = None
- self.device_map = None
- self.model_max_length = None
+ self.compute_dtype: Optional["torch.dtype"] = None
+ self.device_map: Optional[Union[str, Dict[str, Any]]] = None
+ self.model_max_length: Optional[int] = None
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
@@ -208,11 +239,18 @@ class ModelArguments:
if self.new_special_tokens is not None: # support multiple special tokens
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
- assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
- assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
-
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.")
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
+
+ @classmethod
+ def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
+ arg_dict = old_arg.to_dict()
+ arg_dict.update(**kwargs)
+ new_arg = cls(**arg_dict)
+ new_arg.compute_dtype = old_arg.compute_dtype
+ new_arg.device_map = old_arg.device_map
+ new_arg.model_max_length = old_arg.model_max_length
+ return new_arg
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index ff1fbf5d..8b2ea4c1 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -1,3 +1,20 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import logging
import os
import sys
@@ -8,6 +25,7 @@ import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
+from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.versions import require_version
@@ -65,13 +83,13 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Adapter is only valid for the LoRA method.")
- if model_args.use_unsloth and is_deepspeed_zero3_enabled():
- raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
-
if model_args.quantization_bit is not None:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
+ if finetuning_args.pissa_init:
+ raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")
+
if model_args.resize_vocab:
raise ValueError("Cannot resize embedding layers of a quantized model.")
@@ -100,7 +118,7 @@ def _check_extra_dependencies(
require_version("galore_torch", "To fix: pip install galore_torch")
if finetuning_args.use_badam:
- require_version("badam", "To fix: pip install badam")
+ require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
if finetuning_args.plot_loss:
require_version("matplotlib", "To fix: pip install matplotlib")
@@ -162,6 +180,12 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
):
raise ValueError("PPO only accepts wandb or tensorboard logger.")
+ if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
+ raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
+
+ if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
+ raise ValueError("Please use `FORCE_TORCHRUN=1` to launch DeepSpeed training.")
+
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
@@ -171,32 +195,31 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if training_args.do_train and model_args.quantization_device_map == "auto":
raise ValueError("Cannot use device map for quantized models in training.")
- if finetuning_args.use_dora and model_args.use_unsloth:
- raise ValueError("Unsloth does not support DoRA.")
+ if finetuning_args.pissa_init and is_deepspeed_zero3_enabled():
+ raise ValueError("PiSSA is incompatible with DeepSpeed ZeRO-3.")
if finetuning_args.pure_bf16:
if not is_torch_bf16_gpu_available():
raise ValueError("This device does not support `pure_bf16`.")
- if training_args.fp16 or training_args.bf16:
- raise ValueError("Turn off mixed precision training when using `pure_bf16`.")
+ if is_deepspeed_zero3_enabled():
+ raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")
if (
finetuning_args.use_galore
and finetuning_args.galore_layerwise
- and training_args.parallel_mode.value == "distributed"
+ and training_args.parallel_mode == ParallelMode.DISTRIBUTED
):
raise ValueError("Distributed training does not support layer-wise GaLore.")
- if (
- finetuning_args.use_badam
- and finetuning_args.badam_mode == "layer"
- and training_args.parallel_mode.value == "distributed"
- ):
- raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
+ if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
+ if finetuning_args.badam_mode == "ratio":
+ raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
+ elif not is_deepspeed_zero3_enabled():
+ raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")
- if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
- raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
+ if finetuning_args.use_galore and training_args.deepspeed is not None:
+ raise ValueError("GaLore is incompatible with DeepSpeed yet.")
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
@@ -204,6 +227,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if model_args.visual_inputs and data_args.packing:
raise ValueError("Cannot use packing in MLLM fine-tuning.")
+ if model_args.use_unsloth and is_deepspeed_zero3_enabled():
+ raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
+
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
@@ -233,7 +259,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
# Post-process training arguments
if (
- training_args.parallel_mode.value == "distributed"
+ training_args.parallel_mode == ParallelMode.DISTRIBUTED
and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora"
):
@@ -293,7 +319,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
training_args.local_rank,
training_args.device,
training_args.n_gpu,
- training_args.parallel_mode.value == "distributed",
+ training_args.parallel_mode == ParallelMode.DISTRIBUTED,
str(model_args.compute_dtype),
)
)
@@ -332,6 +358,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
if model_args.export_dir is not None and model_args.export_device == "cpu":
model_args.device_map = {"": torch.device("cpu")}
+ model_args.model_max_length = data_args.cutoff_len
else:
model_args.device_map = "auto"
diff --git a/src/llamafactory/launcher.py b/src/llamafactory/launcher.py
index de154db9..65e0b68f 100644
--- a/src/llamafactory/launcher.py
+++ b/src/llamafactory/launcher.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from llamafactory.train.tuner import run_exp
diff --git a/src/llamafactory/model/__init__.py b/src/llamafactory/model/__init__.py
index 9d23d59f..48cfe76c 100644
--- a/src/llamafactory/model/__init__.py
+++ b/src/llamafactory/model/__init__.py
@@ -1,9 +1,25 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .loader import load_config, load_model, load_tokenizer
from .model_utils.misc import find_all_linear_modules
+from .model_utils.quantization import QuantizationMethod
from .model_utils.valuehead import load_valuehead_params
__all__ = [
+ "QuantizationMethod",
"load_config",
"load_model",
"load_tokenizer",
diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index f4e501a7..7caef9cc 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import re
from typing import TYPE_CHECKING
@@ -25,8 +39,12 @@ def _setup_full_tuning(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
+ is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> None:
+ if not is_trainable:
+ return
+
logger.info("Fine-tuning method: Full")
forbidden_modules = set()
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
@@ -47,8 +65,12 @@ def _setup_freeze_tuning(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
+ is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> None:
+ if not is_trainable:
+ return
+
logger.info("Fine-tuning method: Freeze")
if model_args.visual_inputs:
config = model.config.text_config
@@ -132,7 +154,9 @@ def _setup_lora_tuning(
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> "PeftModel":
- logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
+ if is_trainable:
+ logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
+
adapter_to_resume = None
if model_args.adapter_name_or_path is not None:
@@ -155,8 +179,16 @@ def _setup_lora_tuning(
else:
adapter_to_merge = model_args.adapter_name_or_path
+ init_kwargs = {
+ "subfolder": model_args.adapter_folder,
+ "offload_folder": model_args.offload_folder,
+ "cache_dir": model_args.cache_dir,
+ "revision": model_args.model_revision,
+ "token": model_args.hf_hub_token,
+ }
+
for adapter in adapter_to_merge:
- model: "LoraModel" = PeftModel.from_pretrained(model, adapter, offload_folder=model_args.offload_folder)
+ model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload()
if len(adapter_to_merge) > 0:
@@ -166,12 +198,9 @@ def _setup_lora_tuning(
if model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
else:
- model = PeftModel.from_pretrained(
- model,
- adapter_to_resume,
- is_trainable=is_trainable,
- offload_folder=model_args.offload_folder,
- )
+ model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
+
+ logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
@@ -209,16 +238,24 @@ def _setup_lora_tuning(
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora,
+ "use_dora": finetuning_args.use_dora,
"modules_to_save": finetuning_args.additional_target,
}
if model_args.use_unsloth:
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
else:
+ if finetuning_args.pissa_init:
+ if finetuning_args.pissa_iter == -1:
+ logger.info("Using PiSSA initialization.")
+ peft_kwargs["init_lora_weights"] = "pissa"
+ else:
+ logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
+ peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)
+
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
- use_dora=finetuning_args.use_dora,
**peft_kwargs,
)
model = get_peft_model(model, lora_config)
@@ -227,9 +264,6 @@ def _setup_lora_tuning(
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)
- if model_args.adapter_name_or_path is not None:
- logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
-
return model
@@ -247,29 +281,36 @@ def init_adapter(
Note that the trainable parameters must be cast to float32.
"""
- if (not is_trainable) and model_args.adapter_name_or_path is None:
- logger.info("Adapter is not found at evaluation, load the base model.")
- return model
+ if is_trainable and getattr(model, "quantization_method", None) is not None:
+ if finetuning_args.finetuning_type != "lora":
+ raise ValueError("Quantized models can only be used for the LoRA tuning.")
- if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
- raise ValueError("You can only use lora for quantized models.")
+ if finetuning_args.pissa_init:
+ raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
- if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
- logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
- cast_trainable_params_to_fp32 = False
+ # cast trainable parameters to float32 if:
+ # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
+ # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
+ cast_trainable_params_to_fp32 = False
+ if not is_trainable:
+ pass
+ elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
+ logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
+ elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
+ logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
else:
logger.info("Upcasting trainable params to float32.")
cast_trainable_params_to_fp32 = True
- if is_trainable and finetuning_args.finetuning_type == "full":
- _setup_full_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
-
- if is_trainable and finetuning_args.finetuning_type == "freeze":
- _setup_freeze_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
-
- if finetuning_args.finetuning_type == "lora":
+ if finetuning_args.finetuning_type == "full":
+ _setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
+ elif finetuning_args.finetuning_type == "freeze":
+ _setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
+ elif finetuning_args.finetuning_type == "lora":
model = _setup_lora_tuning(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
)
+ else:
+ raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))
return model
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 026a09be..43e65d52 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -1,10 +1,25 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
+import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
-from ..extras.misc import count_parameters, try_download_model_from_ms
+from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
from .adapter import init_adapter
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
@@ -33,6 +48,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
Note: including inplace operation of model_args.
"""
+ skip_check_imports()
model_args.model_name_or_path = try_download_model_from_ms(model_args)
return {
"trust_remote_code": True,
@@ -162,17 +178,21 @@ def load_model(
if not is_trainable:
model.requires_grad_(False)
+ for param in model.parameters():
+ if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
+ param.data = param.data.to(model_args.compute_dtype)
+
model.eval()
else:
model.train()
trainable_params, all_param = count_parameters(model)
if is_trainable:
- param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
+ param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
)
else:
- param_stats = "all params: {:d}".format(all_param)
+ param_stats = "all params: {:,}".format(all_param)
logger.info(param_stats)
diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index b52ddc86..4bed7e21 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -1,7 +1,22 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING
+from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
+
from ...extras.logging import get_logger
-from ...extras.packages import is_flash_attn2_available, is_sdpa_available
if TYPE_CHECKING:
@@ -13,21 +28,33 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
-def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+def configure_attn_implementation(
+ config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
+) -> None:
+ if getattr(config, "model_type", None) == "gemma2" and is_trainable: # gemma2 adopts soft-cap attention
+ if model_args.flash_attn == "auto":
+ logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
+ model_args.flash_attn = "disabled"
+ elif model_args.flash_attn != "disabled":
+ logger.warning(
+ "Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. "
+ "Will proceed at your own risk.".format(model_args.flash_attn)
+ )
+
if model_args.flash_attn == "auto":
return
- elif model_args.flash_attn == "off":
+ elif model_args.flash_attn == "disabled":
requested_attn_implementation = "eager"
elif model_args.flash_attn == "sdpa":
- if not is_sdpa_available():
+ if not is_torch_sdpa_available():
logger.warning("torch>=2.1.1 is required for SDPA attention.")
return
requested_attn_implementation = "sdpa"
elif model_args.flash_attn == "fa2":
- if not is_flash_attn2_available():
+ if not is_flash_attn_2_available():
logger.warning("FlashAttention-2 is not installed.")
return
diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py
index e0657be8..f4f3d8a5 100644
--- a/src/llamafactory/model/model_utils/checkpointing.py
+++ b/src/llamafactory/model/model_utils/checkpointing.py
@@ -1,3 +1,21 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's Transformers and PEFT library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
+# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import inspect
from functools import partial
from types import MethodType
@@ -60,15 +78,12 @@ def _fp32_forward_post_hook(
return output.to(torch.float32)
-def prepare_model_for_training(
- model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
-) -> None:
+def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
- Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
"""
if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.")
@@ -87,8 +102,8 @@ def prepare_model_for_training(
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
- if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
- logger.info("Upcasting lm_head outputs in float32.")
- output_layer = getattr(model, output_layer_name)
+ if model_args.upcast_lmhead_output:
+ output_layer = model.get_output_embeddings()
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
+ logger.info("Upcasting lm_head outputs in float32.")
output_layer.register_forward_hook(_fp32_forward_post_hook)
diff --git a/src/llamafactory/model/model_utils/embedding.py b/src/llamafactory/model/model_utils/embedding.py
index 3d9278e3..3ff79828 100644
--- a/src/llamafactory/model/model_utils/embedding.py
+++ b/src/llamafactory/model/model_utils/embedding.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import math
from contextlib import nullcontext
from typing import TYPE_CHECKING
diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py
index c8dc52f5..af30bd50 100644
--- a/src/llamafactory/model/model_utils/longlora.py
+++ b/src/llamafactory/model/model_utils/longlora.py
@@ -1,3 +1,22 @@
+# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
+#
+# This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
+# This code is also inspired by the original LongLoRA implementation.
+# https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import math
from typing import TYPE_CHECKING, Optional, Tuple
@@ -96,7 +115,8 @@ def llama_attention_forward(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
- )
+ ),
+ dim=2,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -181,11 +201,9 @@ def llama_flash_attention_2_forward(
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
- else:
- groupsz = q_len
attn_output: torch.Tensor = self._flash_attention_forward(
- query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
+ query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
@@ -194,7 +212,8 @@ def llama_flash_attention_2_forward(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
- )
+ ),
+ dim=2,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -293,7 +312,8 @@ def llama_sdpa_attention_forward(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
- )
+ ),
+ dim=2,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -303,7 +323,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None:
- require_version("transformers==4.40.2", "To fix: pip install transformers==4.40.2")
+ require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
diff --git a/src/llamafactory/model/model_utils/misc.py b/src/llamafactory/model/model_utils/misc.py
index 4851bd29..a2812228 100644
--- a/src/llamafactory/model/model_utils/misc.py
+++ b/src/llamafactory/model/model_utils/misc.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, List
from ...extras.logging import get_logger
diff --git a/src/llamafactory/model/model_utils/mod.py b/src/llamafactory/model/model_utils/mod.py
index 5708a1a8..ec73af00 100644
--- a/src/llamafactory/model/model_utils/mod.py
+++ b/src/llamafactory/model/model_utils/mod.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING
from ...extras.constants import MOD_SUPPORTED_MODELS
diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py
index e554e45a..5c7473aa 100644
--- a/src/llamafactory/model/model_utils/moe.py
+++ b/src/llamafactory/model/model_utils/moe.py
@@ -1,5 +1,20 @@
-from typing import TYPE_CHECKING
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING, Sequence
+
+import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
@@ -10,6 +25,13 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
+def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
+ require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
+ from deepspeed.utils import set_z3_leaf_modules # type: ignore
+
+ set_z3_leaf_modules(model, leaf_modules)
+
+
def add_z3_leaf_module(model: "PreTrainedModel") -> None:
r"""
Sets module as a leaf module to skip partitioning in deepspeed zero3.
@@ -17,33 +39,30 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
if not is_deepspeed_zero3_enabled():
return
- require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
- from deepspeed.utils import set_z3_leaf_modules # type: ignore
-
if getattr(model.config, "model_type", None) == "dbrx":
from transformers.models.dbrx.modeling_dbrx import DbrxFFN
- set_z3_leaf_modules(model, [DbrxFFN])
+ _set_z3_leaf_modules(model, [DbrxFFN])
if getattr(model.config, "model_type", None) == "jamba":
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
- set_z3_leaf_modules(model, [JambaSparseMoeBlock])
+ _set_z3_leaf_modules(model, [JambaSparseMoeBlock])
if getattr(model.config, "model_type", None) == "jetmoe":
from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE
- set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
+ _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
if getattr(model.config, "model_type", None) == "mixtral":
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
- set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+ _set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if getattr(model.config, "model_type", None) == "qwen2moe":
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
- set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
+ _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py
index 02a54f07..317646e0 100644
--- a/src/llamafactory/model/model_utils/quantization.py
+++ b/src/llamafactory/model/model_utils/quantization.py
@@ -1,3 +1,21 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's Transformers and Optimum library.
+# https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py
+# https://github.com/huggingface/optimum/blob/v1.20.0/optimum/gptq/data.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
import random
from enum import Enum, unique
@@ -5,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List
import torch
from datasets import load_dataset
-from transformers import BitsAndBytesConfig, GPTQConfig
+from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version
@@ -39,10 +57,9 @@ class QuantizationMethod(str, Enum):
HQQ = "hqq"
-def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
+def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
r"""
- Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
- TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
+ Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
"""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
@@ -51,20 +68,32 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
data_path = model_args.export_quantization_dataset
data_files = None
- dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
- maxlen = model_args.export_quantization_maxlen
+ dataset = load_dataset(
+ path=data_path,
+ data_files=data_files,
+ split="train",
+ cache_dir=model_args.cache_dir,
+ token=model_args.hf_hub_token,
+ )
samples = []
+ maxlen = model_args.export_quantization_maxlen
for _ in range(model_args.export_quantization_nsamples):
+ n_try = 0
while True:
+ if n_try > 100:
+ raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")
+
sample_idx = random.randint(0, len(dataset) - 1)
- sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
- if sample["input_ids"].size(1) >= maxlen:
+ sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
+ n_try += 1
+ if sample["input_ids"].size(1) > maxlen:
break # TODO: fix large maxlen
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
- samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
+ attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen]
+ samples.append({"input_ids": input_ids.tolist(), "attention_mask": attention_mask.tolist()})
return samples
@@ -76,14 +105,14 @@ def configure_quantization(
init_kwargs: Dict[str, Any],
) -> None:
r"""
- Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
+ Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
"""
if getattr(config, "quantization_config", None): # ptq
- if is_deepspeed_zero3_enabled():
- raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
+ if model_args.quantization_bit is not None:
+ logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.")
- if model_args.quantization_device_map != "auto":
- init_kwargs["device_map"] = {"": get_current_device()}
+ if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
+ raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
@@ -105,46 +134,72 @@ def configure_quantization(
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
elif model_args.export_quantization_bit is not None: # auto-gptq
- require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
+ if model_args.export_quantization_bit not in [8, 4, 3, 2]:
+ raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
+
+ require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
- raise ValueError("ChatGLM model is not supported.")
+ raise ValueError("ChatGLM model is not supported yet.")
init_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit,
- tokenizer=tokenizer,
dataset=_get_quantization_dataset(tokenizer, model_args),
)
init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory()
- logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
+ logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))
- elif model_args.quantization_bit is not None: # bnb
- if model_args.quantization_bit == 8:
- require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
- init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
+ elif model_args.quantization_bit is not None: # on-the-fly
+ if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
+ if model_args.quantization_bit == 8:
+ require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+ init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
+ elif model_args.quantization_bit == 4:
+ require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+ init_kwargs["quantization_config"] = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_compute_dtype=model_args.compute_dtype,
+ bnb_4bit_use_double_quant=model_args.double_quantization,
+ bnb_4bit_quant_type=model_args.quantization_type,
+ bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora
+ )
+ else:
+ raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")
- elif model_args.quantization_bit == 4:
- require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
- init_kwargs["quantization_config"] = BitsAndBytesConfig(
- load_in_4bit=True,
- bnb_4bit_compute_dtype=model_args.compute_dtype,
- bnb_4bit_use_double_quant=model_args.double_quantization,
- bnb_4bit_quant_type=model_args.quantization_type,
- bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora
- )
+ # Do not assign device map if:
+ # 1. deepspeed zero3 or fsdp (train)
+ # 2. auto quantization device map (inference)
+ if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
+ if model_args.quantization_bit != 4:
+ raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
- if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
- if model_args.quantization_bit != 4:
- raise ValueError("Only 4-bit quantized model can use auto device map.")
+ require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
+ else:
+ init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
- require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
- require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
- require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
- init_kwargs["torch_dtype"] = model_args.compute_dtype # fsdp+qlora requires same dtype
- else:
- init_kwargs["device_map"] = {"": get_current_device()}
+ logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
+ elif model_args.quantization_method == QuantizationMethod.HQQ.value:
+ if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
+ raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
- logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
+ if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
+ raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
+
+ require_version("hqq", "To fix: pip install hqq")
+ init_kwargs["quantization_config"] = HqqConfig(
+ nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
+ ) # use ATEN kernel (axis=0) for performance
+ logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
+ elif model_args.quantization_method == QuantizationMethod.EETQ.value:
+ if model_args.quantization_bit != 8:
+ raise ValueError("EETQ only accepts 8-bit quantization.")
+
+ if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
+ raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
+
+ require_version("eetq", "To fix: pip install eetq")
+ init_kwargs["quantization_config"] = EetqConfig()
+ logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))
diff --git a/src/llamafactory/model/model_utils/rope.py b/src/llamafactory/model/model_utils/rope.py
index 93ab8929..4373ee19 100644
--- a/src/llamafactory/model/model_utils/rope.py
+++ b/src/llamafactory/model/model_utils/rope.py
@@ -1,3 +1,21 @@
+# Copyright 2024 LMSYS and the LlamaFactory team.
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# This code is inspired by the LMSYS's FastChat library.
+# https://github.com/lm-sys/FastChat/blob/v0.2.30/fastchat/train/train.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import math
from typing import TYPE_CHECKING
@@ -21,8 +39,8 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
logger.warning("Current model does not support RoPE scaling.")
return
- if is_trainable:
- if model_args.rope_scaling == "dynamic":
+ if model_args.model_max_length is not None:
+ if is_trainable and model_args.rope_scaling == "dynamic":
logger.warning(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
diff --git a/src/llamafactory/model/model_utils/unsloth.py b/src/llamafactory/model/model_utils/unsloth.py
index 8a16409d..9cfaec61 100644
--- a/src/llamafactory/model/model_utils/unsloth.py
+++ b/src/llamafactory/model/model_utils/unsloth.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Any, Dict, Optional
from ...extras.logging import get_logger
diff --git a/src/llamafactory/model/model_utils/valuehead.py b/src/llamafactory/model/model_utils/valuehead.py
index 64333688..9ab3d45a 100644
--- a/src/llamafactory/model/model_utils/valuehead.py
+++ b/src/llamafactory/model/model_utils/valuehead.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict
import torch
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index c8260b7f..700bf470 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -1,3 +1,20 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's Transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Tuple
import torch
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 47591de6..f1831ced 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict
@@ -46,13 +60,16 @@ def patch_config(
is_trainable: bool,
) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
- model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
+ if model_args.infer_dtype != "auto" and not is_trainable:
+ model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
+ else:
+ model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
if is_torch_npu_available():
use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
torch.npu.set_compile_mode(jit_compile=use_jit_compile)
- configure_attn_implementation(config, model_args)
+ configure_attn_implementation(config, model_args, is_trainable)
configure_rope(config, model_args, is_trainable)
configure_longlora(config, model_args, is_trainable)
configure_quantization(config, tokenizer, model_args, init_kwargs)
@@ -74,14 +91,17 @@ def patch_config(
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
- if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled(): # cast dtype and device if not use zero3 or fsdp
+ # cast data type of the model if:
+ # 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
+ # 2. quantization_bit is not None (qlora)
+ if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None:
init_kwargs["torch_dtype"] = model_args.compute_dtype
if init_kwargs["low_cpu_mem_usage"]: # device map requires low_cpu_mem_usage=True
if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map
- if init_kwargs["device_map"] == "auto":
+ if init_kwargs.get("device_map", None) == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder
if finetune_args.stage == "sft" and data_args.efficient_packing:
@@ -137,6 +157,10 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
if isinstance(self.pretrained_model, PreTrainedModel):
return self.pretrained_model.get_input_embeddings()
+ def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
+ if isinstance(self.pretrained_model, PreTrainedModel):
+ return self.pretrained_model.get_output_embeddings()
+
def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
if isinstance(self.pretrained_model, PeftModel):
self.pretrained_model.create_or_update_model_card(output_dir)
@@ -145,4 +169,5 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
setattr(model, "tie_weights", MethodType(tie_weights, model))
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
+ setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model))
setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
diff --git a/src/llamafactory/extras/callbacks.py b/src/llamafactory/train/callbacks.py
similarity index 56%
rename from src/llamafactory/extras/callbacks.py
rename to src/llamafactory/train/callbacks.py
index 441ebbfd..4d024278 100644
--- a/src/llamafactory/extras/callbacks.py
+++ b/src/llamafactory/train/callbacks.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import logging
import os
@@ -8,22 +22,78 @@ from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional
+import torch
import transformers
-from transformers import TrainerCallback
+from peft import PeftModel
+from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
+from transformers.utils import (
+ SAFE_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+ is_safetensors_available,
+)
-from .constants import TRAINER_LOG
-from .logging import LoggerHandler, get_logger
-from .misc import fix_valuehead_checkpoint
+from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
+from ..extras.logging import LoggerHandler, get_logger
+if is_safetensors_available():
+ from safetensors import safe_open
+ from safetensors.torch import save_file
+
if TYPE_CHECKING:
from transformers import TrainerControl, TrainerState, TrainingArguments
+ from trl import AutoModelForCausalLMWithValueHead
logger = get_logger(__name__)
+def fix_valuehead_checkpoint(
+ model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
+) -> None:
+ r"""
+ The model is already unwrapped.
+
+ There are three cases:
+ 1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
+ 2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
+ 3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
+
+ We assume `stage3_gather_16bit_weights_on_model_save=true`.
+ """
+ if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
+ return
+
+ if safe_serialization:
+ path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
+ with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
+ state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
+ else:
+ path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
+ state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+
+ decoder_state_dict = {}
+ v_head_state_dict = {}
+ for name, param in state_dict.items():
+ if name.startswith("v_head."):
+ v_head_state_dict[name] = param
+ else:
+ decoder_state_dict[name.replace("pretrained_model.", "")] = param
+
+ os.remove(path_to_checkpoint)
+ model.pretrained_model.save_pretrained(
+ output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
+ )
+
+ if safe_serialization:
+ save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
+ else:
+ torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
+
+ logger.info("Value head model saved at: {}".format(output_dir))
+
+
class FixValueHeadModelCallback(TrainerCallback):
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
@@ -37,8 +107,70 @@ class FixValueHeadModelCallback(TrainerCallback):
)
+class SaveProcessorCallback(TrainerCallback):
+ def __init__(self, processor: "ProcessorMixin") -> None:
+ r"""
+ Initializes a callback for saving the processor.
+ """
+ self.processor = processor
+
+ def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the end of training.
+ """
+ if args.should_save:
+ getattr(self.processor, "image_processor").save_pretrained(args.output_dir)
+
+
+class PissaConvertCallback(TrainerCallback):
+ r"""
+ Initializes a callback for converting the PiSSA adapter to a normal one.
+ """
+
+ def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the beginning of training.
+ """
+ if args.should_save:
+ model = kwargs.pop("model")
+ pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
+ logger.info("Initial PiSSA adatper will be saved at: {}.".format(pissa_init_dir))
+ if isinstance(model, PeftModel):
+ init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
+ setattr(model.peft_config["default"], "init_lora_weights", True)
+ model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
+ setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
+
+ def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the end of training.
+ """
+ if args.should_save:
+ model = kwargs.pop("model")
+ pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
+ pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
+ pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
+ logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir))
+ # 1. save a pissa backup with init_lora_weights: True
+ # 2. save a converted lora with init_lora_weights: pissa
+ # 3. load the pissa backup with init_lora_weights: True
+ # 4. delete the initial adapter and change init_lora_weights to pissa
+ if isinstance(model, PeftModel):
+ init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
+ setattr(model.peft_config["default"], "init_lora_weights", True)
+ model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
+ setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
+ model.save_pretrained(
+ pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
+ )
+ model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
+ model.set_adapter("default")
+ model.delete_adapter("pissa_init")
+ setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
+
+
class LogCallback(TrainerCallback):
- def __init__(self, output_dir: str) -> None:
+ def __init__(self) -> None:
r"""
Initializes a callback for logging training and evaluation status.
"""
@@ -56,7 +188,7 @@ class LogCallback(TrainerCallback):
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
if self.webui_mode:
signal.signal(signal.SIGABRT, self._set_abort)
- self.logger_handler = LoggerHandler(output_dir)
+ self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
logging.root.addHandler(self.logger_handler)
transformers.logging.add_handler(self.logger_handler)
diff --git a/src/llamafactory/train/dpo/__init__.py b/src/llamafactory/train/dpo/__init__.py
index 43fe9420..9ce0d089 100644
--- a/src/llamafactory/train/dpo/__init__.py
+++ b/src/llamafactory/train/dpo/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .workflow import run_dpo
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index d860b29a..e45467d6 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -1,3 +1,21 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
@@ -10,7 +28,8 @@ from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context
+from ..callbacks import PissaConvertCallback, SaveProcessorCallback
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
if TYPE_CHECKING:
@@ -35,7 +54,6 @@ class CustomDPOTrainer(DPOTrainer):
disable_dropout_in_model(ref_model)
self.finetuning_args = finetuning_args
- self.processor = processor
self.reference_free = False
self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation
@@ -61,6 +79,8 @@ class CustomDPOTrainer(DPOTrainer):
if not hasattr(self, "accelerator"):
raise AttributeError("Please update `transformers`.")
+ warnings.simplefilter("ignore") # remove gc warnings on ref model
+
if ref_model is not None:
if self.is_deepspeed_enabled:
if not (
@@ -71,10 +91,17 @@ class CustomDPOTrainer(DPOTrainer):
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval()
- if finetuning_args.use_badam:
- from badam import clip_grad_norm_for_sparse_tensor
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
- self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+ if finetuning_args.pissa_convert:
+ self.callback_handler.add_callback(PissaConvertCallback)
+
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
@@ -87,12 +114,6 @@ class CustomDPOTrainer(DPOTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
- def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
- super()._save(output_dir, state_dict)
- if self.processor is not None:
- output_dir = output_dir if output_dir is not None else self.args.output_dir
- getattr(self.processor, "image_processor").save_pretrained(output_dir)
-
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r"""
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
@@ -176,7 +197,7 @@ class CustomDPOTrainer(DPOTrainer):
if self.ref_model is None:
ref_model = model
- ref_context = get_ref_context(self.accelerator, model)
+ ref_context = self.accelerator.unwrap_model(model).disable_adapter()
else:
ref_model = self.ref_model
ref_context = nullcontext()
diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py
index 992985b0..431b5285 100644
--- a/src/llamafactory/train/dpo/workflow.py
+++ b/src/llamafactory/train/dpo/workflow.py
@@ -1,4 +1,19 @@
-# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import TYPE_CHECKING, List, Optional
diff --git a/src/llamafactory/train/kto/__init__.py b/src/llamafactory/train/kto/__init__.py
index 34c7905a..a1900368 100644
--- a/src/llamafactory/train/kto/__init__.py
+++ b/src/llamafactory/train/kto/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .workflow import run_kto
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 22a84e4a..460311e4 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -1,3 +1,21 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
@@ -9,7 +27,8 @@ from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context
+from ..callbacks import SaveProcessorCallback
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
if TYPE_CHECKING:
@@ -35,7 +54,6 @@ class CustomKTOTrainer(KTOTrainer):
disable_dropout_in_model(ref_model)
self.finetuning_args = finetuning_args
- self.processor = processor
self.reference_free = False
self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation
@@ -60,6 +78,8 @@ class CustomKTOTrainer(KTOTrainer):
if not hasattr(self, "accelerator"):
raise AttributeError("Please update `transformers`.")
+ warnings.simplefilter("ignore") # remove gc warnings on ref model
+
if ref_model is not None:
if self.is_deepspeed_enabled:
if not (
@@ -70,10 +90,14 @@ class CustomKTOTrainer(KTOTrainer):
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval()
- if finetuning_args.use_badam:
- from badam import clip_grad_norm_for_sparse_tensor
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
- self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
@@ -92,12 +116,6 @@ class CustomKTOTrainer(KTOTrainer):
"""
return Trainer._get_train_sampler(self)
- def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
- super()._save(output_dir, state_dict)
- if self.processor is not None:
- output_dir = output_dir if output_dir is not None else self.args.output_dir
- getattr(self.processor, "image_processor").save_pretrained(output_dir)
-
def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor"]:
@@ -143,7 +161,7 @@ class CustomKTOTrainer(KTOTrainer):
"""
if self.ref_model is None:
ref_model = model
- ref_context = get_ref_context(self.accelerator, model)
+ ref_context = self.accelerator.unwrap_model(model).disable_adapter()
else:
ref_model = self.ref_model
ref_context = nullcontext()
diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py
index c79b160b..8182a184 100644
--- a/src/llamafactory/train/kto/workflow.py
+++ b/src/llamafactory/train/kto/workflow.py
@@ -1,3 +1,20 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, List, Optional
from ...data import KTODataCollatorWithPadding, get_dataset, split_dataset
diff --git a/src/llamafactory/train/ppo/__init__.py b/src/llamafactory/train/ppo/__init__.py
index d17336d5..161f6f5d 100644
--- a/src/llamafactory/train/ppo/__init__.py
+++ b/src/llamafactory/train/ppo/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .workflow import run_ppo
diff --git a/src/llamafactory/train/ppo/ppo_utils.py b/src/llamafactory/train/ppo/ppo_utils.py
index fec3fc1e..05c40946 100644
--- a/src/llamafactory/train/ppo/ppo_utils.py
+++ b/src/llamafactory/train/ppo/ppo_utils.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 2e1288e4..57f0b848 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -1,6 +1,24 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import math
import os
import sys
+import warnings
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@@ -9,6 +27,7 @@ from accelerate.utils import DistributedDataParallelKwargs
from tqdm import tqdm
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
from transformers.optimization import get_scheduler
+from transformers.trainer_callback import CallbackHandler
from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
@@ -16,9 +35,9 @@ from trl import PPOConfig, PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from trl.models.utils import unwrap_model_for_generation
-from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
+from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
@@ -81,10 +100,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
)
# Add deepspeed config
- ppo_config.accelerator_kwargs["kwargs_handlers"] = [
- DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
- ]
if training_args.deepspeed_plugin is not None:
+ ppo_config.accelerator_kwargs["kwargs_handlers"] = [
+ DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
+ ]
ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
# Create optimizer and scheduler
@@ -113,7 +132,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.finetuning_args = finetuning_args
self.reward_model = reward_model
self.current_device = get_current_device() # patch for deepspeed training
- self.processor = processor
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,
@@ -125,8 +143,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.control = TrainerControl()
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
- self.log_callback, self.save_callback = callbacks[0], callbacks[1]
- assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)
+ self.callback_handler = CallbackHandler(
+ [callbacks], self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
+ )
if self.args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
@@ -134,8 +153,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm"
- device_type = unwrapped_model.pretrained_model.device.type
- self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
+ self.amp_context = torch.autocast(self.current_device.type, dtype=self.model_args.compute_dtype)
+ warnings.simplefilter("ignore") # remove gc warnings on ref model
if finetuning_args.reward_model_type == "full":
if self.is_deepspeed_enabled:
@@ -147,10 +166,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
- if finetuning_args.use_badam:
- from badam import clip_grad_norm_for_sparse_tensor
+ self.add_callback(FixValueHeadModelCallback)
- self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
+
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
@@ -184,23 +209,23 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.is_world_process_zero():
logger.info("***** Running training *****")
- logger.info(" Num examples = {}".format(num_examples))
- logger.info(" Num Epochs = {}".format(num_train_epochs))
- logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
+ logger.info(" Num examples = {:,}".format(num_examples))
+ logger.info(" Num Epochs = {:,}".format(num_train_epochs))
+ logger.info(" Instantaneous batch size per device = {:,}".format(self.args.per_device_train_batch_size))
logger.info(
- " Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
+ " Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
total_train_batch_size
)
)
- logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
- logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
- logger.info(" Total training steps = {}".format(max_steps))
- logger.info(" Number of trainable parameters = {}".format(count_parameters(self.model)[0]))
+ logger.info(" Gradient Accumulation steps = {:,}".format(self.args.gradient_accumulation_steps))
+ logger.info(" Num optimization epochs per batch = {:,}".format(self.finetuning_args.ppo_epochs))
+ logger.info(" Total training steps = {:,}".format(max_steps))
+ logger.info(" Number of trainable parameters = {:,}".format(count_parameters(self.model)[0]))
dataiter = iter(self.dataloader)
loss_meter = AverageMeter()
reward_meter = AverageMeter()
- self.log_callback.on_train_begin(self.args, self.state, self.control)
+ self.callback_handler.on_train_begin(self.args, self.state, self.control)
for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
try:
@@ -238,7 +263,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
logger.warning("Failed to save stats due to unknown errors.")
self.state.global_step += 1
- self.log_callback.on_step_end(self.args, self.state, self.control)
+ self.callback_handler.on_step_end(self.args, self.state, self.control)
if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
logs = dict(
@@ -250,7 +275,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
tqdm.write(str(logs))
logs["step"] = step
self.state.log_history.append(logs)
- self.log_callback.on_log(self.args, self.state, self.control)
+ self.callback_handler.on_log(self.args, self.state, self.control, logs)
loss_meter.reset()
reward_meter.reset()
@@ -258,17 +283,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.save_model(
os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
)
- self.save_callback.on_save(
- self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
- )
+ self.callback_handler.on_save(self.args, self.state, self.control)
if self.control.should_epoch_stop or self.control.should_training_stop:
break
- self.log_callback.on_train_end(self.args, self.state, self.control)
- self.save_callback.on_train_end(
- self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
- )
+ self.callback_handler.on_train_end(self.args, self.state, self.control)
def create_optimizer(
self,
@@ -486,7 +506,3 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
elif self.args.should_save:
self._save(output_dir)
-
- if self.processor is not None and self.args.should_save:
- output_dir = output_dir if output_dir is not None else self.args.output_dir
- getattr(self.processor, "image_processor").save_pretrained(output_dir)
diff --git a/src/llamafactory/train/ppo/workflow.py b/src/llamafactory/train/ppo/workflow.py
index 111704c6..651296f3 100644
--- a/src/llamafactory/train/ppo/workflow.py
+++ b/src/llamafactory/train/ppo/workflow.py
@@ -1,14 +1,28 @@
-# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import TYPE_CHECKING, List, Optional
from transformers import DataCollatorWithPadding
from ...data import get_dataset
-from ...extras.callbacks import FixValueHeadModelCallback
-from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
+from ..callbacks import FixValueHeadModelCallback, fix_valuehead_checkpoint
from ..trainer_utils import create_ref_model, create_reward_model
from .trainer import CustomPPOTrainer
@@ -60,6 +74,7 @@ def run_ppo(
ppo_trainer.save_model()
if training_args.should_save:
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+
ppo_trainer.save_state() # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"])
diff --git a/src/llamafactory/train/pt/__init__.py b/src/llamafactory/train/pt/__init__.py
index bdf397f6..d80e6f22 100644
--- a/src/llamafactory/train/pt/__init__.py
+++ b/src/llamafactory/train/pt/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .workflow import run_pt
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py
index 1d96e82f..e8f180a6 100644
--- a/src/llamafactory/train/pt/trainer.py
+++ b/src/llamafactory/train/pt/trainer.py
@@ -1,9 +1,24 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from types import MethodType
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Optional
from transformers import Trainer
from ...extras.logging import get_logger
+from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
@@ -27,11 +42,18 @@ class CustomTrainer(Trainer):
) -> None:
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
- self.processor = processor
- if finetuning_args.use_badam:
- from badam import clip_grad_norm_for_sparse_tensor
- self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
+
+ if finetuning_args.pissa_convert:
+ self.add_callback(PissaConvertCallback)
+
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
@@ -43,9 +65,3 @@ class CustomTrainer(Trainer):
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
-
- def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
- super()._save(output_dir, state_dict)
- if self.processor is not None:
- output_dir = output_dir if output_dir is not None else self.args.output_dir
- getattr(self.processor, "image_processor").save_pretrained(output_dir)
diff --git a/src/llamafactory/train/pt/workflow.py b/src/llamafactory/train/pt/workflow.py
index 8a635567..b84a0e7d 100644
--- a/src/llamafactory/train/pt/workflow.py
+++ b/src/llamafactory/train/pt/workflow.py
@@ -1,4 +1,19 @@
-# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import math
from typing import TYPE_CHECKING, List, Optional
diff --git a/src/llamafactory/train/rm/__init__.py b/src/llamafactory/train/rm/__init__.py
index dedac35f..48278315 100644
--- a/src/llamafactory/train/rm/__init__.py
+++ b/src/llamafactory/train/rm/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .workflow import run_rm
diff --git a/src/llamafactory/train/rm/metric.py b/src/llamafactory/train/rm/metric.py
index 99dc6ab8..fb880b1c 100644
--- a/src/llamafactory/train/rm/metric.py
+++ b/src/llamafactory/train/rm/metric.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import Dict, Sequence, Tuple, Union
import numpy as np
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index bfb344dc..accc877d 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -1,3 +1,42 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# This code is inspired by the CarperAI's trlx library.
+# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/reward_model.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# MIT License
+#
+# Copyright (c) 2022 CarperAI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
import json
import os
from types import MethodType
@@ -7,6 +46,7 @@ import torch
from transformers import Trainer
from ...extras.logging import get_logger
+from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
@@ -30,12 +70,20 @@ class PairwiseTrainer(Trainer):
) -> None:
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
- self.processor = processor
self.can_return_loss = True # override property to return eval_loss
- if finetuning_args.use_badam:
- from badam import clip_grad_norm_for_sparse_tensor
+ self.add_callback(FixValueHeadModelCallback)
- self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
+
+ if finetuning_args.pissa_convert:
+ self.add_callback(PissaConvertCallback)
+
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
@@ -48,12 +96,6 @@ class PairwiseTrainer(Trainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
- def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
- super()._save(output_dir, state_dict)
- if self.processor is not None:
- output_dir = output_dir if output_dir is not None else self.args.output_dir
- getattr(self.processor, "image_processor").save_pretrained(output_dir)
-
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
@@ -63,7 +105,7 @@ class PairwiseTrainer(Trainer):
Subclass and override to inject custom behavior.
Note that the first element will be removed from the output tuple.
- See: https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/trainer.py#L3777
+ See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842
"""
# Compute rewards
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
@@ -79,7 +121,6 @@ class PairwiseTrainer(Trainer):
chosen_scores, rejected_scores = [], []
# Compute pairwise loss. Only backprop on the different tokens before padding
- # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
loss = 0
for i in range(batch_size):
chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
@@ -125,4 +166,5 @@ class PairwiseTrainer(Trainer):
res: List[str] = []
for c_score, r_score in zip(chosen_scores, rejected_scores):
res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
+
writer.write("\n".join(res))
diff --git a/src/llamafactory/train/rm/workflow.py b/src/llamafactory/train/rm/workflow.py
index 2e9e194b..e0b32b77 100644
--- a/src/llamafactory/train/rm/workflow.py
+++ b/src/llamafactory/train/rm/workflow.py
@@ -1,12 +1,48 @@
-# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
+# Copyright 2024 the LlamaFactory team.
+#
+# This code is inspired by the CarperAI's trlx library.
+# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# MIT License
+#
+# Copyright (c) 2022 CarperAI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
from typing import TYPE_CHECKING, List, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
-from ...extras.callbacks import FixValueHeadModelCallback
-from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
+from ..callbacks import fix_valuehead_checkpoint
from ..trainer_utils import create_modelcard_and_push
from .metric import compute_accuracy
from .trainer import PairwiseTrainer
@@ -40,7 +76,7 @@ def run_rm(
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
- callbacks=callbacks + [FixValueHeadModelCallback()],
+ callbacks=callbacks,
compute_metrics=compute_accuracy,
**tokenizer_module,
**split_dataset(dataset, data_args, training_args),
@@ -52,6 +88,7 @@ def run_rm(
trainer.save_model()
if training_args.should_save:
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
diff --git a/src/llamafactory/train/sft/__init__.py b/src/llamafactory/train/sft/__init__.py
index f2f84e78..475dfe5f 100644
--- a/src/llamafactory/train/sft/__init__.py
+++ b/src/llamafactory/train/sft/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .workflow import run_sft
diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py
index b135fcfb..c69608c0 100644
--- a/src/llamafactory/train/sft/metric.py
+++ b/src/llamafactory/train/sft/metric.py
@@ -1,14 +1,35 @@
+# Copyright 2024 HuggingFace Inc., THUDM, and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
+# https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict
import numpy as np
+import torch
+from transformers import EvalPrediction
+from transformers.utils import is_jieba_available, is_nltk_available
from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
+from ...extras.packages import is_rouge_available
if TYPE_CHECKING:
- from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers import PreTrainedTokenizer
if is_jieba_available():
@@ -23,6 +44,22 @@ if is_rouge_available():
from rouge_chinese import Rouge
+def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
+ preds, labels = eval_preds.predictions, eval_preds.label_ids
+ accuracies = []
+ for i in range(len(preds)):
+ pred, label = preds[i, :-1], labels[i, 1:]
+ label_mask = label != IGNORE_INDEX
+ accuracies.append(np.mean(pred[label_mask] == label[label_mask]))
+
+ return {"accuracy": float(np.mean(accuracies))}
+
+
+def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
+ logits = logits[0] if isinstance(logits, (list, tuple)) else logits
+ return torch.argmax(logits, dim=-1)
+
+
@dataclass
class ComputeMetrics:
r"""
@@ -31,11 +68,11 @@ class ComputeMetrics:
tokenizer: "PreTrainedTokenizer"
- def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
+ def __call__(self, eval_preds: "EvalPrediction") -> Dict[str, float]:
r"""
Uses the model predictions to compute metrics.
"""
- preds, labels = eval_preds
+ preds, labels = eval_preds.predictions, eval_preds.label_ids
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index c063b214..954bb69f 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -1,3 +1,20 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import os
from types import MethodType
@@ -9,10 +26,12 @@ from transformers import Seq2SeqTrainer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
+from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
+ from torch.utils.data import Dataset
from transformers import ProcessorMixin
from transformers.trainer import PredictionOutput
@@ -32,11 +51,18 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
) -> None:
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
- self.processor = processor
- if finetuning_args.use_badam:
- from badam import clip_grad_norm_for_sparse_tensor
- self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
+
+ if finetuning_args.pissa_convert:
+ self.add_callback(PissaConvertCallback)
+
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
@@ -49,12 +75,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
- def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
- super()._save(output_dir, state_dict)
- if self.processor is not None:
- output_dir = output_dir if output_dir is not None else self.args.output_dir
- getattr(self.processor, "image_processor").save_pretrained(output_dir)
-
def prediction_step(
self,
model: "torch.nn.Module",
@@ -94,7 +114,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding
return padded_tensor.contiguous() # in contiguous memory
- def save_predictions(self, predict_results: "PredictionOutput") -> None:
+ def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None:
r"""
Saves model predictions to `output_dir`.
@@ -115,18 +135,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
for i in range(len(preds)):
pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
- if len(pad_len):
- preds[i] = np.concatenate(
- (preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1
- ) # move pad token to last
+ if len(pad_len): # move pad token to last
+ preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
- decoded_labels = self.tokenizer.batch_decode(
- labels, skip_special_tokens=True, clean_up_tokenization_spaces=False
- )
- decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+ decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
+ decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+ decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
- for label, pred in zip(decoded_labels, decoded_preds):
- res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
+ for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
+ res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))
+
writer.write("\n".join(res))
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index f1e000bd..c12a70aa 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -1,4 +1,19 @@
-# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import TYPE_CHECKING, List, Optional
@@ -10,7 +25,7 @@ from ...extras.misc import get_logits_processor
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
-from .metric import ComputeMetrics
+from .metric import ComputeMetrics, compute_accuracy, eval_logit_processor
from .trainer import CustomSeq2SeqTrainer
if TYPE_CHECKING:
@@ -56,7 +71,8 @@ def run_sft(
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
- compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
+ compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else compute_accuracy,
+ preprocess_logits_for_metrics=None if training_args.predict_with_generate else eval_logit_processor,
**tokenizer_module,
**split_dataset(dataset, data_args, training_args),
)
@@ -75,7 +91,7 @@ def run_sft(
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
- plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
# Evaluation
if training_args.do_eval:
@@ -92,7 +108,7 @@ def run_sft(
predict_results.metrics.pop("predict_loss", None)
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
- trainer.save_predictions(predict_results)
+ trainer.save_predictions(dataset, predict_results)
# Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index 0ddcdb11..4b581691 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -1,8 +1,27 @@
-from contextlib import contextmanager
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the original GaLore's implementation: https://github.com/jiaweizzhao/GaLore
+# and the original LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus
+# and the original BAdam's implementation: https://github.com/Ledzy/BAdam
+# and the HuggingFace's TRL library: https://github.com/huggingface/trl
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import Trainer
+from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
@@ -19,7 +38,6 @@ if is_galore_available():
if TYPE_CHECKING:
- from accelerate import Accelerator
from transformers import PreTrainedModel, Seq2SeqTrainingArguments
from trl import AutoModelForCausalLMWithValueHead
@@ -83,15 +101,12 @@ def create_ref_model(
The valuehead parameter is randomly initialized since it is useless for PPO training.
"""
if finetuning_args.ref_model is not None:
- ref_model_args_dict = model_args.to_dict()
- ref_model_args_dict.update(
- dict(
- model_name_or_path=finetuning_args.ref_model,
- adapter_name_or_path=finetuning_args.ref_model_adapters,
- quantization_bit=finetuning_args.ref_model_quantization_bit,
- )
+ ref_model_args = ModelArguments.copyfrom(
+ model_args,
+ model_name_or_path=finetuning_args.ref_model,
+ adapter_name_or_path=finetuning_args.ref_model_adapters,
+ quantization_bit=finetuning_args.ref_model_quantization_bit,
)
- ref_model_args = ModelArguments(**ref_model_args_dict)
ref_finetuning_args = FinetuningArguments()
tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
ref_model = load_model(
@@ -102,9 +117,11 @@ def create_ref_model(
if finetuning_args.finetuning_type == "lora":
ref_model = None
else:
- tokenizer = load_tokenizer(model_args)["tokenizer"]
+ ref_model_args = ModelArguments.copyfrom(model_args)
+ ref_finetuning_args = FinetuningArguments()
+ tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
ref_model = load_model(
- tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+ tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
)
logger.info("Created reference model from the model itself.")
@@ -139,15 +156,12 @@ def create_reward_model(
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
return None
else:
- reward_model_args_dict = model_args.to_dict()
- reward_model_args_dict.update(
- dict(
- model_name_or_path=finetuning_args.reward_model,
- adapter_name_or_path=finetuning_args.reward_model_adapters,
- quantization_bit=finetuning_args.reward_model_quantization_bit,
- )
+ reward_model_args = ModelArguments.copyfrom(
+ model_args,
+ model_name_or_path=finetuning_args.reward_model,
+ adapter_name_or_path=finetuning_args.reward_model_adapters,
+ quantization_bit=finetuning_args.reward_model_quantization_bit,
)
- reward_model_args = ModelArguments(**reward_model_args_dict)
reward_finetuning_args = FinetuningArguments()
tokenizer = load_tokenizer(reward_model_args)["tokenizer"]
reward_model = load_model(
@@ -158,17 +172,6 @@ def create_reward_model(
return reward_model
-@contextmanager
-def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"):
- r"""
- Gets adapter context for the reference model.
- """
- with accelerator.unwrap_model(model).disable_adapter():
- model.eval()
- yield
- model.train()
-
-
def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
r"""
Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
@@ -184,7 +187,7 @@ def _create_galore_optimizer(
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
- galore_targets = find_all_linear_modules(model)
+ galore_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
else:
galore_targets = finetuning_args.galore_target
@@ -334,6 +337,7 @@ def _create_badam_optimizer(
start_block=finetuning_args.badam_start_block,
switch_mode=finetuning_args.badam_switch_mode,
verbose=finetuning_args.badam_verbose,
+ ds_zero3_enabled=is_deepspeed_zero3_enabled(),
)
logger.info(
f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
@@ -355,7 +359,7 @@ def _create_badam_optimizer(
**optim_kwargs,
)
logger.info(
- f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+ f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
f"mask mode is {finetuning_args.badam_mask_mode}"
)
diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py
index eed875e9..dc982e07 100644
--- a/src/llamafactory/train/tuner.py
+++ b/src/llamafactory/train/tuner.py
@@ -1,13 +1,30 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from transformers import PreTrainedModel
from ..data import get_template_and_fix_tokenizer
-from ..extras.callbacks import LogCallback
+from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import get_logger
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer
+from .callbacks import LogCallback
from .dpo import run_dpo
from .kto import run_kto
from .ppo import run_ppo
@@ -24,8 +41,8 @@ logger = get_logger(__name__)
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
+ callbacks.append(LogCallback())
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
- callbacks.append(LogCallback(training_args.output_dir))
if finetuning_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
@@ -84,6 +101,25 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
safe_serialization=(not model_args.export_legacy_format),
)
+ if finetuning_args.stage == "rm":
+ if model_args.adapter_name_or_path is not None:
+ vhead_path = model_args.adapter_name_or_path[-1]
+ else:
+ vhead_path = model_args.model_name_or_path
+
+ if os.path.exists(os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME)):
+ shutil.copy(
+ os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
+ os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
+ )
+ logger.info("Copied valuehead to {}.".format(model_args.export_dir))
+ elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
+ shutil.copy(
+ os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
+ os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
+ )
+ logger.info("Copied valuehead to {}.".format(model_args.export_dir))
+
try:
tokenizer.padding_side = "left" # restore padding side
tokenizer.init_kwargs["padding_side"] = "left"
diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py
index c82710d3..8abef920 100644
--- a/src/llamafactory/webui/chatter.py
+++ b/src/llamafactory/webui/chatter.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import os
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
@@ -9,7 +23,7 @@ from ..data import Role
from ..extras.constants import PEFT_METHODS
from ..extras.misc import torch_gc
from ..extras.packages import is_gradio_available
-from .common import get_save_dir
+from .common import QUANTIZATION_BITS, get_save_dir
from .locales import ALERTS
@@ -62,17 +76,24 @@ class WebChatModel(ChatModel):
yield error
return
+ if get("top.quantization_bit") in QUANTIZATION_BITS:
+ quantization_bit = int(get("top.quantization_bit"))
+ else:
+ quantization_bit = None
+
yield ALERTS["info_loading"][lang]
args = dict(
model_name_or_path=model_path,
finetuning_type=finetuning_type,
- quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+ quantization_bit=quantization_bit,
+ quantization_method=get("top.quantization_method"),
template=get("top.template"),
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
infer_backend=get("infer.infer_backend"),
+ infer_dtype=get("infer.infer_dtype"),
)
if checkpoint_path:
@@ -126,16 +147,15 @@ class WebChatModel(ChatModel):
):
response += new_text
if tools:
- result = self.engine.template.format_tools.extract(response)
+ result = self.engine.template.extract_tool(response)
else:
result = response
- if isinstance(result, tuple):
- name, arguments = result
- arguments = json.loads(arguments)
- tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
- output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
- bot_text = "```json\n" + tool_call + "\n```"
+ if isinstance(result, list):
+ tool_calls = [{"name": tool[0], "arguments": json.loads(tool[1])} for tool in result]
+ tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
+ output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
+ bot_text = "```json\n" + tool_calls + "\n```"
else:
output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
bot_text = result
diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py
index 37b38df0..bced18f0 100644
--- a/src/llamafactory/webui/common.py
+++ b/src/llamafactory/webui/common.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import os
from collections import defaultdict
@@ -33,13 +47,19 @@ DEFAULT_CONFIG_DIR = "config"
DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user_config.yaml"
+QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"]
+GPTQ_BITS = ["8", "4", "3", "2"]
def get_save_dir(*paths: str) -> os.PathLike:
r"""
Gets the path to saved model checkpoints.
"""
- paths = (path.replace(os.path.sep, "").replace(" ", "").strip() for path in paths)
+ if os.path.sep in paths[-1]:
+ logger.warning("Found complex path, some features may be not available.")
+ return paths[-1]
+
+ paths = (path.replace(" ", "").strip() for path in paths)
return os.path.join(DEFAULT_SAVE_DIR, *paths)
diff --git a/src/llamafactory/webui/components/__init__.py b/src/llamafactory/webui/components/__init__.py
index 5c1e21b8..715fb6e4 100644
--- a/src/llamafactory/webui/components/__init__.py
+++ b/src/llamafactory/webui/components/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .chatbot import create_chat_box
from .eval import create_eval_tab
from .export import create_export_tab
diff --git a/src/llamafactory/webui/components/chatbot.py b/src/llamafactory/webui/components/chatbot.py
index f83694b1..ad74114b 100644
--- a/src/llamafactory/webui/components/chatbot.py
+++ b/src/llamafactory/webui/components/chatbot.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict, Tuple
from ...data import Role
diff --git a/src/llamafactory/webui/components/data.py b/src/llamafactory/webui/components/data.py
index 232b973d..88e500cf 100644
--- a/src/llamafactory/webui/components/data.py
+++ b/src/llamafactory/webui/components/data.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
diff --git a/src/llamafactory/webui/components/eval.py b/src/llamafactory/webui/components/eval.py
index 0a7a0f44..b522913e 100644
--- a/src/llamafactory/webui/components/eval.py
+++ b/src/llamafactory/webui/components/eval.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict
from ...extras.packages import is_gradio_available
diff --git a/src/llamafactory/webui/components/export.py b/src/llamafactory/webui/components/export.py
index 7e1493c8..0a938f02 100644
--- a/src/llamafactory/webui/components/export.py
+++ b/src/llamafactory/webui/components/export.py
@@ -1,10 +1,24 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict, Generator, List, Union
from ...extras.constants import PEFT_METHODS
from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available
from ...train.tuner import export_model
-from ..common import get_save_dir
+from ..common import GPTQ_BITS, get_save_dir
from ..locales import ALERTS
@@ -18,7 +32,11 @@ if TYPE_CHECKING:
from ..engine import Engine
-GPTQ_BITS = ["8", "4", "3", "2"]
+def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
+ if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
+ return gr.Dropdown(value="none", interactive=False)
+ else:
+ return gr.Dropdown(interactive=True)
def save_model(
@@ -96,6 +114,9 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
export_dir = gr.Textbox()
export_hub_model_id = gr.Textbox()
+ checkpoint_path: gr.Dropdown = engine.manager.get_elem_by_id("top.checkpoint_path")
+ checkpoint_path.change(can_quantize, [checkpoint_path], [export_quantization_bit], queue=False)
+
export_btn = gr.Button()
info_box = gr.Textbox(show_label=False, interactive=False)
diff --git a/src/llamafactory/webui/components/infer.py b/src/llamafactory/webui/components/infer.py
index 970f4629..a0064479 100644
--- a/src/llamafactory/webui/components/infer.py
+++ b/src/llamafactory/webui/components/infer.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict
from ...extras.packages import is_gradio_available
@@ -18,15 +32,26 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
- infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
+ with gr.Row():
+ infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
+ infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto")
+
with gr.Row():
load_btn = gr.Button()
unload_btn = gr.Button()
info_box = gr.Textbox(show_label=False, interactive=False)
- input_elems.update({infer_backend})
- elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
+ input_elems.update({infer_backend, infer_dtype})
+ elem_dict.update(
+ dict(
+ infer_backend=infer_backend,
+ infer_dtype=infer_dtype,
+ load_btn=load_btn,
+ unload_btn=unload_btn,
+ info_box=info_box,
+ )
+ )
chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(chat_elems)
diff --git a/src/llamafactory/webui/components/top.py b/src/llamafactory/webui/components/top.py
index fd0ead3d..9df3f062 100644
--- a/src/llamafactory/webui/components/top.py
+++ b/src/llamafactory/webui/components/top.py
@@ -1,10 +1,24 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict
from ...data import TEMPLATES
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ...extras.packages import is_gradio_available
from ..common import get_model_info, list_checkpoints, save_config
-from ..utils import can_quantize
+from ..utils import can_quantize, can_quantize_to
if is_gradio_available():
@@ -29,17 +43,23 @@ def create_top() -> Dict[str, "Component"]:
with gr.Accordion(open=False) as advanced_tab:
with gr.Row():
- quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
- template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
- rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
- booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
+ quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=1)
+ quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=1)
+ template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1)
+ rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2)
+ booster = gr.Radio(choices=["auto", "flashattn2", "unsloth"], value="auto", scale=2)
visual_inputs = gr.Checkbox(scale=1)
- model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False)
+ model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
+ list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
+ )
model_name.input(save_config, inputs=[lang, model_name], queue=False)
model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
- finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False)
+ finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then(
+ list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
+ )
checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
+ quantization_method.change(can_quantize_to, [quantization_method], [quantization_bit], queue=False)
return dict(
lang=lang,
@@ -49,6 +69,7 @@ def create_top() -> Dict[str, "Component"]:
checkpoint_path=checkpoint_path,
advanced_tab=advanced_tab,
quantization_bit=quantization_bit,
+ quantization_method=quantization_method,
template=template,
rope_scaling=rope_scaling,
booster=booster,
diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py
index dccc8500..4636050b 100644
--- a/src/llamafactory/webui/components/train.py
+++ b/src/llamafactory/webui/components/train.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict
from transformers.trainer_utils import SchedulerType
@@ -40,7 +54,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
num_train_epochs = gr.Textbox(value="3.0")
max_grad_norm = gr.Textbox(value="1.0")
max_samples = gr.Textbox(value="100000")
- compute_type = gr.Dropdown(choices=["fp16", "bf16", "fp32", "pure_bf16"], value="fp16")
+ compute_type = gr.Dropdown(choices=["bf16", "fp16", "fp32", "pure_bf16"], value="bf16")
input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type})
elem_dict.update(
@@ -152,10 +166,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
create_new_adapter = gr.Checkbox()
with gr.Row():
- with gr.Column(scale=1):
- use_rslora = gr.Checkbox()
- use_dora = gr.Checkbox()
-
+ use_rslora = gr.Checkbox()
+ use_dora = gr.Checkbox()
+ use_pissa = gr.Checkbox()
lora_target = gr.Textbox(scale=2)
additional_target = gr.Textbox(scale=2)
@@ -168,6 +181,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
create_new_adapter,
use_rslora,
use_dora,
+ use_pissa,
lora_target,
additional_target,
}
@@ -182,6 +196,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
create_new_adapter=create_new_adapter,
use_rslora=use_rslora,
use_dora=use_dora,
+ use_pissa=use_pissa,
lora_target=lora_target,
additional_target=additional_target,
)
@@ -279,7 +294,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Column(scale=1):
loss_viewer = gr.Plot()
- input_elems.update({output_dir, config_path, device_count, ds_stage, ds_offload})
+ input_elems.update({output_dir, config_path, ds_stage, ds_offload})
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
diff --git a/src/llamafactory/webui/css.py b/src/llamafactory/webui/css.py
index 36e3d4c2..53982119 100644
--- a/src/llamafactory/webui/css.py
+++ b/src/llamafactory/webui/css.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
CSS = r"""
.duplicate-button {
margin: auto !important;
diff --git a/src/llamafactory/webui/engine.py b/src/llamafactory/webui/engine.py
index eb6142d3..04893215 100644
--- a/src/llamafactory/webui/engine.py
+++ b/src/llamafactory/webui/engine.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Any, Dict
from .chatter import WebChatModel
diff --git a/src/llamafactory/webui/interface.py b/src/llamafactory/webui/interface.py
index bae3ba76..d25f4d38 100644
--- a/src/llamafactory/webui/interface.py
+++ b/src/llamafactory/webui/interface.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from ..extras.packages import is_gradio_available
diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py
index 05cf3bed..852b1b3c 100644
--- a/src/llamafactory/webui/locales.py
+++ b/src/llamafactory/webui/locales.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
LOCALES = {
"lang": {
"en": {
@@ -71,15 +85,29 @@ LOCALES = {
"quantization_bit": {
"en": {
"label": "Quantization bit",
- "info": "Enable 4/8-bit model quantization (QLoRA).",
+ "info": "Enable quantization (QLoRA).",
},
"ru": {
"label": "Уровень квантования",
- "info": "Включить 4/8-битное квантование модели (QLoRA).",
+ "info": "Включить квантование (QLoRA).",
},
"zh": {
"label": "量化等级",
- "info": "启用 4/8 比特模型量化(QLoRA)。",
+ "info": "启用量化(QLoRA)。",
+ },
+ },
+ "quantization_method": {
+ "en": {
+ "label": "Quantization method",
+ "info": "Quantization algorithm to use.",
+ },
+ "ru": {
+ "label": "Метод квантования",
+ "info": "Алгоритм квантования, который следует использовать.",
+ },
+ "zh": {
+ "label": "量化方法",
+ "info": "使用的量化算法。",
},
},
"template": {
@@ -732,6 +760,20 @@ LOCALES = {
"info": "使用权重分解的 LoRA。",
},
},
+ "use_pissa": {
+ "en": {
+ "label": "Use PiSSA",
+ "info": "Use PiSSA method.",
+ },
+ "ru": {
+ "label": "используйте PiSSA",
+ "info": "Используйте метод PiSSA.",
+ },
+ "zh": {
+ "label": "使用 PiSSA",
+ "info": "使用 PiSSA 方法。",
+ },
+ },
"lora_target": {
"en": {
"label": "LoRA modules (optional)",
@@ -1192,6 +1234,17 @@ LOCALES = {
"label": "推理引擎",
},
},
+ "infer_dtype": {
+ "en": {
+ "label": "Inference data type",
+ },
+ "ru": {
+ "label": "Тип данных для вывода",
+ },
+ "zh": {
+ "label": "推理数据类型",
+ },
+ },
"load_btn": {
"en": {
"value": "Load model",
diff --git a/src/llamafactory/webui/manager.py b/src/llamafactory/webui/manager.py
index 326fdb8d..ebe9f1b9 100644
--- a/src/llamafactory/webui/manager.py
+++ b/src/llamafactory/webui/manager.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple
@@ -57,6 +71,7 @@ class Manager:
self._id_to_elem["top.finetuning_type"],
self._id_to_elem["top.checkpoint_path"],
self._id_to_elem["top.quantization_bit"],
+ self._id_to_elem["top.quantization_method"],
self._id_to_elem["top.template"],
self._id_to_elem["top.rope_scaling"],
self._id_to_elem["top.booster"],
diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py
index 852805da..ffec54e2 100644
--- a/src/llamafactory/webui/runner.py
+++ b/src/llamafactory/webui/runner.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from copy import deepcopy
from subprocess import Popen, TimeoutExpired
@@ -8,9 +22,9 @@ from transformers.trainer import TRAINING_ARGS_NAME
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
from ..extras.misc import is_gpu_or_npu_available, torch_gc
from ..extras.packages import is_gradio_available
-from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config
+from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
from .locales import ALERTS, LOCALES
-from .utils import abort_leaf_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
+from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
if is_gradio_available():
@@ -38,7 +52,7 @@ class Runner:
def set_abort(self) -> None:
self.aborted = True
if self.trainer is not None:
- abort_leaf_process(self.trainer.pid)
+ abort_process(self.trainer.pid)
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
@@ -90,6 +104,11 @@ class Runner:
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
+ if get("top.quantization_bit") in QUANTIZATION_BITS:
+ quantization_bit = int(get("top.quantization_bit"))
+ else:
+ quantization_bit = None
+
args = dict(
stage=TRAINING_STAGES[get("train.training_stage")],
do_train=True,
@@ -97,7 +116,8 @@ class Runner:
cache_dir=user_config.get("cache_dir", None),
preprocessing_num_workers=16,
finetuning_type=finetuning_type,
- quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+ quantization_bit=quantization_bit,
+ quantization_method=get("top.quantization_method"),
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
@@ -160,6 +180,8 @@ class Runner:
args["create_new_adapter"] = get("train.create_new_adapter")
args["use_rslora"] = get("train.use_rslora")
args["use_dora"] = get("train.use_dora")
+ args["pissa_init"] = get("train.use_pissa")
+ args["pissa_convert"] = get("train.use_pissa")
args["lora_target"] = get("train.lora_target") or "all"
args["additional_target"] = get("train.additional_target") or None
@@ -219,13 +241,19 @@ class Runner:
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
+ if get("top.quantization_bit") in QUANTIZATION_BITS:
+ quantization_bit = int(get("top.quantization_bit"))
+ else:
+ quantization_bit = None
+
args = dict(
stage="sft",
model_name_or_path=get("top.model_path"),
cache_dir=user_config.get("cache_dir", None),
preprocessing_num_workers=16,
finetuning_type=finetuning_type,
- quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+ quantization_bit=quantization_bit,
+ quantization_method=get("top.quantization_method"),
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
@@ -283,6 +311,7 @@ class Runner:
env = deepcopy(os.environ)
env["LLAMABOARD_ENABLED"] = "1"
+ env["LLAMABOARD_WORKDIR"] = args["output_dir"]
if args.get("deepspeed", None) is not None:
env["FORCE_TORCHRUN"] = "1"
@@ -291,7 +320,7 @@ class Runner:
def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
config_dict = {}
- skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count"]
+ skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
for elem, value in data.items():
elem_id = self.manager.get_id_by_elem(elem)
if elem_id not in skip_ids:
diff --git a/src/llamafactory/webui/utils.py b/src/llamafactory/webui/utils.py
index e39f2aa4..6e5fdbe4 100644
--- a/src/llamafactory/webui/utils.py
+++ b/src/llamafactory/webui/utils.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import os
import signal
@@ -11,6 +25,7 @@ from yaml import safe_dump, safe_load
from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES
from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import gen_loss_plot
+from ..model import QuantizationMethod
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir
from .locales import ALERTS
@@ -19,16 +34,19 @@ if is_gradio_available():
import gradio as gr
-def abort_leaf_process(pid: int) -> None:
+def abort_process(pid: int) -> None:
r"""
- Aborts the leaf processes.
+ Aborts the processes recursively in a bottom-up way.
"""
- children = psutil.Process(pid).children()
- if children:
- for child in children:
- abort_leaf_process(child.pid)
- else:
+ try:
+ children = psutil.Process(pid).children()
+ if children:
+ for child in children:
+ abort_process(child.pid)
+
os.kill(pid, signal.SIGABRT)
+ except Exception:
+ pass
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
@@ -41,6 +59,20 @@ def can_quantize(finetuning_type: str) -> "gr.Dropdown":
return gr.Dropdown(interactive=True)
+def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
+ r"""
+ Returns the available quantization bits.
+ """
+ if quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
+ available_bits = ["none", "8", "4"]
+ elif quantization_method == QuantizationMethod.HQQ.value:
+ available_bits = ["none", "8", "6", "5", "4", "3", "2", "1"]
+ elif quantization_method == QuantizationMethod.EETQ.value:
+ available_bits = ["none", "8"]
+
+ return gr.Dropdown(choices=available_bits)
+
+
def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
r"""
Modifys states after changing the training stage.
diff --git a/src/train.py b/src/train.py
index b20aa9d2..6703ffdb 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from llamafactory.train.tuner import run_exp
diff --git a/src/webui.py b/src/webui.py
index bbefb54e..99370af2 100644
--- a/src/webui.py
+++ b/src/webui.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from llamafactory.webui.interface import create_ui
diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py
new file mode 100644
index 00000000..1845df24
--- /dev/null
+++ b/tests/data/test_formatter.py
@@ -0,0 +1,123 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
+
+
+def test_empty_formatter():
+ formatter = EmptyFormatter(slots=["\n"])
+ assert formatter.apply() == ["\n"]
+
+
+def test_string_formatter():
+ formatter = StringFormatter(slots=["", "Human: {{content}}\nAssistant:"])
+ assert formatter.apply(content="Hi") == ["", "Human: Hi\nAssistant:"]
+
+
+def test_function_formatter():
+ formatter = FunctionFormatter(slots=[], tool_format="default")
+ tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
+ assert formatter.apply(content=tool_calls) == [
+ """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n"""
+ ]
+
+
+def test_multi_function_formatter():
+ formatter = FunctionFormatter(slots=[], tool_format="default")
+ tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2)
+ assert formatter.apply(content=tool_calls) == [
+ """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
+ """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
+ ]
+
+
+def test_default_tool_formatter():
+ formatter = ToolFormatter(tool_format="default")
+ tools = [
+ {
+ "name": "test_tool",
+ "description": "tool_desc",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "foo": {"type": "string", "description": "foo_desc"},
+ "bar": {"type": "number", "description": "bar_desc"},
+ },
+ "required": ["foo"],
+ },
+ }
+ ]
+ assert formatter.apply(content=json.dumps(tools)) == [
+ "You have access to the following tools:\n"
+ "> Tool Name: test_tool\n"
+ "Tool Description: tool_desc\n"
+ "Tool Args:\n"
+ " - foo (string, required): foo_desc\n"
+ " - bar (number): bar_desc\n\n"
+ "Use the following format if using a tool:\n"
+ "```\n"
+ "Action: tool name (one of [test_tool]).\n"
+ "Action Input: the input to the tool, in a JSON format representing the kwargs "
+ """(e.g. ```{"input": "hello world", "num_beams": 5}```).\n"""
+ "```\n"
+ ]
+
+
+def test_default_tool_extractor():
+ formatter = ToolFormatter(tool_format="default")
+ result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
+ assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
+
+
+def test_default_multi_tool_extractor():
+ formatter = ToolFormatter(tool_format="default")
+ result = (
+ """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
+ """Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n"""
+ )
+ assert formatter.extract(result) == [
+ ("test_tool", """{"foo": "bar", "size": 10}"""),
+ ("another_tool", """{"foo": "job", "size": 2}"""),
+ ]
+
+
+def test_glm4_tool_formatter():
+ formatter = ToolFormatter(tool_format="glm4")
+ tools = [
+ {
+ "name": "test_tool",
+ "description": "tool_desc",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "foo": {"type": "string", "description": "foo_desc"},
+ "bar": {"type": "number", "description": "bar_desc"},
+ },
+ "required": ["foo"],
+ },
+ }
+ ]
+ assert formatter.apply(content=json.dumps(tools)) == [
+ "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
+ "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
+ "## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(json.dumps(tools[0], indent=4))
+ ]
+
+
+def test_glm4_tool_extractor():
+ formatter = ToolFormatter(tool_format="glm4")
+ result = """test_tool\n{"foo": "bar", "size": 10}\n"""
+ assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
diff --git a/tests/data/test_processor.py b/tests/data/test_processor.py
new file mode 100644
index 00000000..fa8f7172
--- /dev/null
+++ b/tests/data/test_processor.py
@@ -0,0 +1,32 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+import pytest
+
+from llamafactory.data.processors.processor_utils import infer_seqlen
+
+
+@pytest.mark.parametrize(
+ "test_input,test_output",
+ [
+ ((3000, 2000, 1000), (600, 400)),
+ ((2000, 3000, 1000), (400, 600)),
+ ((1000, 100, 1000), (900, 100)),
+ ((100, 1000, 1000), (100, 900)),
+ ],
+)
+def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]):
+ assert test_output == infer_seqlen(*test_input)
diff --git a/tests/data/test_supervised.py b/tests/data/test_supervised.py
index bb7f71df..9cb49615 100644
--- a/tests/data/test_supervised.py
+++ b/tests/data/test_supervised.py
@@ -1,24 +1,40 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
+import random
import pytest
from datasets import load_dataset
+from transformers import AutoTokenizer
from llamafactory.data import get_dataset
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
-TRAINING_ARGS = {
+TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"stage": "sft",
"do_train": True,
"finetuning_type": "full",
- "dataset": "llamafactory/tiny_dataset",
+ "dataset": "llamafactory/tiny-supervised-dataset",
"dataset_dir": "ONLINE",
"template": "llama3",
- "cutoff_len": 1024,
+ "cutoff_len": 8192,
"overwrite_cache": True,
"output_dir": "dummy_dir",
"overwrite_output_dir": True,
@@ -26,19 +42,26 @@ TRAINING_ARGS = {
}
-@pytest.mark.parametrize("test_num", [5])
-def test_supervised(test_num: int):
- model_args, data_args, training_args, _, _ = get_train_args(TRAINING_ARGS)
+@pytest.mark.parametrize("num_samples", [16])
+def test_supervised(num_samples: int):
+ model_args, data_args, training_args, _, _ = get_train_args(TRAIN_ARGS)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
tokenized_data = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
- original_data = load_dataset(TRAINING_ARGS["dataset"], split="train")
- for test_idx in range(test_num):
- decode_result = tokenizer.decode(tokenized_data["input_ids"][test_idx])
+ ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+
+ original_data = load_dataset(TRAIN_ARGS["dataset"], split="train")
+ indexes = random.choices(range(len(original_data)), k=num_samples)
+ for index in indexes:
+ prompt = original_data[index]["instruction"]
+ if original_data[index]["input"]:
+ prompt += "\n" + original_data[index]["input"]
+
messages = [
- {"role": "user", "content": original_data[test_idx]["instruction"]},
- {"role": "assistant", "content": original_data[test_idx]["output"]},
+ {"role": "user", "content": prompt},
+ {"role": "assistant", "content": original_data[index]["output"]},
]
- templated_result = tokenizer.apply_chat_template(messages, tokenize=False)
- assert decode_result == templated_result
+ templated_result = ref_tokenizer.apply_chat_template(messages, tokenize=False)
+ decoded_result = tokenizer.decode(tokenized_data["input_ids"][index])
+ assert templated_result == decoded_result
diff --git a/tests/data/test_template.py b/tests/data/test_template.py
new file mode 100644
index 00000000..e4728a84
--- /dev/null
+++ b/tests/data/test_template.py
@@ -0,0 +1,80 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from transformers import AutoTokenizer
+
+from llamafactory.data import get_template_and_fix_tokenizer
+
+
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+
+MESSAGES = [
+ {"role": "user", "content": "How are you"},
+ {"role": "assistant", "content": "I am fine!"},
+ {"role": "user", "content": "你好"},
+ {"role": "assistant", "content": "很高兴认识你!"},
+]
+
+
+def test_encode_oneturn():
+ tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+ template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
+ prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+ assert tokenizer.decode(prompt_ids) == (
+ "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
+ )
+ assert tokenizer.decode(answer_ids) == "很高兴认识你!<|eot_id|>"
+
+
+def test_encode_multiturn():
+ tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+ template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
+ encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
+ assert tokenizer.decode(encoded_pairs[0][0]) == (
+ "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
+ )
+ assert tokenizer.decode(encoded_pairs[0][1]) == "I am fine!<|eot_id|>"
+ assert tokenizer.decode(encoded_pairs[1][0]) == (
+ "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
+ )
+ assert tokenizer.decode(encoded_pairs[1][1]) == "很高兴认识你!<|eot_id|>"
+
+
+def test_jinja_template():
+ tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+ ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+ get_template_and_fix_tokenizer(tokenizer, name="llama3")
+ assert tokenizer.chat_template != ref_tokenizer.chat_template
+ assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
+
+
+def test_qwen_template():
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
+ template = get_template_and_fix_tokenizer(tokenizer, name="qwen")
+ prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+ assert tokenizer.decode(prompt_ids) == (
+ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+ "<|im_start|>user\nHow are you<|im_end|>\n"
+ "<|im_start|>assistant\nI am fine!<|im_end|>\n"
+ "<|im_start|>user\n你好<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ )
+ assert tokenizer.decode(answer_ids) == "很高兴认识你!<|im_end|>"
diff --git a/tests/eval/test_eval_template.py b/tests/eval/test_eval_template.py
new file mode 100644
index 00000000..f85d9d57
--- /dev/null
+++ b/tests/eval/test_eval_template.py
@@ -0,0 +1,91 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from llamafactory.eval.template import get_eval_template
+
+
+def test_eval_template_en():
+ support_set = [
+ {
+ "question": "Fewshot question",
+ "A": "Fewshot1",
+ "B": "Fewshot2",
+ "C": "Fewshot3",
+ "D": "Fewshot4",
+ "answer": "B",
+ }
+ ]
+ example = {
+ "question": "Target question",
+ "A": "Target1",
+ "B": "Target2",
+ "C": "Target3",
+ "D": "Target4",
+ "answer": "C",
+ }
+ template = get_eval_template(name="en")
+ messages = template.format_example(example, support_set=support_set, subject_name="SubName")
+ assert messages == [
+ {
+ "role": "user",
+ "content": (
+ "The following are multiple choice questions (with answers) about SubName.\n\n"
+ "Fewshot question\nA. Fewshot1\nB. Fewshot2\nC. Fewshot3\nD. Fewshot4\nAnswer:"
+ ),
+ },
+ {"role": "assistant", "content": "B"},
+ {
+ "role": "user",
+ "content": "Target question\nA. Target1\nB. Target2\nC. Target3\nD. Target4\nAnswer:",
+ },
+ {"role": "assistant", "content": "C"},
+ ]
+
+
+def test_eval_template_zh():
+ support_set = [
+ {
+ "question": "示例问题",
+ "A": "示例答案1",
+ "B": "示例答案2",
+ "C": "示例答案3",
+ "D": "示例答案4",
+ "answer": "B",
+ }
+ ]
+ example = {
+ "question": "目标问题",
+ "A": "目标答案1",
+ "B": "目标答案2",
+ "C": "目标答案3",
+ "D": "目标答案4",
+ "answer": "C",
+ }
+ template = get_eval_template(name="zh")
+ messages = template.format_example(example, support_set=support_set, subject_name="主题")
+ assert messages == [
+ {
+ "role": "user",
+ "content": (
+ "以下是中国关于主题考试的单项选择题,请选出其中的正确答案。\n\n"
+ "示例问题\nA. 示例答案1\nB. 示例答案2\nC. 示例答案3\nD. 示例答案4\n答案:"
+ ),
+ },
+ {"role": "assistant", "content": "B"},
+ {
+ "role": "user",
+ "content": "目标问题\nA. 目标答案1\nB. 目标答案2\nC. 目标答案3\nD. 目标答案4\n答案:",
+ },
+ {"role": "assistant", "content": "C"},
+ ]
diff --git a/tests/model/model_utils/test_attention.py b/tests/model/model_utils/test_attention.py
index 4d414289..4cae3d7c 100644
--- a/tests/model/model_utils/test_attention.py
+++ b/tests/model/model_utils/test_attention.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
@@ -6,11 +20,16 @@ from llamafactory.hparams import get_infer_args
from llamafactory.model import load_model, load_tokenizer
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+
+INFER_ARGS = {
+ "model_name_or_path": TINY_LLAMA,
+ "template": "llama3",
+}
def test_attention():
- attention_available = ["off"]
+ attention_available = ["disabled"]
if is_torch_sdpa_available():
attention_available.append("sdpa")
@@ -18,18 +37,12 @@ def test_attention():
attention_available.append("fa2")
llama_attention_classes = {
- "off": "LlamaAttention",
+ "disabled": "LlamaAttention",
"sdpa": "LlamaSdpaAttention",
"fa2": "LlamaFlashAttention2",
}
for requested_attention in attention_available:
- model_args, _, finetuning_args, _ = get_infer_args(
- {
- "model_name_or_path": TINY_LLAMA,
- "template": "llama2",
- "flash_attn": requested_attention,
- }
- )
+ model_args, _, finetuning_args, _ = get_infer_args({"flash_attn": requested_attention, **INFER_ARGS})
tokenizer_module = load_tokenizer(model_args)
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args)
for module in model.modules():
diff --git a/tests/model/model_utils/test_checkpointing.py b/tests/model/model_utils/test_checkpointing.py
new file mode 100644
index 00000000..9b6dfc9e
--- /dev/null
+++ b/tests/model/model_utils/test_checkpointing.py
@@ -0,0 +1,74 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import torch
+
+from llamafactory.extras.misc import get_current_device
+from llamafactory.hparams import get_train_args
+from llamafactory.model import load_model, load_tokenizer
+
+
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+
+TRAIN_ARGS = {
+ "model_name_or_path": TINY_LLAMA,
+ "stage": "sft",
+ "do_train": True,
+ "finetuning_type": "lora",
+ "lora_target": "all",
+ "dataset": "llamafactory/tiny-supervised-dataset",
+ "dataset_dir": "ONLINE",
+ "template": "llama3",
+ "cutoff_len": 1024,
+ "overwrite_cache": True,
+ "output_dir": "dummy_dir",
+ "overwrite_output_dir": True,
+ "fp16": True,
+}
+
+
+def test_checkpointing_enable():
+ model_args, _, _, finetuning_args, _ = get_train_args({"disable_gradient_checkpointing": False, **TRAIN_ARGS})
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+ for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
+ assert getattr(module, "gradient_checkpointing") is True
+
+
+def test_checkpointing_disable():
+ model_args, _, _, finetuning_args, _ = get_train_args({"disable_gradient_checkpointing": True, **TRAIN_ARGS})
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+ for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
+ assert getattr(module, "gradient_checkpointing") is False
+
+
+def test_upcast_layernorm():
+ model_args, _, _, finetuning_args, _ = get_train_args({"upcast_layernorm": True, **TRAIN_ARGS})
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+ for name, param in model.named_parameters():
+ if param.ndim == 1 and "norm" in name:
+ assert param.dtype == torch.float32
+
+
+def test_upcast_lmhead_output():
+ model_args, _, _, finetuning_args, _ = get_train_args({"upcast_lmhead_output": True, **TRAIN_ARGS})
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+ inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())
+ outputs: "torch.Tensor" = model.get_output_embeddings()(inputs)
+ assert outputs.dtype == torch.float32
diff --git a/tests/model/test_base.py b/tests/model/test_base.py
new file mode 100644
index 00000000..6431a504
--- /dev/null
+++ b/tests/model/test_base.py
@@ -0,0 +1,80 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Dict
+
+import pytest
+import torch
+from transformers import AutoModelForCausalLM
+from trl import AutoModelForCausalLMWithValueHead
+
+from llamafactory.extras.misc import get_current_device
+from llamafactory.hparams import get_infer_args
+from llamafactory.model import load_model, load_tokenizer
+
+
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+
+TINY_LLAMA_VALUEHEAD = os.environ.get("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
+
+INFER_ARGS = {
+ "model_name_or_path": TINY_LLAMA,
+ "template": "llama3",
+ "infer_dtype": "float16",
+}
+
+
+def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
+ state_dict_a = model_a.state_dict()
+ state_dict_b = model_b.state_dict()
+ assert set(state_dict_a.keys()) == set(state_dict_b.keys())
+ for name in state_dict_a.keys():
+ assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5)
+
+
+@pytest.fixture
+def fix_valuehead_cpu_loading():
+ def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
+ state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
+ self.v_head.load_state_dict(state_dict, strict=False)
+ del state_dict
+
+ AutoModelForCausalLMWithValueHead.post_init = post_init
+
+
+def test_base():
+ model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
+
+ ref_model = AutoModelForCausalLM.from_pretrained(
+ TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
+ )
+ compare_model(model, ref_model)
+
+
+@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
+def test_valuehead():
+ model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(
+ tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False, add_valuehead=True
+ )
+
+ ref_model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
+ TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device()
+ )
+ ref_model.v_head = ref_model.v_head.to(torch.float16)
+ compare_model(model, ref_model)
diff --git a/tests/model/test_freeze.py b/tests/model/test_freeze.py
index c6cdec78..5f478af6 100644
--- a/tests/model/test_freeze.py
+++ b/tests/model/test_freeze.py
@@ -1,19 +1,33 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
import torch
-from llamafactory.hparams import get_train_args
+from llamafactory.hparams import get_infer_args, get_train_args
from llamafactory.model import load_model, load_tokenizer
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
-TRAINING_ARGS = {
+TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"stage": "sft",
"do_train": True,
"finetuning_type": "freeze",
- "dataset": "llamafactory/tiny_dataset",
+ "dataset": "llamafactory/tiny-supervised-dataset",
"dataset_dir": "ONLINE",
"template": "llama3",
"cutoff_len": 1024,
@@ -23,16 +37,19 @@ TRAINING_ARGS = {
"fp16": True,
}
+INFER_ARGS = {
+ "model_name_or_path": TINY_LLAMA,
+ "finetuning_type": "freeze",
+ "template": "llama3",
+ "infer_dtype": "float16",
+}
-def test_freeze_all_modules():
- model_args, _, _, finetuning_args, _ = get_train_args(
- {
- "freeze_trainable_layers": 1,
- **TRAINING_ARGS,
- }
- )
+
+def test_freeze_train_all_modules():
+ model_args, _, _, finetuning_args, _ = get_train_args({"freeze_trainable_layers": 1, **TRAIN_ARGS})
tokenizer_module = load_tokenizer(model_args)
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
for name, param in model.named_parameters():
if name.startswith("model.layers.1."):
assert param.requires_grad is True
@@ -42,16 +59,13 @@ def test_freeze_all_modules():
assert param.dtype == torch.float16
-def test_freeze_extra_modules():
+def test_freeze_train_extra_modules():
model_args, _, _, finetuning_args, _ = get_train_args(
- {
- "freeze_trainable_layers": 1,
- "freeze_extra_modules": "embed_tokens,lm_head",
- **TRAINING_ARGS,
- }
+ {"freeze_trainable_layers": 1, "freeze_extra_modules": "embed_tokens,lm_head", **TRAIN_ARGS}
)
tokenizer_module = load_tokenizer(model_args)
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
for name, param in model.named_parameters():
if name.startswith("model.layers.1.") or any(module in name for module in ["embed_tokens", "lm_head"]):
assert param.requires_grad is True
@@ -59,3 +73,13 @@ def test_freeze_extra_modules():
else:
assert param.requires_grad is False
assert param.dtype == torch.float16
+
+
+def test_freeze_inference():
+ model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
+
+ for param in model.parameters():
+ assert param.requires_grad is False
+ assert param.dtype == torch.float16
diff --git a/tests/model/test_full.py b/tests/model/test_full.py
index ef57a980..0a6e0743 100644
--- a/tests/model/test_full.py
+++ b/tests/model/test_full.py
@@ -1,19 +1,33 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
import torch
-from llamafactory.hparams import get_train_args
+from llamafactory.hparams import get_infer_args, get_train_args
from llamafactory.model import load_model, load_tokenizer
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
-TRAINING_ARGS = {
+TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"stage": "sft",
"do_train": True,
"finetuning_type": "full",
- "dataset": "llamafactory/tiny_dataset",
+ "dataset": "llamafactory/tiny-supervised-dataset",
"dataset_dir": "ONLINE",
"template": "llama3",
"cutoff_len": 1024,
@@ -23,11 +37,29 @@ TRAINING_ARGS = {
"fp16": True,
}
+INFER_ARGS = {
+ "model_name_or_path": TINY_LLAMA,
+ "finetuning_type": "full",
+ "template": "llama3",
+ "infer_dtype": "float16",
+}
-def test_full():
- model_args, _, _, finetuning_args, _ = get_train_args(TRAINING_ARGS)
+
+def test_full_train():
+ model_args, _, _, finetuning_args, _ = get_train_args(TRAIN_ARGS)
tokenizer_module = load_tokenizer(model_args)
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
for param in model.parameters():
assert param.requires_grad is True
assert param.dtype == torch.float32
+
+
+def test_full_inference():
+ model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
+
+ for param in model.parameters():
+ assert param.requires_grad is False
+ assert param.dtype == torch.float16
diff --git a/tests/model/test_lora.py b/tests/model/test_lora.py
index 1f2c02ae..630e5f75 100644
--- a/tests/model/test_lora.py
+++ b/tests/model/test_lora.py
@@ -1,19 +1,43 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
+from typing import Dict, Sequence
+import pytest
import torch
+from peft import LoraModel, PeftModel
+from transformers import AutoModelForCausalLM
+from trl import AutoModelForCausalLMWithValueHead
-from llamafactory.hparams import get_train_args
+from llamafactory.extras.misc import get_current_device
+from llamafactory.hparams import get_infer_args, get_train_args
from llamafactory.model import load_model, load_tokenizer
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
-TRAINING_ARGS = {
+TINY_LLAMA_ADAPTER = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
+
+TINY_LLAMA_VALUEHEAD = os.environ.get("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
+
+TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"stage": "sft",
"do_train": True,
"finetuning_type": "lora",
- "dataset": "llamafactory/tiny_dataset",
+ "dataset": "llamafactory/tiny-supervised-dataset",
"dataset_dir": "ONLINE",
"template": "llama3",
"cutoff_len": 1024,
@@ -23,16 +47,70 @@ TRAINING_ARGS = {
"fp16": True,
}
+INFER_ARGS = {
+ "model_name_or_path": TINY_LLAMA,
+ "adapter_name_or_path": TINY_LLAMA_ADAPTER,
+ "finetuning_type": "lora",
+ "template": "llama3",
+ "infer_dtype": "float16",
+}
-def test_lora_all_modules():
- model_args, _, _, finetuning_args, _ = get_train_args(
- {
- "lora_target": "all",
- **TRAINING_ARGS,
- }
+
+def load_reference_model(is_trainable: bool = False) -> "LoraModel":
+ model = AutoModelForCausalLM.from_pretrained(
+ TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
)
+ lora_model = PeftModel.from_pretrained(model, TINY_LLAMA_ADAPTER, is_trainable=is_trainable)
+ for param in filter(lambda p: p.requires_grad, lora_model.parameters()):
+ param.data = param.data.to(torch.float32)
+
+ return lora_model
+
+
+def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []):
+ state_dict_a = model_a.state_dict()
+ state_dict_b = model_b.state_dict()
+ assert set(state_dict_a.keys()) == set(state_dict_b.keys())
+ for name in state_dict_a.keys():
+ if any(key in name for key in diff_keys):
+ assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False
+ else:
+ assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True
+
+
+@pytest.fixture
+def fix_valuehead_cpu_loading():
+ def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
+ state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
+ self.v_head.load_state_dict(state_dict, strict=False)
+ del state_dict
+
+ AutoModelForCausalLMWithValueHead.post_init = post_init
+
+
+def test_lora_train_qv_modules():
+ model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "q_proj,v_proj", **TRAIN_ARGS})
tokenizer_module = load_tokenizer(model_args)
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
+ linear_modules = set()
+ for name, param in model.named_parameters():
+ if any(module in name for module in ["lora_A", "lora_B"]):
+ linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1])
+ assert param.requires_grad is True
+ assert param.dtype == torch.float32
+ else:
+ assert param.requires_grad is False
+ assert param.dtype == torch.float16
+
+ assert linear_modules == {"q_proj", "v_proj"}
+
+
+def test_lora_train_all_modules():
+ model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "all", **TRAIN_ARGS})
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
linear_modules = set()
for name, param in model.named_parameters():
if any(module in name for module in ["lora_A", "lora_B"]):
@@ -46,16 +124,13 @@ def test_lora_all_modules():
assert linear_modules == {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"}
-def test_lora_extra_modules():
+def test_lora_train_extra_modules():
model_args, _, _, finetuning_args, _ = get_train_args(
- {
- "lora_target": "all",
- "additional_target": "embed_tokens,lm_head",
- **TRAINING_ARGS,
- }
+ {"lora_target": "all", "additional_target": "embed_tokens,lm_head", **TRAIN_ARGS}
)
tokenizer_module = load_tokenizer(model_args)
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
extra_modules = set()
for name, param in model.named_parameters():
if any(module in name for module in ["lora_A", "lora_B"]):
@@ -70,3 +145,54 @@ def test_lora_extra_modules():
assert param.dtype == torch.float16
assert extra_modules == {"embed_tokens", "lm_head"}
+
+
+def test_lora_train_old_adapters():
+ model_args, _, _, finetuning_args, _ = get_train_args(
+ {"adapter_name_or_path": TINY_LLAMA_ADAPTER, "create_new_adapter": False, **TRAIN_ARGS}
+ )
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
+ ref_model = load_reference_model(is_trainable=True)
+ compare_model(model, ref_model)
+
+
+def test_lora_train_new_adapters():
+ model_args, _, _, finetuning_args, _ = get_train_args(
+ {"adapter_name_or_path": TINY_LLAMA_ADAPTER, "create_new_adapter": True, **TRAIN_ARGS}
+ )
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
+ ref_model = load_reference_model(is_trainable=True)
+ compare_model(
+ model, ref_model, diff_keys=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"]
+ )
+
+
+@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
+def test_lora_train_valuehead():
+ model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(
+ tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True, add_valuehead=True
+ )
+
+ ref_model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
+ TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device()
+ )
+ state_dict = model.state_dict()
+ ref_state_dict = ref_model.state_dict()
+
+ assert torch.allclose(state_dict["v_head.summary.weight"], ref_state_dict["v_head.summary.weight"])
+ assert torch.allclose(state_dict["v_head.summary.bias"], ref_state_dict["v_head.summary.bias"])
+
+
+def test_lora_inference():
+ model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
+
+ ref_model = load_reference_model().merge_and_unload()
+ compare_model(model, ref_model)
diff --git a/tests/model/test_pissa.py b/tests/model/test_pissa.py
new file mode 100644
index 00000000..030310d0
--- /dev/null
+++ b/tests/model/test_pissa.py
@@ -0,0 +1,90 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import torch
+from peft import LoraModel, PeftModel
+from transformers import AutoModelForCausalLM
+
+from llamafactory.extras.misc import get_current_device
+from llamafactory.hparams import get_infer_args, get_train_args
+from llamafactory.model import load_model, load_tokenizer
+
+
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+
+TINY_LLAMA_PISSA = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-pissa")
+
+TRAIN_ARGS = {
+ "model_name_or_path": TINY_LLAMA,
+ "stage": "sft",
+ "do_train": True,
+ "finetuning_type": "lora",
+ "pissa_init": True,
+ "pissa_iter": -1,
+ "dataset": "llamafactory/tiny-supervised-dataset",
+ "dataset_dir": "ONLINE",
+ "template": "llama3",
+ "cutoff_len": 1024,
+ "overwrite_cache": True,
+ "output_dir": "dummy_dir",
+ "overwrite_output_dir": True,
+ "fp16": True,
+}
+
+INFER_ARGS = {
+ "model_name_or_path": TINY_LLAMA_PISSA,
+ "adapter_name_or_path": TINY_LLAMA_PISSA,
+ "adapter_folder": "pissa_init",
+ "finetuning_type": "lora",
+ "template": "llama3",
+ "infer_dtype": "float16",
+}
+
+
+def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
+ state_dict_a = model_a.state_dict()
+ state_dict_b = model_b.state_dict()
+ assert set(state_dict_a.keys()) == set(state_dict_b.keys())
+ for name in state_dict_a.keys():
+ assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5)
+
+
+def test_pissa_init():
+ model_args, _, _, finetuning_args, _ = get_train_args(TRAIN_ARGS)
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
+ base_model = AutoModelForCausalLM.from_pretrained(
+ TINY_LLAMA_PISSA, torch_dtype=torch.float16, device_map=get_current_device()
+ )
+ ref_model = PeftModel.from_pretrained(base_model, TINY_LLAMA_PISSA, subfolder="pissa_init", is_trainable=True)
+ for param in filter(lambda p: p.requires_grad, ref_model.parameters()):
+ param.data = param.data.to(torch.float32)
+
+ compare_model(model, ref_model)
+
+
+def test_pissa_inference():
+ model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
+
+ base_model = AutoModelForCausalLM.from_pretrained(
+ TINY_LLAMA_PISSA, torch_dtype=torch.float16, device_map=get_current_device()
+ )
+ ref_model: "LoraModel" = PeftModel.from_pretrained(base_model, TINY_LLAMA_PISSA, subfolder="pissa_init")
+ ref_model = ref_model.merge_and_unload()
+ compare_model(model, ref_model)