Merge remote-tracking branch 'upstream/main'
This commit is contained in:
commit
ea1f3ba5e0
|
@ -4,6 +4,8 @@
|
|||
.venv
|
||||
cache
|
||||
data
|
||||
hf_cache
|
||||
output
|
||||
examples
|
||||
.dockerignore
|
||||
.gitattributes
|
||||
|
|
|
@ -13,6 +13,18 @@ body:
|
|||
- label: I have read the README and searched the existing issues.
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: system-info
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: System Info
|
||||
description: |
|
||||
Please share your system info with us. You can run the command **llamafactory-cli env** and copy-paste its output below.
|
||||
请提供您的系统信息。您可以在命令行运行 **llamafactory-cli env** 并将其输出复制到该文本框中。
|
||||
|
||||
placeholder: llamafactory version, platform, python version, ...
|
||||
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
validations:
|
||||
|
@ -26,7 +38,9 @@ body:
|
|||
请合理使用 Markdown 标签来格式化您的文本。
|
||||
|
||||
placeholder: |
|
||||
python src/train_bash.py ...
|
||||
```bash
|
||||
llamafactory-cli train ...
|
||||
```
|
||||
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
|
@ -38,18 +52,6 @@ body:
|
|||
Please provide a clear and concise description of what you would expect to happen.
|
||||
请提供您原本的目的,即这段代码的期望行为。
|
||||
|
||||
- type: textarea
|
||||
id: system-info
|
||||
validations:
|
||||
required: false
|
||||
attributes:
|
||||
label: System Info
|
||||
description: |
|
||||
Please share your system info with us. You can run the command **transformers-cli env** and copy-paste its output below.
|
||||
请提供您的系统信息。您可以在命令行运行 **transformers-cli env** 并将其输出复制到该文本框中。
|
||||
|
||||
placeholder: transformers version, platform, python version, ...
|
||||
|
||||
- type: textarea
|
||||
id: others
|
||||
validations:
|
||||
|
|
|
@ -5,3 +5,4 @@ Fixes # (issue)
|
|||
## Before submitting
|
||||
|
||||
- [ ] Did you read the [contributor guideline](https://github.com/hiyouga/LLaMA-Factory/blob/main/.github/CONTRIBUTING.md)?
|
||||
- [ ] Did you write any new necessary tests?
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
name: label_issue
|
||||
|
||||
on:
|
||||
issues:
|
||||
types:
|
||||
- opened
|
||||
|
||||
jobs:
|
||||
label_issue:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
ISSUE_URL: ${{ github.event.issue.html_url }}
|
||||
run: |
|
||||
gh issue edit $ISSUE_URL --add-label "pending"
|
|
@ -2,28 +2,44 @@ name: tests
|
|||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**.py"
|
||||
- "requirements.txt"
|
||||
- ".github/workflows/*.yml"
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**.py"
|
||||
- "requirements.txt"
|
||||
- ".github/workflows/*.yml"
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
|
||||
tests:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.8"
|
||||
cache: "pip"
|
||||
cache-dependency-path: "setup.py"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install ruff
|
||||
python -m pip install .[torch,dev]
|
||||
|
||||
- name: Check quality
|
||||
run: |
|
||||
make style && make quality
|
||||
make style && make quality
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
|
|
38
Dockerfile
38
Dockerfile
|
@ -1,14 +1,44 @@
|
|||
FROM nvcr.io/nvidia/pytorch:24.01-py3
|
||||
# Use the NVIDIA official image with PyTorch 2.3.0
|
||||
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html
|
||||
FROM nvcr.io/nvidia/pytorch:24.02-py3
|
||||
|
||||
# Define installation arguments
|
||||
ARG INSTALL_BNB=false
|
||||
ARG INSTALL_VLLM=false
|
||||
ARG INSTALL_DEEPSPEED=false
|
||||
ARG PIP_INDEX=https://pypi.org/simple
|
||||
|
||||
# Set the working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Install the requirements
|
||||
COPY requirements.txt /app/
|
||||
RUN pip install -r requirements.txt
|
||||
RUN pip config set global.index-url $PIP_INDEX
|
||||
RUN python -m pip install --upgrade pip
|
||||
RUN python -m pip install -r requirements.txt
|
||||
|
||||
# Copy the rest of the application into the image
|
||||
COPY . /app/
|
||||
RUN pip install -e .[metrics,bitsandbytes,qwen]
|
||||
|
||||
# Install the LLaMA Factory
|
||||
RUN EXTRA_PACKAGES="metrics"; \
|
||||
if [ "$INSTALL_BNB" = "true" ]; then \
|
||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \
|
||||
fi; \
|
||||
if [ "$INSTALL_VLLM" = "true" ]; then \
|
||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \
|
||||
fi; \
|
||||
if [ "$INSTALL_DEEPSPEED" = "true" ]; then \
|
||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
||||
fi; \
|
||||
pip install -e .[$EXTRA_PACKAGES] && \
|
||||
pip uninstall -y transformer-engine flash-attn
|
||||
|
||||
# Set up volumes
|
||||
VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]
|
||||
|
||||
# Expose port 7860 for the LLaMA Board
|
||||
EXPOSE 7860
|
||||
|
||||
CMD [ "llamafactory-cli", "webui" ]
|
||||
# Expose port 8000 for the API service
|
||||
EXPOSE 8000
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
include LICENSE requirements.txt
|
5
Makefile
5
Makefile
|
@ -1,4 +1,4 @@
|
|||
.PHONY: quality style
|
||||
.PHONY: quality style test
|
||||
|
||||
check_dirs := scripts src tests
|
||||
|
||||
|
@ -9,3 +9,6 @@ quality:
|
|||
style:
|
||||
ruff check $(check_dirs) --fix
|
||||
ruff format $(check_dirs)
|
||||
|
||||
test:
|
||||
CUDA_VISIBLE_DEVICES= pytest tests/
|
||||
|
|
184
README.md
184
README.md
|
@ -8,9 +8,10 @@
|
|||
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
|
||||
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
|
||||
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
|
||||
[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
||||
[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
|
||||
|
||||
[![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
|
||||
|
||||
|
@ -25,6 +26,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/9840a653-7e9c-41c8-ae89
|
|||
Choose your path:
|
||||
|
||||
- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
||||
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
||||
- **Local machine**: Please refer to [usage](#getting-started)
|
||||
|
||||
## Table of Contents
|
||||
|
@ -45,9 +47,9 @@ Choose your path:
|
|||
## Features
|
||||
|
||||
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO and ORPO.
|
||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
|
||||
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
|
||||
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
|
||||
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
|
||||
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
|
||||
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
|
||||
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
|
||||
|
@ -69,14 +71,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||
|
||||
## Changelog
|
||||
|
||||
[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
|
||||
|
||||
[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
|
||||
|
||||
[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
||||
|
||||
<details><summary>Full Changelog</summary>
|
||||
|
||||
[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `gemma` template for chat completion.
|
||||
|
||||
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
||||
|
||||
<details><summary>Full Changelog</summary>
|
||||
|
||||
[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
|
||||
|
||||
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
|
||||
|
@ -145,38 +151,38 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||
|
||||
## Supported Models
|
||||
|
||||
| Model | Model size | Default module | Template |
|
||||
| -------------------------------------------------------- | -------------------------------- | ----------------- | --------- |
|
||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
|
||||
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
|
||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | q_proj,v_proj | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | query_key_value | falcon |
|
||||
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
|
||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
|
||||
| [PaliGemma](https://huggingface.co/google) | 3B | q_proj,v_proj | gemma |
|
||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | qkv_proj | phi |
|
||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen |
|
||||
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
|
||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi_vl |
|
||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
|
||||
| Model | Model size | Template |
|
||||
| -------------------------------------------------------- | -------------------------------- | --------- |
|
||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
||||
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
||||
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
|
||||
| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
|
||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
|
||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
|
||||
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
|
||||
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
||||
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
|
||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||
|
||||
> [!NOTE]
|
||||
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules for better convergence.
|
||||
>
|
||||
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
|
||||
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
|
||||
>
|
||||
> Remember to use the **SAME** template in training and inference.
|
||||
|
||||
|
@ -208,6 +214,8 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
|||
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
||||
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
|
||||
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
|
||||
- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
|
||||
- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
|
||||
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
|
||||
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
||||
|
||||
|
@ -251,6 +259,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
|||
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
||||
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
||||
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
||||
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
|
||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||
|
@ -267,6 +276,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
|||
<details><summary>Preference datasets</summary>
|
||||
|
||||
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
|
||||
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
|
||||
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
|
@ -286,21 +296,21 @@ huggingface-cli login
|
|||
|
||||
| Mandatory | Minimum | Recommend |
|
||||
| ------------ | ------- | --------- |
|
||||
| python | 3.8 | 3.10 |
|
||||
| torch | 1.13.1 | 2.2.0 |
|
||||
| transformers | 4.37.2 | 4.41.0 |
|
||||
| datasets | 2.14.3 | 2.19.1 |
|
||||
| accelerate | 0.27.2 | 0.30.1 |
|
||||
| peft | 0.9.0 | 0.11.1 |
|
||||
| trl | 0.8.2 | 0.8.6 |
|
||||
| python | 3.8 | 3.11 |
|
||||
| torch | 1.13.1 | 2.3.0 |
|
||||
| transformers | 4.41.2 | 4.41.2 |
|
||||
| datasets | 2.16.0 | 2.19.2 |
|
||||
| accelerate | 0.30.1 | 0.30.1 |
|
||||
| peft | 0.11.1 | 0.11.1 |
|
||||
| trl | 0.8.6 | 0.9.4 |
|
||||
|
||||
| Optional | Minimum | Recommend |
|
||||
| ------------ | ------- | --------- |
|
||||
| CUDA | 11.6 | 12.2 |
|
||||
| deepspeed | 0.10.0 | 0.14.0 |
|
||||
| bitsandbytes | 0.39.0 | 0.43.1 |
|
||||
| vllm | 0.4.0 | 0.4.2 |
|
||||
| flash-attn | 2.3.0 | 2.5.8 |
|
||||
| vllm | 0.4.3 | 0.4.3 |
|
||||
| flash-attn | 2.3.0 | 2.5.9 |
|
||||
|
||||
### Hardware Requirement
|
||||
|
||||
|
@ -326,10 +336,10 @@ huggingface-cli login
|
|||
```bash
|
||||
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
||||
cd LLaMA-Factory
|
||||
pip install -e .[torch,metrics]
|
||||
pip install -e ".[torch,metrics]"
|
||||
```
|
||||
|
||||
Extra dependencies available: torch, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality
|
||||
Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality
|
||||
|
||||
> [!TIP]
|
||||
> Use `pip install --no-deps -e .` to resolve package conflicts.
|
||||
|
@ -350,14 +360,28 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
|
|||
|
||||
Join [NPU user group](assets/wechat_npu.jpg).
|
||||
|
||||
To utilize Ascend NPU devices for (distributed) training and inference, you need to install the **[torch-npu](https://gitee.com/ascend/pytorch)** library and the **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**.
|
||||
To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e '.[torch-npu,metrics]'`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
|
||||
|
||||
| Requirement | Minimum | Recommend |
|
||||
| ------------ | ------- | --------- |
|
||||
| CANN | 8.0.RC1 | 8.0.RC1 |
|
||||
| torch | 2.2.0 | 2.2.0 |
|
||||
| torch-npu | 2.2.0 | 2.2.0 |
|
||||
| deepspeed | 0.13.2 | 0.13.2 |
|
||||
```bash
|
||||
# replace the url according to your CANN version and devices
|
||||
# install CANN Toolkit
|
||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
|
||||
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
|
||||
|
||||
# install CANN Kernels
|
||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
|
||||
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
|
||||
|
||||
# set env variables
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
```
|
||||
|
||||
| Requirement | Minimum | Recommend |
|
||||
| ------------ | ------- | ----------- |
|
||||
| CANN | 8.0.RC1 | 8.0.RC1 |
|
||||
| torch | 2.1.0 | 2.1.0 |
|
||||
| torch-npu | 2.1.0 | 2.1.0.post3 |
|
||||
| deepspeed | 0.13.2 | 0.13.2 |
|
||||
|
||||
Docker image:
|
||||
|
||||
|
@ -382,9 +406,9 @@ Please refer to [data/README.md](data/README.md) for checking the details about
|
|||
Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
See [examples/README.md](examples/README.md) for advanced usage (including distributed training).
|
||||
|
@ -394,36 +418,38 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
|
|||
|
||||
### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
|
||||
|
||||
> [!IMPORTANT]
|
||||
> LLaMA Board GUI only supports training on a single GPU.
|
||||
|
||||
#### Use local environment
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui
|
||||
llamafactory-cli webui
|
||||
```
|
||||
|
||||
</details>
|
||||
### Build Docker
|
||||
|
||||
#### Use Docker
|
||||
|
||||
```bash
|
||||
docker build -f ./Dockerfile -t llama-factory:latest .
|
||||
docker run --gpus=all \
|
||||
docker build -f ./Dockerfile \
|
||||
--build-arg INSTALL_BNB=false \
|
||||
--build-arg INSTALL_VLLM=false \
|
||||
--build-arg INSTALL_DEEPSPEED=false \
|
||||
--build-arg PIP_INDEX=https://pypi.org/simple \
|
||||
-t llamafactory:latest .
|
||||
|
||||
docker run -it --gpus=all \
|
||||
-v ./hf_cache:/root/.cache/huggingface/ \
|
||||
-v ./data:/app/data \
|
||||
-v ./output:/app/output \
|
||||
-e CUDA_VISIBLE_DEVICES=0 \
|
||||
-p 7860:7860 \
|
||||
-p 8000:8000 \
|
||||
--shm-size 16G \
|
||||
--name llama_factory \
|
||||
-d llama-factory:latest
|
||||
--name llamafactory \
|
||||
llamafactory:latest
|
||||
```
|
||||
|
||||
#### Use Docker Compose
|
||||
|
||||
```bash
|
||||
docker compose -f ./docker-compose.yml up -d
|
||||
docker-compose up -d
|
||||
docker-compose exec llamafactory bash
|
||||
```
|
||||
|
||||
<details><summary>Details about volume</summary>
|
||||
|
@ -437,9 +463,12 @@ docker compose -f ./docker-compose.yml up -d
|
|||
### Deploy with OpenAI-style API and vLLM
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
||||
API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> Visit https://platform.openai.com/docs/api-reference/chat/create for API document.
|
||||
|
||||
### Download from ModelScope Hub
|
||||
|
||||
If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope.
|
||||
|
@ -448,7 +477,18 @@ If you have trouble with downloading models and datasets from Hugging Face, you
|
|||
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
|
||||
```
|
||||
|
||||
Train the model by specifying a model ID of the ModelScope Hub as the `--model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
|
||||
Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
|
||||
|
||||
### Use W&B Logger
|
||||
|
||||
To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments.
|
||||
|
||||
```yaml
|
||||
report_to: wandb
|
||||
run_name: test_run # optional
|
||||
```
|
||||
|
||||
Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account.
|
||||
|
||||
## Projects using LLaMA Factory
|
||||
|
||||
|
@ -507,7 +547,7 @@ If you have a project that should be incorporated, please contact via email or c
|
|||
|
||||
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
||||
|
||||
Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
|
||||
## Citation
|
||||
|
||||
|
|
188
README_zh.md
188
README_zh.md
|
@ -8,9 +8,10 @@
|
|||
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
|
||||
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
|
||||
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
|
||||
[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
||||
[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
|
||||
|
||||
[![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
|
||||
|
||||
|
@ -25,6 +26,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
选择你的打开方式:
|
||||
|
||||
- **Colab**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
|
||||
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
||||
- **本地机器**:请见[如何使用](#如何使用)
|
||||
|
||||
## 目录
|
||||
|
@ -45,9 +47,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
## 项目特色
|
||||
|
||||
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练和 ORPO 训练。
|
||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
|
||||
- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
|
||||
- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
|
||||
- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
|
||||
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
|
||||
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
|
||||
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
|
||||
|
@ -69,14 +71,18 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
|
||||
## 更新日志
|
||||
|
||||
[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
|
||||
[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。
|
||||
|
||||
[24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
|
||||
<details><summary>展开日志</summary>
|
||||
|
||||
[24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `gemma` 模板进行微调使其获得对话能力。
|
||||
|
||||
[24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
|
||||
<details><summary>展开日志</summary>
|
||||
|
||||
[24/05/14] 我们支持了昇腾 NPU 设备的训练和推理。详情请查阅[安装](#安装-llama-factory)部分。
|
||||
|
||||
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
|
@ -145,40 +151,40 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
|
||||
## 模型
|
||||
|
||||
| 模型名 | 模型大小 | 默认模块 | Template |
|
||||
| -------------------------------------------------------- | -------------------------------- | ----------------- | --------- |
|
||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
|
||||
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
|
||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | q_proj,v_proj | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | query_key_value | falcon |
|
||||
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
|
||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
|
||||
| [PaliGemma](https://huggingface.co/google) | 3B | q_proj,v_proj | gemma |
|
||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | qkv_proj | phi |
|
||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen |
|
||||
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
|
||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi_vl |
|
||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
|
||||
| 模型名 | 模型大小 | Template |
|
||||
| -------------------------------------------------------- | -------------------------------- | --------- |
|
||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
||||
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
||||
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
|
||||
| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
|
||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
|
||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
|
||||
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
|
||||
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
||||
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
|
||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||
|
||||
> [!NOTE]
|
||||
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块以取得更好的效果。
|
||||
> 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
|
||||
>
|
||||
> 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
|
||||
>
|
||||
> 请务必在训练和推理时使用**完全一致**的模板。
|
||||
> 请务必在训练和推理时采用**完全一致**的模板。
|
||||
|
||||
项目所支持模型的完整列表请参阅 [constants.py](src/llamafactory/extras/constants.py)。
|
||||
|
||||
|
@ -208,6 +214,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
||||
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
|
||||
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
|
||||
- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
|
||||
- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
|
||||
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
|
||||
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
||||
|
||||
|
@ -251,6 +259,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
||||
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
||||
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
||||
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
|
||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||
|
@ -267,6 +276,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
<details><summary>偏好数据集</summary>
|
||||
|
||||
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
|
||||
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
|
||||
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
|
@ -286,21 +296,21 @@ huggingface-cli login
|
|||
|
||||
| 必需项 | 至少 | 推荐 |
|
||||
| ------------ | ------- | --------- |
|
||||
| python | 3.8 | 3.10 |
|
||||
| torch | 1.13.1 | 2.2.0 |
|
||||
| transformers | 4.37.2 | 4.41.0 |
|
||||
| datasets | 2.14.3 | 2.19.1 |
|
||||
| accelerate | 0.27.2 | 0.30.1 |
|
||||
| peft | 0.9.0 | 0.11.1 |
|
||||
| trl | 0.8.2 | 0.8.6 |
|
||||
| python | 3.8 | 3.11 |
|
||||
| torch | 1.13.1 | 2.3.0 |
|
||||
| transformers | 4.41.2 | 4.41.2 |
|
||||
| datasets | 2.16.0 | 2.19.2 |
|
||||
| accelerate | 0.30.1 | 0.30.1 |
|
||||
| peft | 0.11.1 | 0.11.1 |
|
||||
| trl | 0.8.6 | 0.9.4 |
|
||||
|
||||
| 可选项 | 至少 | 推荐 |
|
||||
| ------------ | ------- | --------- |
|
||||
| CUDA | 11.6 | 12.2 |
|
||||
| deepspeed | 0.10.0 | 0.14.0 |
|
||||
| bitsandbytes | 0.39.0 | 0.43.1 |
|
||||
| vllm | 0.4.0 | 0.4.2 |
|
||||
| flash-attn | 2.3.0 | 2.5.8 |
|
||||
| vllm | 0.4.3 | 0.4.3 |
|
||||
| flash-attn | 2.3.0 | 2.5.9 |
|
||||
|
||||
### 硬件依赖
|
||||
|
||||
|
@ -326,10 +336,10 @@ huggingface-cli login
|
|||
```bash
|
||||
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
||||
cd LLaMA-Factory
|
||||
pip install -e .[torch,metrics]
|
||||
pip install -e ".[torch,metrics]"
|
||||
```
|
||||
|
||||
可选的额外依赖项:torch、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality
|
||||
可选的额外依赖项:torch、torch_npu、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality
|
||||
|
||||
> [!TIP]
|
||||
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
|
||||
|
@ -350,21 +360,35 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
|||
|
||||
加入 [NPU 用户群](assets/wechat_npu.jpg)。
|
||||
|
||||
如果使用昇腾 NPU 设备进行(分布式)训练或推理,需要安装 **[torch-npu](https://gitee.com/ascend/pytorch)** 库和 **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**。
|
||||
在昇腾 NPU 设备上安装 LLaMA Factory 时,需要指定额外依赖项,使用 `pip install -e '.[torch-npu,metrics]'` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
|
||||
|
||||
| 依赖项 | 至少 | 推荐 |
|
||||
| ------------ | ------- | --------- |
|
||||
| CANN | 8.0.RC1 | 8.0.RC1 |
|
||||
| torch | 2.2.0 | 2.2.0 |
|
||||
| torch-npu | 2.2.0 | 2.2.0 |
|
||||
| deepspeed | 0.13.2 | 0.13.2 |
|
||||
```bash
|
||||
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
|
||||
# 安装 CANN Toolkit
|
||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
|
||||
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
|
||||
|
||||
# 安装 CANN Kernels
|
||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
|
||||
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
|
||||
|
||||
# 设置环境变量
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
```
|
||||
|
||||
| 依赖项 | 至少 | 推荐 |
|
||||
| ------------ | ------- | ----------- |
|
||||
| CANN | 8.0.RC1 | 8.0.RC1 |
|
||||
| torch | 2.1.0 | 2.1.0 |
|
||||
| torch-npu | 2.1.0 | 2.1.0.post3 |
|
||||
| deepspeed | 0.13.2 | 0.13.2 |
|
||||
|
||||
Docker 镜像:
|
||||
|
||||
- 32GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
|
||||
- 64GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
||||
|
||||
请记得使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定您使用的设备。
|
||||
请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
|
||||
|
||||
如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`。
|
||||
|
||||
|
@ -382,9 +406,9 @@ Docker 镜像:
|
|||
下面三行命令分别对 Llama3-8B-Instruct 模型进行 LoRA **微调**、**推理**和**合并**。
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
高级用法请参考 [examples/README_zh.md](examples/README_zh.md)(包括多 GPU 微调)。
|
||||
|
@ -394,34 +418,38 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_s
|
|||
|
||||
### LLaMA Board 可视化微调(由 [Gradio](https://github.com/gradio-app/gradio) 驱动)
|
||||
|
||||
> [!IMPORTANT]
|
||||
> LLaMA Board 可视化界面目前仅支持单 GPU 训练。
|
||||
|
||||
#### 使用本地环境
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui
|
||||
llamafactory-cli webui
|
||||
```
|
||||
|
||||
### 构建 Docker
|
||||
|
||||
#### 使用 Docker
|
||||
|
||||
```bash
|
||||
docker build -f ./Dockerfile -t llama-factory:latest .
|
||||
docker run --gpus=all \
|
||||
docker build -f ./Dockerfile \
|
||||
--build-arg INSTALL_BNB=false \
|
||||
--build-arg INSTALL_VLLM=false \
|
||||
--build-arg INSTALL_DEEPSPEED=false \
|
||||
--build-arg PIP_INDEX=https://pypi.org/simple \
|
||||
-t llamafactory:latest .
|
||||
|
||||
docker run -it --gpus=all \
|
||||
-v ./hf_cache:/root/.cache/huggingface/ \
|
||||
-v ./data:/app/data \
|
||||
-v ./output:/app/output \
|
||||
-e CUDA_VISIBLE_DEVICES=0 \
|
||||
-p 7860:7860 \
|
||||
-p 8000:8000 \
|
||||
--shm-size 16G \
|
||||
--name llama_factory \
|
||||
-d llama-factory:latest
|
||||
--name llamafactory \
|
||||
llamafactory:latest
|
||||
```
|
||||
|
||||
#### 使用 Docker Compose
|
||||
|
||||
```bash
|
||||
docker compose -f ./docker-compose.yml up -d
|
||||
docker-compose up -d
|
||||
docker-compose exec llamafactory bash
|
||||
```
|
||||
|
||||
<details><summary>数据卷详情</summary>
|
||||
|
@ -435,9 +463,12 @@ docker compose -f ./docker-compose.yml up -d
|
|||
### 利用 vLLM 部署 OpenAI API
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
||||
API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> API 文档请查阅 https://platform.openai.com/docs/api-reference/chat/create。
|
||||
|
||||
### 从魔搭社区下载
|
||||
|
||||
如果您在 Hugging Face 模型和数据集的下载中遇到了问题,可以通过下述方法使用魔搭社区。
|
||||
|
@ -446,7 +477,18 @@ CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/l
|
|||
export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
||||
```
|
||||
|
||||
将 `--model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`。
|
||||
将 `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`。
|
||||
|
||||
### 使用 W&B 面板
|
||||
|
||||
若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请添加下面的参数。
|
||||
|
||||
```yaml
|
||||
report_to: wandb
|
||||
run_name: test_run # 可选
|
||||
```
|
||||
|
||||
在启动训练任务时,将 `WANDB_API_KEY` 设置为[密钥](https://wandb.ai/authorize)来登录 W&B 账户。
|
||||
|
||||
## 使用了 LLaMA Factory 的项目
|
||||
|
||||
|
@ -505,7 +547,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
|||
|
||||
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
|
||||
|
||||
使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
|
||||
## 引用
|
||||
|
||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 192 KiB After Width: | Height: | Size: 140 KiB |
Binary file not shown.
Before Width: | Height: | Size: 146 KiB After Width: | Height: | Size: 148 KiB |
|
@ -12,6 +12,7 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
|
|||
"ranking": "whether the dataset is a preference dataset or not. (default: False)",
|
||||
"subset": "the name of the subset. (optional, default: None)",
|
||||
"folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
|
||||
"num_samples": "the number of samples in the dataset used for training. (optional, default: None)",
|
||||
"columns (optional)": {
|
||||
"prompt": "the column name in the dataset containing the prompts. (default: instruction)",
|
||||
"query": "the column name in the dataset containing the queries. (default: input)",
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
"ranking": "是否为偏好数据集(可选,默认:False)",
|
||||
"subset": "数据集子集的名称(可选,默认:None)",
|
||||
"folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
|
||||
"num_samples": "该数据集中用于训练的样本数量。(可选,默认:None)",
|
||||
"columns(可选)": {
|
||||
"prompt": "数据集代表提示词的表头名称(默认:instruction)",
|
||||
"query": "数据集代表请求的表头名称(默认:input)",
|
||||
|
|
|
@ -248,6 +248,10 @@
|
|||
"ruozhiba_gpt4": {
|
||||
"hf_hub_url": "hfl/ruozhiba_gpt4_turbo"
|
||||
},
|
||||
"neo_sft": {
|
||||
"hf_hub_url": "m-a-p/neo_sft_phase2",
|
||||
"formatting": "sharegpt"
|
||||
},
|
||||
"llava_1k_en": {
|
||||
"hf_hub_url": "BUAADreamer/llava-en-zh-2k",
|
||||
"subset": "en",
|
||||
|
@ -308,6 +312,20 @@
|
|||
"assistant_tag": "assistant"
|
||||
}
|
||||
},
|
||||
"mllm_pt_demo": {
|
||||
"hf_hub_url": "BUAADreamer/mllm_pt_demo",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "messages",
|
||||
"images": "images"
|
||||
},
|
||||
"tags": {
|
||||
"role_tag": "role",
|
||||
"content_tag": "content",
|
||||
"user_tag": "user",
|
||||
"assistant_tag": "assistant"
|
||||
}
|
||||
},
|
||||
"oasst_de": {
|
||||
"hf_hub_url": "mayflowergmbh/oasst_de"
|
||||
},
|
||||
|
@ -377,6 +395,16 @@
|
|||
"rejected": "rejected"
|
||||
}
|
||||
},
|
||||
"ultrafeedback": {
|
||||
"hf_hub_url": "llamafactory/ultrafeedback_binarized",
|
||||
"ms_hub_url": "llamafactory/ultrafeedback_binarized",
|
||||
"ranking": true,
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"chosen": "chosen",
|
||||
"rejected": "rejected"
|
||||
}
|
||||
},
|
||||
"orca_pairs": {
|
||||
"hf_hub_url": "Intel/orca_dpo_pairs",
|
||||
"ranking": true,
|
||||
|
@ -434,6 +462,15 @@
|
|||
"assistant_tag": "assistant"
|
||||
}
|
||||
},
|
||||
"ultrafeedback_kto": {
|
||||
"hf_hub_url": "argilla/ultrafeedback-binarized-preferences-cleaned-kto",
|
||||
"ms_hub_url": "AI-ModelScope/ultrafeedback-binarized-preferences-cleaned-kto",
|
||||
"columns": {
|
||||
"prompt": "prompt",
|
||||
"response": "completion",
|
||||
"kto_tag": "label"
|
||||
}
|
||||
},
|
||||
"wiki_demo": {
|
||||
"file_name": "wiki_demo.txt",
|
||||
"columns": {
|
||||
|
@ -487,6 +524,18 @@
|
|||
"prompt": "text"
|
||||
}
|
||||
},
|
||||
"fileweb": {
|
||||
"hf_hub_url": "HuggingFaceFW/fineweb",
|
||||
"columns": {
|
||||
"prompt": "text"
|
||||
}
|
||||
},
|
||||
"fileweb_edu": {
|
||||
"hf_hub_url": "HuggingFaceFW/fineweb-edu",
|
||||
"columns": {
|
||||
"prompt": "text"
|
||||
}
|
||||
},
|
||||
"the_stack": {
|
||||
"hf_hub_url": "bigcode/the-stack",
|
||||
"ms_hub_url": "AI-ModelScope/the-stack",
|
||||
|
|
|
@ -1,20 +1,25 @@
|
|||
version: '3.8'
|
||||
|
||||
services:
|
||||
llama-factory:
|
||||
llamafactory:
|
||||
build:
|
||||
dockerfile: Dockerfile
|
||||
context: .
|
||||
container_name: llama_factory
|
||||
args:
|
||||
INSTALL_BNB: false
|
||||
INSTALL_VLLM: false
|
||||
INSTALL_DEEPSPEED: false
|
||||
PIP_INDEX: https://pypi.org/simple
|
||||
container_name: llamafactory
|
||||
volumes:
|
||||
- ./hf_cache:/root/.cache/huggingface/
|
||||
- ./data:/app/data
|
||||
- ./output:/app/output
|
||||
environment:
|
||||
- CUDA_VISIBLE_DEVICES=0
|
||||
ports:
|
||||
- "7860:7860"
|
||||
- "8000:8000"
|
||||
ipc: host
|
||||
tty: true
|
||||
stdin_open: true
|
||||
command: bash
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
import datasets
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
import datasets
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
import datasets
|
||||
|
@ -154,7 +155,7 @@ class MMLU(datasets.GeneratorBasedBuilder):
|
|||
]
|
||||
|
||||
def _generate_examples(self, filepath):
|
||||
df = pd.read_csv(filepath)
|
||||
df = pd.read_csv(filepath, header=None)
|
||||
df.columns = ["question", "A", "B", "C", "D", "answer"]
|
||||
|
||||
for i, instance in enumerate(df.to_dict(orient="records")):
|
||||
|
|
|
@ -4,59 +4,59 @@ Make sure to execute these commands in the `LLaMA-Factory` directory.
|
|||
|
||||
## Table of Contents
|
||||
|
||||
- [LoRA Fine-Tuning on A Single GPU](#lora-fine-tuning-on-a-single-gpu)
|
||||
- [QLoRA Fine-Tuning on a Single GPU](#qlora-fine-tuning-on-a-single-gpu)
|
||||
- [LoRA Fine-Tuning on Multiple GPUs](#lora-fine-tuning-on-multiple-gpus)
|
||||
- [LoRA Fine-Tuning on Multiple NPUs](#lora-fine-tuning-on-multiple-npus)
|
||||
- [Full-Parameter Fine-Tuning on Multiple GPUs](#full-parameter-fine-tuning-on-multiple-gpus)
|
||||
- [LoRA Fine-Tuning](#lora-fine-tuning)
|
||||
- [QLoRA Fine-Tuning](#qlora-fine-tuning)
|
||||
- [Full-Parameter Fine-Tuning](#full-parameter-fine-tuning)
|
||||
- [Merging LoRA Adapters and Quantization](#merging-lora-adapters-and-quantization)
|
||||
- [Inferring LoRA Fine-Tuned Models](#inferring-lora-fine-tuned-models)
|
||||
- [Extras](#extras)
|
||||
|
||||
Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.
|
||||
|
||||
## Examples
|
||||
|
||||
### LoRA Fine-Tuning on A Single GPU
|
||||
### LoRA Fine-Tuning
|
||||
|
||||
#### (Continuous) Pre-Training
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Multimodal Supervised Fine-Tuning
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml
|
||||
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Reward Modeling
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
|
||||
```
|
||||
|
||||
#### PPO Training
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
|
||||
```
|
||||
|
||||
#### DPO/ORPO/SimPO Training
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
|
||||
```
|
||||
|
||||
#### KTO Training
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
|
||||
```
|
||||
|
||||
#### Preprocess Dataset
|
||||
|
@ -64,93 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
|
|||
It is useful for large dataset, use `tokenized_path` in config to load the preprocessed dataset.
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
|
||||
```
|
||||
|
||||
#### Evaluating on MMLU/CMMLU/C-Eval Benchmarks
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml
|
||||
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
|
||||
```
|
||||
|
||||
#### Batch Predicting and Computing BLEU and ROUGE Scores
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
|
||||
```
|
||||
|
||||
### QLoRA Fine-Tuning on a Single GPU
|
||||
|
||||
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes Quantization (Recommended)
|
||||
#### Supervised Fine-Tuning on Multiple Nodes
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning with 4-bit AWQ Quantization
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning with 2-bit AQLM Quantization
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
|
||||
```
|
||||
|
||||
### LoRA Fine-Tuning on Multiple GPUs
|
||||
|
||||
#### Supervised Fine-Tuning with Accelerate on Single Node
|
||||
|
||||
```bash
|
||||
bash examples/lora_multi_gpu/single_node.sh
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning with Accelerate on Multiple Nodes
|
||||
|
||||
```bash
|
||||
bash examples/lora_multi_gpu/multi_node.sh
|
||||
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
|
||||
|
||||
```bash
|
||||
bash examples/lora_multi_gpu/ds_zero3.sh
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
|
||||
```
|
||||
|
||||
### LoRA Fine-Tuning on Multiple NPUs
|
||||
### QLoRA Fine-Tuning
|
||||
|
||||
#### Supervised Fine-Tuning with DeepSpeed ZeRO-0
|
||||
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes Quantization (Recommended)
|
||||
|
||||
```bash
|
||||
bash examples/lora_multi_npu/ds_zero0.sh
|
||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml
|
||||
```
|
||||
|
||||
### Full-Parameter Fine-Tuning on Multiple GPUs
|
||||
|
||||
#### Supervised Fine-Tuning with Accelerate on Single Node
|
||||
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
|
||||
|
||||
```bash
|
||||
bash examples/full_multi_gpu/single_node.sh
|
||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning with Accelerate on Multiple Nodes
|
||||
#### Supervised Fine-Tuning with 4-bit AWQ Quantization
|
||||
|
||||
```bash
|
||||
bash examples/full_multi_gpu/multi_node.sh
|
||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning with 2-bit AQLM Quantization
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
|
||||
```
|
||||
|
||||
### Full-Parameter Fine-Tuning
|
||||
|
||||
#### Supervised Fine-Tuning on Single Node
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning on Multiple Nodes
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||
```
|
||||
|
||||
#### Batch Predicting and Computing BLEU and ROUGE Scores
|
||||
|
||||
```bash
|
||||
bash examples/full_multi_gpu/predict.sh
|
||||
llamafactory-cli train examples/train_full/llama3_full_predict.yaml
|
||||
```
|
||||
|
||||
### Merging LoRA Adapters and Quantization
|
||||
|
@ -160,35 +146,33 @@ bash examples/full_multi_gpu/predict.sh
|
|||
Note: DO NOT use quantized model or `quantization_bit` when merging LoRA adapters.
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Quantizing Model using AutoGPTQ
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
||||
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
||||
```
|
||||
|
||||
### Inferring LoRA Fine-Tuned Models
|
||||
|
||||
Use `CUDA_VISIBLE_DEVICES=0,1` to infer models on multiple devices.
|
||||
|
||||
#### Use CLI
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Use Web UI
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
||||
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Launch OpenAI-style API
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
||||
llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
### Extras
|
||||
|
@ -196,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y
|
|||
#### Full-Parameter Fine-Tuning using GaLore
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
||||
llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
||||
```
|
||||
|
||||
#### Full-Parameter Fine-Tuning using BAdam
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
||||
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
||||
```
|
||||
|
||||
#### LoRA+ Fine-Tuning
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
|
||||
llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### PiSSA Fine-Tuning
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Mixture-of-Depths Fine-Tuning
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
|
||||
llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
|
||||
```
|
||||
|
||||
#### LLaMA-Pro Fine-Tuning
|
||||
|
||||
```bash
|
||||
bash examples/extras/llama_pro/expand.sh
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
|
||||
llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
|
||||
```
|
||||
|
||||
#### FSDP+QLoRA Fine-Tuning
|
||||
|
||||
```bash
|
||||
bash examples/extras/fsdp_qlora/single_node.sh
|
||||
bash examples/extras/fsdp_qlora/train.sh
|
||||
```
|
||||
|
|
|
@ -4,59 +4,59 @@
|
|||
|
||||
## 目录
|
||||
|
||||
- [单 GPU LoRA 微调](#单-gpu-lora-微调)
|
||||
- [单 GPU QLoRA 微调](#单-gpu-qlora-微调)
|
||||
- [多 GPU LoRA 微调](#多-gpu-lora-微调)
|
||||
- [多 NPU LoRA 微调](#多-npu-lora-微调)
|
||||
- [多 GPU 全参数微调](#多-gpu-全参数微调)
|
||||
- [LoRA 微调](#lora-微调)
|
||||
- [QLoRA 微调](#qlora-微调)
|
||||
- [全参数微调](#全参数微调)
|
||||
- [合并 LoRA 适配器与模型量化](#合并-lora-适配器与模型量化)
|
||||
- [推理 LoRA 模型](#推理-lora-模型)
|
||||
- [杂项](#杂项)
|
||||
|
||||
使用 `CUDA_VISIBLE_DEVICES`(GPU)或 `ASCEND_RT_VISIBLE_DEVICES`(NPU)选择计算设备。
|
||||
|
||||
## 示例
|
||||
|
||||
### 单 GPU LoRA 微调
|
||||
### LoRA 微调
|
||||
|
||||
#### (增量)预训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
|
||||
```
|
||||
|
||||
#### 指令监督微调
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 多模态指令监督微调
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml
|
||||
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 奖励模型训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
|
||||
```
|
||||
|
||||
#### PPO 训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
|
||||
```
|
||||
|
||||
#### DPO/ORPO/SimPO 训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
|
||||
```
|
||||
|
||||
#### KTO 训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
|
||||
```
|
||||
|
||||
#### 预处理数据集
|
||||
|
@ -64,93 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
|
|||
对于大数据集有帮助,在配置中使用 `tokenized_path` 以加载预处理后的数据集。
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
|
||||
```
|
||||
|
||||
#### 在 MMLU/CMMLU/C-Eval 上评估
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml
|
||||
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
|
||||
```
|
||||
|
||||
#### 批量预测并计算 BLEU 和 ROUGE 分数
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
|
||||
```
|
||||
|
||||
### 单 GPU QLoRA 微调
|
||||
|
||||
#### 基于 4/8 比特 Bitsandbytes 量化进行指令监督微调(推荐)
|
||||
#### 多机指令监督微调
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml
|
||||
```
|
||||
|
||||
#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
|
||||
```
|
||||
|
||||
#### 基于 4 比特 AWQ 量化进行指令监督微调
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
|
||||
```
|
||||
|
||||
#### 基于 2 比特 AQLM 量化进行指令监督微调
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
|
||||
```
|
||||
|
||||
### 多 GPU LoRA 微调
|
||||
|
||||
#### 使用 Accelerate 进行单节点训练
|
||||
|
||||
```bash
|
||||
bash examples/lora_multi_gpu/single_node.sh
|
||||
```
|
||||
|
||||
#### 使用 Accelerate 进行多节点训练
|
||||
|
||||
```bash
|
||||
bash examples/lora_multi_gpu/multi_node.sh
|
||||
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 使用 DeepSpeed ZeRO-3 平均分配显存
|
||||
|
||||
```bash
|
||||
bash examples/lora_multi_gpu/ds_zero3.sh
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
|
||||
```
|
||||
|
||||
### 多 NPU LoRA 微调
|
||||
### QLoRA 微调
|
||||
|
||||
#### 使用 DeepSpeed ZeRO-0 训练
|
||||
#### 基于 4/8 比特 Bitsandbytes 量化进行指令监督微调(推荐)
|
||||
|
||||
```bash
|
||||
bash examples/lora_multi_npu/ds_zero0.sh
|
||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml
|
||||
```
|
||||
|
||||
### 多 GPU 全参数微调
|
||||
|
||||
#### 使用 DeepSpeed 进行单节点训练
|
||||
#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
|
||||
|
||||
```bash
|
||||
bash examples/full_multi_gpu/single_node.sh
|
||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml
|
||||
```
|
||||
|
||||
#### 使用 DeepSpeed 进行多节点训练
|
||||
#### 基于 4 比特 AWQ 量化进行指令监督微调
|
||||
|
||||
```bash
|
||||
bash examples/full_multi_gpu/multi_node.sh
|
||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml
|
||||
```
|
||||
|
||||
#### 基于 2 比特 AQLM 量化进行指令监督微调
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
|
||||
```
|
||||
|
||||
### 全参数微调
|
||||
|
||||
#### 在单机上进行指令监督微调
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||
```
|
||||
|
||||
#### 在多机上进行指令监督微调
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||
```
|
||||
|
||||
#### 批量预测并计算 BLEU 和 ROUGE 分数
|
||||
|
||||
```bash
|
||||
bash examples/full_multi_gpu/predict.sh
|
||||
llamafactory-cli train examples/train_full/llama3_full_predict.yaml
|
||||
```
|
||||
|
||||
### 合并 LoRA 适配器与模型量化
|
||||
|
@ -160,35 +146,33 @@ bash examples/full_multi_gpu/predict.sh
|
|||
注:请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 使用 AutoGPTQ 量化模型
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
||||
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
||||
```
|
||||
|
||||
### 推理 LoRA 模型
|
||||
|
||||
使用 `CUDA_VISIBLE_DEVICES=0,1` 进行多卡推理。
|
||||
|
||||
#### 使用命令行接口
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 使用浏览器界面
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
||||
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 启动 OpenAI 风格 API
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
||||
llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
### 杂项
|
||||
|
@ -196,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y
|
|||
#### 使用 GaLore 进行全参数训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
||||
llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
||||
```
|
||||
|
||||
#### 使用 BAdam 进行全参数训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
||||
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
||||
```
|
||||
|
||||
#### LoRA+ 微调
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
|
||||
llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### PiSSA 微调
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 深度混合微调
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
|
||||
llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
|
||||
```
|
||||
|
||||
#### LLaMA-Pro 微调
|
||||
|
||||
```bash
|
||||
bash examples/extras/llama_pro/expand.sh
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
|
||||
llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
|
||||
```
|
||||
|
||||
#### FSDP+QLoRA 微调
|
||||
|
||||
```bash
|
||||
bash examples/extras/fsdp_qlora/single_node.sh
|
||||
bash examples/extras/fsdp_qlora/train.sh
|
||||
```
|
||||
|
|
|
@ -5,16 +5,16 @@ downcast_bf16: 'no'
|
|||
fsdp_config:
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_backward_prefetch: BACKWARD_PRE
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_forward_prefetch: false
|
||||
fsdp_offload_params: true
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_offload_params: true # offload may affect training speed
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_use_orig_params: false
|
||||
fsdp_use_orig_params: true
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: fp16
|
||||
mixed_precision: fp16 # or bf16
|
||||
num_machines: 1 # the number of nodes
|
||||
num_processes: 2 # the number of GPUs in all nodes
|
||||
rdzv_backend: static
|
||||
|
|
|
@ -1,18 +0,0 @@
|
|||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: MULTI_GPU
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_process_ip: 192.168.0.1
|
||||
main_process_port: 29555
|
||||
main_training_function: main
|
||||
mixed_precision: fp16
|
||||
num_machines: 2 # the number of nodes
|
||||
num_processes: 8 # the number of GPUs in all nodes
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
|
@ -1,16 +0,0 @@
|
|||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: MULTI_GPU
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: fp16
|
||||
num_machines: 1 # the number of nodes
|
||||
num_processes: 4 # the number of GPUs in all nodes
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
|
@ -1,18 +0,0 @@
|
|||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: MULTI_GPU
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
machine_rank: 1
|
||||
main_process_ip: 192.168.0.1
|
||||
main_process_port: 29555
|
||||
main_training_function: main
|
||||
mixed_precision: fp16
|
||||
num_machines: 2 # the number of nodes
|
||||
num_processes: 8 # the number of GPUs in all nodes
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
|
@ -28,14 +28,14 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
pure_bf16: true
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
|
|
@ -6,10 +6,7 @@ quantization_bit: 4
|
|||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
|
||||
### ddp
|
||||
ddp_timeout: 180000000
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
|
@ -29,14 +26,15 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
|
|
@ -1,10 +1,6 @@
|
|||
#!/bin/bash
|
||||
# DO NOT use GPTQ/AWQ model in FSDP+QLoRA
|
||||
|
||||
pip install "transformers>=4.39.1"
|
||||
pip install "accelerate>=0.28.0"
|
||||
pip install "bitsandbytes>=0.43.0"
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
|
||||
--config_file examples/accelerate/fsdp_config.yaml \
|
||||
src/train.py examples/extras/fsdp_qlora/llama3_lora_sft.yaml
|
|
@ -29,14 +29,14 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
pure_bf16: true
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
|
|
@ -27,14 +27,15 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
|
|
@ -5,7 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
|||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
loraplus_lr_ratio: 16.0
|
||||
|
||||
### dataset
|
||||
|
@ -26,14 +26,15 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
|
|
@ -26,14 +26,15 @@ overwrite_output_dir: true
|
|||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
optim: paged_adamw_8bit
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
pure_bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
|
|
@ -5,10 +5,10 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
|||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
|
||||
### ddp
|
||||
ddp_timeout: 180000000
|
||||
lora_target: all
|
||||
pissa_init: true
|
||||
pissa_iter: 4
|
||||
pissa_convert: true
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
|
@ -27,15 +27,16 @@ overwrite_output_dir: true
|
|||
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
learning_rate: 0.0001
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
|
@ -1,15 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
NPROC_PER_NODE=4
|
||||
NNODES=2
|
||||
RANK=0
|
||||
MASTER_ADDR=192.168.0.1
|
||||
MASTER_PORT=29500
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
|
||||
--nproc_per_node $NPROC_PER_NODE \
|
||||
--nnodes $NNODES \
|
||||
--node_rank $RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT \
|
||||
src/train.py examples/full_multi_gpu/llama3_full_sft.yaml
|
|
@ -1,5 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||
--config_file examples/accelerate/single_config.yaml \
|
||||
src/train.py examples/full_multi_gpu/llama3_full_predict.yaml
|
|
@ -1,15 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
NPROC_PER_NODE=4
|
||||
NNODES=1
|
||||
RANK=0
|
||||
MASTER_ADDR=127.0.0.1
|
||||
MASTER_PORT=29500
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
|
||||
--nproc_per_node $NPROC_PER_NODE \
|
||||
--nnodes $NNODES \
|
||||
--node_rank $RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT \
|
||||
src/train.py examples/full_multi_gpu/llama3_full_sft.yaml
|
|
@ -1,15 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
NPROC_PER_NODE=4
|
||||
NNODES=1
|
||||
RANK=0
|
||||
MASTER_ADDR=127.0.0.1
|
||||
MASTER_PORT=29500
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
|
||||
--nproc_per_node $NPROC_PER_NODE \
|
||||
--nnodes $NNODES \
|
||||
--node_rank $RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT \
|
||||
src/train.py examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
|
|
@ -1,6 +0,0 @@
|
|||
#!/bin/bash
|
||||
# also launch it on slave machine using slave_config.yaml
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||
--config_file examples/accelerate/master_config.yaml \
|
||||
src/train.py examples/lora_multi_gpu/llama3_lora_sft.yaml
|
|
@ -1,5 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||
--config_file examples/accelerate/single_config.yaml \
|
||||
src/train.py examples/lora_multi_gpu/llama3_lora_sft.yaml
|
|
@ -1,15 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
NPROC_PER_NODE=4
|
||||
NNODES=1
|
||||
RANK=0
|
||||
MASTER_ADDR=127.0.0.1
|
||||
MASTER_PORT=29500
|
||||
|
||||
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 torchrun \
|
||||
--nproc_per_node $NPROC_PER_NODE \
|
||||
--nnodes $NNODES \
|
||||
--node_rank $RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT \
|
||||
src/train.py examples/lora_multi_npu/llama3_lora_sft_ds.yaml
|
|
@ -5,9 +5,6 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
|||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: full
|
||||
|
||||
### ddp
|
||||
ddp_timeout: 180000000
|
||||
deepspeed: examples/deepspeed/ds_z3_config.json
|
||||
|
||||
### dataset
|
||||
|
@ -28,14 +25,15 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
|
@ -5,7 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
|||
stage: dpo
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
pref_beta: 0.1
|
||||
pref_loss: sigmoid # [sigmoid (dpo), orpo, simpo]
|
||||
|
||||
|
@ -27,14 +27,15 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.000005
|
||||
learning_rate: 5.0e-6
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
|
@ -5,7 +5,8 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
|||
stage: kto
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
pref_beta: 0.1
|
||||
|
||||
### dataset
|
||||
dataset: kto_en_demo
|
||||
|
@ -25,14 +26,15 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.000005
|
||||
learning_rate: 5.0e-6
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
|
@ -6,7 +6,7 @@ reward_model: saves/llama3-8b/lora/reward
|
|||
stage: ppo
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
|
@ -26,11 +26,12 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.00001
|
||||
learning_rate: 1.0e-5
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### generate
|
||||
max_new_tokens: 512
|
|
@ -22,3 +22,4 @@ overwrite_output_dir: true
|
|||
### eval
|
||||
per_device_eval_batch_size: 1
|
||||
predict_with_generate: true
|
||||
ddp_timeout: 180000000
|
|
@ -5,7 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
|||
stage: pt
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: c4_demo
|
||||
|
@ -24,14 +24,15 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
|
@ -5,7 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
|||
stage: rm
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: dpo_en_demo
|
||||
|
@ -25,14 +25,15 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.00001
|
||||
learning_rate: 1.0e-5
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
|
@ -5,7 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
|||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
|
@ -25,14 +25,15 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
|
@ -5,10 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
|||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
|
||||
### ddp
|
||||
ddp_timeout: 180000000
|
||||
lora_target: all
|
||||
deepspeed: examples/deepspeed/ds_z0_config.json
|
||||
|
||||
### dataset
|
||||
|
@ -29,14 +26,15 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
|
@ -5,10 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
|||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
|
||||
### ddp
|
||||
ddp_timeout: 180000000
|
||||
lora_target: all
|
||||
deepspeed: examples/deepspeed/ds_z3_config.json
|
||||
|
||||
### dataset
|
||||
|
@ -29,14 +26,15 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
|
@ -5,7 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
|||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
|
@ -6,7 +6,7 @@ visual_inputs: true
|
|||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: mllm_demo
|
||||
|
@ -26,14 +26,15 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
|
@ -5,7 +5,7 @@ model_name_or_path: ISTA-DASLab/Meta-Llama-3-8B-Instruct-AQLM-2Bit-1x16
|
|||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
|
@ -25,14 +25,15 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
|
@ -5,7 +5,7 @@ model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-AWQ
|
|||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
|
@ -25,14 +25,15 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
|
@ -6,7 +6,7 @@ quantization_bit: 4
|
|||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
|
@ -26,14 +26,15 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
|
@ -5,7 +5,7 @@ model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-GPTQ
|
|||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
|
@ -25,14 +25,15 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
|
@ -1,12 +1,13 @@
|
|||
transformers>=4.37.2
|
||||
datasets>=2.14.3
|
||||
accelerate>=0.27.2
|
||||
peft>=0.10.0
|
||||
trl>=0.8.1
|
||||
transformers>=4.41.2
|
||||
datasets>=2.16.0
|
||||
accelerate>=0.30.1
|
||||
peft>=0.11.1
|
||||
trl>=0.8.6
|
||||
gradio>=4.0.0
|
||||
scipy
|
||||
einops
|
||||
sentencepiece
|
||||
tiktoken
|
||||
protobuf
|
||||
uvicorn
|
||||
pydantic
|
||||
|
|
|
@ -1,7 +1,20 @@
|
|||
# coding=utf-8
|
||||
# Calculates the flops of pre-trained models.
|
||||
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
|
||||
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
|
||||
# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the Microsoft's DeepSpeed library.
|
||||
# https://www.deepspeed.ai/tutorials/flops-profiler/
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import fire
|
||||
import torch
|
||||
|
@ -17,6 +30,10 @@ def calculate_flops(
|
|||
seq_length: int = 256,
|
||||
flash_attn: str = "auto",
|
||||
):
|
||||
r"""
|
||||
Calculates the flops of pre-trained models.
|
||||
Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
|
||||
"""
|
||||
with get_accelerator().device(0):
|
||||
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
|
||||
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
|
||||
|
|
|
@ -1,7 +1,20 @@
|
|||
# coding=utf-8
|
||||
# Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
|
||||
# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
|
||||
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
|
||||
# Copyright 2024 imoneoi and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the imoneoi's OpenChat library.
|
||||
# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
@ -32,6 +45,10 @@ def calculate_lr(
|
|||
cutoff_len: int = 1024, # i.e. maximum input length during training
|
||||
is_mistral: bool = False, # mistral model uses a smaller learning rate,
|
||||
):
|
||||
r"""
|
||||
Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
|
||||
Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
|
||||
"""
|
||||
model_args, data_args, training_args, _, _ = get_train_args(
|
||||
dict(
|
||||
stage=stage,
|
||||
|
|
|
@ -1,6 +1,17 @@
|
|||
# coding=utf-8
|
||||
# Calculates the ppl on the dataset of the pre-trained models.
|
||||
# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
@ -56,6 +67,10 @@ def cal_ppl(
|
|||
max_samples: Optional[int] = None,
|
||||
train_on_prompt: bool = False,
|
||||
):
|
||||
r"""
|
||||
Calculates the ppl on the dataset of the pre-trained models.
|
||||
Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
|
||||
"""
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
|
||||
dict(
|
||||
stage=stage,
|
||||
|
|
|
@ -1,6 +1,17 @@
|
|||
# coding=utf-8
|
||||
# Calculates the distribution of the input lengths in the dataset.
|
||||
# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
|
@ -19,6 +30,10 @@ def length_cdf(
|
|||
template: str = "default",
|
||||
interval: int = 1000,
|
||||
):
|
||||
r"""
|
||||
Calculates the distribution of the input lengths in the dataset.
|
||||
Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
|
||||
"""
|
||||
model_args, data_args, training_args, _, _ = get_train_args(
|
||||
dict(
|
||||
stage="sft",
|
||||
|
|
|
@ -1,7 +1,20 @@
|
|||
# coding=utf-8
|
||||
# Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
|
||||
# Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
|
||||
# Inspired by: https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
|
||||
# Copyright 2024 Tencent Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the Tencent's LLaMA-Pro library.
|
||||
# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
|
@ -37,6 +50,10 @@ def block_expansion(
|
|||
shard_size: Optional[str] = "2GB",
|
||||
save_safetensors: Optional[bool] = False,
|
||||
):
|
||||
r"""
|
||||
Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
|
||||
Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
|
||||
"""
|
||||
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
|
||||
num_layers = getattr(config, "num_hidden_layers")
|
||||
setattr(config, "num_hidden_layers", num_layers + num_expand)
|
||||
|
@ -103,11 +120,11 @@ def block_expansion(
|
|||
json.dump(index, f, indent=2, sort_keys=True)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
|
||||
print("Fine-tune this model with:")
|
||||
print(" --model_name_or_path {} \\".format(output_dir))
|
||||
print(" --finetuning_type freeze \\")
|
||||
print(" --freeze_trainable_layers {} \\".format(num_expand))
|
||||
print(" --use_llama_pro")
|
||||
print("- Fine-tune this model with:")
|
||||
print("model_name_or_path: {}".format(output_dir))
|
||||
print("finetuning_type: freeze")
|
||||
print("freeze_trainable_layers: {}".format(num_expand))
|
||||
print("use_llama_pro: true")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,8 +1,17 @@
|
|||
# coding=utf-8
|
||||
# Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
|
||||
# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
|
||||
# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
|
||||
# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
|
@ -79,6 +88,11 @@ def save_config(input_dir: str, output_dir: str):
|
|||
def llamafy_baichuan2(
|
||||
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
|
||||
):
|
||||
r"""
|
||||
Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
|
||||
Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
|
||||
Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
|
||||
"""
|
||||
try:
|
||||
os.makedirs(output_dir, exist_ok=False)
|
||||
except Exception as e:
|
||||
|
|
|
@ -1,7 +1,17 @@
|
|||
# coding=utf-8
|
||||
# Converts the Qwen models in the same format as LLaMA2.
|
||||
# Usage: python llamafy_qwen.py --input_dir input --output_dir output
|
||||
# Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
|
@ -131,6 +141,11 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
|
|||
def llamafy_qwen(
|
||||
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
|
||||
):
|
||||
r"""
|
||||
Converts the Qwen models in the same format as LLaMA2.
|
||||
Usage: python llamafy_qwen.py --input_dir input --output_dir output
|
||||
Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
|
||||
"""
|
||||
try:
|
||||
os.makedirs(output_dir, exist_ok=False)
|
||||
except Exception as e:
|
||||
|
|
|
@ -1,14 +1,25 @@
|
|||
# coding=utf-8
|
||||
# Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
|
||||
# Usage: python loftq_init.py --model_name_or_path path_to_model --save_dir output_dir
|
||||
# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is based on the HuggingFace's PEFT library.
|
||||
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import fire
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
@ -17,38 +28,21 @@ if TYPE_CHECKING:
|
|||
from transformers import PreTrainedModel
|
||||
|
||||
|
||||
class Shell(nn.Module):
|
||||
def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(weight, requires_grad=False)
|
||||
if bias is not None:
|
||||
self.bias = nn.Parameter(bias, requires_grad=False)
|
||||
|
||||
|
||||
def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
|
||||
for name in {k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k}:
|
||||
parent_name = ".".join(name.split(".")[:-1])
|
||||
child_name = name.split(".")[-1]
|
||||
parent_module = model.get_submodule(parent_name)
|
||||
child_module = getattr(parent_module, child_name)
|
||||
base_layer = getattr(child_module, "base_layer")
|
||||
weight = getattr(base_layer, "weight", None)
|
||||
bias = getattr(base_layer, "bias", None)
|
||||
setattr(parent_module, child_name, Shell(weight, bias))
|
||||
|
||||
print("Model unwrapped.")
|
||||
|
||||
|
||||
def quantize_loftq(
|
||||
model_name_or_path: str,
|
||||
save_dir: str,
|
||||
loftq_bits: Optional[int] = 4,
|
||||
loftq_iter: Optional[int] = 1,
|
||||
lora_alpha: Optional[int] = None,
|
||||
lora_rank: Optional[int] = 16,
|
||||
lora_target: Optional[str] = "q_proj,v_proj",
|
||||
save_safetensors: Optional[bool] = False,
|
||||
output_dir: str,
|
||||
loftq_bits: int = 4,
|
||||
loftq_iter: int = 4,
|
||||
lora_alpha: int = None,
|
||||
lora_rank: int = 16,
|
||||
lora_dropout: float = 0,
|
||||
lora_target: str = "q_proj,v_proj",
|
||||
save_safetensors: bool = True,
|
||||
):
|
||||
r"""
|
||||
Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
|
||||
Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
|
||||
"""
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
|
||||
loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
|
||||
|
@ -57,25 +51,34 @@ def quantize_loftq(
|
|||
inference_mode=True,
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
||||
lora_dropout=0.1,
|
||||
lora_dropout=lora_dropout,
|
||||
target_modules=[name.strip() for name in lora_target.split(",")],
|
||||
init_lora_weights="loftq",
|
||||
loftq_config=loftq_config,
|
||||
)
|
||||
|
||||
# Init LoftQ model
|
||||
lora_model = get_peft_model(model, lora_config)
|
||||
base_model: "PreTrainedModel" = lora_model.get_base_model()
|
||||
print("Initializing LoftQ weights, it may be take several minutes, wait patiently.")
|
||||
peft_model = get_peft_model(model, lora_config)
|
||||
loftq_dir = os.path.join(output_dir, "loftq_init")
|
||||
|
||||
# Save LoftQ model
|
||||
setattr(lora_model.base_model.peft_config["default"], "base_model_name_or_path", save_dir)
|
||||
setattr(lora_model.base_model.peft_config["default"], "init_lora_weights", True)
|
||||
lora_model.save_pretrained(os.path.join(save_dir, "adapters"), safe_serialization=save_safetensors)
|
||||
setattr(peft_model.peft_config["default"], "base_model_name_or_path", output_dir)
|
||||
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
|
||||
peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
|
||||
print("Adapter weights saved in {}".format(loftq_dir))
|
||||
|
||||
# Save base model
|
||||
unwrap_model(base_model)
|
||||
base_model.save_pretrained(save_dir, safe_serialization=save_safetensors)
|
||||
tokenizer.save_pretrained(save_dir)
|
||||
base_model: "PreTrainedModel" = peft_model.unload()
|
||||
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
|
||||
print("- Fine-tune this model with:")
|
||||
print("model_name_or_path: {}".format(output_dir))
|
||||
print("adapter_name_or_path: {}".format(loftq_dir))
|
||||
print("finetuning_type: lora")
|
||||
print("quantization_bit: {}".format(loftq_bits))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is based on the HuggingFace's PEFT library.
|
||||
# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import fire
|
||||
from peft import LoraConfig, TaskType, get_peft_model
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
|
||||
def quantize_pissa(
|
||||
model_name_or_path: str,
|
||||
output_dir: str,
|
||||
pissa_iter: int = 4,
|
||||
lora_alpha: int = None,
|
||||
lora_rank: int = 16,
|
||||
lora_dropout: float = 0,
|
||||
lora_target: str = "q_proj,v_proj",
|
||||
save_safetensors: bool = True,
|
||||
):
|
||||
r"""
|
||||
Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
|
||||
Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
|
||||
"""
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
||||
lora_dropout=lora_dropout,
|
||||
target_modules=[name.strip() for name in lora_target.split(",")],
|
||||
init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter)
|
||||
)
|
||||
|
||||
# Init PiSSA model
|
||||
peft_model = get_peft_model(model, lora_config)
|
||||
pissa_dir = os.path.join(output_dir, "pissa_init")
|
||||
|
||||
# Save PiSSA model
|
||||
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
|
||||
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
|
||||
print("Adapter weights saved in {}".format(pissa_dir))
|
||||
|
||||
# Save base model
|
||||
base_model: "PreTrainedModel" = peft_model.unload()
|
||||
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
|
||||
print("- Fine-tune this model with:")
|
||||
print("model_name_or_path: {}".format(output_dir))
|
||||
print("adapter_name_or_path: {}".format(pissa_dir))
|
||||
print("finetuning_type: lora")
|
||||
print("pissa_init: false")
|
||||
print("pissa_convert: true")
|
||||
print("- and optionally with:")
|
||||
print("quantization_bit: 4")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(quantize_pissa)
|
|
@ -1,3 +1,18 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Sequence
|
||||
|
@ -20,7 +35,7 @@ def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
|
|||
|
||||
def main():
|
||||
client = OpenAI(
|
||||
api_key="0",
|
||||
api_key="{}".format(os.environ.get("API_KEY", "0")),
|
||||
base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
|
||||
)
|
||||
tools = [
|
25
setup.py
25
setup.py
|
@ -1,3 +1,17 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
|
@ -5,7 +19,7 @@ from setuptools import find_packages, setup
|
|||
|
||||
|
||||
def get_version():
|
||||
with open(os.path.join("src", "llamafactory", "cli.py"), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
|
||||
file_content = f.read()
|
||||
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
|
||||
(version,) = re.findall(pattern, file_content)
|
||||
|
@ -21,18 +35,19 @@ def get_requires():
|
|||
|
||||
extra_require = {
|
||||
"torch": ["torch>=1.13.1"],
|
||||
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
|
||||
"metrics": ["nltk", "jieba", "rouge-chinese"],
|
||||
"deepspeed": ["deepspeed>=0.10.0,<=0.14.0"],
|
||||
"deepspeed": ["deepspeed>=0.10.0"],
|
||||
"bitsandbytes": ["bitsandbytes>=0.39.0"],
|
||||
"vllm": ["vllm>=0.4.0"],
|
||||
"vllm": ["vllm>=0.4.3"],
|
||||
"galore": ["galore-torch"],
|
||||
"badam": ["badam"],
|
||||
"gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
|
||||
"awq": ["autoawq"],
|
||||
"aqlm": ["aqlm[gpu]>=1.1.0"],
|
||||
"qwen": ["tiktoken", "transformers_stream_generator"],
|
||||
"qwen": ["transformers_stream_generator"],
|
||||
"modelscope": ["modelscope"],
|
||||
"quality": ["ruff"],
|
||||
"dev": ["ruff", "pytest"],
|
||||
}
|
||||
|
||||
|
||||
|
|
14
src/api.py
14
src/api.py
|
@ -1,3 +1,17 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
|
|
|
@ -1,4 +1,18 @@
|
|||
# Level: api, webui > chat, eval, train > data, model > extras, hparams
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Level: api, webui > chat, eval, train > data, model > hparams > extras
|
||||
|
||||
from .cli import VERSION
|
||||
|
||||
|
|
|
@ -1,3 +1,17 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
|
|
@ -1,10 +1,27 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
from ..data import Role as DataRole
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.packages import is_fastapi_available
|
||||
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
|
||||
from .common import dictify, jsonify
|
||||
from .protocol import (
|
||||
ChatCompletionMessage,
|
||||
|
@ -25,7 +42,17 @@ if is_fastapi_available():
|
|||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
if is_requests_available():
|
||||
import requests
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from ..chat import ChatModel
|
||||
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
|
||||
|
||||
|
@ -40,7 +67,9 @@ ROLE_MAPPING = {
|
|||
}
|
||||
|
||||
|
||||
def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
|
||||
def _process_request(
|
||||
request: "ChatCompletionRequest",
|
||||
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
|
||||
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
|
||||
|
||||
if len(request.messages) == 0:
|
||||
|
@ -49,12 +78,13 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
|
|||
if request.messages[0].role == Role.SYSTEM:
|
||||
system = request.messages.pop(0).content
|
||||
else:
|
||||
system = ""
|
||||
system = None
|
||||
|
||||
if len(request.messages) % 2 == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
input_messages = []
|
||||
image = None
|
||||
for i, message in enumerate(request.messages):
|
||||
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
|
@ -66,6 +96,21 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
|
|||
arguments = message.tool_calls[0].function.arguments
|
||||
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
|
||||
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
|
||||
elif isinstance(message.content, list):
|
||||
for input_item in message.content:
|
||||
if input_item.type == "text":
|
||||
input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
|
||||
else:
|
||||
image_url = input_item.image_url.url
|
||||
if image_url.startswith("data:image"): # base64 image
|
||||
image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
|
||||
image_path = io.BytesIO(image_data)
|
||||
elif os.path.isfile(image_url): # local file
|
||||
image_path = open(image_url, "rb")
|
||||
else: # web uri
|
||||
image_path = requests.get(image_url, stream=True).raw
|
||||
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
else:
|
||||
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
|
||||
|
||||
|
@ -76,9 +121,9 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
|
|||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
||||
else:
|
||||
tools = ""
|
||||
tools = None
|
||||
|
||||
return input_messages, system, tools
|
||||
return input_messages, system, tools, image
|
||||
|
||||
|
||||
def _create_stream_chat_completion_chunk(
|
||||
|
@ -97,11 +142,12 @@ async def create_chat_completion_response(
|
|||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> "ChatCompletionResponse":
|
||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
input_messages, system, tools = _process_request(request)
|
||||
input_messages, system, tools, image = _process_request(request)
|
||||
responses = await chat_model.achat(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
|
@ -145,7 +191,7 @@ async def create_stream_chat_completion_response(
|
|||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> AsyncGenerator[str, None]:
|
||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
input_messages, system, tools = _process_request(request)
|
||||
input_messages, system, tools, image = _process_request(request)
|
||||
if tools:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||
|
||||
|
@ -159,6 +205,7 @@ async def create_stream_chat_completion_response(
|
|||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
|
|
|
@ -1,3 +1,17 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
|
|
|
@ -1,3 +1,17 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from enum import Enum, unique
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
@ -56,9 +70,19 @@ class FunctionCall(BaseModel):
|
|||
function: Function
|
||||
|
||||
|
||||
class ImageURL(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class MultimodalInputItem(BaseModel):
|
||||
type: Literal["text", "image_url"]
|
||||
text: Optional[str] = None
|
||||
image_url: Optional[ImageURL] = None
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Role
|
||||
content: Optional[str] = None
|
||||
content: Optional[Union[str, List[MultimodalInputItem]]] = None
|
||||
tool_calls: Optional[List[FunctionCall]] = None
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +1,17 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .base_engine import BaseEngine
|
||||
from .chat_model import ChatModel
|
||||
|
||||
|
|
|
@ -1,3 +1,17 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
|
||||
|
|
|
@ -1,3 +1,20 @@
|
|||
# Copyright 2024 THUDM and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the THUDM's ChatGLM implementation.
|
||||
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
||||
|
|
|
@ -1,3 +1,17 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import os
|
||||
|
@ -8,6 +22,7 @@ import torch
|
|||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_logits_processor
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .base_engine import BaseEngine, Response
|
||||
|
@ -23,6 +38,9 @@ if TYPE_CHECKING:
|
|||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class HuggingfaceEngine(BaseEngine):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -79,6 +97,7 @@ class HuggingfaceEngine(BaseEngine):
|
|||
|
||||
prompt_length = len(prompt_ids)
|
||||
inputs = torch.tensor([prompt_ids], device=model.device)
|
||||
attention_mask = torch.ones_like(inputs, dtype=torch.bool)
|
||||
|
||||
do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
|
||||
temperature: Optional[float] = input_kwargs.pop("temperature", None)
|
||||
|
@ -92,7 +111,7 @@ class HuggingfaceEngine(BaseEngine):
|
|||
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
||||
|
||||
if stop is not None:
|
||||
raise ValueError("Stop parameter is not supported in Huggingface engine yet.")
|
||||
logger.warning("Stop parameter is not supported in Huggingface engine yet.")
|
||||
|
||||
generating_args = generating_args.copy()
|
||||
generating_args.update(
|
||||
|
@ -132,6 +151,7 @@ class HuggingfaceEngine(BaseEngine):
|
|||
|
||||
gen_kwargs = dict(
|
||||
inputs=inputs,
|
||||
attention_mask=attention_mask,
|
||||
generation_config=GenerationConfig(**generating_args),
|
||||
logits_processor=get_logits_processor(),
|
||||
)
|
||||
|
|
|
@ -1,19 +1,37 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_device_count, infer_optim_dtype
|
||||
from ..extras.packages import is_vllm_available
|
||||
from ..extras.misc import get_device_count
|
||||
from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
|
||||
from ..model import load_config, load_tokenizer
|
||||
from ..model.utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
|
||||
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
|
||||
from .base_engine import BaseEngine, Response
|
||||
|
||||
|
||||
if is_vllm_available():
|
||||
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import MultiModalData
|
||||
|
||||
if is_vllm_version_greater_than_0_5():
|
||||
from vllm.multimodal.image import ImagePixelData
|
||||
else:
|
||||
from vllm.sequence import MultiModalData
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -35,8 +53,6 @@ class VllmEngine(BaseEngine):
|
|||
generating_args: "GeneratingArguments",
|
||||
) -> None:
|
||||
config = load_config(model_args) # may download model from ms hub
|
||||
infer_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
infer_dtype = str(infer_dtype).split(".")[-1]
|
||||
|
||||
self.can_generate = finetuning_args.stage == "sft"
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
|
@ -50,7 +66,7 @@ class VllmEngine(BaseEngine):
|
|||
"model": model_args.model_name_or_path,
|
||||
"trust_remote_code": True,
|
||||
"download_dir": model_args.cache_dir,
|
||||
"dtype": infer_dtype,
|
||||
"dtype": model_args.infer_dtype,
|
||||
"max_model_len": model_args.vllm_maxlen,
|
||||
"tensor_parallel_size": get_device_count() or 1,
|
||||
"gpu_memory_utilization": model_args.vllm_gpu_util,
|
||||
|
@ -70,7 +86,6 @@ class VllmEngine(BaseEngine):
|
|||
engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
|
||||
engine_args["image_feature_size"] = self.image_feature_size
|
||||
if getattr(config, "is_yi_vl_derived_model", None):
|
||||
# bug in vllm 0.4.2, see: https://github.com/vllm-project/vllm/pull/4828
|
||||
import vllm.model_executor.models.llava
|
||||
|
||||
logger.info("Detected Yi-VL model, applying projector patch.")
|
||||
|
@ -109,7 +124,10 @@ class VllmEngine(BaseEngine):
|
|||
if self.processor is not None and image is not None: # add image features
|
||||
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
|
||||
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
|
||||
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
|
||||
if is_vllm_version_greater_than_0_5():
|
||||
multi_modal_data = ImagePixelData(image=pixel_values)
|
||||
else: # TODO: remove vllm 0.4.3 support
|
||||
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
|
||||
else:
|
||||
multi_modal_data = None
|
||||
|
||||
|
@ -158,12 +176,10 @@ class VllmEngine(BaseEngine):
|
|||
)
|
||||
|
||||
result_generator = self.model.generate(
|
||||
prompt=None,
|
||||
inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
prompt_token_ids=prompt_ids,
|
||||
lora_request=self.lora_request,
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
return result_generator
|
||||
|
||||
|
|
|
@ -1,9 +1,30 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
import sys
|
||||
from enum import Enum, unique
|
||||
|
||||
from . import launcher
|
||||
from .api.app import run_api
|
||||
from .chat.chat_model import run_chat
|
||||
from .eval.evaluator import run_eval
|
||||
from .extras.env import VERSION, print_env
|
||||
from .extras.logging import get_logger
|
||||
from .extras.misc import get_device_count
|
||||
from .train.tuner import export_model, run_exp
|
||||
from .webui.interface import run_web_demo, run_web_ui
|
||||
|
||||
|
@ -23,8 +44,6 @@ USAGE = (
|
|||
+ "-" * 70
|
||||
)
|
||||
|
||||
VERSION = "0.7.2.dev0"
|
||||
|
||||
WELCOME = (
|
||||
"-" * 58
|
||||
+ "\n"
|
||||
|
@ -37,11 +56,14 @@ WELCOME = (
|
|||
+ "-" * 58
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@unique
|
||||
class Command(str, Enum):
|
||||
API = "api"
|
||||
CHAT = "chat"
|
||||
ENV = "env"
|
||||
EVAL = "eval"
|
||||
EXPORT = "export"
|
||||
TRAIN = "train"
|
||||
|
@ -57,12 +79,35 @@ def main():
|
|||
run_api()
|
||||
elif command == Command.CHAT:
|
||||
run_chat()
|
||||
elif command == Command.ENV:
|
||||
print_env()
|
||||
elif command == Command.EVAL:
|
||||
run_eval()
|
||||
elif command == Command.EXPORT:
|
||||
export_model()
|
||||
elif command == Command.TRAIN:
|
||||
run_exp()
|
||||
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
|
||||
if force_torchrun or get_device_count() > 1:
|
||||
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
|
||||
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
|
||||
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
|
||||
subprocess.run(
|
||||
(
|
||||
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
||||
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
||||
).format(
|
||||
nnodes=os.environ.get("NNODES", "1"),
|
||||
node_rank=os.environ.get("RANK", "0"),
|
||||
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
|
||||
master_addr=master_addr,
|
||||
master_port=master_port,
|
||||
file_name=launcher.__file__,
|
||||
args=" ".join(sys.argv[1:]),
|
||||
),
|
||||
shell=True,
|
||||
)
|
||||
else:
|
||||
run_exp()
|
||||
elif command == Command.WEBDEMO:
|
||||
run_web_demo()
|
||||
elif command == Command.WEBUI:
|
||||
|
|
|
@ -1,16 +1,30 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
|
||||
from .data_utils import Role, split_dataset
|
||||
from .loader import get_dataset
|
||||
from .template import Template, get_template_and_fix_tokenizer, templates
|
||||
from .utils import Role, split_dataset
|
||||
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
|
||||
|
||||
|
||||
__all__ = [
|
||||
"KTODataCollatorWithPadding",
|
||||
"PairwiseDataCollatorWithPadding",
|
||||
"get_dataset",
|
||||
"Template",
|
||||
"get_template_and_fix_tokenizer",
|
||||
"templates",
|
||||
"Role",
|
||||
"split_dataset",
|
||||
"get_dataset",
|
||||
"TEMPLATES",
|
||||
"Template",
|
||||
"get_template_and_fix_tokenizer",
|
||||
]
|
||||
|
|
|
@ -1,3 +1,17 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
@ -5,11 +19,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
|
|||
from datasets import Features
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .utils import Role
|
||||
from .data_utils import Role
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from ..hparams import DataArguments
|
||||
from .parser import DatasetAttr
|
||||
|
@ -175,7 +190,10 @@ def convert_sharegpt(
|
|||
|
||||
|
||||
def align_dataset(
|
||||
dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
dataset: Union["Dataset", "IterableDataset"],
|
||||
dataset_attr: "DatasetAttr",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
r"""
|
||||
Aligned dataset:
|
||||
|
@ -208,7 +226,7 @@ def align_dataset(
|
|||
if not data_args.streaming:
|
||||
kwargs = dict(
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=(not data_args.overwrite_cache),
|
||||
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
|
||||
desc="Converting format of dataset",
|
||||
)
|
||||
|
||||
|
|
|
@ -1,3 +1,17 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Sequence
|
||||
|
||||
|
|
|
@ -1,3 +1,17 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum, unique
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
||||
|
|
@ -1,3 +1,17 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
|
|
|
@ -1,24 +1,38 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from ..extras.constants import FILEEXT2TYPE
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import has_tokenized_data
|
||||
from .aligner import align_dataset
|
||||
from .data_utils import merge_dataset
|
||||
from .parser import get_dataset_list
|
||||
from .preprocess import get_preprocess_and_print_func
|
||||
from .template import get_template_and_fix_tokenizer
|
||||
from .utils import merge_dataset
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
|
||||
|
||||
from ..hparams import DataArguments, ModelArguments
|
||||
from .parser import DatasetAttr
|
||||
|
@ -31,6 +45,7 @@ def load_single_dataset(
|
|||
dataset_attr: "DatasetAttr",
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
data_path, data_name, data_dir, data_files = None, None, None, None
|
||||
|
@ -61,9 +76,9 @@ def load_single_dataset(
|
|||
raise ValueError("File {} not found.".format(local_path))
|
||||
|
||||
if data_path is None:
|
||||
raise ValueError("File extension must be txt, csv, json or jsonl.")
|
||||
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
|
||||
|
||||
if dataset_attr.load_from == "ms_hub":
|
||||
try:
|
||||
|
@ -106,18 +121,30 @@ def load_single_dataset(
|
|||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||
|
||||
if data_args.max_samples is not None: # truncate dataset
|
||||
num_samples = min(data_args.max_samples, len(dataset))
|
||||
dataset = dataset.select(range(num_samples))
|
||||
if dataset_attr.num_samples is not None and not data_args.streaming:
|
||||
target_num = dataset_attr.num_samples
|
||||
indexes = np.random.permutation(len(dataset))[:target_num]
|
||||
target_num -= len(indexes)
|
||||
if target_num > 0:
|
||||
expand_indexes = np.random.choice(len(dataset), target_num)
|
||||
indexes = np.concatenate((indexes, expand_indexes), axis=0)
|
||||
|
||||
return align_dataset(dataset, dataset_attr, data_args)
|
||||
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
|
||||
dataset = dataset.select(indexes)
|
||||
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
|
||||
|
||||
if data_args.max_samples is not None: # truncate dataset
|
||||
max_samples = min(data_args.max_samples, len(dataset))
|
||||
dataset = dataset.select(range(max_samples))
|
||||
|
||||
return align_dataset(dataset, dataset_attr, data_args, training_args)
|
||||
|
||||
|
||||
def get_dataset(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "kto"],
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
|
@ -144,7 +171,8 @@ def get_dataset(
|
|||
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
|
||||
raise ValueError("The dataset is not applicable in the current training stage.")
|
||||
|
||||
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
|
||||
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))
|
||||
|
||||
dataset = merge_dataset(all_datasets, data_args, training_args)
|
||||
|
||||
with training_args.main_process_first(desc="pre-process dataset"):
|
||||
|
@ -156,7 +184,7 @@ def get_dataset(
|
|||
if not data_args.streaming:
|
||||
kwargs = dict(
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=(not data_args.overwrite_cache),
|
||||
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
|
||||
desc="Running tokenizer on dataset",
|
||||
)
|
||||
|
||||
|
@ -166,7 +194,7 @@ def get_dataset(
|
|||
if training_args.should_save:
|
||||
dataset.save_to_disk(data_args.tokenized_path)
|
||||
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
|
||||
logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path))
|
||||
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
|
|
|
@ -1,3 +1,17 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
@ -20,11 +34,12 @@ class DatasetAttr:
|
|||
""" basic configs """
|
||||
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
||||
dataset_name: str
|
||||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||
ranking: bool = False
|
||||
""" extra configs """
|
||||
subset: Optional[str] = None
|
||||
folder: Optional[str] = None
|
||||
ranking: bool = False
|
||||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||
num_samples: Optional[int] = None
|
||||
""" common columns """
|
||||
system: Optional[str] = None
|
||||
tools: Optional[str] = None
|
||||
|
@ -102,10 +117,11 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
|||
else:
|
||||
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
|
||||
|
||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||
dataset_attr.set_attr("subset", dataset_info[name])
|
||||
dataset_attr.set_attr("folder", dataset_info[name])
|
||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
dataset_attr.set_attr("num_samples", dataset_info[name])
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
|
||||
|
|
|
@ -1,3 +1,17 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
|
||||
|
||||
|
@ -13,8 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
|
||||
|
||||
from ..hparams import DataArguments
|
||||
from .template import Template
|
||||
|
@ -23,7 +36,7 @@ if TYPE_CHECKING:
|
|||
def get_preprocess_and_print_func(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "kto"],
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
|
|
|
@ -1,13 +1,26 @@
|
|||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
|
@ -16,6 +29,55 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_feedback_example(
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
kl_response: Sequence[Dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
|
||||
if response[0]["content"]: # desired example
|
||||
kto_tag = True
|
||||
messages = prompt + [response[0]]
|
||||
else: # undesired example
|
||||
kto_tag = False
|
||||
messages = prompt + [response[1]]
|
||||
|
||||
if kl_response[0]["content"]:
|
||||
kl_messages = prompt + [kl_response[0]]
|
||||
else:
|
||||
kl_messages = prompt + [kl_response[1]]
|
||||
|
||||
prompt_ids, response_ids = template.encode_oneturn(
|
||||
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
_, kl_response_ids = template.encode_oneturn(
|
||||
tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
response_ids += [tokenizer.eos_token_id]
|
||||
kl_response_ids += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
input_ids = prompt_ids + response_ids
|
||||
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
|
||||
kl_input_ids = prompt_ids + kl_response_ids
|
||||
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
|
||||
|
||||
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
|
||||
|
||||
|
||||
def preprocess_feedback_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
|
@ -45,50 +107,17 @@ def preprocess_feedback_dataset(
|
|||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"]
|
||||
|
||||
if examples["response"][i][0]["content"]: # desired example
|
||||
kto_tag = True
|
||||
messages = examples["prompt"][i] + [examples["response"][i][0]]
|
||||
else: # undesired example
|
||||
kto_tag = False
|
||||
messages = examples["prompt"][i] + [examples["response"][i][1]]
|
||||
|
||||
if kl_response[i][0]["content"]:
|
||||
kl_messages = examples["prompt"][i] + [kl_response[i][0]]
|
||||
else:
|
||||
kl_messages = examples["prompt"][i] + [kl_response[i][1]]
|
||||
|
||||
prompt_ids, response_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
|
||||
prompt=examples["prompt"][i],
|
||||
response=examples["response"][i],
|
||||
kl_response=kl_response[i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
_, kl_response_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
kl_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
response_ids += [tokenizer.eos_token_id]
|
||||
kl_response_ids += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
input_ids = prompt_ids + response_ids
|
||||
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
|
||||
kl_input_ids = prompt_ids + kl_response_ids
|
||||
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
|
|
|
@ -1,27 +0,0 @@
|
|||
from typing import TYPE_CHECKING, List, Sequence
|
||||
|
||||
from ...extras.packages import is_pillow_available
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
from PIL.Image import Image as ImageObject
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
|
||||
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
|
||||
# process visual inputs (currently only supports a single image)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
|
||||
return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
|
||||
|
||||
|
||||
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
|
||||
# get paligemma token type ids for computing loss
|
||||
image_seq_length = getattr(processor, "image_seq_length")
|
||||
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
|
|
@ -1,13 +1,26 @@
|
|||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
|
@ -16,6 +29,44 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_pairwise_example(
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Tuple[List[int], List[int], List[int], List[int]]:
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
|
||||
chosen_messages = prompt + [response[0]]
|
||||
rejected_messages = prompt + [response[1]]
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(
|
||||
tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
_, rejected_ids = template.encode_oneturn(
|
||||
tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
chosen_ids += [tokenizer.eos_token_id]
|
||||
rejected_ids += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
chosen_input_ids = prompt_ids + chosen_ids
|
||||
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
|
||||
rejected_input_ids = prompt_ids + rejected_ids
|
||||
rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids
|
||||
|
||||
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
|
||||
|
||||
|
||||
def preprocess_pairwise_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
|
@ -43,40 +94,16 @@ def preprocess_pairwise_dataset(
|
|||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"]
|
||||
|
||||
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
|
||||
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
chosen_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
|
||||
prompt=examples["prompt"][i],
|
||||
response=examples["response"][i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
_, rejected_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
rejected_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
chosen_ids += [tokenizer.eos_token_id]
|
||||
rejected_ids += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
chosen_input_ids = prompt_ids + chosen_ids
|
||||
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
|
||||
rejected_input_ids = prompt_ids + rejected_ids
|
||||
rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids
|
||||
model_inputs["chosen_input_ids"].append(chosen_input_ids)
|
||||
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
|
||||
model_inputs["chosen_labels"].append(chosen_labels)
|
||||
|
|
|
@ -1,9 +1,26 @@
|
|||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from ...hparams import DataArguments
|
||||
|
||||
|
@ -12,13 +29,14 @@ def preprocess_pretrain_dataset(
|
|||
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
|
||||
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
|
||||
eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
|
||||
text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
|
||||
|
||||
if not data_args.packing:
|
||||
if data_args.template == "gemma":
|
||||
text_examples = [tokenizer.bos_token + example for example in text_examples]
|
||||
|
||||
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
|
||||
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
|
||||
else:
|
||||
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
|
||||
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import bisect
|
||||
from typing import TYPE_CHECKING, List, Sequence
|
||||
|
||||
from ...extras.packages import is_pillow_available
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
from PIL.Image import Image as ImageObject
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
|
||||
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
|
||||
r"""
|
||||
Finds the index of largest number that fits into the knapsack with the given capacity.
|
||||
"""
|
||||
index = bisect.bisect(numbers, capacity)
|
||||
return -1 if index == 0 else (index - 1)
|
||||
|
||||
|
||||
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
|
||||
r"""
|
||||
An efficient greedy algorithm with binary search for the knapsack problem.
|
||||
"""
|
||||
numbers.sort() # sort numbers in ascending order for binary search
|
||||
knapsacks = []
|
||||
|
||||
while numbers:
|
||||
current_knapsack = []
|
||||
remaining_capacity = capacity
|
||||
|
||||
while True:
|
||||
index = search_for_fit(numbers, remaining_capacity)
|
||||
if index == -1:
|
||||
break # no more numbers fit in this knapsack
|
||||
|
||||
remaining_capacity -= numbers[index] # update the remaining capacity
|
||||
current_knapsack.append(numbers.pop(index)) # add the number to knapsack
|
||||
|
||||
knapsacks.append(current_knapsack)
|
||||
|
||||
return knapsacks
|
||||
|
||||
|
||||
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
|
||||
r"""
|
||||
Processes visual inputs. (currently only supports a single image)
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
|
||||
return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
|
||||
|
||||
|
||||
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
|
||||
r"""
|
||||
Gets paligemma token type ids for computing loss.
|
||||
"""
|
||||
image_seq_length = getattr(processor, "image_seq_length")
|
||||
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
|
|
@ -1,13 +1,27 @@
|
|||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
|
@ -16,6 +30,48 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_supervised_example(
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
|
||||
messages = prompt + response
|
||||
input_ids, labels = [], []
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
|
||||
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
|
||||
|
||||
encoded_pairs = template.encode_multiturn(
|
||||
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
return input_ids, labels
|
||||
|
||||
|
||||
def preprocess_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
|
@ -36,41 +92,16 @@ def preprocess_supervised_dataset(
|
|||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"]
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
input_ids, labels = [], []
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
|
||||
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
|
||||
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
input_ids, labels = _encode_supervised_example(
|
||||
prompt=examples["prompt"][i],
|
||||
response=examples["response"][i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
|
@ -90,41 +121,55 @@ def preprocess_packed_supervised_dataset(
|
|||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
input_ids, labels = [], []
|
||||
valid_num = 0
|
||||
batch_input_ids, batch_labels = [], []
|
||||
lengths = []
|
||||
length2indexes = defaultdict(list)
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
for source_ids, target_ids in template.encode_multiturn(
|
||||
tokenizer, messages, examples["system"][i], examples["tools"][i]
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif len(input_ids) != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
input_ids, labels = _encode_supervised_example(
|
||||
prompt=examples["prompt"][i],
|
||||
response=examples["response"][i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=None,
|
||||
data_args=data_args,
|
||||
)
|
||||
length = len(input_ids)
|
||||
if length > data_args.cutoff_len:
|
||||
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
|
||||
else:
|
||||
lengths.append(length)
|
||||
length2indexes[length].append(valid_num)
|
||||
batch_input_ids.append(input_ids)
|
||||
batch_labels.append(labels)
|
||||
valid_num += 1
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
|
||||
for knapsack in knapsacks:
|
||||
packed_input_ids, packed_labels = [], []
|
||||
for length in knapsack:
|
||||
index = length2indexes[length].pop()
|
||||
packed_input_ids += batch_input_ids[index]
|
||||
packed_labels += batch_labels[index]
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
if len(packed_input_ids) < data_args.cutoff_len:
|
||||
pad_length = data_args.cutoff_len - len(packed_input_ids)
|
||||
packed_input_ids += [tokenizer.pad_token_id] * pad_length
|
||||
packed_labels += [IGNORE_INDEX] * pad_length
|
||||
|
||||
total_length = len(input_ids)
|
||||
block_size = data_args.cutoff_len
|
||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of cutoff_len
|
||||
for i in range(0, total_length, block_size):
|
||||
if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
|
||||
model_inputs["input_ids"].append(input_ids[i : i + block_size])
|
||||
model_inputs["attention_mask"].append([1] * block_size)
|
||||
model_inputs["labels"].append(labels[i : i + block_size])
|
||||
if len(packed_input_ids) != data_args.cutoff_len:
|
||||
raise ValueError("The length of packed example should be identical to the cutoff length.")
|
||||
|
||||
model_inputs["input_ids"].append(packed_input_ids)
|
||||
model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
|
||||
model_inputs["labels"].append(packed_labels)
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
|
|
@ -1,13 +1,26 @@
|
|||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ..utils import Role
|
||||
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
from ..data_utils import Role
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
|
@ -16,6 +29,37 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_unsupervised_example(
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
|
||||
if len(response) == 1:
|
||||
messages = prompt + response
|
||||
else:
|
||||
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
|
||||
input_ids, labels = template.encode_oneturn(
|
||||
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
if template.efficient_eos:
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
|
||||
|
||||
return input_ids, labels
|
||||
|
||||
|
||||
def preprocess_unsupervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
|
@ -35,30 +79,16 @@ def preprocess_unsupervised_dataset(
|
|||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"]
|
||||
|
||||
if len(examples["response"][i]) == 1:
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
else:
|
||||
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
|
||||
input_ids, labels = template.encode_oneturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
input_ids, labels = _encode_unsupervised_example(
|
||||
prompt=examples["prompt"][i],
|
||||
response=examples["response"][i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
|
|
|
@ -1,9 +1,23 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .data_utils import Role, infer_max_len
|
||||
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
||||
from .utils import Role, infer_max_len
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -196,7 +210,7 @@ class Llama2Template(Template):
|
|||
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
|
||||
|
||||
|
||||
templates: Dict[str, Template] = {}
|
||||
TEMPLATES: Dict[str, Template] = {}
|
||||
|
||||
|
||||
def _register_template(
|
||||
|
@ -248,7 +262,7 @@ def _register_template(
|
|||
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
|
||||
default_tool_formatter = ToolFormatter(tool_format="default")
|
||||
default_separator_formatter = EmptyFormatter()
|
||||
templates[name] = template_class(
|
||||
TEMPLATES[name] = template_class(
|
||||
format_user=format_user or default_user_formatter,
|
||||
format_assistant=format_assistant or default_assistant_formatter,
|
||||
format_system=format_system or default_user_formatter,
|
||||
|
@ -348,9 +362,9 @@ def get_template_and_fix_tokenizer(
|
|||
name: Optional[str] = None,
|
||||
) -> Template:
|
||||
if name is None:
|
||||
template = templates["empty"] # placeholder
|
||||
template = TEMPLATES["empty"] # placeholder
|
||||
else:
|
||||
template = templates.get(name, None)
|
||||
template = TEMPLATES.get(name, None)
|
||||
if template is None:
|
||||
raise ValueError("Template {} does not exist.".format(name))
|
||||
|
||||
|
@ -544,8 +558,13 @@ _register_template(
|
|||
)
|
||||
]
|
||||
),
|
||||
format_system=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
force_system=True,
|
||||
format_system=StringFormatter(
|
||||
slots=[{"bos_token"}, "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]
|
||||
),
|
||||
default_system=(
|
||||
"You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users "
|
||||
"by providing thorough responses. You are trained by Cohere."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
@ -653,6 +672,19 @@ _register_template(
|
|||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="glm4",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
||||
format_system=StringFormatter(slots=["[gMASK]<sop>{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
||||
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
||||
stop_words=["<|user|>", "<|observation|>"],
|
||||
efficient_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="intern",
|
||||
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
|
||||
|
@ -682,17 +714,8 @@ _register_template(
|
|||
_register_template(
|
||||
name="llama2",
|
||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||
format_assistant=StringFormatter(slots=[" {{content}} ", {"eos_token"}]),
|
||||
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||
default_system=(
|
||||
"You are a helpful, respectful and honest assistant. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, "
|
||||
"racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
@ -742,7 +765,6 @@ _register_template(
|
|||
_register_template(
|
||||
name="olmo",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
|
||||
format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
@ -751,12 +773,28 @@ _register_template(
|
|||
_register_template(
|
||||
name="openchat",
|
||||
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="openchat-3.6",
|
||||
format_user=StringFormatter(
|
||||
slots=[
|
||||
(
|
||||
"<|start_header_id|>GPT4 Correct User<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
"<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
]
|
||||
),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
stop_words=["<|eot_id|>"],
|
||||
replace_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="orion",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
|
||||
|
@ -807,6 +845,15 @@ _register_template(
|
|||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="telechat",
|
||||
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
|
||||
format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
|
||||
stop_words=["<_end>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="vicuna",
|
||||
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||
|
@ -857,6 +904,7 @@ _register_template(
|
|||
_register_template(
|
||||
name="yi",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
|
|
|
@ -1,4 +1,41 @@
|
|||
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the Dan's test library.
|
||||
# https://github.com/hendrycks/test/blob/master/evaluate_flan.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2020 Dan Hendrycks
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
|
@ -26,9 +63,7 @@ class Evaluator:
|
|||
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
|
||||
self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
|
||||
self.eval_template = get_eval_template(self.eval_args.lang)
|
||||
self.choice_inputs = [
|
||||
self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
|
||||
]
|
||||
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
|
||||
|
||||
@torch.inference_mode()
|
||||
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
|
||||
|
|
|
@ -1,3 +1,17 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Sequence, Tuple
|
||||
|
||||
|
@ -10,7 +24,6 @@ class EvalTemplate:
|
|||
system: str
|
||||
choice: str
|
||||
answer: str
|
||||
prefix: str
|
||||
|
||||
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
|
||||
r"""
|
||||
|
@ -42,8 +55,8 @@ class EvalTemplate:
|
|||
eval_templates: Dict[str, "EvalTemplate"] = {}
|
||||
|
||||
|
||||
def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
|
||||
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
|
||||
def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
|
||||
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer)
|
||||
|
||||
|
||||
def get_eval_template(name: str) -> "EvalTemplate":
|
||||
|
@ -56,8 +69,7 @@ _register_eval_template(
|
|||
name="en",
|
||||
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
|
||||
choice="\n{choice}. {content}",
|
||||
answer="\nAnswer: ",
|
||||
prefix=" ",
|
||||
answer="\nAnswer:",
|
||||
)
|
||||
|
||||
|
||||
|
@ -66,5 +78,4 @@ _register_eval_template(
|
|||
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
|
||||
choice="\n{choice}. {content}",
|
||||
answer="\n答案:",
|
||||
prefix=" ",
|
||||
)
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue