Merge branch 'main' into main

Commit e8e6af2651

.dockerignore

@@ -4,10 +4,10 @@
 .venv
 cache
 data
+docker
+saves
 hf_cache
 output
-examples
 .dockerignore
 .gitattributes
 .gitignore
-Dockerfile

Bug report issue template (.github/ISSUE_TEMPLATE)

@@ -38,7 +38,9 @@ body:
         请合理使用 Markdown 标签来格式化您的文本。

       placeholder: |
+        ```bash
         llamafactory-cli train ...
+        ```

   - type: textarea
     id: expected-behavior

.github/PULL_REQUEST_TEMPLATE.md

@@ -5,3 +5,4 @@ Fixes # (issue)
 ## Before submitting

 - [ ] Did you read the [contributor guideline](https://github.com/hiyouga/LLaMA-Factory/blob/main/.github/CONTRIBUTING.md)?
+- [ ] Did you write any new necessary tests?

.github/workflows/label_issue.yml (new file)

@@ -0,0 +1,27 @@
+name: label_issue
+
+on:
+  issues:
+    types:
+      - opened
+
+jobs:
+  label_issue:
+    runs-on: ubuntu-latest
+
+    steps:
+      - env:
+          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+          ISSUE_URL: ${{ github.event.issue.html_url }}
+          ISSUE_TITLE: ${{ github.event.issue.title }}
+        run: |
+          LABEL=pending
+          NPU_KEYWORDS=(npu ascend huawei 华为 昇腾)
+          ISSUE_TITLE_LOWER=$(echo $ISSUE_TITLE | tr '[:upper:]' '[:lower:]')
+          for KEYWORD in ${NPU_KEYWORDS[@]}; do
+            if [[ $ISSUE_TITLE_LOWER == *$KEYWORD* ]] && [[ $ISSUE_TITLE_LOWER != *input* ]]; then
+              LABEL=pending,npu
+              break
+            fi
+          done
+          gh issue edit $ISSUE_URL --add-label $LABEL
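
The labeling step is plain bash, so it can be dry-run outside of Actions. A minimal sketch (the sample title and the final `echo` are illustrative; in the workflow, `gh issue edit` applies the label instead):

```bash
# Dry-run of the keyword matching above. Note the `!= *input*` guard:
# the substring "npu" occurs inside the word "input", so titles that
# merely contain "input" must not be routed to the npu label.
ISSUE_TITLE="Training hangs on Ascend 910B"
LABEL=pending
NPU_KEYWORDS=(npu ascend huawei 华为 昇腾)
ISSUE_TITLE_LOWER=$(echo $ISSUE_TITLE | tr '[:upper:]' '[:lower:]')
for KEYWORD in "${NPU_KEYWORDS[@]}"; do
  if [[ $ISSUE_TITLE_LOWER == *$KEYWORD* ]] && [[ $ISSUE_TITLE_LOWER != *input* ]]; then
    LABEL=pending,npu
    break
  fi
done
echo "$LABEL"   # -> pending,npu ("ascend" matched)
```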

.github/workflows/publish.yml (new file)

@@ -0,0 +1,40 @@
+name: publish
+
+on:
+  release:
+    types:
+      - published
+
+jobs:
+  publish:
+    name: Upload release to PyPI
+
+    runs-on: ubuntu-latest
+
+    environment:
+      name: release
+      url: https://pypi.org/p/llamafactory
+
+    permissions:
+      id-token: write
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.8"
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install build
+
+      - name: Build package
+        run: |
+          python -m build
+
+      - name: Publish package
+        uses: pypa/gh-action-pypi-publish@release/v1
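
The `id-token: write` permission lets `pypa/gh-action-pypi-publish` authenticate to PyPI via OIDC ("trusted publishing"), so no API-token secret appears in the workflow. The build stage can be reproduced locally; a minimal sketch (uploading from a laptop would instead need `twine` and a real PyPI token, which this workflow deliberately avoids):

```bash
# Reproduce the workflow's build stage on a local checkout.
python -m pip install --upgrade pip build
python -m build   # writes the sdist and wheel into dist/
ls dist/          # inspect what would be uploaded to PyPI
```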

.github/workflows/tests.yml

@@ -19,21 +19,27 @@ on:
 jobs:
   tests:
     runs-on: ubuntu-latest

     steps:
-      - uses: actions/checkout@v4
+      - name: Checkout
+        uses: actions/checkout@v4
+
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
           python-version: "3.8"
           cache: "pip"
           cache-dependency-path: "setup.py"

       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install .[torch,dev]
+          python -m pip install ".[torch,dev]"

       - name: Check quality
         run: |
           make style && make quality

       - name: Test with pytest
         run: |
           make test
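
Besides the renamed checkout step, the substantive change is quoting the extras specifier. This matters when the same command is pasted into shells where square brackets are glob characters; a quick illustration (the failure mode shown is zsh's; plain bash only breaks with glob options like `failglob` enabled):

```bash
# Unquoted extras are parsed as a glob pattern by zsh:
python -m pip install .[torch,dev]     # zsh: no matches found: .[torch,dev]
# Quoting the specifier is safe in any POSIX-style shell:
python -m pip install ".[torch,dev]"
```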

.gitignore

@@ -160,6 +160,8 @@ cython_debug/
 .idea/

 # custom .gitignore
-user.config
-saves/
 cache/
+config/
+saves/
+output/
+wandb/

CITATION.cff

@@ -12,12 +12,16 @@ authors:
   given-names: "Yanhan"
 - family-names: "Luo"
   given-names: "Zheyan"
+- family-names: "Feng"
+  given-names: "Zhangchi"
 - family-names: "Ma"
   given-names: "Yongqiang"
 title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models"
 url: "https://arxiv.org/abs/2403.13372"
 preferred-citation:
-  type: article
+  type: conference-paper
+  conference:
+    name: "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)"
   authors:
   - family-names: "Zheng"
     given-names: "Yaowei"

@@ -29,9 +33,12 @@ preferred-citation:
     given-names: "Yanhan"
   - family-names: "Luo"
     given-names: "Zheyan"
+  - family-names: "Feng"
+    given-names: "Zhangchi"
   - family-names: "Ma"
     given-names: "Yongqiang"
-  journal: "arXiv preprint arXiv:2403.13372"
   title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models"
   url: "https://arxiv.org/abs/2403.13372"
   year: 2024
+  publisher: "Association for Computational Linguistics"
+  address: "Bangkok, Thailand"

Dockerfile (deleted)

@@ -1,14 +0,0 @@
-FROM nvcr.io/nvidia/pytorch:24.01-py3
-
-WORKDIR /app
-
-COPY requirements.txt /app/
-RUN pip install -r requirements.txt
-
-COPY . /app/
-RUN pip install -e .[metrics,bitsandbytes,qwen]
-
-VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]
-EXPOSE 7860
-
-CMD [ "llamafactory-cli", "webui" ]

MANIFEST.in (new file)

@@ -0,0 +1 @@
+include LICENSE requirements.txt

Makefile

@@ -11,4 +11,4 @@ style:
	ruff format $(check_dirs)

 test:
-	pytest tests/
+	CUDA_VISIBLE_DEVICES= pytest tests/
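
Assigning an empty value to `CUDA_VISIBLE_DEVICES` hides every GPU from the process, so `make test` exercises the CPU code path even on a GPU host and stays deterministic in CI. A quick check:

```bash
# With the variable empty, PyTorch reports no usable CUDA devices.
CUDA_VISIBLE_DEVICES= python -c "import torch; print(torch.cuda.is_available())"  # False
```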

README.md

@@ -4,7 +4,7 @@
 [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
 [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
-[![Citation](https://img.shields.io/badge/citation-44-green)](#projects-using-llama-factory)
+[![Citation](https://img.shields.io/badge/citation-71-green)](#projects-using-llama-factory)
 [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)

@@ -15,7 +15,7 @@
 [![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)

-👋 Join our [WeChat](assets/wechat.jpg).
+👋 Join our [WeChat](assets/wechat.jpg) or [NPU user group](assets/wechat_npu.jpg).

 \[ English | [中文](README_zh.md) \]


@@ -48,8 +48,8 @@ Choose your path:

 - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
 - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
-- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
-- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
+- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
+- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
 - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
 - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.

@@ -71,9 +71,9 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog

-[24/06/07] We supported fine-tuning the **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** series models.
+[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.

-[24/06/05] We supported fine-tuning the **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** models.
+[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.

 [24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.


@@ -151,35 +151,32 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

 ## Supported Models

 | Model | Model size | Template |
-| -------------------------------------------------------- | -------------------------------- | --------- |
-| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
-| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
-| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
-| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
-| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
-| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
-| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
-| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
-| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
-| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
-| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
-| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
-| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
-| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
-| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
-| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
-| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
-| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
-| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
-| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
-| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
-| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
-| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
-| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
+| ------------------------------------------------------------ | -------------------------------- | --------- |
+| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
+| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
+| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
+| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
+| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
+| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
+| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
+| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
+| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
+| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
+| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
+| [Llama 3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
+| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
+| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
+| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
+| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
+| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
+| [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
+| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
+| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
+| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
+| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
+| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

 > [!NOTE]
 > For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.

@@ -259,6 +256,9 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
 - [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
 - [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
 - [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
+- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
+- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
+- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
 - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
 - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
 - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)

@@ -335,10 +335,10 @@ huggingface-cli login
 ```bash
 git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
 cd LLaMA-Factory
-pip install -e '.[torch,metrics]'
+pip install -e ".[torch,metrics]"
 ```

-Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality
+Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality

 > [!TIP]
 > Use `pip install --no-deps -e .` to resolve package conflicts.

@@ -357,9 +357,7 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec

 <details><summary>For Ascend NPU users</summary>

-Join [NPU user group](assets/wechat_npu.jpg).
-
-To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e '.[torch-npu,metrics]'`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
+To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:

 ```bash
 # replace the url according to your CANN version and devices

@@ -382,15 +380,12 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
 | torch-npu | 2.1.0 | 2.1.0.post3 |
 | deepspeed | 0.13.2 | 0.13.2 |

-Docker image:
-
-- 32GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
-- 64GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
-
 Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.

 If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations.

+Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
+
 </details>

 ### Data Preparation
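
Device selection on NPU follows the same pattern as on CUDA, just with Ascend's variable; an illustrative run (the yaml path is taken from the quickstart below):

```bash
# Pin the run to the first Ascend NPU; the variable plays the same role
# that CUDA_VISIBLE_DEVICES plays on NVIDIA GPUs.
ASCEND_RT_VISIBLE_DEVICES=0 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```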

@@ -405,9 +400,9 @@ Please refer to [data/README.md](data/README.md) for checking the details about
 Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
 ```

 See [examples/README.md](examples/README.md) for advanced usage (including distributed training).
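
Note that the `CUDA_VISIBLE_DEVICES=0` prefixes were dropped and the training config moved from `examples/lora_single_gpu/` to `examples/train_lora/`; by default the commands now see every visible device. Pinning a single GPU still works the old way:

```bash
# Optional: restrict training to GPU 0, as the previous quickstart did.
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```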

@@ -417,34 +412,89 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr

 ### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))

-#### Use local environment
-
 ```bash
-CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui
+llamafactory-cli webui
 ```

-</details>
+### Build Docker

-#### Use Docker
+For CUDA users:

 ```bash
-docker build -f ./Dockerfile -t llama-factory:latest .
-docker run --gpus=all \
-    -v ./hf_cache:/root/.cache/huggingface/ \
+cd docker/docker-cuda/
+docker-compose up -d
+docker-compose exec llamafactory bash
+```
+
+For Ascend NPU users:
+
+```bash
+cd docker/docker-npu/
+docker-compose up -d
+docker-compose exec llamafactory bash
+```
+
+<details><summary>Build without Docker Compose</summary>
+
+For CUDA users:
+
+```bash
+docker build -f ./docker/docker-cuda/Dockerfile \
+    --build-arg INSTALL_BNB=false \
+    --build-arg INSTALL_VLLM=false \
+    --build-arg INSTALL_DEEPSPEED=false \
+    --build-arg INSTALL_FLASHATTN=false \
+    --build-arg PIP_INDEX=https://pypi.org/simple \
+    -t llamafactory:latest .
+
+docker run -dit --gpus=all \
+    -v ./hf_cache:/root/.cache/huggingface \
+    -v ./ms_cache:/root/.cache/modelscope \
     -v ./data:/app/data \
     -v ./output:/app/output \
     -p 7860:7860 \
+    -p 8000:8000 \
     --shm-size 16G \
-    --name llama_factory \
-    -d llama-factory:latest
+    --name llamafactory \
+    llamafactory:latest
+
+docker exec -it llamafactory bash
 ```

-#### Use Docker Compose
+For Ascend NPU users:

 ```bash
-docker compose -f ./docker-compose.yml up -d
+# Choose docker image upon your environment
+docker build -f ./docker/docker-npu/Dockerfile \
+    --build-arg INSTALL_DEEPSPEED=false \
+    --build-arg PIP_INDEX=https://pypi.org/simple \
+    -t llamafactory:latest .
+
+# Change `device` upon your resources
+docker run -dit \
+    -v ./hf_cache:/root/.cache/huggingface \
+    -v ./ms_cache:/root/.cache/modelscope \
+    -v ./data:/app/data \
+    -v ./output:/app/output \
+    -v /usr/local/dcmi:/usr/local/dcmi \
+    -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
+    -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+    -v /etc/ascend_install.info:/etc/ascend_install.info \
+    -p 7860:7860 \
+    -p 8000:8000 \
+    --device /dev/davinci0 \
+    --device /dev/davinci_manager \
+    --device /dev/devmm_svm \
+    --device /dev/hisi_hdc \
+    --shm-size 16G \
+    --name llamafactory \
+    llamafactory:latest
+
+docker exec -it llamafactory bash
 ```

+</details>
+
 <details><summary>Details about volume</summary>

 - hf_cache: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.

@@ -456,7 +506,7 @@ docker compose -f ./docker-compose.yml up -d
 ### Deploy with OpenAI-style API and vLLM

 ```bash
-CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
+API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
 ```

 > [!TIP]
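
Since the server speaks the OpenAI API dialect, any OpenAI-compatible client can talk to it once it is up. A minimal sketch (the `/v1/chat/completions` route and JSON shape follow the OpenAI convention, which the README's "OpenAI-style API" implies; the `model` field is typically ignored by single-model servers):

```bash
# Query the server started above with a plain OpenAI-style request.
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "llama3", "messages": [{"role": "user", "content": "Hello!"}]}'
```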

@@ -474,7 +524,7 @@ Train the model by specifying a model ID of the ModelScope Hub as the `model_nam
 ### Use W&B Logger

-To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments.
+To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.

 ```yaml
 report_to: wandb
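
Credentials are handled by wandb itself rather than by these yaml arguments; a minimal sketch (`WANDB_API_KEY` is wandb's standard environment variable, not a LLaMA-Factory option):

```bash
# Export the key per run (or use `wandb login` once), then train as usual;
# the `report_to: wandb` entry in the yaml does the rest.
WANDB_API_KEY=<your_key> llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```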

@@ -494,38 +544,63 @@ If you have a project that should be incorporated, please contact via email or c
 1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
 1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
 1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
-1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
+1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
-1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
+1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
 1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
 1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
 1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
 1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
 1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
 1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
-1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
+1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
 1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
 1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
 1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
-1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
+1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
 1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
 1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
 1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
 1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
 1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
 1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
-1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
+1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
 1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
 1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
 1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
 1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
-1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
+1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
 1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
 1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
 1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
 1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
 1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
-1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2404.17140)
+1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
-1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
+1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
+1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
+1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
+1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
+1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
+1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
+1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
+1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
+1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
+1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
+1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
+1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
+1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
+1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
+1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
+1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
+1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
+1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
+1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
+1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
+1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
+1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
+1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
+1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
+1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
+1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh's Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
 1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
 1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
 1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.

@@ -533,6 +608,8 @@ If you have a project that should be incorporated, please contact via email or c
 1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
 1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
 1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
+1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
+1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.

 </details>


@@ -540,17 +617,19 @@ If you have a project that should be incorporated, please contact via email or c

 This repository is licensed under the [Apache-2.0 License](LICENSE).

-Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

 ## Citation

 If this work is helpful, please kindly cite as:

 ```bibtex
-@article{zheng2024llamafactory,
+@inproceedings{zheng2024llamafactory,
   title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
-  author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Yongqiang Ma},
-  journal={arXiv preprint arXiv:2403.13372},
+  author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
+  booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
+  address={Bangkok, Thailand},
+  publisher={Association for Computational Linguistics},
   year={2024},
   url={http://arxiv.org/abs/2403.13372}
 }

README_zh.md

@@ -4,7 +4,7 @@
 [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
 [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
-[![Citation](https://img.shields.io/badge/citation-44-green)](#使用了-llama-factory-的项目)
+[![Citation](https://img.shields.io/badge/citation-71-green)](#使用了-llama-factory-的项目)
 [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)

@@ -15,7 +15,7 @@
 [![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)

-👋 加入我们的[微信群](assets/wechat.jpg)。
+👋 加入我们的[微信群](assets/wechat.jpg)或 [NPU 用户群](assets/wechat_npu.jpg)。

 \[ [English](README.md) | 中文 \]


@@ -48,8 +48,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd

 - **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
 - **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
-- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
-- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
+- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
+- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
 - **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
 - **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
 - **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。

@@ -71,9 +71,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 ## 更新日志

-[24/06/07] 我们支持了 **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** 系列模型的微调。
+[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。

-[24/06/05] 我们支持了 **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** 模型的微调。
+[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。

 [24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。

@ -151,35 +151,32 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||||
|
|
||||||
## 模型
|
## 模型
|
||||||
|
|
||||||
| Model | Model size | Template |
| ----- | ---------- | -------- |
-| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
-| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
-| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
-| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
-| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
-| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
-| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
-| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
-| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
-| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
-| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
-| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
-| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
-| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
-| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
-| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
-| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
-| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
-| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
-| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
-| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
-| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
-| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
-| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
+| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
+| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
+| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
+| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
+| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
+| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
+| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
+| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
+| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
+| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
+| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
+| [Llama 3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
+| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
+| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
+| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
+| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
+| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
+| [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
+| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
+| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
+| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
+| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
+| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

> [!NOTE]
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna`, etc. But make sure to use the **corresponding template** for the "instruct/chat" models.

@@ -259,6 +256,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd

- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
+- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
+- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
+- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)

@@ -335,10 +335,10 @@ huggingface-cli login

```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```

-Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality
+Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality

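For instance, to add DeepSpeed and bitsandbytes support on top of the default install, combine extras from the list above (a sketch; pick the extras that match your hardware):

```bash
# deepspeed and bitsandbytes are taken from the extras list documented above
pip install -e ".[torch,metrics,deepspeed,bitsandbytes]"
```
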
> [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts.

@@ -357,9 +357,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl

<details><summary>For Ascend NPU users</summary>

-Join the [NPU user group](assets/wechat_npu.jpg).

To install LLaMA Factory on Ascend NPU devices, specify the extra dependencies and install with `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**; please refer to the [installation tutorial](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html) or use the following commands:

```bash
# Replace the URL according to your CANN version and devices
```

@@ -382,15 +380,12 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh

| torch-npu | 2.1.0 | 2.1.0.post3 |
| deepspeed | 0.13.2 | 0.13.2 |

-Docker image:
-
-- 32GB: [download link](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
-- 64GB: [download link](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)

Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the computing devices.
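For example, the quickstart commands run unchanged on NPU once the device variable is swapped (a sketch reusing the training example from this README):

```bash
# select NPU 0 instead of setting CUDA_VISIBLE_DEVICES
ASCEND_RT_VISIBLE_DEVICES=0 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```
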
If you cannot infer the model normally, try setting `do_sample: false` in the generation arguments.

+Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)

</details>

### Data Preparation

@@ -405,9 +400,9 @@ Docker image:

The following three commands run LoRA **fine-tuning**, **inference**, and **merging** for the Llama3-8B-Instruct model, respectively.

```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```

See [examples/README_zh.md](examples/README_zh.md) for advanced usage (including multi-GPU fine-tuning).
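The same `verb config-file` pattern drives every subcommand; benchmark evaluation, for instance, uses the `eval` verb with a config from the examples directory (command taken from the examples document referenced above):

```bash
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
```
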
@@ -417,32 +412,89 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_s

### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))

-#### Use local environment

```bash
llamafactory-cli webui
```
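If you need a public share link or a specific port, Gradio's standard environment variables apply; `GRADIO_SHARE` appeared in the previous revision of this command and `GRADIO_SERVER_PORT` is set in the Docker images below (a sketch, not an officially documented invocation):

```bash
# GRADIO_SHARE enables a public link; GRADIO_SERVER_PORT pins the port
GRADIO_SHARE=1 GRADIO_SERVER_PORT=7860 llamafactory-cli webui
```
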
-#### Use Docker
+### Build Docker

For CUDA users:

```bash
cd docker/docker-cuda/
docker-compose up -d
docker-compose exec llamafactory bash
```

For Ascend NPU users:

```bash
cd docker/docker-npu/
docker-compose up -d
docker-compose exec llamafactory bash
```

<details><summary>Build without Docker Compose</summary>

For CUDA users:

```bash
docker build -f ./docker/docker-cuda/Dockerfile \
    --build-arg INSTALL_BNB=false \
    --build-arg INSTALL_VLLM=false \
    --build-arg INSTALL_DEEPSPEED=false \
    --build-arg INSTALL_FLASHATTN=false \
    --build-arg PIP_INDEX=https://pypi.org/simple \
    -t llamafactory:latest .

docker run -dit --gpus=all \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -p 7860:7860 \
    -p 8000:8000 \
    --shm-size 16G \
    --name llamafactory \
    llamafactory:latest

docker exec -it llamafactory bash
```

For Ascend NPU users:

```bash
# Choose your Docker image according to your environment
docker build -f ./docker/docker-npu/Dockerfile \
    --build-arg INSTALL_DEEPSPEED=false \
    --build-arg PIP_INDEX=https://pypi.org/simple \
    -t llamafactory:latest .

# Change `device` according to your resources
docker run -dit \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -v /usr/local/dcmi:/usr/local/dcmi \
    -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
    -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
    -v /etc/ascend_install.info:/etc/ascend_install.info \
    -p 7860:7860 \
    -p 8000:8000 \
    --device /dev/davinci0 \
    --device /dev/davinci_manager \
    --device /dev/devmm_svm \
    --device /dev/hisi_hdc \
    --shm-size 16G \
    --name llamafactory \
    llamafactory:latest

docker exec -it llamafactory bash
```

</details>
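To sanity-check GPU visibility inside the freshly started CUDA container, standard Docker tooling works as usual (a sketch; assumes the NVIDIA container runtime is configured on the host):

```bash
# should list the GPUs passed through by --gpus=all
docker exec -it llamafactory nvidia-smi
```
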
<details><summary>Details about volumes</summary>

- hf_cache: mounts the host's Hugging Face cache directory; a different directory may be used instead.

@@ -454,7 +506,7 @@ docker compose -f ./docker-compose.yml up -d

### Deploy OpenAI-style API with vLLM

```bash
API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
```
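Once the server is up, any OpenAI-compatible client can talk to it. A minimal sketch with curl, assuming the standard `/v1/chat/completions` route of OpenAI-style servers and a placeholder model name:

```bash
# "llama3" is a placeholder model identifier, not a value prescribed by this repo
curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "llama3", "messages": [{"role": "user", "content": "Hello!"}]}'
```
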
> [!TIP]

@@ -472,7 +524,7 @@ export USE_MODELSCOPE_HUB=1 # use `set USE_MODELSCOPE_HUB=1` on Windows
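With the hub switch exported, the same commands download models and datasets from ModelScope instead of Hugging Face (a sketch reusing the quickstart training example):

```bash
export USE_MODELSCOPE_HUB=1  # `set USE_MODELSCOPE_HUB=1` on Windows
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```
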
### Use W&B Logger

To log experiment results with [Weights & Biases](https://wandb.ai), add the following arguments to the yaml file.

```yaml
report_to: wandb
```
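W&B needs credentials before the first run; the standard login from the wandb CLI (part of the wandb package, not specific to this project) is:

```bash
# prompts for the API key from wandb.ai/authorize on first use
wandb login
```
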
@@ -492,38 +544,63 @@ run_name: test_run # optional

1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
+1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
+1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
+1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
+1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
+1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
+1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
+1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
+1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
+1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
+1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
+1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
+1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
+1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
+1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
+1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
+1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
+1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
+1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
+1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
+1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
+1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
+1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
+1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
+1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
+1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh's Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for astronomy, fine-tuned from ChatGLM2-6B and Qwen-14B on astronomical data.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model for the Chinese legal domain, fine-tuned from Baichuan-13B, capable of legal reasoning and knowledge retrieval.
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: Sunsimiao, a Chinese medical large language model fine-tuned from Baichuan-7B and ChatGLM-6B on Chinese medical data.

@@ -531,6 +608,8 @@ run_name: test_run # optional

1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI personality large language models, able to give any LLM one of 16 personality types depending on the dataset and training method.
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model for generating Stable Diffusion prompts. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model for Chinese medicine, fine-tuned from LLaVA-1.5-7B on Chinese multimodal medical data.
+1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
+1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning large language models on Windows PCs with NVIDIA RTX devices.

</details>

@@ -538,17 +617,19 @@ run_name: test_run # optional

The code in this repository is open-sourced under the [Apache-2.0](LICENSE) license.

Please follow the corresponding model licenses when using the model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

## Citation

If this work is helpful, please kindly cite as:

```bibtex
@inproceedings{zheng2024llamafactory,
  title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
  author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
  booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
  address={Bangkok, Thailand},
  publisher={Association for Computational Linguistics},
  year={2024},
  url={http://arxiv.org/abs/2403.13372}
}
```

Binary file not shown. (Before: 142 KiB | After: 141 KiB)
Binary file not shown. (Before: 148 KiB | After: 147 KiB)

@@ -248,6 +248,21 @@

"ruozhiba_gpt4": {
|
"ruozhiba_gpt4": {
|
||||||
"hf_hub_url": "hfl/ruozhiba_gpt4_turbo"
|
"hf_hub_url": "hfl/ruozhiba_gpt4_turbo"
|
||||||
},
|
},
|
||||||
|
"neo_sft": {
|
||||||
|
"hf_hub_url": "m-a-p/neo_sft_phase2",
|
||||||
|
"formatting": "sharegpt"
|
||||||
|
},
|
||||||
|
"magpie_pro_300k": {
|
||||||
|
"hf_hub_url": "Magpie-Align/Magpie-Pro-300K-Filtered",
|
||||||
|
"formatting": "sharegpt"
|
||||||
|
},
|
||||||
|
"web_instruct": {
|
||||||
|
"hf_hub_url": "TIGER-Lab/WebInstructSub",
|
||||||
|
"columns": {
|
||||||
|
"prompt": "question",
|
||||||
|
"response": "answer"
|
||||||
|
}
|
||||||
|
},
|
||||||
"llava_1k_en": {
|
"llava_1k_en": {
|
||||||
"hf_hub_url": "BUAADreamer/llava-en-zh-2k",
|
"hf_hub_url": "BUAADreamer/llava-en-zh-2k",
|
||||||
"subset": "en",
|
"subset": "en",
|
||||||
@@ -520,13 +535,13 @@

"prompt": "text"
|
"prompt": "text"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fileweb": {
|
"fineweb": {
|
||||||
"hf_hub_url": "HuggingFaceFW/fineweb",
|
"hf_hub_url": "HuggingFaceFW/fineweb",
|
||||||
"columns": {
|
"columns": {
|
||||||
"prompt": "text"
|
"prompt": "text"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fileweb_edu": {
|
"fineweb_edu": {
|
||||||
"hf_hub_url": "HuggingFaceFW/fineweb-edu",
|
"hf_hub_url": "HuggingFaceFW/fineweb-edu",
|
||||||
"columns": {
|
"columns": {
|
||||||
"prompt": "text"
|
"prompt": "text"
|
||||||
|
|
|
@@ -1,4 +1,41 @@

[
+  {
+    "conversations": [
+      {
+        "from": "human",
+        "value": "Hi, I have some ingredients and I want to cook something. Can you help me find a recipe?"
+      },
+      {
+        "from": "gpt",
+        "value": "Of course! I can help you with that. Please tell me what ingredients you have."
+      },
+      {
+        "from": "human",
+        "value": "I have chicken, bell peppers, and rice."
+      },
+      {
+        "from": "function_call",
+        "value": "{\"name\": \"search_recipes\", \"arguments\": {\"ingredients\": [\"chicken\", \"bell peppers\", \"rice\"]}}"
+      },
+      {
+        "from": "observation",
+        "value": "{\"recipes\": [{\"name\": \"Chicken and Bell Pepper Stir Fry\", \"ingredients\": [\"chicken\", \"bell peppers\", \"rice\"], \"instructions\": \"Cut the chicken into small pieces. Slice the bell peppers. Cook the rice. Stir fry the chicken and bell peppers. Serve over rice.\"}, {\"name\": \"Chicken and Rice Casserole\", \"ingredients\": [\"chicken\", \"bell peppers\", \"rice\"], \"instructions\": \"Cook the chicken and rice separately. Mix them together with the bell peppers in a casserole dish. Bake until golden brown.\"}]}"
+      },
+      {
+        "from": "gpt",
+        "value": "I found two recipes for you. The first one is \"Chicken and Bell Pepper Stir Fry\". The instructions are: Cut the chicken into small pieces. Slice the bell peppers. Cook the rice. Stir fry the chicken and bell peppers. Serve over rice. The second one is \"Chicken and Rice Casserole\". The instructions are: Cook the chicken and rice separately. Mix them together with the bell peppers in a casserole dish. Bake until golden brown. Which one would you like to try?"
+      },
+      {
+        "from": "human",
+        "value": "They both sound delicious, but I think I'll try the stir fry. Can you order the ingredients for me?"
+      },
+      {
+        "from": "gpt",
+        "value": "I'm sorry, but as an AI, I don't have the capability to perform external tasks such as ordering ingredients. However, I can help you find more recipes or provide cooking tips if you need."
+      }
+    ],
+    "tools": "[{\"name\": \"search_recipes\", \"description\": \"Search for recipes based on ingredients\", \"parameters\": {\"type\": \"object\", \"properties\": {\"ingredients\": {\"type\": \"array\", \"items\": {\"type\": \"string\"}, \"description\": \"The ingredients to search for\"}}, \"required\": [\"ingredients\"]}}]"
+  },
  {
    "conversations": [
      {

@@ -155,43 +192,6 @@

    ],
    "tools": "[]"
  },
  {
    "conversations": [
      {

@@ -1,25 +1,4 @@

[
  {
    "conversations": [
      {

@@ -41,6 +20,27 @@

    ],
    "tools": "[{\"name\": \"generate_invoice\", \"description\": \"生成发票\", \"parameters\": {\"type\": \"object\", \"properties\": {\"customer_name\": {\"type\": \"string\", \"description\": \"客户名称\"}, \"items\": {\"type\": \"array\", \"items\": {\"type\": \"object\", \"properties\": {\"name\": {\"type\": \"string\", \"description\": \"The item name\"}, \"quantity\": {\"type\": \"integer\", \"description\": \"The quantity of the item\"}, \"price\": {\"type\": \"number\", \"description\": \"The price per unit\"}}, \"required\": [\"name\", \"quantity\", \"price\"]}}}, \"required\": [\"customer_name\", \"items\"]}}, {\"name\": \"generate_password\", \"description\": \"生成随机密码\", \"parameters\": {\"type\": \"object\", \"properties\": {\"length\": {\"type\": \"integer\", \"description\": \"密码的长度\"}}, \"required\": [\"length\"]}}]"
  },
+  {
+    "conversations": [
+      {
+        "from": "human",
+        "value": "假设你有一个需要随机数的Java程序,范围在0到10之间。你可以使用什么代码片段来生成这样的数字?\n\n下面是一个示例代码片段,它使用Math.random()方法来生成一个在0到1之间的随机数,然后缩放并四舍五入结果,得到一个在0到10之间的随机数:\n\n```java\nint randomNumber = (int) (Math.random() * 11);\n```\n\n在Java中是否有更有效的方法来生成随机数?你会如何修改代码以生成具有不同范围或分布的随机数?"
+      },
+      {
+        "from": "gpt",
+        "value": "您提供的代码片段是Java中生成随机数的一种常见且高效的方法。然而,如果您需要对随机数生成有更多的控制,您可以使用java.util.Random类,它提供了更多生成随机数的选项。\n\n要生成一个在不同范围内的随机数,您可以修改代码以调整缩放因子。例如,要生成一个在1到10之间的随机数,您可以使用以下代码:\n\nint randomNumber = (int) (Math.random() * 10) + 1;\n\n要生成具有不同分布的随机数,您可以使用各种技术,如正态分布、指数分布或自定义分布。生成正态分布的一种方法是使用Box-Muller变换。这里有一个代码示例,它生成具有正态分布的随机数:\n\nRandom random = new Random();\ndouble u1 = random.nextDouble();\ndouble u2 = random.nextDouble();\ndouble z = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);\ndouble mean = 5;\ndouble stdDev = 2;\nint randomNumber = (int) (z * stdDev + mean);\n\n这段代码生成一个具有均值5和标准差2的正态分布随机数。您可以调整均值和标准差来生成具有不同分布的随机数。"
+      },
+      {
+        "from": "human",
+        "value": "你能解释一下算法分析中的大O记法概念吗?"
+      },
+      {
+        "from": "gpt",
+        "value": "是的,大O表示法是一种描述算法性能随输入数据大小增加的方法。它用于估计算法在最坏情况下的时间复杂度或空间复杂度。大O中的“O”代表“数量级”,这意味着随着输入大小的增加,算法所需的操作数以一定的速率增长。\n\n例如,如果我们有一个需要n步才能解决大小为n的问题的算法,那么时间复杂度可以表示为O(n),这意味着操作数随着输入大小线性增长。如果我们有一个需要n^2步才能解决大小为n的问题的算法,那么时间复杂度可以表示为O(n^2),这意味着操作数随着输入大小二次方增长。\n\n还有其他几种大O表示法,如O(log n)、O(n log n)、O(2^n)等,它们代表了随着输入大小增加的不同操作数增长速率。通过分析算法的大O表示法,我们可以比较不同算法的性能,并选择给定问题最有效的算法。\n\n您是否有关于大O表示法或算法分析的特定问题?"
+      }
+    ],
+    "tools": "[]"
+  },
  {
    "conversations": [
      {

@@ -1,23 +0,0 @@

version: '3.8'

services:
  llama-factory:
    build:
      dockerfile: Dockerfile
      context: .
    container_name: llama_factory
    volumes:
      - ./hf_cache:/root/.cache/huggingface/
      - ./data:/app/data
      - ./output:/app/output
    ports:
      - "7860:7860"
    ipc: host
    deploy:
      resources:
        reservations:
          devices:
          - driver: nvidia
            count: "all"
            capabilities: [gpu]
    restart: unless-stopped

@@ -0,0 +1,58 @@

# Use the NVIDIA official image with PyTorch 2.3.0
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html
FROM nvcr.io/nvidia/pytorch:24.02-py3

# Define environments
ENV MAX_JOBS=4
ENV FLASH_ATTENTION_FORCE_BUILD=TRUE

# Define installation arguments
ARG INSTALL_BNB=false
ARG INSTALL_VLLM=false
ARG INSTALL_DEEPSPEED=false
ARG INSTALL_FLASHATTN=false
ARG PIP_INDEX=https://pypi.org/simple

# Set the working directory
WORKDIR /app

# Install the requirements
COPY requirements.txt /app
RUN pip config set global.index-url "$PIP_INDEX" && \
    pip config set global.extra-index-url "$PIP_INDEX" && \
    python -m pip install --upgrade pip && \
    python -m pip install -r requirements.txt

# Rebuild flash attention
RUN pip uninstall -y transformer-engine flash-attn && \
    if [ "$INSTALL_FLASHATTN" == "true" ]; then \
        pip uninstall -y ninja && pip install ninja && \
        pip install --no-cache-dir flash-attn --no-build-isolation; \
    fi

# Copy the rest of the application into the image
COPY . /app

# Install the LLaMA Factory
RUN EXTRA_PACKAGES="metrics"; \
    if [ "$INSTALL_BNB" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \
    fi; \
    if [ "$INSTALL_VLLM" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \
    fi; \
    if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
    fi; \
    pip install -e ".[$EXTRA_PACKAGES]"

# Set up volumes
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]

# Expose port 7860 for the LLaMA Board
ENV GRADIO_SERVER_PORT 7860
EXPOSE 7860

# Expose port 8000 for the API service
ENV API_PORT 8000
EXPOSE 8000
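To opt back in to the heavier dependencies at image build time, flip the corresponding build args (a sketch; the args are the ones defined in the Dockerfile above):

```bash
# enables the flash-attn rebuild and the deepspeed extra declared in the Dockerfile
docker build -f ./docker/docker-cuda/Dockerfile \
    --build-arg INSTALL_FLASHATTN=true \
    --build-arg INSTALL_DEEPSPEED=true \
    -t llamafactory:latest .
```
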
@@ -0,0 +1,32 @@

services:
  llamafactory:
    build:
      dockerfile: ./docker/docker-cuda/Dockerfile
      context: ../..
      args:
        INSTALL_BNB: false
        INSTALL_VLLM: false
        INSTALL_DEEPSPEED: false
        INSTALL_FLASHATTN: false
        PIP_INDEX: https://pypi.org/simple
    container_name: llamafactory
    volumes:
      - ../../hf_cache:/root/.cache/huggingface
      - ../../ms_cache:/root/.cache/modelscope
      - ../../data:/app/data
      - ../../output:/app/output
    ports:
      - "7860:7860"
      - "8000:8000"
    ipc: host
    tty: true
    stdin_open: true
    command: bash
    deploy:
      resources:
        reservations:
          devices:
          - driver: nvidia
            count: "all"
            capabilities: [gpu]
    restart: unless-stopped

@@ -0,0 +1,45 @@

# Use the Ubuntu 22.04 image with CANN 8.0.rc1
# More versions can be found at https://hub.docker.com/r/cosdt/cann/tags
# FROM cosdt/cann:8.0.rc1-910-ubuntu22.04
FROM cosdt/cann:8.0.rc1-910b-ubuntu22.04
# FROM cosdt/cann:8.0.rc1-910-openeuler22.03
# FROM cosdt/cann:8.0.rc1-910b-openeuler22.03

# Define environments
ENV DEBIAN_FRONTEND=noninteractive

# Define installation arguments
ARG INSTALL_DEEPSPEED=false
ARG PIP_INDEX=https://pypi.org/simple
ARG TORCH_INDEX=https://download.pytorch.org/whl/cpu

# Set the working directory
WORKDIR /app

# Install the requirements
COPY requirements.txt /app
RUN pip config set global.index-url "$PIP_INDEX" && \
    pip config set global.extra-index-url "$TORCH_INDEX" && \
    python -m pip install --upgrade pip && \
    python -m pip install -r requirements.txt

# Copy the rest of the application into the image
COPY . /app

# Install the LLaMA Factory
RUN EXTRA_PACKAGES="torch-npu,metrics"; \
    if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
    fi; \
    pip install -e ".[$EXTRA_PACKAGES]"

# Set up volumes
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]

# Expose port 7860 for the LLaMA Board
ENV GRADIO_SERVER_PORT 7860
EXPOSE 7860

# Expose port 8000 for the API service
ENV API_PORT 8000
EXPOSE 8000

@@ -0,0 +1,31 @@

services:
  llamafactory:
    build:
      dockerfile: ./docker/docker-npu/Dockerfile
      context: ../..
      args:
        INSTALL_DEEPSPEED: false
        PIP_INDEX: https://pypi.org/simple
    container_name: llamafactory
    volumes:
      - ../../hf_cache:/root/.cache/huggingface
      - ../../ms_cache:/root/.cache/modelscope
      - ../../data:/app/data
      - ../../output:/app/output
      - /usr/local/dcmi:/usr/local/dcmi
      - /usr/local/bin/npu-smi:/usr/local/bin/npu-smi
      - /usr/local/Ascend/driver:/usr/local/Ascend/driver
      - /etc/ascend_install.info:/etc/ascend_install.info
    ports:
      - "7860:7860"
      - "8000:8000"
    ipc: host
    tty: true
    stdin_open: true
    command: bash
    devices:
      - /dev/davinci0
      - /dev/davinci_manager
      - /dev/devmm_svm
      - /dev/hisi_hdc
    restart: unless-stopped

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import datasets

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import datasets

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import datasets

@@ -4,59 +4,59 @@ Make sure to execute these commands in the `LLaMA-Factory` directory.

|
|
||||||
## Table of Contents
|
## Table of Contents
|
||||||
|
|
||||||
- [LoRA Fine-Tuning on A Single GPU](#lora-fine-tuning-on-a-single-gpu)
|
- [LoRA Fine-Tuning](#lora-fine-tuning)
|
||||||
- [QLoRA Fine-Tuning on a Single GPU](#qlora-fine-tuning-on-a-single-gpu)
|
- [QLoRA Fine-Tuning](#qlora-fine-tuning)
|
||||||
- [LoRA Fine-Tuning on Multiple GPUs](#lora-fine-tuning-on-multiple-gpus)
|
- [Full-Parameter Fine-Tuning](#full-parameter-fine-tuning)
|
||||||
- [LoRA Fine-Tuning on Multiple NPUs](#lora-fine-tuning-on-multiple-npus)
|
|
||||||
- [Full-Parameter Fine-Tuning on Multiple GPUs](#full-parameter-fine-tuning-on-multiple-gpus)
|
|
||||||
- [Merging LoRA Adapters and Quantization](#merging-lora-adapters-and-quantization)
|
- [Merging LoRA Adapters and Quantization](#merging-lora-adapters-and-quantization)
|
||||||
- [Inferring LoRA Fine-Tuned Models](#inferring-lora-fine-tuned-models)
|
- [Inferring LoRA Fine-Tuned Models](#inferring-lora-fine-tuned-models)
|
||||||
- [Extras](#extras)
|
- [Extras](#extras)
|
||||||
|
|
||||||
|
Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
### LoRA Fine-Tuning on A Single GPU
|
### LoRA Fine-Tuning
|
||||||
|
|
||||||
#### (Continuous) Pre-Training
|
#### (Continuous) Pre-Training
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Supervised Fine-Tuning
|
#### Supervised Fine-Tuning
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Multimodal Supervised Fine-Tuning
|
#### Multimodal Supervised Fine-Tuning
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml
|
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Reward Modeling
|
#### Reward Modeling
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### PPO Training
|
#### PPO Training
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### DPO/ORPO/SimPO Training
|
#### DPO/ORPO/SimPO Training
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### KTO Training
|
#### KTO Training
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Preprocess Dataset
|
#### Preprocess Dataset
|
||||||
|
@ -64,95 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
|
||||||
It is useful for large datasets; set `tokenized_path` in the config to load the preprocessed dataset.
|
It is useful for large datasets; set `tokenized_path` in the config to load the preprocessed dataset.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml
|
llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
|
||||||
```
|
```
|
||||||
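A sketch of the resulting two-pass workflow; the assumption is that both configs point at the same `tokenized_path` (the concrete path lives in the YAML files, not here):

```bash
# Pass 1: tokenize once and write the result to `tokenized_path`
llamafactory-cli train examples/train_lora/llama3_preprocess.yaml

# Pass 2: a training config that sets the same `tokenized_path` loads the
# cached tokens instead of re-tokenizing the raw dataset
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```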
|
|
||||||
#### Evaluating on MMLU/CMMLU/C-Eval Benchmarks
|
#### Evaluating on MMLU/CMMLU/C-Eval Benchmarks
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml
|
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Batch Predicting and Computing BLEU and ROUGE Scores
|
#### Batch Predicting and Computing BLEU and ROUGE Scores
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
|
||||||
```
|
|
||||||
|
|
||||||
### QLoRA Fine-Tuning on a Single GPU
|
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes Quantization (Recommended)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with 4-bit AWQ Quantization
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with 2-bit AQLM Quantization
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
### LoRA Fine-Tuning on Multiple GPUs
|
|
||||||
|
|
||||||
#### Supervised Fine-Tuning on Single Node
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Supervised Fine-Tuning on Multiple Nodes
|
#### Supervised Fine-Tuning on Multiple Nodes
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
|
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
|
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
|
||||||
```
|
```
|
||||||
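`FORCE_TORCHRUN=1` routes the launch through torchrun so DeepSpeed can shard the weights across workers; it composes with the device mask described earlier. A minimal sketch, assuming four local GPUs:

```bash
# Shard the model across GPUs 0-3 only; torchrun typically starts one worker per visible device
CUDA_VISIBLE_DEVICES=0,1,2,3 FORCE_TORCHRUN=1 \
    llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
```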
|
|
||||||
### LoRA Fine-Tuning on Multiple NPUs
|
### QLoRA Fine-Tuning
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with DeepSpeed ZeRO-0
|
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### Full-Parameter Fine-Tuning on Multiple GPUs
|
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Supervised Fine-Tuning with 4-bit AWQ Quantization
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Supervised Fine-Tuning with 2-bit AQLM Quantization
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### Full-Parameter Fine-Tuning
|
||||||
|
|
||||||
#### Supervised Fine-Tuning on Single Node
|
#### Supervised Fine-Tuning on Single Node
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Supervised Fine-Tuning on Multiple Nodes
|
#### Supervised Fine-Tuning on Multiple Nodes
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Batch Predicting and Computing BLEU and ROUGE Scores
|
#### Batch Predicting and Computing BLEU and ROUGE Scores
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml
|
llamafactory-cli train examples/train_full/llama3_full_predict.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### Merging LoRA Adapters and Quantization
|
### Merging LoRA Adapters and Quantization
|
||||||
|
@ -162,35 +146,33 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llam
|
||||||
Note: DO NOT use a quantized model or the `quantization_bit` argument when merging LoRA adapters.
|
Note: DO NOT use a quantized model or the `quantization_bit` argument when merging LoRA adapters.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
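For orientation, a sketch of the kind of fields such an export config carries. The key names and paths below are illustrative assumptions; the shipped `examples/merge_lora/llama3_lora_sft.yaml` is authoritative:

```bash
# Hypothetical export config; note the base model is loaded in full precision
cat > merge_sketch.yaml <<'EOF'
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct  # unquantized base model
adapter_name_or_path: saves/llama3-8b/lora/sft           # LoRA adapter to fold in
template: llama3
finetuning_type: lora
export_dir: models/llama3_lora_sft                       # merged weights land here
EOF
llamafactory-cli export merge_sketch.yaml
```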
|
|
||||||
#### Quantizing Model using AutoGPTQ
|
#### Quantizing Model using AutoGPTQ
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### Inferring LoRA Fine-Tuned Models
|
### Inferring LoRA Fine-Tuned Models
|
||||||
|
|
||||||
Use `CUDA_VISIBLE_DEVICES=0,1` to run inference on multiple devices.
|
|
||||||
|
|
||||||
#### Use CLI
|
#### Use CLI
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Use Web UI
|
#### Use Web UI
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Launch OpenAI-style API
|
#### Launch OpenAI-style API
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
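Once the server is up it speaks the OpenAI chat-completions protocol, so any OpenAI-compatible client works. A usage sketch assuming the default listen address `http://localhost:8000` (check the startup log for the actual host and port):

```bash
# the "model" value is a placeholder; the server serves whatever model the config loaded
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "llama3",
        "messages": [{"role": "user", "content": "Hello!"}]
      }'
```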
|
|
||||||
### Extras
|
### Extras
|
||||||
|
@ -198,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y
|
||||||
#### Full-Parameter Fine-Tuning using GaLore
|
#### Full-Parameter Fine-Tuning using GaLore
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Full-Parameter Fine-Tuning using BAdam
|
#### Full-Parameter Fine-Tuning using BAdam
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### LoRA+ Fine-Tuning
|
#### LoRA+ Fine-Tuning
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
|
llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
#### PiSSA Fine-Tuning
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Mixture-of-Depths Fine-Tuning
|
#### Mixture-of-Depths Fine-Tuning
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### LLaMA-Pro Fine-Tuning
|
#### LLaMA-Pro Fine-Tuning
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
bash examples/extras/llama_pro/expand.sh
|
bash examples/extras/llama_pro/expand.sh
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
|
llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### FSDP+QLoRA Fine-Tuning
|
#### FSDP+QLoRA Fine-Tuning
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
bash examples/extras/fsdp_qlora/single_node.sh
|
bash examples/extras/fsdp_qlora/train.sh
|
||||||
```
|
```
|
||||||
|
|
|
@ -4,59 +4,59 @@
|
||||||
|
|
||||||
## 目录
|
## 目录
|
||||||
|
|
||||||
- [单 GPU LoRA 微调](#单-gpu-lora-微调)
|
- [LoRA 微调](#lora-微调)
|
||||||
- [单 GPU QLoRA 微调](#单-gpu-qlora-微调)
|
- [QLoRA 微调](#qlora-微调)
|
||||||
- [多 GPU LoRA 微调](#多-gpu-lora-微调)
|
- [全参数微调](#全参数微调)
|
||||||
- [多 NPU LoRA 微调](#多-npu-lora-微调)
|
|
||||||
- [多 GPU 全参数微调](#多-gpu-全参数微调)
|
|
||||||
- [合并 LoRA 适配器与模型量化](#合并-lora-适配器与模型量化)
|
- [合并 LoRA 适配器与模型量化](#合并-lora-适配器与模型量化)
|
||||||
- [推理 LoRA 模型](#推理-lora-模型)
|
- [推理 LoRA 模型](#推理-lora-模型)
|
||||||
- [杂项](#杂项)
|
- [杂项](#杂项)
|
||||||
|
|
||||||
|
使用 `CUDA_VISIBLE_DEVICES`(GPU)或 `ASCEND_RT_VISIBLE_DEVICES`(NPU)选择计算设备。
|
||||||
|
|
||||||
## 示例
|
## 示例
|
||||||
|
|
||||||
### 单 GPU LoRA 微调
|
### LoRA 微调
|
||||||
|
|
||||||
#### (增量)预训练
|
#### (增量)预训练
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 指令监督微调
|
#### 指令监督微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 多模态指令监督微调
|
#### 多模态指令监督微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml
|
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 奖励模型训练
|
#### 奖励模型训练
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### PPO 训练
|
#### PPO 训练
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### DPO/ORPO/SimPO 训练
|
#### DPO/ORPO/SimPO 训练
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### KTO 训练
|
#### KTO 训练
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 预处理数据集
|
#### 预处理数据集
|
||||||
|
@ -64,95 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
|
||||||
对于大数据集有帮助,在配置中使用 `tokenized_path` 以加载预处理后的数据集。
|
对于大数据集有帮助,在配置中使用 `tokenized_path` 以加载预处理后的数据集。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml
|
llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 在 MMLU/CMMLU/C-Eval 上评估
|
#### 在 MMLU/CMMLU/C-Eval 上评估
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml
|
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 批量预测并计算 BLEU 和 ROUGE 分数
|
#### 批量预测并计算 BLEU 和 ROUGE 分数
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### 单 GPU QLoRA 微调
|
#### 多机指令监督微调
|
||||||
|
|
||||||
#### 基于 4/8 比特 Bitsandbytes 量化进行指令监督微调(推荐)
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml
|
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
```
|
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
|
|
||||||
#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 基于 4 比特 AWQ 量化进行指令监督微调
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 基于 2 比特 AQLM 量化进行指令监督微调
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
### 多 GPU LoRA 微调
|
|
||||||
|
|
||||||
#### 在单机上进行指令监督微调
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 在多机上进行指令监督微调
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 使用 DeepSpeed ZeRO-3 平均分配显存
|
#### 使用 DeepSpeed ZeRO-3 平均分配显存
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
|
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### 多 NPU LoRA 微调
|
### QLoRA 微调
|
||||||
|
|
||||||
#### 使用 DeepSpeed ZeRO-0 进行指令监督微调
|
#### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### 多 GPU 全参数微调
|
#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 基于 4 比特 AWQ 量化进行指令监督微调
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 基于 2 比特 AQLM 量化进行指令监督微调
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### 全参数微调
|
||||||
|
|
||||||
#### 在单机上进行指令监督微调
|
#### 在单机上进行指令监督微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 在多机上进行指令监督微调
|
#### 在多机上进行指令监督微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 批量预测并计算 BLEU 和 ROUGE 分数
|
#### 批量预测并计算 BLEU 和 ROUGE 分数
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml
|
llamafactory-cli train examples/train_full/llama3_full_predict.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### 合并 LoRA 适配器与模型量化
|
### 合并 LoRA 适配器与模型量化
|
||||||
|
@ -162,35 +146,33 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llam
|
||||||
注:请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。
|
注:请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 使用 AutoGPTQ 量化模型
|
#### 使用 AutoGPTQ 量化模型
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### 推理 LoRA 模型
|
### 推理 LoRA 模型
|
||||||
|
|
||||||
使用 `CUDA_VISIBLE_DEVICES=0,1` 进行多卡推理。
|
|
||||||
|
|
||||||
#### 使用命令行接口
|
#### 使用命令行接口
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 使用浏览器界面
|
#### 使用浏览器界面
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 启动 OpenAI 风格 API
|
#### 启动 OpenAI 风格 API
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### 杂项
|
### 杂项
|
||||||
|
@ -198,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y
|
||||||
#### 使用 GaLore 进行全参数训练
|
#### 使用 GaLore 进行全参数训练
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 使用 BAdam 进行全参数训练
|
#### 使用 BAdam 进行全参数训练
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### LoRA+ 微调
|
#### LoRA+ 微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
|
llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
#### PiSSA 微调
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 深度混合微调
|
#### 深度混合微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### LLaMA-Pro 微调
|
#### LLaMA-Pro 微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
bash examples/extras/llama_pro/expand.sh
|
bash examples/extras/llama_pro/expand.sh
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
|
llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### FSDP+QLoRA 微调
|
#### FSDP+QLoRA 微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
bash examples/extras/fsdp_qlora/single_node.sh
|
bash examples/extras/fsdp_qlora/train.sh
|
||||||
```
|
```
|
||||||
|
|
|
@ -5,10 +5,11 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: full
|
finetuning_type: full
|
||||||
|
use_badam: true
|
||||||
### ddp
|
badam_mode: layer
|
||||||
ddp_timeout: 180000000
|
badam_switch_mode: ascending
|
||||||
deepspeed: examples/deepspeed/ds_z3_config.json
|
badam_switch_interval: 50
|
||||||
|
badam_verbose: 2
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
|
@ -27,12 +28,11 @@ overwrite_output_dir: true
|
||||||
|
|
||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 8
|
||||||
learning_rate: 1.0e-4
|
learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
|
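The hunk above replaces the DeepSpeed block with layer-wise BAdam. A hedged reading of the added keys, with values copied from the diff and semantics paraphrased from the BAdam method (block-coordinate Adam that trains one block of layers at a time):

```bash
# Annotated copy of the added keys; append them to a copy of a full fine-tuning config to try it
cat >> my_badam_sft.yaml <<'EOF'
use_badam: true               # block-coordinate Adam instead of full-parameter Adam
badam_mode: layer             # treat each transformer layer as one block
badam_switch_mode: ascending  # visit blocks from the bottom layer upward
badam_switch_interval: 50     # rotate to the next block every 50 optimizer steps
badam_verbose: 2              # log which block is currently active
EOF
```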
@ -6,9 +6,11 @@ stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: full
|
finetuning_type: full
|
||||||
use_badam: true
|
use_badam: true
|
||||||
|
badam_mode: layer
|
||||||
badam_switch_mode: ascending
|
badam_switch_mode: ascending
|
||||||
badam_switch_interval: 50
|
badam_switch_interval: 50
|
||||||
badam_verbose: 2
|
badam_verbose: 2
|
||||||
|
deepspeed: examples/deepspeed/ds_z3_config.json
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
|
@ -32,7 +34,6 @@ learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
pure_bf16: true
|
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
|
@ -8,9 +8,6 @@ do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### ddp
|
|
||||||
ddp_timeout: 180000000
|
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
|
@ -33,7 +30,8 @@ learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
|
|
|
@ -31,7 +31,8 @@ learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
|
|
|
@ -30,7 +30,8 @@ learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
|
|
|
@ -31,6 +31,7 @@ num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
pure_bf16: true
|
pure_bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
quantization_bit: 4
|
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
pissa_init: true
|
||||||
|
pissa_iter: 4
|
||||||
|
pissa_convert: true
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
|
@ -30,7 +32,8 @@ learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
|
@ -4,11 +4,8 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: full
|
||||||
lora_target: all
|
deepspeed: examples/deepspeed/ds_z3_config.json
|
||||||
|
|
||||||
### ddp
|
|
||||||
ddp_timeout: 180000000
|
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
|
@ -19,7 +16,7 @@ overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
|
||||||
### output
|
### output
|
||||||
output_dir: saves/llama3-8b/lora/sft
|
output_dir: saves/llama3-8b/full/sft
|
||||||
logging_steps: 10
|
logging_steps: 10
|
||||||
save_steps: 500
|
save_steps: 500
|
||||||
plot_loss: true
|
plot_loss: true
|
||||||
|
@ -32,7 +29,8 @@ learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
|
@ -7,7 +7,7 @@ do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
lora_target: all
|
lora_target: all
|
||||||
pref_beta: 0.1
|
pref_beta: 0.1
|
||||||
pref_loss: sigmoid # [sigmoid (dpo), orpo, simpo]
|
pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: dpo_en_demo
|
dataset: dpo_en_demo
|
||||||
|
@ -31,7 +31,8 @@ learning_rate: 5.0e-6
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
|
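The `pref_loss` comment above means one config drives three preference objectives, and switching is a one-key change. A sketch on a hypothetical copy of the demo config (`my_dpo_copy.yaml` is not a shipped file):

```bash
# DPO (default): pref_loss: sigmoid   ORPO: pref_loss: orpo   SimPO: pref_loss: simpo
cp examples/train_lora/llama3_lora_dpo.yaml my_dpo_copy.yaml
sed -i 's/^pref_loss: .*/pref_loss: orpo/' my_dpo_copy.yaml
llamafactory-cli train my_dpo_copy.yaml
```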
@ -6,6 +6,7 @@ stage: kto
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
pref_beta: 0.1
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: kto_en_demo
|
dataset: kto_en_demo
|
||||||
|
@ -29,7 +30,8 @@ learning_rate: 5.0e-6
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
|
@ -30,7 +30,8 @@ learning_rate: 1.0e-5
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### generate
|
### generate
|
||||||
max_new_tokens: 512
|
max_new_tokens: 512
|
|
@ -22,3 +22,4 @@ overwrite_output_dir: true
|
||||||
### eval
|
### eval
|
||||||
per_device_eval_batch_size: 1
|
per_device_eval_batch_size: 1
|
||||||
predict_with_generate: true
|
predict_with_generate: true
|
||||||
|
ddp_timeout: 180000000
|
|
@ -28,7 +28,8 @@ learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
|
@ -25,11 +25,12 @@ overwrite_output_dir: true
|
||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
gradient_accumulation_steps: 8
|
gradient_accumulation_steps: 8
|
||||||
learning_rate: 1.0e-5
|
learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
|
@ -29,7 +29,8 @@ learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
|
@ -6,9 +6,6 @@ stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### ddp
|
|
||||||
ddp_timeout: 180000000
|
|
||||||
deepspeed: examples/deepspeed/ds_z0_config.json
|
deepspeed: examples/deepspeed/ds_z0_config.json
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
|
@ -33,7 +30,8 @@ learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
|
@ -6,9 +6,6 @@ stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### ddp
|
|
||||||
ddp_timeout: 180000000
|
|
||||||
deepspeed: examples/deepspeed/ds_z3_config.json
|
deepspeed: examples/deepspeed/ds_z3_config.json
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
|
@ -33,7 +30,8 @@ learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
|
@ -30,7 +30,8 @@ learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
|
@ -29,7 +29,8 @@ learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
|
@ -29,7 +29,8 @@ learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
|
@ -29,7 +29,8 @@ learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
|
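A recurring change in the hunks above is `fp16: true` becoming `bf16: true` (plus a pinned `ddp_timeout`). bf16 needs hardware support, roughly Ampere-or-newer NVIDIA GPUs or recent Ascend NPUs; a quick check before adopting the new configs:

```bash
python -c "import torch; print(torch.cuda.is_bf16_supported())"
```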
@ -0,0 +1,41 @@
|
||||||
|
### model
|
||||||
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
quantization_bit: 4
|
||||||
|
quantization_method: bitsandbytes # choices: [bitsandbytes (4/8), hqq (2/3/4/5/6/8), eetq (8)]
|
||||||
|
|
||||||
|
### method
|
||||||
|
stage: sft
|
||||||
|
do_train: true
|
||||||
|
finetuning_type: lora
|
||||||
|
lora_target: all
|
||||||
|
|
||||||
|
### dataset
|
||||||
|
dataset: identity,alpaca_en_demo
|
||||||
|
template: llama3
|
||||||
|
cutoff_len: 1024
|
||||||
|
max_samples: 1000
|
||||||
|
overwrite_cache: true
|
||||||
|
preprocessing_num_workers: 16
|
||||||
|
|
||||||
|
### output
|
||||||
|
output_dir: saves/llama3-8b/lora/sft
|
||||||
|
logging_steps: 10
|
||||||
|
save_steps: 500
|
||||||
|
plot_loss: true
|
||||||
|
overwrite_output_dir: true
|
||||||
|
|
||||||
|
### train
|
||||||
|
per_device_train_batch_size: 1
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
learning_rate: 1.0e-4
|
||||||
|
num_train_epochs: 3.0
|
||||||
|
lr_scheduler_type: cosine
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
|
### eval
|
||||||
|
val_size: 0.1
|
||||||
|
per_device_eval_batch_size: 1
|
||||||
|
eval_strategy: steps
|
||||||
|
eval_steps: 500
|
|
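The new `llama3_lora_sft_otfq.yaml` above quantizes on the fly; per its own inline comment, bitsandbytes handles 4/8 bit, HQQ 2/3/4/5/6/8 bit, and EETQ 8 bit. A sketch of switching backends on a hypothetical copy of the file:

```bash
cp examples/train_qlora/llama3_lora_sft_otfq.yaml my_otfq.yaml
sed -i 's/^quantization_method: .*/quantization_method: hqq/' my_otfq.yaml  # keep quantization_bit within HQQ's range
llamafactory-cli train my_otfq.yaml
```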
@ -4,6 +4,7 @@ accelerate>=0.30.1
|
||||||
peft>=0.11.1
|
peft>=0.11.1
|
||||||
trl>=0.8.6
|
trl>=0.8.6
|
||||||
gradio>=4.0.0
|
gradio>=4.0.0
|
||||||
|
pandas>=2.0.0
|
||||||
scipy
|
scipy
|
||||||
einops
|
einops
|
||||||
sentencepiece
|
sentencepiece
|
||||||
|
@ -17,3 +18,4 @@ matplotlib>=3.7.0
|
||||||
fire
|
fire
|
||||||
packaging
|
packaging
|
||||||
pyyaml
|
pyyaml
|
||||||
|
numpy<2.0.0
|
||||||
|
|
|
@ -1,7 +1,20 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Calculates the flops of pre-trained models.
|
# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
|
||||||
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
|
#
|
||||||
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
|
# This code is inspired by Microsoft's DeepSpeed library.
|
||||||
|
# https://www.deepspeed.ai/tutorials/flops-profiler/
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
|
@ -17,6 +30,10 @@ def calculate_flops(
|
||||||
seq_length: int = 256,
|
seq_length: int = 256,
|
||||||
flash_attn: str = "auto",
|
flash_attn: str = "auto",
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
Calculates the FLOPs of pre-trained models.
|
||||||
|
Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
|
||||||
|
"""
|
||||||
with get_accelerator().device(0):
|
with get_accelerator().device(0):
|
||||||
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
|
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
|
||||||
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
|
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
|
||||||
|
|
|
@ -1,7 +1,20 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
|
# Copyright 2024 imoneoi and the LlamaFactory team.
|
||||||
# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
|
#
|
||||||
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
|
# This code is inspired by imoneoi's OpenChat library.
|
||||||
|
# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
@ -32,6 +45,10 @@ def calculate_lr(
|
||||||
cutoff_len: int = 1024, # i.e. maximum input length during training
|
cutoff_len: int = 1024, # i.e. maximum input length during training
|
||||||
is_mistral: bool = False,  # the Mistral model uses a smaller learning rate
|
is_mistral: bool = False,  # the Mistral model uses a smaller learning rate
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
|
||||||
|
Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
|
||||||
|
"""
|
||||||
model_args, data_args, training_args, _, _ = get_train_args(
|
model_args, data_args, training_args, _, _ = get_train_args(
|
||||||
dict(
|
dict(
|
||||||
stage=stage,
|
stage=stage,
|
||||||
|
|
|
@ -1,6 +1,17 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Calculates the ppl on the dataset of the pre-trained models.
|
# Copyright 2024 the LlamaFactory team.
|
||||||
# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
@ -56,6 +67,10 @@ def cal_ppl(
|
||||||
max_samples: Optional[int] = None,
|
max_samples: Optional[int] = None,
|
||||||
train_on_prompt: bool = False,
|
train_on_prompt: bool = False,
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
Calculates the perplexity (ppl) of pre-trained models on the dataset.
|
||||||
|
Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
|
||||||
|
"""
|
||||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
|
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
|
||||||
dict(
|
dict(
|
||||||
stage=stage,
|
stage=stage,
|
||||||
|
|
|
@ -1,6 +1,17 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Calculates the distribution of the input lengths in the dataset.
|
# Copyright 2024 the LlamaFactory team.
|
||||||
# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
@ -19,6 +30,10 @@ def length_cdf(
|
||||||
template: str = "default",
|
template: str = "default",
|
||||||
interval: int = 1000,
|
interval: int = 1000,
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
Calculates the distribution of the input lengths in the dataset.
|
||||||
|
Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
|
||||||
|
"""
|
||||||
model_args, data_args, training_args, _, _ = get_train_args(
|
model_args, data_args, training_args, _, _ = get_train_args(
|
||||||
dict(
|
dict(
|
||||||
stage="sft",
|
stage="sft",
|
||||||
|
|
|
@ -1,7 +1,20 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
|
# Copyright 2024 Tencent Inc. and the LlamaFactory team.
|
||||||
# Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
|
#
|
||||||
# Inspired by: https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
|
# This code is inspired by Tencent's LLaMA-Pro library.
|
||||||
|
# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
@ -37,6 +50,10 @@ def block_expansion(
|
||||||
shard_size: Optional[str] = "2GB",
|
shard_size: Optional[str] = "2GB",
|
||||||
save_safetensors: Optional[bool] = False,
|
save_safetensors: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
|
||||||
|
Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
|
||||||
|
"""
|
||||||
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
|
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
|
||||||
num_layers = getattr(config, "num_hidden_layers")
|
num_layers = getattr(config, "num_hidden_layers")
|
||||||
setattr(config, "num_hidden_layers", num_layers + num_expand)
|
setattr(config, "num_hidden_layers", num_layers + num_expand)
|
||||||
|
@ -103,7 +120,7 @@ def block_expansion(
|
||||||
json.dump(index, f, indent=2, sort_keys=True)
|
json.dump(index, f, indent=2, sort_keys=True)
|
||||||
print("Model weights saved in {}".format(output_dir))
|
print("Model weights saved in {}".format(output_dir))
|
||||||
|
|
||||||
print("Fine-tune this model with:")
|
print("- Fine-tune this model with:")
|
||||||
print("model_name_or_path: {}".format(output_dir))
|
print("model_name_or_path: {}".format(output_dir))
|
||||||
print("finetuning_type: freeze")
|
print("finetuning_type: freeze")
|
||||||
print("freeze_trainable_layers: {}".format(num_expand))
|
print("freeze_trainable_layers: {}".format(num_expand))
|
||||||
|
|
|
@ -1,8 +1,17 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
|
# Copyright 2024 the LlamaFactory team.
|
||||||
# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
|
#
|
||||||
# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
@ -79,6 +88,11 @@ def save_config(input_dir: str, output_dir: str):
|
||||||
def llamafy_baichuan2(
|
def llamafy_baichuan2(
|
||||||
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
|
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
Converts the Baichuan2-7B model into the same format as LLaMA2-7B.
|
||||||
|
Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
|
||||||
|
Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
os.makedirs(output_dir, exist_ok=False)
|
os.makedirs(output_dir, exist_ok=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -1,7 +1,17 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Converts the Qwen models in the same format as LLaMA2.
|
# Copyright 2024 the LlamaFactory team.
|
||||||
# Usage: python llamafy_qwen.py --input_dir input --output_dir output
|
#
|
||||||
# Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
@ -131,6 +141,11 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
|
||||||
def llamafy_qwen(
|
def llamafy_qwen(
|
||||||
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
|
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
Converts the Qwen models into the same format as LLaMA2.
|
||||||
|
Usage: python llamafy_qwen.py --input_dir input --output_dir output
|
||||||
|
Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
os.makedirs(output_dir, exist_ok=False)
|
os.makedirs(output_dir, exist_ok=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -1,14 +1,25 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
# Usage: python loftq_init.py --model_name_or_path path_to_model --save_dir output_dir
|
#
|
||||||
# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py
|
# This code is based on HuggingFace's PEFT library.
|
||||||
|
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
|
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
@ -17,65 +28,61 @@ if TYPE_CHECKING:
|
||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
class Shell(nn.Module):
|
|
||||||
def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(weight, requires_grad=False)
|
|
||||||
if bias is not None:
|
|
||||||
self.bias = nn.Parameter(bias, requires_grad=False)
|
|
||||||
|
|
||||||
|
|
||||||
def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
|
|
||||||
for name in {k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k}:
|
|
||||||
parent_name = ".".join(name.split(".")[:-1])
|
|
||||||
child_name = name.split(".")[-1]
|
|
||||||
parent_module = model.get_submodule(parent_name)
|
|
||||||
child_module = getattr(parent_module, child_name)
|
|
||||||
base_layer = getattr(child_module, "base_layer")
|
|
||||||
weight = getattr(base_layer, "weight", None)
|
|
||||||
bias = getattr(base_layer, "bias", None)
|
|
||||||
setattr(parent_module, child_name, Shell(weight, bias))
|
|
||||||
|
|
||||||
print("Model unwrapped.")
|
|
||||||
|
|
||||||
|
|
||||||
def quantize_loftq(
|
def quantize_loftq(
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
save_dir: str,
|
output_dir: str,
|
||||||
loftq_bits: Optional[int] = 4,
|
loftq_bits: int = 4,
|
||||||
loftq_iter: Optional[int] = 1,
|
loftq_iter: int = 4,
|
||||||
lora_alpha: Optional[int] = None,
|
lora_alpha: int = None,
|
||||||
lora_rank: Optional[int] = 16,
|
lora_rank: int = 16,
|
||||||
lora_target: Optional[str] = "q_proj,v_proj",
|
lora_dropout: float = 0,
|
||||||
save_safetensors: Optional[bool] = False,
|
lora_target: tuple = ("q_proj", "v_proj"),
|
||||||
|
save_safetensors: bool = True,
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ).
|
||||||
|
Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
|
||||||
|
"""
|
||||||
|
if isinstance(lora_target, str):
|
||||||
|
lora_target = [name.strip() for name in lora_target.split(",")]
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
|
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
|
||||||
|
|
||||||
loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
|
loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
|
||||||
lora_config = LoraConfig(
|
lora_config = LoraConfig(
|
||||||
task_type=TaskType.CAUSAL_LM,
|
task_type=TaskType.CAUSAL_LM,
|
||||||
inference_mode=True,
|
inference_mode=True,
|
||||||
r=lora_rank,
|
r=lora_rank,
|
||||||
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
||||||
lora_dropout=0.1,
|
lora_dropout=lora_dropout,
|
||||||
target_modules=[name.strip() for name in lora_target.split(",")],
|
target_modules=lora_target,
|
||||||
init_lora_weights="loftq",
|
init_lora_weights="loftq",
|
||||||
loftq_config=loftq_config,
|
loftq_config=loftq_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init LoftQ model
|
# Init LoftQ model
|
||||||
lora_model = get_peft_model(model, lora_config)
|
print("Initializing LoftQ weights, it may be take several minutes, wait patiently.")
|
||||||
base_model: "PreTrainedModel" = lora_model.get_base_model()
|
peft_model = get_peft_model(model, lora_config)
|
||||||
|
loftq_dir = os.path.join(output_dir, "loftq_init")
|
||||||
|
|
||||||
# Save LoftQ model
|
# Save LoftQ model
|
||||||
setattr(lora_model.base_model.peft_config["default"], "base_model_name_or_path", save_dir)
|
setattr(peft_model.peft_config["default"], "base_model_name_or_path", output_dir)
|
||||||
setattr(lora_model.base_model.peft_config["default"], "init_lora_weights", True)
|
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
|
||||||
lora_model.save_pretrained(os.path.join(save_dir, "adapters"), safe_serialization=save_safetensors)
|
peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
|
||||||
|
print("Adapter weights saved in {}".format(loftq_dir))
|
||||||
|
|
||||||
# Save base model
|
# Save base model
|
||||||
unwrap_model(base_model)
|
base_model: "PreTrainedModel" = peft_model.unload()
|
||||||
base_model.save_pretrained(save_dir, safe_serialization=save_safetensors)
|
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
||||||
tokenizer.save_pretrained(save_dir)
|
tokenizer.save_pretrained(output_dir)
|
||||||
|
print("Model weights saved in {}".format(output_dir))
|
||||||
|
|
||||||
|
print("- Fine-tune this model with:")
|
||||||
|
print("model_name_or_path: {}".format(output_dir))
|
||||||
|
print("adapter_name_or_path: {}".format(loftq_dir))
|
||||||
|
print("finetuning_type: lora")
|
||||||
|
print("quantization_bit: {}".format(loftq_bits))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
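The reworked script produces a two-part artifact: the dequantized base weights in `output_dir` and the LoftQ-initialized adapters in `output_dir/loftq_init`. A minimal sketch of driving it from Python, assuming the file is importable as a module; the import path and model id are placeholders, not part of the diff:

```python
# Hypothetical driver for loftq_init.py shown above.
from loftq_init import quantize_loftq

quantize_loftq(
    model_name_or_path="meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    output_dir="output/llama2-7b-loftq",
    loftq_bits=4,
    loftq_iter=4,
    lora_rank=16,
)
# The script then prints the fine-tuning settings it expects:
# model_name_or_path: output/llama2-7b-loftq
# adapter_name_or_path: output/llama2-7b-loftq/loftq_init
```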
@@ -0,0 +1,86 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is based on the HuggingFace's PEFT library.
+# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import TYPE_CHECKING
+
+import fire
+from peft import LoraConfig, TaskType, get_peft_model
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+
+def quantize_pissa(
+    model_name_or_path: str,
+    output_dir: str,
+    pissa_iter: int = 4,
+    lora_alpha: int = None,
+    lora_rank: int = 16,
+    lora_dropout: float = 0,
+    lora_target: tuple = ("q_proj", "v_proj"),
+    save_safetensors: bool = True,
+):
+    r"""
+    Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
+    Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
+    """
+    if isinstance(lora_target, str):
+        lora_target = [name.strip() for name in lora_target.split(",")]
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
+
+    lora_config = LoraConfig(
+        task_type=TaskType.CAUSAL_LM,
+        r=lora_rank,
+        lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
+        lora_dropout=lora_dropout,
+        target_modules=lora_target,
+        init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
+    )
+
+    # Init PiSSA model
+    peft_model = get_peft_model(model, lora_config)
+    pissa_dir = os.path.join(output_dir, "pissa_init")
+
+    # Save PiSSA model
+    setattr(peft_model.peft_config["default"], "init_lora_weights", True)  # don't apply pissa again
+    peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
+    print("Adapter weights saved in {}".format(pissa_dir))
+
+    # Save base model
+    base_model: "PreTrainedModel" = peft_model.unload()
+    base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
+    tokenizer.save_pretrained(output_dir)
+    print("Model weights saved in {}".format(output_dir))
+
+    print("- Fine-tune this model with:")
+    print("model_name_or_path: {}".format(output_dir))
+    print("adapter_name_or_path: {}".format(pissa_dir))
+    print("finetuning_type: lora")
+    print("pissa_init: false")
+    print("pissa_convert: true")
+    print("- and optionally with:")
+    print("quantization_bit: 4")
+
+
+if __name__ == "__main__":
+    fire.Fire(quantize_pissa)
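One line above is easy to misread: `init_lora_weights` is a string that encodes the iteration count. A small sketch of the mapping, assuming PEFT's naming convention that `"pissa"` requests the exact SVD and `"pissa_niter_{n}"` the fast variant with n subspace iterations:

```python
# Mirrors the conditional in the diff; the exact-vs-fast SVD reading is an
# assumption based on PEFT's documented naming convention.
def pissa_init_mode(pissa_iter: int) -> str:
    return "pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter)

assert pissa_init_mode(-1) == "pissa"         # exact SVD, slower
assert pissa_init_mode(4) == "pissa_niter_4"  # fast SVD with 4 iterations
```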
@@ -1,3 +1,18 @@
+# coding=utf-8
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 import os
 from typing import Sequence
setup.py (26 changed lines)

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 import re

@@ -23,14 +37,16 @@ extra_require = {
     "torch": ["torch>=1.13.1"],
     "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
     "metrics": ["nltk", "jieba", "rouge-chinese"],
-    "deepspeed": ["deepspeed>=0.10.0,<=0.14.0"],
+    "deepspeed": ["deepspeed>=0.10.0"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
-    "vllm": ["vllm>=0.4.3"],
-    "galore": ["galore-torch"],
-    "badam": ["badam"],
-    "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
+    "hqq": ["hqq"],
+    "eetq": ["eetq"],
+    "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
+    "vllm": ["vllm>=0.4.3"],
+    "galore": ["galore-torch"],
+    "badam": ["badam>=1.2.1"],
     "qwen": ["transformers_stream_generator"],
     "modelscope": ["modelscope"],
     "dev": ["ruff", "pytest"],
src/api.py (14 changed lines)

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os

 import uvicorn

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Level: api, webui > chat, eval, train > data, model > hparams > extras

 from .cli import VERSION
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 from contextlib import asynccontextmanager
 from typing import Optional

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import base64
 import io
 import json
@@ -78,9 +92,11 @@ def _process_request(
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")

         if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
-            name = message.tool_calls[0].function.name
-            arguments = message.tool_calls[0].function.arguments
-            content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
+            tool_calls = [
+                {"name": tool_call.function.name, "arguments": tool_call.function.arguments}
+                for tool_call in message.tool_calls
+            ]
+            content = json.dumps(tool_calls, ensure_ascii=False)
             input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
         elif isinstance(message.content, list):
             for input_item in message.content:

@@ -104,7 +120,7 @@ def _process_request(
     if isinstance(tool_list, list) and len(tool_list):
         try:
             tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
-        except Exception:
+        except json.JSONDecodeError:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
     else:
         tools = None

@@ -146,15 +162,17 @@ async def create_chat_completion_response(
     choices = []
     for i, response in enumerate(responses):
         if tools:
-            result = chat_model.engine.template.format_tools.extract(response.response_text)
+            result = chat_model.engine.template.extract_tool(response.response_text)
         else:
             result = response.response_text

-        if isinstance(result, tuple):
-            name, arguments = result
-            function = Function(name=name, arguments=arguments)
-            tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
-            response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
+        if isinstance(result, list):
+            tool_calls = []
+            for tool in result:
+                function = Function(name=tool[0], arguments=tool[1])
+                tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
+
+            response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
             finish_reason = Finish.TOOL
         else:
             response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
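The net effect of these three hunks is that assistant messages can carry several tool calls through one JSON round trip instead of keeping only the first (the old code also misspelled the key as "argument"). A self-contained sketch of that round trip, with made-up tool names and arguments:

```python
import json

# Serialize two calls the way _process_request now does, then recover them the
# way a template extractor would; names and arguments here are invented.
tool_calls = [
    {"name": "get_weather", "arguments": {"city": "Paris"}},
    {"name": "get_time", "arguments": {"timezone": "CET"}},
]
content = json.dumps(tool_calls, ensure_ascii=False)  # one message payload

recovered = json.loads(content)
assert [call["name"] for call in recovered] == ["get_weather", "get_time"]
```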
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 from typing import TYPE_CHECKING, Any, Dict

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import time
 from enum import Enum, unique
 from typing import Any, Dict, List, Optional, Union
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from .base_engine import BaseEngine
 from .chat_model import ChatModel

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union

@@ -36,11 +50,6 @@ class BaseEngine(ABC):
         generating_args: "GeneratingArguments",
     ) -> None: ...

-    @abstractmethod
-    async def start(
-        self,
-    ) -> None: ...
-
     @abstractmethod
     async def chat(
         self,
|
||||||
|
# Copyright 2024 THUDM and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is inspired by the THUDM's ChatGLM implementation.
|
||||||
|
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
||||||
|
@ -14,7 +31,7 @@ if TYPE_CHECKING:
|
||||||
from .base_engine import BaseEngine, Response
|
from .base_engine import BaseEngine, Response
|
||||||
|
|
||||||
|
|
||||||
def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
|
def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
loop.run_forever()
|
loop.run_forever()
|
||||||
|
|
||||||
|
@ -32,7 +49,6 @@ class ChatModel:
|
||||||
self._loop = asyncio.new_event_loop()
|
self._loop = asyncio.new_event_loop()
|
||||||
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)
|
|
||||||
|
|
||||||
def chat(
|
def chat(
|
||||||
self,
|
self,
|
||||||
|
|
|
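With the eager `engine.start()` call gone, the surviving pattern is worth spelling out: a daemon thread owns a private event loop, and synchronous callers hand coroutines to it. A runnable sketch of just that pattern, where the sleep stands in for model inference:

```python
import asyncio
from threading import Thread

def _run_loop(loop: asyncio.AbstractEventLoop) -> None:
    asyncio.set_event_loop(loop)
    loop.run_forever()

async def answer() -> str:
    await asyncio.sleep(0.1)  # placeholder for model generation
    return "hello"

loop = asyncio.new_event_loop()
Thread(target=_run_loop, args=(loop,), daemon=True).start()

future = asyncio.run_coroutine_threadsafe(answer(), loop)
print(future.result())  # blocks this thread only; the loop keeps running
```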
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import asyncio
 import concurrent.futures
 import os

@@ -40,11 +54,19 @@ class HuggingfaceEngine(BaseEngine):
         self.tokenizer = tokenizer_module["tokenizer"]
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left" if self.can_generate else "right"
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
         self.model = load_model(
             self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
         )  # must after fixing tokenizer to resize vocab
         self.generating_args = generating_args.to_dict()
+        try:
+            asyncio.get_event_loop()
+        except RuntimeError:
+            logger.warning("There is no current event loop, creating a new one.")
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+        self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))

     @staticmethod
     def _process_args(
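The semaphore that used to be allocated lazily in the removed `start()` coroutine is now created in `__init__`, next to a guard that makes sure an event loop exists. A runnable sketch of how such a `MAX_CONCURRENT` cap behaves, with a sleep standing in for a generation call:

```python
import asyncio
import os

semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))

async def handle(request_id: int) -> None:
    async with semaphore:  # at most MAX_CONCURRENT bodies run at once
        await asyncio.sleep(0.1)  # placeholder for blocking inference
        print("finished", request_id)

async def main() -> None:
    await asyncio.gather(*(handle(i) for i in range(4)))

asyncio.run(main())
```

With the default of 1, the four requests above finish one after another; raising `MAX_CONCURRENT` lets them overlap.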
@@ -245,9 +267,6 @@ class HuggingfaceEngine(BaseEngine):

         return scores

-    async def start(self) -> None:
-        self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
-
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],

@@ -272,7 +291,7 @@ class HuggingfaceEngine(BaseEngine):
             image,
             input_kwargs,
         )
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 return await loop.run_in_executor(pool, self._chat, *input_args)

@@ -300,7 +319,7 @@ class HuggingfaceEngine(BaseEngine):
             image,
             input_kwargs,
         )
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 stream = self._stream_chat(*input_args)
                 while True:

@@ -319,6 +338,6 @@ class HuggingfaceEngine(BaseEngine):

         loop = asyncio.get_running_loop()
         input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 return await loop.run_in_executor(pool, self._get_scores, *input_args)
@@ -1,10 +1,24 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union

 from ..data import get_template_and_fix_tokenizer
 from ..extras.logging import get_logger
 from ..extras.misc import get_device_count
-from ..extras.packages import is_vllm_available
+from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
 from ..model import load_config, load_tokenizer
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response

@@ -13,7 +27,11 @@ from .base_engine import BaseEngine, Response
 if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest
-    from vllm.sequence import MultiModalData
+
+    if is_vllm_version_greater_than_0_5():
+        from vllm.multimodal.image import ImagePixelData
+    else:
+        from vllm.sequence import MultiModalData


 if TYPE_CHECKING:
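`is_vllm_version_greater_than_0_5` lives in `extras.packages` and is not shown in this diff. A plausible shape for such a gate, offered as an assumption about the helper rather than a copy of it (the exact version boundary is also assumed):

```python
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

def is_vllm_version_greater_than_0_5() -> bool:
    # Treat a missing vllm install as "old" so callers fall into the legacy branch.
    try:
        return Version(version("vllm")) >= Version("0.5.0")
    except PackageNotFoundError:
        return False
```

Gating the import itself, rather than the call sites alone, keeps the module importable against both vllm releases.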
@@ -41,14 +59,14 @@ class VllmEngine(BaseEngine):
         self.tokenizer = tokenizer_module["tokenizer"]
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left"
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
         self.generating_args = generating_args.to_dict()

         engine_args = {
             "model": model_args.model_name_or_path,
             "trust_remote_code": True,
             "download_dir": model_args.cache_dir,
-            "dtype": model_args.vllm_dtype,
+            "dtype": model_args.infer_dtype,
             "max_model_len": model_args.vllm_maxlen,
             "tensor_parallel_size": get_device_count() or 1,
             "gpu_memory_utilization": model_args.vllm_gpu_util,

@@ -106,7 +124,10 @@ class VllmEngine(BaseEngine):
         if self.processor is not None and image is not None:  # add image features
             image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
             pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
-            multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
+            if is_vllm_version_greater_than_0_5():
+                multi_modal_data = ImagePixelData(image=pixel_values)
+            else:  # TODO: remove vllm 0.4.3 support
+                multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
         else:
             multi_modal_data = None

@@ -162,9 +183,6 @@ class VllmEngine(BaseEngine):
         )
         return result_generator

-    async def start(self) -> None:
-        pass
-
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 import random
 import subprocess

@@ -60,7 +74,7 @@ class Command(str, Enum):


 def main():
-    command = sys.argv.pop(1)
+    command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
     if command == Command.API:
         run_api()
     elif command == Command.CHAT:

@@ -77,7 +91,7 @@ def main():
             master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
             master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
             logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
-            subprocess.run(
+            process = subprocess.run(
                 (
                     "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
                     "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"

@@ -92,6 +106,7 @@ def main():
                 ),
                 shell=True,
             )
+            sys.exit(process.returncode)
         else:
             run_exp()
     elif command == Command.WEBDEMO:
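The one-line addition matters for automation: previously a failing `torchrun` child still let the CLI wrapper exit with status 0, which hides failures from shell scripts and CI. A toy demonstration of propagating a child's exit status, with a shell `exit 3` standing in for the real torchrun command:

```python
import subprocess
import sys

process = subprocess.run("exit 3", shell=True)  # placeholder failing child
print(process.returncode)  # 3
sys.exit(process.returncode)  # the wrapper now fails when the child fails
```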
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
 from .data_utils import Role, split_dataset
 from .loader import get_dataset

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 from functools import partial
 from typing import TYPE_CHECKING, Any, Dict, List, Union

@@ -10,6 +24,7 @@ from .data_utils import Role

 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
+    from transformers import Seq2SeqTrainingArguments

     from ..hparams import DataArguments
     from .parser import DatasetAttr

@@ -175,7 +190,10 @@ def convert_sharegpt(


 def align_dataset(
-    dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
+    dataset: Union["Dataset", "IterableDataset"],
+    dataset_attr: "DatasetAttr",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     r"""
     Aligned dataset:

@@ -208,7 +226,7 @@ def align_dataset(
     if not data_args.streaming:
         kwargs = dict(
             num_proc=data_args.preprocessing_num_workers,
-            load_from_cache_file=(not data_args.overwrite_cache),
+            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
             desc="Converting format of dataset",
         )
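The new `load_from_cache_file` expression reads oddly at first glance, so here it is pulled out as a plain function with its truth table; combined with `main_process_first`, rank 0 rebuilds the cache and every other rank reuses its result:

```python
def load_from_cache_file(overwrite_cache: bool, local_process_index: int) -> bool:
    # Mirrors the expression in the diff: only the main process may ignore the cache.
    return (not overwrite_cache) or (local_process_index != 0)

assert load_from_cache_file(False, 0) is True   # default: reuse the cache
assert load_from_cache_file(True, 0) is False   # rank 0 rebuilds it
assert load_from_cache_file(True, 1) is True    # other ranks load rank 0's result
```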
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import dataclass
 from typing import Any, Dict, Sequence

@@ -1,5 +1,19 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from enum import Enum, unique
-from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Sequence, Set, Union

 from datasets import concatenate_datasets, interleave_datasets

@@ -16,6 +30,9 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


+SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
+
+
 @unique
 class Role(str, Enum):
     USER = "user"

@@ -25,13 +42,6 @@ class Role(str, Enum):
     OBSERVATION = "observation"


-def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
-    max_target_len = int(max_len * (target_len / (source_len + target_len)))
-    max_target_len = max(max_target_len, reserved_label_len)
-    max_source_len = max_len - min(max_target_len, target_len)
-    return max_source_len, max_target_len
-
-
 def merge_dataset(
     all_datasets: List[Union["Dataset", "IterableDataset"]],
     data_args: "DataArguments",
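For reference, here is what the deleted `infer_max_len` computed, with one worked example; whatever replaces it (a later hunk imports `infer_seqlen` from `processor_utils`) has to cover the same job of splitting a token budget between prompt and response:

```python
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int):
    # Proportional split of max_len, with a floor reserved for the labels.
    max_target_len = int(max_len * (target_len / (source_len + target_len)))
    max_target_len = max(max_target_len, reserved_label_len)
    max_source_len = max_len - min(max_target_len, target_len)
    return max_source_len, max_target_len

# 300 prompt tokens and 100 response tokens squeezed into 256:
# the response keeps its quarter share (64) and the prompt gets the rest (192).
print(infer_max_len(300, 100, 256, 1))  # (192, 64)
```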
@@ -1,83 +1,36 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union

+from .data_utils import SLOTS
+from .tool_utils import DefaultToolUtils, GLM4ToolUtils

-SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
-
-
-JSON_FORMAT_PROMPT = (
-    """, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
-)
-
-
-TOOL_SYSTEM_PROMPT = (
-    "You have access to the following tools:\n{tool_text}"
-    "Use the following format if using a tool:\n"
-    "```\n"
-    "Action: tool name (one of [{tool_names}]).\n"
-    "Action Input: the input to the tool{format_prompt}.\n"
-    "```\n"
-)
-
-
-def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
-    tool_text = ""
-    tool_names = []
-    for tool in tools:
-        param_text = ""
-        for name, param in tool["parameters"]["properties"].items():
-            required = ", required" if name in tool["parameters"].get("required", []) else ""
-            enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
-            items = (
-                ", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
-            )
-            param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
-                name=name,
-                type=param.get("type", ""),
-                required=required,
-                desc=param.get("description", ""),
-                enum=enum,
-                items=items,
-            )
-
-        tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
-            name=tool["name"], desc=tool.get("description", ""), args=param_text
-        )
-        tool_names.append(tool["name"])
-
-    return TOOL_SYSTEM_PROMPT.format(
-        tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
-    )
-
-
-def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
-    regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
-    action_match = re.search(regex, content)
-    if not action_match:
-        return content
-
-    tool_name = action_match.group(1).strip()
-    tool_input = action_match.group(2).strip().strip('"').strip("```")
-    try:
-        arguments = json.loads(tool_input)
-    except json.JSONDecodeError:
-        return content
-
-    return tool_name, json.dumps(arguments, ensure_ascii=False)


 @dataclass
 class Formatter(ABC):
     slots: SLOTS = field(default_factory=list)
-    tool_format: Optional[Literal["default"]] = None
+    tool_format: Optional[Literal["default", "glm4"]] = None

     @abstractmethod
     def apply(self, **kwargs) -> SLOTS: ...

-    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
+    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
         raise NotImplementedError

@@ -128,34 +81,37 @@ class StringFormatter(Formatter):
 @dataclass
 class FunctionFormatter(Formatter):
     def __post_init__(self):
-        has_name, has_args = False, False
-        for slot in filter(lambda s: isinstance(s, str), self.slots):
-            if "{{name}}" in slot:
-                has_name = True
-            if "{{arguments}}" in slot:
-                has_args = True
-
-        if not has_name or not has_args:
-            raise ValueError("Name and arguments placeholders are required in the function formatter.")
+        if self.tool_format == "default":
+            self.slots = DefaultToolUtils.get_function_slots() + self.slots
+        elif self.tool_format == "glm4":
+            self.slots = GLM4ToolUtils.get_function_slots() + self.slots
+        else:
+            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))

     def apply(self, **kwargs) -> SLOTS:
         content = kwargs.pop("content")
+        functions: List[Tuple[str, str]] = []
         try:
-            function = json.loads(content)
-            name = function["name"]
-            arguments = json.dumps(function["arguments"], ensure_ascii=False)
-        except Exception:
-            name, arguments = "", ""
+            tool_calls = json.loads(content)
+            if not isinstance(tool_calls, list):  # parallel function call
+                tool_calls = [tool_calls]
+
+            for tool_call in tool_calls:
+                functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
+
+        except json.JSONDecodeError:
+            functions = []

         elements = []
-        for slot in self.slots:
-            if isinstance(slot, str):
-                slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
-                elements.append(slot)
-            elif isinstance(slot, (dict, set)):
-                elements.append(slot)
-            else:
-                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+        for name, arguments in functions:
+            for slot in self.slots:
+                if isinstance(slot, str):
+                    slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
+                    elements.append(slot)
+                elif isinstance(slot, (dict, set)):
+                    elements.append(slot)
+                else:
+                    raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))

         return elements

@@ -163,25 +119,22 @@ class FunctionFormatter(Formatter):
 @dataclass
 class ToolFormatter(Formatter):
     def __post_init__(self):
-        if self.tool_format is None:
-            raise ValueError("Tool format was not found.")
+        if self.tool_format == "default":
+            self._tool_formatter = DefaultToolUtils.tool_formatter
+            self._tool_extractor = DefaultToolUtils.tool_extractor
+        elif self.tool_format == "glm4":
+            self._tool_formatter = GLM4ToolUtils.tool_formatter
+            self._tool_extractor = GLM4ToolUtils.tool_extractor
+        else:
+            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))

     def apply(self, **kwargs) -> SLOTS:
         content = kwargs.pop("content")
         try:
             tools = json.loads(content)
-            if not len(tools):
-                return [""]
-
-            if self.tool_format == "default":
-                return [default_tool_formatter(tools)]
-            else:
-                raise NotImplementedError
-        except Exception:
+            return [self._tool_formatter(tools) if len(tools) != 0 else ""]
+        except json.JSONDecodeError:
             return [""]

-    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
-        if self.tool_format == "default":
-            return default_tool_extractor(content)
-        else:
-            raise NotImplementedError
+    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+        return self._tool_extractor(content)
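The core of the new `FunctionFormatter.apply` is the normalization step: a single JSON object or a JSON list of calls both become a list of `(name, arguments)` pairs before the slots are filled once per call. A standalone sketch of just that step:

```python
import json
from typing import List, Tuple

def parse_tool_calls(content: str) -> List[Tuple[str, str]]:
    # Mirrors the try/except block in the diff above.
    functions: List[Tuple[str, str]] = []
    try:
        tool_calls = json.loads(content)
        if not isinstance(tool_calls, list):  # single call, normalize to a list
            tool_calls = [tool_calls]
        for tool_call in tool_calls:
            functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
    except json.JSONDecodeError:
        functions = []
    return functions

print(parse_tool_calls('[{"name": "add", "arguments": {"a": 1, "b": 2}}]'))
# [('add', '{"a": 1, "b": 2}')]
```

Dispatching the prompt-side formatting and the extraction through `DefaultToolUtils`/`GLM4ToolUtils` keeps both directions of a tool format defined in one place.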
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import inspect
 import os
 import sys

@@ -18,8 +32,7 @@ from .template import get_template_and_fix_tokenizer

 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
-    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

     from ..hparams import DataArguments, ModelArguments
     from .parser import DatasetAttr

@@ -32,6 +45,7 @@ def load_single_dataset(
     dataset_attr: "DatasetAttr",
     model_args: "ModelArguments",
     data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     logger.info("Loading dataset {}...".format(dataset_attr))
     data_path, data_name, data_dir, data_files = None, None, None, None

@@ -123,7 +137,7 @@ def load_single_dataset(
         max_samples = min(data_args.max_samples, len(dataset))
         dataset = dataset.select(range(max_samples))

-    return align_dataset(dataset, dataset_attr, data_args)
+    return align_dataset(dataset, dataset_attr, data_args, training_args)


 def get_dataset(

@@ -134,7 +148,7 @@ def get_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
 ) -> Union["Dataset", "IterableDataset"]:
-    template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")

@@ -157,7 +171,8 @@ def get_dataset(
         if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
             raise ValueError("The dataset is not applicable in the current training stage.")

-        all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
+        all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))

     dataset = merge_dataset(all_datasets, data_args, training_args)

     with training_args.main_process_first(desc="pre-process dataset"):

@@ -169,7 +184,7 @@ def get_dataset(
         if not data_args.streaming:
             kwargs = dict(
                 num_proc=data_args.preprocessing_num_workers,
-                load_from_cache_file=(not data_args.overwrite_cache),
+                load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
                 desc="Running tokenizer on dataset",
             )
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
|
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from functools import partial
 from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple

@@ -13,8 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example


 if TYPE_CHECKING:
-    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

     from ..hparams import DataArguments
     from .template import Template
@@ -1,13 +1,26 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen


 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
     from ..template import Template

@@ -42,12 +55,8 @@ def _encode_feedback_example(
     else:
         kl_messages = prompt + [kl_response[1]]

-    prompt_ids, response_ids = template.encode_oneturn(
-        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
-    _, kl_response_ids = template.encode_oneturn(
-        tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
+    prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
+    _, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)

     if template.efficient_eos:
         response_ids += [tokenizer.eos_token_id]

@@ -57,6 +66,12 @@ def _encode_feedback_example(
         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
         prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids

+    # do not consider the kl_response
+    source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
+    prompt_ids = prompt_ids[:source_len]
+    response_ids = response_ids[:target_len]
+    kl_response_ids = kl_response_ids[:target_len]
+
     input_ids = prompt_ids + response_ids
     labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
     kl_input_ids = prompt_ids + kl_response_ids
@@ -1,13 +1,26 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen


 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
     from ..template import Template

@@ -31,12 +44,8 @@ def _encode_pairwise_example(

     chosen_messages = prompt + [response[0]]
     rejected_messages = prompt + [response[1]]
-    prompt_ids, chosen_ids = template.encode_oneturn(
-        tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
-    _, rejected_ids = template.encode_oneturn(
-        tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
+    prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
+    _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)

     if template.efficient_eos:
         chosen_ids += [tokenizer.eos_token_id]

@@ -46,6 +55,13 @@ def _encode_pairwise_example(
         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
         prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids

+    source_len, target_len = infer_seqlen(
+        len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
+    )  # consider the response is more important
+    prompt_ids = prompt_ids[:source_len]
+    chosen_ids = chosen_ids[:target_len]
+    rejected_ids = rejected_ids[:target_len]
+
     chosen_input_ids = prompt_ids + chosen_ids
     chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
     rejected_input_ids = prompt_ids + rejected_ids
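Worked numbers (hypothetical) for the truncation rule just added: the budget is computed once against the longer of the two responses, so chosen and rejected receive the same slice lengths and stay comparable against one shared prompt slice.

```python
# infer_seqlen is defined later in this commit (processors/processor_utils.py);
# with these made-up lengths both sides are too long, so the
# "truncate both" branch applies.
cutoff_len, prompt_len = 1024, 900
target_len = max(600, 550)  # max(len(chosen_ids), len(rejected_ids))

max_target_len = int(cutoff_len * target_len / (prompt_len + target_len))  # 409
new_target_len = min(max_target_len, target_len)                           # 409
new_source_len = max(cutoff_len - new_target_len, 0)                       # 615
assert (new_source_len, new_target_len) == (615, 409)
# -> prompt_ids[:615], chosen_ids[:409], rejected_ids[:409]
```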
@@ -1,9 +1,26 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from itertools import chain
 from typing import TYPE_CHECKING, Any, Dict, List


 if TYPE_CHECKING:
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer

     from ...hparams import DataArguments

@@ -12,7 +29,8 @@ def preprocess_pretrain_dataset(
     examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
 ) -> Dict[str, List[List[int]]]:
     # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
-    text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
+    eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
+    text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]

     if not data_args.packing:
         if data_args.template == "gemma":
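A toy illustration (fake tokenizer, illustrative names only) of the eos selection above. The llama3 special case appears to exist because that template replaces the chat eos with `<|eot_id|>`, while pretraining corpora should still be joined by the plain `<|end_of_text|>` document separator:

```python
class FakeTokenizer:  # stand-in; the real code receives a PreTrainedTokenizer
    eos_token = "<|eot_id|>"

template = "llama3"  # e.g. data_args.template
eos_token = "<|end_of_text|>" if template == "llama3" else FakeTokenizer().eos_token

examples = {"prompt": [[{"content": "doc one"}], [{"content": "doc two"}]]}
text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
print(text_examples)  # ['doc one<|end_of_text|>', 'doc two<|end_of_text|>']
```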
@@ -1,5 +1,19 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import bisect
-from typing import TYPE_CHECKING, List, Sequence
+from typing import TYPE_CHECKING, List, Sequence, Tuple

 from ...extras.packages import is_pillow_available

@@ -62,3 +76,16 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
     """
     image_seq_length = getattr(processor, "image_seq_length")
     return [0] * image_seq_length + [1] * (input_len - image_seq_length)
+
+
+def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
+    if target_len * 2 < cutoff_len:  # truncate source
+        max_target_len = cutoff_len
+    elif source_len * 2 < cutoff_len:  # truncate target
+        max_target_len = cutoff_len - source_len
+    else:  # truncate both
+        max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
+
+    new_target_len = min(max_target_len, target_len)
+    new_source_len = max(cutoff_len - new_target_len, 0)
+    return new_source_len, new_target_len
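infer_seqlen is a pure function, so its three branches can be checked directly; the body below is copied verbatim from the hunk above:

```python
from typing import Tuple

def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
    if target_len * 2 < cutoff_len:  # truncate source
        max_target_len = cutoff_len
    elif source_len * 2 < cutoff_len:  # truncate target
        max_target_len = cutoff_len - source_len
    else:  # truncate both
        max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))

    new_target_len = min(max_target_len, target_len)
    new_source_len = max(cutoff_len - new_target_len, 0)
    return new_source_len, new_target_len

print(infer_seqlen(500, 10, 128))   # (118, 10): short target, the source absorbs the budget
print(infer_seqlen(10, 500, 128))   # (10, 118): short source, the target is trimmed
print(infer_seqlen(100, 100, 120))  # (60, 60):  both long, split proportionally
```

Note that the returned source length is a budget, not a measured length: it may exceed the actual source, in which case the subsequent slice is a no-op.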
@@ -1,14 +1,27 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen


 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
     from ..template import Template

@@ -38,10 +51,17 @@ def _encode_supervised_example(
         input_ids += [image_token_id] * getattr(processor, "image_seq_length")
         labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")

-    encoded_pairs = template.encode_multiturn(
-        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
+    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
+    total_length = 1 if template.efficient_eos else 0
     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
+        if total_length >= data_args.cutoff_len:
+            break
+
+        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
+        source_ids = source_ids[:source_len]
+        target_ids = target_ids[:target_len]
+        total_length += source_len + target_len
+
         if data_args.train_on_prompt:
             source_mask = source_ids
         elif turn_idx != 0 and template.efficient_eos:
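A toy walk-through (hypothetical turn lengths; infer_seqlen copied from the processor_utils hunk earlier in this commit) of how the running total_length budget plays out across turns:

```python
def infer_seqlen(source_len, target_len, cutoff_len):
    if target_len * 2 < cutoff_len:
        max_target_len = cutoff_len
    elif source_len * 2 < cutoff_len:
        max_target_len = cutoff_len - source_len
    else:
        max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
    new_target_len = min(max_target_len, target_len)
    return max(cutoff_len - new_target_len, 0), new_target_len

cutoff_len, total_length = 256, 0
encoded_pairs = [(200, 100), (150, 80)]  # (len(source_ids), len(target_ids)) per turn
for turn_idx, (src, tgt) in enumerate(encoded_pairs):
    if total_length >= cutoff_len:
        break  # budget spent: remaining turns are dropped entirely
    source_len, target_len = infer_seqlen(src, tgt, cutoff_len - total_length)
    total_length += source_len + target_len
    print(turn_idx, source_len, target_len)  # prints only: 0 156 100
```

Because total_length accumulates the budgets rather than the measured slice lengths, a single long turn can exhaust the whole cutoff for the turns that follow.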
@@ -1,13 +1,26 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

 from ...extras.logging import get_logger
 from ..data_utils import Role
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen


 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
     from ..template import Template

@@ -34,9 +47,7 @@ def _encode_unsupervised_example(
     else:
         messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]

-    input_ids, labels = template.encode_oneturn(
-        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
+    input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
     if template.efficient_eos:
         labels += [tokenizer.eos_token_id]

@@ -44,6 +55,9 @@ def _encode_unsupervised_example(
         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
         input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids

+    source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len)
+    input_ids = input_ids[:source_len]
+    labels = labels[:target_len]
     return input_ids, labels

@@ -1,8 +1,22 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union

 from ..extras.logging import get_logger
-from .data_utils import Role, infer_max_len
+from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter

@@ -24,69 +38,74 @@ class Template:
     format_observation: "Formatter"
     format_tools: "Formatter"
     format_separator: "Formatter"
+    format_prefix: "Formatter"
     default_system: str
     stop_words: List[str]
     image_token: str
     efficient_eos: bool
     replace_eos: bool
-    force_system: bool

     def encode_oneturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: List[Dict[str, str]],
+        messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        cutoff_len: int = 1_000_000,
-        reserved_label_len: int = 1,
     ) -> Tuple[List[int], List[int]]:
         r"""
        Returns a single pair of token ids representing prompt and response respectively.
         """
-        encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
         prompt_ids = []
-        for query_ids, resp_ids in encoded_pairs[:-1]:
-            prompt_ids += query_ids + resp_ids
-        prompt_ids = prompt_ids + encoded_pairs[-1][0]
-        answer_ids = encoded_pairs[-1][1]
+        for encoded_ids in encoded_messages[:-1]:
+            prompt_ids += encoded_ids
+
+        answer_ids = encoded_messages[-1]
         return prompt_ids, answer_ids

     def encode_multiturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: List[Dict[str, str]],
+        messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        cutoff_len: int = 1_000_000,
-        reserved_label_len: int = 1,
-    ) -> Sequence[Tuple[List[int], List[int]]]:
+    ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
         """
-        return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+
+    def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+        r"""
+        Extracts tool message.
+        """
+        return self.format_tools.extract(content)

     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: List[Dict[str, str]],
+        messages: Sequence[Dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        cutoff_len: int,
-        reserved_label_len: int,
-    ) -> Sequence[Tuple[List[int], List[int]]]:
+    ) -> List[List[int]]:
         r"""
         Encodes formatted inputs to pairs of token ids.
-        Turn 0: system + query        resp
+        Turn 0: prefix + system + query        resp
         Turn t: sep + query           resp
         """
         system = system or self.default_system
         encoded_messages = []
         for i, message in enumerate(messages):
             elements = []
-            if i == 0 and (system or tools or self.force_system):
-                tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
-                elements += self.format_system.apply(content=(system + tool_text))
-            elif i > 0 and i % 2 == 0:
+
+            if i == 0:
+                elements += self.format_prefix.apply()
+                if system or tools:
+                    tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+                    elements += self.format_system.apply(content=(system + tool_text))
+
+            if i > 0 and i % 2 == 0:
                 elements += self.format_separator.apply()

             if message["role"] == Role.USER.value:
@@ -102,11 +121,9 @@ class Template:

         encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

-        return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
+        return encoded_messages

-    def _convert_elements_to_ids(
-        self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
-    ) -> List[int]:
+    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
         r"""
         Converts elements to token ids.
         """
@@ -127,57 +144,34 @@ class Template:

         return token_ids

-    def _make_pairs(
-        self,
-        encoded_messages: Sequence[List[int]],
-        cutoff_len: int,
-        reserved_label_len: int,
-    ) -> Sequence[Tuple[List[int], List[int]]]:
-        encoded_pairs = []
-        total_length = 0
-        for i in range(0, len(encoded_messages), 2):
-            if total_length >= cutoff_len:
-                break
-
-            max_source_len, max_target_len = infer_max_len(
-                source_len=len(encoded_messages[i]),
-                target_len=len(encoded_messages[i + 1]),
-                max_len=(cutoff_len - total_length),
-                reserved_label_len=reserved_label_len,
-            )
-            source_ids = encoded_messages[i][:max_source_len]
-            target_ids = encoded_messages[i + 1][:max_target_len]
-            total_length += len(source_ids) + len(target_ids)
-            encoded_pairs.append((source_ids, target_ids))
-
-        return encoded_pairs
-

 @dataclass
 class Llama2Template(Template):
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: List[Dict[str, str]],
+        messages: Sequence[Dict[str, str]],
         system: str,
         tools: str,
-        cutoff_len: int,
-        reserved_label_len: int,
-    ) -> Sequence[Tuple[List[int], List[int]]]:
+    ) -> List[List[int]]:
         r"""
         Encodes formatted inputs to pairs of token ids.
-        Turn 0: system + query        resp
+        Turn 0: prefix + system + query        resp
         Turn t: sep + query           resp
         """
         system = system or self.default_system
         encoded_messages = []
         for i, message in enumerate(messages):
             elements = []

             system_text = ""
-            if i == 0 and (system or tools or self.force_system):
-                tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
-                system_text = self.format_system.apply(content=(system + tool_text))[0]
-            elif i > 0 and i % 2 == 0:
+            if i == 0:
+                elements += self.format_prefix.apply()
+                if system or tools:
+                    tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+                    system_text = self.format_system.apply(content=(system + tool_text))[0]
+
+            if i > 0 and i % 2 == 0:
                 elements += self.format_separator.apply()

             if message["role"] == Role.USER.value:
@@ -193,7 +187,7 @@ class Llama2Template(Template):

         encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

-        return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
+        return encoded_messages


 TEMPLATES: Dict[str, Template] = {}

@@ -208,12 +202,12 @@ def _register_template(
     format_observation: Optional["Formatter"] = None,
     format_tools: Optional["Formatter"] = None,
     format_separator: Optional["Formatter"] = None,
+    format_prefix: Optional["Formatter"] = None,
     default_system: str = "",
-    stop_words: List[str] = [],
+    stop_words: Sequence[str] = [],
     image_token: str = "<image>",
     efficient_eos: bool = False,
     replace_eos: bool = False,
-    force_system: bool = False,
 ) -> None:
     r"""
     Registers a chat template.

@@ -245,9 +239,10 @@ def _register_template(
     template_class = Llama2Template if name.startswith("llama2") else Template
     default_user_formatter = StringFormatter(slots=["{{content}}"])
     default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
-    default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
+    default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
     default_tool_formatter = ToolFormatter(tool_format="default")
     default_separator_formatter = EmptyFormatter()
+    default_prefix_formatter = EmptyFormatter()
     TEMPLATES[name] = template_class(
         format_user=format_user or default_user_formatter,
         format_assistant=format_assistant or default_assistant_formatter,

@@ -256,12 +251,12 @@ def _register_template(
         format_observation=format_observation or format_user or default_user_formatter,
         format_tools=format_tools or default_tool_formatter,
         format_separator=format_separator or default_separator_formatter,
+        format_prefix=format_prefix or default_prefix_formatter,
         default_system=default_system,
         stop_words=stop_words,
         image_token=image_token,
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
-        force_system=force_system,
     )


@@ -307,6 +302,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder
 def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
     jinja_template = ""

+    prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
+    if prefix:
+        jinja_template += "{{ " + prefix + " }}"
+
     if template.default_system:
         jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"

@@ -315,11 +314,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
     )

     system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
-    if isinstance(template, Llama2Template):
-        pass
-    elif template.force_system:
-        jinja_template += "{{ " + system_message + " }}"
-    else:
+    if not isinstance(template, Llama2Template):
         jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"

     jinja_template += "{% for message in messages %}"

@@ -346,6 +341,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
 def get_template_and_fix_tokenizer(
     tokenizer: "PreTrainedTokenizer",
     name: Optional[str] = None,
+    tool_format: Optional[str] = None,
 ) -> Template:
     if name is None:
         template = TEMPLATES["empty"]  # placeholder

@@ -354,6 +350,12 @@ def get_template_and_fix_tokenizer(
     if template is None:
         raise ValueError("Template {} does not exist.".format(name))

+    if tool_format is not None:
+        logger.info("Using tool format: {}.".format(tool_format))
+        eos_slots = [] if template.efficient_eos else [{"eos_token"}]
+        template.format_tools = ToolFormatter(tool_format=tool_format)
+        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
+
     stop_words = template.stop_words
     if template.replace_eos:
         if not stop_words:
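To see what the refactor changes shape-wise, here is a toy regrouping (plain lists, no tokenizer; not part of the commit): `_encode` now returns one id list per message, and the two public helpers slice that list instead of receiving pre-truncated pairs.

```python
# encoded "messages" as plain id lists: user, assistant, user, assistant
encoded_messages = [[1, 2], [3], [4, 5], [6]]

# encode_multiturn: adjacent (prompt, response) pairs; no truncation here anymore
pairs = [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
assert pairs == [([1, 2], [3]), ([4, 5], [6])]

# encode_oneturn: everything except the last message is flattened into the prompt
prompt_ids = []
for encoded_ids in encoded_messages[:-1]:
    prompt_ids += encoded_ids
assert (prompt_ids, encoded_messages[-1]) == ([1, 2, 3, 4, 5], [6])
```

Truncation now lives in the stage-specific processors via infer_seqlen, which is why cutoff_len and reserved_label_len disappear from these signatures.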
@@ -435,9 +437,8 @@ _register_template(
 _register_template(
     name="belle",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
     format_separator=EmptyFormatter(slots=["\n\n"]),
-    force_system=True,
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 )


@@ -450,11 +451,7 @@ _register_template(
 _register_template(
     name="breeze",
     format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
-    default_system=(
-        "You are a helpful AI assistant built by MediaTek Research. "
-        "The user you are helping speaks Traditional Chinese and comes from Taiwan."
-    ),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     efficient_eos=True,
 )


@@ -462,10 +459,9 @@ _register_template(
 _register_template(
     name="chatglm2",
     format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
-    format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
     format_separator=EmptyFormatter(slots=["\n\n"]),
+    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
     efficient_eos=True,
-    force_system=True,
 )


@@ -473,32 +469,13 @@ _register_template(
     name="chatglm3",
     format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
     format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
-    format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
-    format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
+    format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
+    format_function=FunctionFormatter(slots=[], tool_format="glm4"),
     format_observation=StringFormatter(
         slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
     ),
-    stop_words=["<|user|>", "<|observation|>"],
-    efficient_eos=True,
-    force_system=True,
-)
-
-
-_register_template(
-    name="chatglm3_system",
-    format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
-    format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
-    format_system=StringFormatter(
-        slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
-    ),
-    format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
-    format_observation=StringFormatter(
-        slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
-    ),
-    default_system=(
-        "You are ChatGLM3, a large language model trained by Zhipu.AI. "
-        "Follow the user's instructions carefully. Respond using markdown."
-    ),
+    format_tools=ToolFormatter(tool_format="glm4"),
+    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
     stop_words=["<|user|>", "<|observation|>"],
     efficient_eos=True,
 )

@@ -529,8 +506,7 @@ _register_template(

 _register_template(
     name="codegeex2",
-    format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
-    force_system=True,
+    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
 )


@@ -544,21 +520,15 @@ _register_template(
             )
         ]
     ),
-    format_system=StringFormatter(
-        slots=[{"bos_token"}, "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]
-    ),
-    default_system=(
-        "You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users "
-        "by providing thorough responses. You are trained by Cohere."
-    ),
+    format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 )


 _register_template(
     name="cpm",
     format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
-    force_system=True,
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 )


@@ -591,30 +561,28 @@ _register_template(
 _register_template(
     name="deepseek",
     format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
-    force_system=True,
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 )


 _register_template(
     name="deepseekcoder",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
-    format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
-    format_separator=EmptyFormatter(slots=["\n<|EOT|>\n"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     default_system=(
         "You are an AI programming assistant, utilizing the Deepseek Coder model, "
         "developed by Deepseek Company, and you only answer questions related to computer science. "
         "For politically sensitive questions, security and privacy issues, "
         "and other non-computer science questions, you will refuse to answer\n"
     ),
-    stop_words=["<|EOT|>"],
-    efficient_eos=True,
 )


 _register_template(
     name="default",
-    format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
+    format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
     format_system=StringFormatter(slots=["{{content}}\n"]),
     format_separator=EmptyFormatter(slots=["\n"]),
 )

@@ -622,11 +590,7 @@ _register_template(

 _register_template(
     name="empty",
-    format_user=StringFormatter(slots=["{{content}}"]),
-    format_assistant=StringFormatter(slots=["{{content}}"]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
     efficient_eos=True,
-    force_system=True,
 )


@@ -648,13 +612,12 @@ _register_template(
 _register_template(
     name="gemma",
     format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
     format_observation=StringFormatter(
         slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
     ),
     format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     efficient_eos=True,
-    force_system=True,
 )


@@ -662,36 +625,33 @@ _register_template(
     name="glm4",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
     format_assistant=StringFormatter(slots=["\n{{content}}"]),
-    format_system=StringFormatter(slots=["[gMASK]<sop>{{content}}"]),
-    format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+    format_function=FunctionFormatter(slots=[], tool_format="glm4"),
     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+    format_tools=ToolFormatter(tool_format="glm4"),
+    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
     stop_words=["<|user|>", "<|observation|>"],
     efficient_eos=True,
-    force_system=True,
 )


 _register_template(
     name="intern",
-    format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
-    format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
+    format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
+    format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
+    format_separator=EmptyFormatter(slots=["<eoa>\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<eoa>"],
-    efficient_eos=True,
+    efficient_eos=True,  # internlm tokenizer cannot set eos_token_id
 )


 _register_template(
     name="intern2",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
-    default_system=(
-        "You are an AI assistant whose name is InternLM (书生·浦语).\n"
-        "- InternLM (书生·浦语) is a conversational language model that is developed "
-        "by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
-        "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
-        "by the user such as English and 中文."
-    ),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_separator=EmptyFormatter(slots=["<|im_end|>\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|im_end|>"],
     efficient_eos=True,  # internlm2 tokenizer cannot set eos_token_id
 )

@@ -700,7 +660,6 @@ _register_template(
 _register_template(
     name="llama2",
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
-    format_assistant=StringFormatter(slots=[" {{content}} ", {"eos_token"}]),
     format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
 )

@@ -723,9 +682,7 @@ _register_template(
             )
         ]
     ),
-    format_system=StringFormatter(
-        slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
-    ),
+    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
     format_observation=StringFormatter(
         slots=[
             (
@@ -734,7 +691,7 @@ _register_template(
             )
         ]
     ),
-    default_system="You are a helpful assistant.",
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot_id|>"],
     replace_eos=True,
 )

@@ -743,24 +700,21 @@ _register_template(
 _register_template(
     name="mistral",
     format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
-    force_system=True,
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 )


 _register_template(
     name="olmo",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
-    format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
-    force_system=True,
+    format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
 )


 _register_template(
     name="openchat",
     format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
-    force_system=True,
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 )


@@ -774,27 +728,25 @@ _register_template(
             )
         ]
     ),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot_id|>"],
     replace_eos=True,
-    force_system=True,
 )


 _register_template(
     name="orion",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
-    force_system=True,
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 )


 _register_template(
     name="phi",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
     format_separator=EmptyFormatter(slots=["\n"]),
-    default_system="You are a helpful AI assistant.",
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|end|>"],
     replace_eos=True,
 )

@@ -827,7 +779,6 @@ _register_template(
     format_separator=EmptyFormatter(slots=["\n"]),
     stop_words=["<|end|>"],
     replace_eos=True,
-    force_system=True,
 )
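A string-level sketch of what moving the bos token from format_system into the new format_prefix means for a template like deepseek (toy rendering only; the real formatters emit token ids, and `<s>` is an assumed bos_token):

```python
bos = "<s>"  # assumption: the model's bos_token
prefix_slots = [bos]                          # format_prefix=EmptyFormatter(slots=[{"bos_token"}])
user_slot = "User: {content}\n\nAssistant:"   # format_user

def render(contents):
    text = ""
    for i, content in enumerate(contents):
        if i == 0:
            text += "".join(prefix_slots)  # emitted exactly once, with or without a system message
        text += user_slot.format(content=content) if i % 2 == 0 else content
    return text

print(render(["Hi", " Hello!", "Bye"]))
# <s>User: Hi
#
# Assistant: Hello!User: Bye
#
# Assistant:
```

Previously the bos lived inside format_system, so these templates needed force_system=True to guarantee it appeared even without a system prompt; the prefix formatter makes that flag unnecessary, which is why it is deleted everywhere above.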
@@ -0,0 +1,140 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import re
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple, Union
+
+from .data_utils import SLOTS
+
+
+DEFAULT_TOOL_PROMPT = (
+    "You have access to the following tools:\n{tool_text}"
+    "Use the following format if using a tool:\n"
+    "```\n"
+    "Action: tool name (one of [{tool_names}]).\n"
+    "Action Input: the input to the tool, in a JSON format representing the kwargs "
+    """(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
+    "```\n"
+)
+
+
+GLM4_TOOL_PROMPT = (
+    "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
+    "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
+)
+
+
+@dataclass
+class ToolUtils(ABC):
+    @staticmethod
+    @abstractmethod
+    def get_function_slots() -> SLOTS: ...
+
+    @staticmethod
+    @abstractmethod
+    def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
+
+    @staticmethod
+    @abstractmethod
+    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
+
+
+class DefaultToolUtils(ToolUtils):
+    @staticmethod
+    def get_function_slots() -> SLOTS:
+        return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
+
+    @staticmethod
+    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+        tool_text = ""
+        tool_names = []
+        for tool in tools:
+            param_text = ""
+            for name, param in tool["parameters"]["properties"].items():
+                required, enum, items = "", "", ""
+                if name in tool["parameters"].get("required", []):
+                    required = ", required"
+
+                if param.get("enum", None):
+                    enum = ", should be one of [{}]".format(", ".join(param["enum"]))
+
+                if param.get("items", None):
+                    items = ", where each item should be {}".format(param["items"].get("type", ""))
+
+                param_text += "  - {name} ({type}{required}): {desc}{enum}{items}\n".format(
+                    name=name,
+                    type=param.get("type", ""),
+                    required=required,
+                    desc=param.get("description", ""),
+                    enum=enum,
+                    items=items,
+                )
+
+            tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
+                name=tool["name"], desc=tool.get("description", ""), args=param_text
+            )
+            tool_names.append(tool["name"])
+
+        return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
+
+    @staticmethod
+    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+        regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
+        action_match: List[Tuple[str, str]] = re.findall(regex, content)
+        if not action_match:
+            return content
+
+        results = []
+        for match in action_match:
+            tool_name = match[0].strip()
+            tool_input = match[1].strip().strip('"').strip("```")
+            try:
+                arguments = json.loads(tool_input)
+                results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
+            except json.JSONDecodeError:
+                return content
+
+        return results
+
+
+class GLM4ToolUtils(ToolUtils):
+    @staticmethod
+    def get_function_slots() -> SLOTS:
+        return ["{{name}}\n{{arguments}}"]
+
+    @staticmethod
+    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+        tool_text = ""
+        for tool in tools:
+            tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
+                name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
+            )
+
+        return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
+
+    @staticmethod
+    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+        if "\n" not in content:
+            return content
+
+        tool_name, tool_input = content.split("\n", maxsplit=1)
+        try:
+            arguments = json.loads(tool_input)
+        except json.JSONDecodeError:
+            return content
+
+        return [(tool_name, json.dumps(arguments, ensure_ascii=False))]

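The new module above is self-contained enough to exercise directly. A usage sketch for `DefaultToolUtils` follows; the import path is assumed from the package layout, and the tool schema is an example of the JSON-schema shape `tool_formatter` expects:

```python
# Usage sketch for the new tool utilities; the import path is assumed.
from llamafactory.data.tool_utils import DefaultToolUtils

tools = [
    {
        "name": "get_weather",
        "description": "Query the current weather.",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "City name."},
            },
            "required": ["city"],
        },
    }
]

# Builds the system prompt that teaches the model the Action / Action Input format.
prompt = DefaultToolUtils.tool_formatter(tools)

# Parses a model reply back into (tool_name, json_arguments) tuples;
# on malformed JSON the raw content is returned unchanged.
reply = 'Action: get_weather\nAction Input: {"city": "Berlin"}'
calls = DefaultToolUtils.tool_extractor(reply)
print(calls)  # [('get_weather', '{"city": "Berlin"}')]
```

`GLM4ToolUtils` implements the same interface but expects the tool name on the first line and the JSON arguments on the remainder, matching GLM-4's native function-call format.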
@@ -1,4 +1,41 @@
-# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
+# Copyright 2024 the LlamaFactory team.
+#
+# This code is inspired by Dan Hendrycks' test library.
+# https://github.com/hendrycks/test/blob/master/evaluate_flan.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# MIT License
+#
+# Copyright (c) 2020 Dan Hendrycks
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
 import inspect
 import json

@@ -26,9 +63,7 @@ class Evaluator:
         self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
         self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
         self.eval_template = get_eval_template(self.eval_args.lang)
-        self.choice_inputs = [
-            self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
-        ]
+        self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]

     @torch.inference_mode()
     def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:

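The `choice_inputs` change drops the `prefix` indirection: the evaluator now scores the bare choice letters rather than `prefix + letter` (previously `" A"`, `" B"`, and so on). A sketch of why the resulting token ids can differ, using GPT-2's tokenizer purely as an example:

```python
# Sketch of the choice_inputs change; "gpt2" is only an example tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
CHOICES = ["A", "B", "C", "D"]

# Old behavior: encode prefix + letter and keep the last token id.
old_ids = [tokenizer.encode(" " + ch, add_special_tokens=False)[-1] for ch in CHOICES]
# New behavior: encode the bare letter.
new_ids = [tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]

# In BPE vocabularies " A" and "A" are usually distinct tokens, so the two
# id sets differ, and the scored logits differ with them.
print(old_ids, new_ids)
```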
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import dataclass
 from typing import Dict, List, Sequence, Tuple

@@ -10,7 +24,6 @@ class EvalTemplate:
     system: str
     choice: str
     answer: str
-    prefix: str

     def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
         r"""

@@ -42,8 +55,8 @@ class EvalTemplate:
 eval_templates: Dict[str, "EvalTemplate"] = {}


-def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
-    eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
+def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
+    eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer)


 def get_eval_template(name: str) -> "EvalTemplate":

@@ -56,8 +69,7 @@ _register_eval_template(
     name="en",
     system="The following are multiple choice questions (with answers) about {subject}.\n\n",
     choice="\n{choice}. {content}",
-    answer="\nAnswer: ",
-    prefix=" ",
+    answer="\nAnswer:",
 )

@@ -66,5 +78,4 @@ _register_eval_template(
     system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
     choice="\n{choice}. {content}",
     answer="\n答案:",
-    prefix=" ",
 )

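Putting the remaining fields together, this is roughly how the "en" template composes a question after the `prefix` removal (a sketch, not the project's `_parse_example`; the question and options are hypothetical):

```python
# Sketch of prompt composition with the updated "en" eval template fields.
system = "The following are multiple choice questions (with answers) about {subject}.\n\n"
choice = "\n{choice}. {content}"
answer = "\nAnswer:"

question = "What is 2 + 2?"
options = {"A": "3", "B": "4", "C": "5", "D": "6"}  # hypothetical example item

prompt = system.format(subject="elementary mathematics") + question
for letter, content in options.items():
    prompt += choice.format(choice=letter, content=content)
prompt += answer  # the model's next-token logits are scored against choice_inputs
print(prompt)
```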
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from collections import OrderedDict, defaultdict
 from enum import Enum
 from typing import Dict, Optional

@@ -404,6 +418,18 @@ register_model_group(
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
         },
+        "DeepSeek-MoE-Coder-16B-Base": {
+            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
+        },
+        "DeepSeek-MoE-Coder-236B-Base": {
+            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base",
+        },
+        "DeepSeek-MoE-Coder-16B-Chat": {
+            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
+        },
+        "DeepSeek-MoE-Coder-236B-Chat": {
+            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
+        },
     },
     template="deepseek",
 )

@@ -496,6 +522,18 @@ register_model_group(
         "Gemma-1.1-7B-Chat": {
             DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
         },
+        "Gemma-2-9B": {
+            DownloadSource.DEFAULT: "google/gemma-2-9b",
+        },
+        "Gemma-2-27B": {
+            DownloadSource.DEFAULT: "google/gemma-2-27b",
+        },
+        "Gemma-2-9B-Chat": {
+            DownloadSource.DEFAULT: "google/gemma-2-9b-it",
+        },
+        "Gemma-2-27B-Chat": {
+            DownloadSource.DEFAULT: "google/gemma-2-27b-it",
+        },
     },
     template="gemma",
 )

@@ -568,7 +606,7 @@ register_model_group(

 register_model_group(
     models={
-        "Jambda-v0.1": {
+        "Jamba-v0.1": {
             DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
         }

@@ -683,6 +721,21 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "MiniCPM-2B-SFT-Chat": {
+            DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
+            DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
+        },
+        "MiniCPM-2B-DPO-Chat": {
+            DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16",
+            DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16",
+        },
+    },
+    template="cpm",
+)
+
+
 register_model_group(
     models={
         "Mistral-7B-v0.1": {

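Each entry added above maps a display name to per-source hub paths. A sketch of how such a registry resolves to a download path; the `resolve` helper is hypothetical, and only the data shape mirrors the diff:

```python
# Hypothetical lookup over the registry shape used above.
from enum import Enum

class DownloadSource(str, Enum):
    DEFAULT = "hf"
    MODELSCOPE = "ms"

SUPPORTED_MODELS = {
    "Gemma-2-9B-Chat": {DownloadSource.DEFAULT: "google/gemma-2-9b-it"},
    "MiniCPM-2B-SFT-Chat": {
        DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
        DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
    },
}

def resolve(name: str, use_modelscope: bool = False) -> str:
    # Fall back to the default hub when no ModelScope mirror is registered.
    sources = SUPPORTED_MODELS[name]
    if use_modelscope and DownloadSource.MODELSCOPE in sources:
        return sources[DownloadSource.MODELSCOPE]
    return sources[DownloadSource.DEFAULT]

print(resolve("MiniCPM-2B-SFT-Chat", use_modelscope=True))  # OpenBMB/miniCPM-bf16
```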
@@ -1,3 +1,20 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import platform

 import accelerate

@@ -9,7 +26,7 @@ import trl
 from transformers.utils import is_torch_cuda_available, is_torch_npu_available


-VERSION = "0.8.1.dev0"
+VERSION = "0.8.3.dev0"


 def print_env() -> None:

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import logging
 import os
 import sys

@@ -1,13 +1,29 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's PEFT library.
+# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import gc
 import os
-from typing import TYPE_CHECKING, Dict, Tuple
+from typing import TYPE_CHECKING, Tuple

 import torch
-from peft import PeftModel
-from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
+import transformers.dynamic_module_utils
+from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
+from transformers.dynamic_module_utils import get_relative_imports
 from transformers.utils import (
-    SAFE_WEIGHTS_NAME,
-    WEIGHTS_NAME,
     is_torch_bf16_gpu_available,
     is_torch_cuda_available,
     is_torch_mps_available,

@@ -16,7 +32,6 @@ from transformers.utils import (
 )
 from transformers.utils.versions import require_version

-from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from .logging import get_logger

@@ -28,8 +43,6 @@ except Exception:

 if TYPE_CHECKING:
-    from trl import AutoModelForCausalLMWithValueHead
-
     from ..hparams import ModelArguments

@@ -58,6 +71,9 @@ class AverageMeter:

 def check_dependencies() -> None:
+    r"""
+    Checks the version of the required packages.
+    """
     if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
         logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
     else:

@@ -68,7 +84,7 @@ def check_dependencies() -> None:
     require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")


-def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
+def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
     r"""
     Returns the number of trainable parameters and number of all parameters in the model.
     """

@@ -79,7 +95,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
         if num_params == 0 and hasattr(param, "ds_numel"):
             num_params = param.ds_numel

-        # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
+        # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize
         if param.__class__.__name__ == "Params4bit":
             if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
                 num_bytes = param.quant_storage.itemsize

@@ -97,55 +113,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param


-def fix_valuehead_checkpoint(
-    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
-) -> None:
-    r"""
-    The model is already unwrapped.
-
-    There are three cases:
-    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
-    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
-    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
-
-    We assume `stage3_gather_16bit_weights_on_model_save=true`.
-    """
-    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
-        return
-
-    if safe_serialization:
-        from safetensors import safe_open
-        from safetensors.torch import save_file
-
-        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
-        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
-            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
-    else:
-        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
-        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
-
-    decoder_state_dict = {}
-    v_head_state_dict = {}
-    for name, param in state_dict.items():
-        if name.startswith("v_head."):
-            v_head_state_dict[name] = param
-        else:
-            decoder_state_dict[name.replace("pretrained_model.", "")] = param
-
-    os.remove(path_to_checkpoint)
-    model.pretrained_model.save_pretrained(
-        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
-    )
-
-    if safe_serialization:
-        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
-    else:
-        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
-
-    logger.info("Value head model saved at: {}".format(output_dir))
-
-
-def get_current_device() -> torch.device:
+def get_current_device() -> "torch.device":
     r"""
     Gets the current available device.
     """

@@ -184,7 +152,14 @@ def get_logits_processor() -> "LogitsProcessorList":
     return logits_processor


-def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
+def has_tokenized_data(path: "os.PathLike") -> bool:
+    r"""
+    Checks if the path has a tokenized dataset.
+    """
+    return os.path.isdir(path) and len(os.listdir(path)) > 0
+
+
+def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
     r"""
     Infers the optimal dtype according to the model_dtype and device compatibility.
     """

@@ -203,11 +178,9 @@ def is_gpu_or_npu_available() -> bool:
     return is_torch_npu_available() or is_torch_cuda_available()


-def has_tokenized_data(path: os.PathLike) -> bool:
-    r"""
-    Checks if the path has a tokenized dataset.
-    """
-    return os.path.isdir(path) and len(os.listdir(path)) > 0
+def skip_check_imports() -> None:
+    if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
+        transformers.dynamic_module_utils.check_imports = get_relative_imports


 def torch_gc() -> None:

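The new `skip_check_imports` monkey-patches transformers' `check_imports` with `get_relative_imports`, so loading a `trust_remote_code` repository no longer hard-fails on that repo's optional dependencies unless `FORCE_CHECK_IMPORTS` is set. A usage sketch; the import path and model repo are assumptions:

```python
# Usage sketch for skip_check_imports; the module path is assumed from the layout.
import os
from llamafactory.extras.misc import skip_check_imports
from transformers import AutoModelForCausalLM

os.environ.pop("FORCE_CHECK_IMPORTS", None)  # leave unset so the patch applies
skip_check_imports()

# After the patch, dynamic-module loading only resolves relative imports
# instead of verifying every dependency the remote repo declares.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-remote-code-model",  # hypothetical repo
    trust_remote_code=True,
)
```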
Some files were not shown because too many files have changed in this diff.