support FlashAttention2

This commit is contained in:
hiyouga 2023-09-10 20:43:56 +08:00
parent 815b92e698
commit d8aa1404be
9 changed files with 875 additions and 115 deletions

View File

@ -12,11 +12,13 @@
## Changelog
[23/09/10] Now we support using **[FlashAttention](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs (experimental feature).
[23/08/18] Now we support **resuming training**, upgrade `transformers` to `4.31.0` to enjoy this feature.
[23/08/12] Now we support **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature).
[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models.
[23/08/03] Now we support training the **Qwen-7B** model in this repo. Try `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments to train the Qwen-7B model. Remember to use `--template chatml` argument when you are using the Qwen-7B-Chat model.
@ -62,8 +64,11 @@
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
- **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
- For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models.
> **Note**
>
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
>
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models.
## Supported Training Approaches
@ -75,7 +80,9 @@
| PPO Training | | | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
- Use `--quantization_bit 4/8` argument to enable QLoRA.
> **Note**
>
> Use `--quantization_bit 4/8` argument to enable QLoRA.
## Provided Datasets
@ -138,7 +145,9 @@ And **powerful GPUs**!
Please refer to `data/example_dataset` for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset.
Note: please update `data/dataset_info.json` to use your custom dataset. About the format of this file, please refer to `data/README.md`.
> **Note**
>
> Please update `data/dataset_info.json` to use your custom dataset. About the format of this file, please refer to `data/README.md`.
### Dependence Installation (optional)
@ -164,10 +173,16 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py
We strongly recommend using the all-in-one Web UI for newcomers since it can also generate training scripts **automatically**.
Currently the web UI only supports training on **a single GPU**.
> **Warning**
>
> Currently the web UI only supports training on **a single GPU**.
### Train on a single GPU
> **Warning**
>
> If you want to train models on multiple GPUs, please refer to [#distributed-training](Distributed Training).
#### Pre-Training
```bash
@ -300,19 +315,13 @@ accelerate config # configure the environment
accelerate launch src/train_bash.py # arguments (same as above)
```
<details><summary>Example config.yaml for training with DeepSpeed ZeRO-2</summary>
<details><summary>Example config for LoRA training</summary>
```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config:
gradient_accumulation_steps: 4
gradient_clipping: 0.5
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
@ -336,7 +345,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
... # arguments (same as above)
```
<details><summary>Example ds_config.json for training with DeepSpeed ZeRO-2</summary>
<details><summary>Example config for full-parameter training with DeepSpeed ZeRO-2</summary>
```json
{
@ -387,7 +396,9 @@ python src/api_demo.py \
--checkpoint_dir path_to_checkpoint
```
Visit `http://localhost:8000/docs` for API documentation.
> **Note**
>
> Visit `http://localhost:8000/docs` for API documentation.
### CLI Demo
@ -426,7 +437,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--predict_with_generate
```
We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
> **Note**
>
> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
### Predict
@ -445,12 +458,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--predict_with_generate
```
## TODO
- [ ] Supporting flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention)).
- [ ] Implementing multi-query attention for faster inference.
- [ ] Supporting full-parameter RLHF training.
## License
This repository is licensed under the [Apache-2.0 License](LICENSE).

View File

@ -12,11 +12,13 @@
## 更新日志
[23/09/10] 现在我们支持了 LLaMA 模型的 **[FlashAttention](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU请使用 `--flash_attn` 参数以启用 FlashAttention-2实验性功能
[23/08/18] 现在我们支持了**训练状态恢复**,请将 `transformers` 升级至 `4.31.0` 以启用此功能。
[23/08/12] 现在我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请尝试使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
[23/08/12] 现在我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
[23/08/11] 现在我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。详情请参阅[此示例](#dpo-训练)(实验性功能)
[23/08/11] 现在我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。详情请参阅[此示例](#dpo-训练)。
[23/08/03] 现在我们支持了 **Qwen-7B** 模型的训练。请尝试使用 `--model_name_or_path Qwen/Qwen-7B-Chat``--lora_target c_attn` 参数。使用 Qwen-7B-Chat 模型时请添加 `--template chatml` 参数。
@ -62,8 +64,11 @@
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
- **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。
- 对于所有“基座”Base模型`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”Chat模型请务必使用对应的模板。
> **Note**
>
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。
>
> 对于所有“基座”Base模型`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”Chat模型请务必使用对应的模板。
## 训练方法
@ -75,7 +80,9 @@
| PPO 训练 | | | :white_check_mark: | :white_check_mark: |
| DPO 训练 | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
- 使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。
> **Note**
>
> 请使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。
## 数据集
@ -138,7 +145,9 @@ huggingface-cli login
关于数据集文件的格式,请参考 `data/example_dataset` 文件夹的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。
注意:使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README.md`
> **Note**
>
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README.md`
### 环境搭建(可跳过)
@ -164,10 +173,16 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py
我们极力推荐新手使用浏览器一体化界面,因为它还可以**自动**生成运行所需的命令行脚本。
目前网页 UI 仅支持**单卡训练**。
> **Warning**
>
> 目前网页 UI 仅支持**单卡训练**。
### 单 GPU 训练
> **Warning**
>
> 如果您使用多张 GPU 训练模型,请移步[多 GPU 分布式训练](#多-gpu-分布式训练)部分。
#### 预训练
```bash
@ -299,19 +314,13 @@ accelerate config # 首先配置分布式环境
accelerate launch src/train_bash.py # 参数同上
```
<details><summary>使用 DeepSpeed ZeRO-2 进行全参数微调的 Accelerate 配置示例</summary>
<details><summary>LoRA 训练的 Accelerate 配置示例</summary>
```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config:
gradient_accumulation_steps: 4
gradient_clipping: 0.5
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
@ -335,7 +344,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
... # 参数同上
```
<details><summary>使用 DeepSpeed ZeRO-2 进行全参数微调的 DeepSpeed 配置示例</summary>
<details><summary>使用 DeepSpeed ZeRO-2 进行全参数训练的 DeepSpeed 配置示例</summary>
```json
{
@ -386,7 +395,9 @@ python src/api_demo.py \
--checkpoint_dir path_to_checkpoint
```
关于 API 文档请见 `http://localhost:8000/docs`
> **Note**
>
> 关于 API 文档请见 `http://localhost:8000/docs`
### 命令行测试
@ -425,7 +436,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--predict_with_generate
```
我们建议在量化模型的评估中使用 `--per_device_eval_batch_size=1``--max_target_length 128`
> **Note**
>
> 我们建议在量化模型的评估中使用 `--per_device_eval_batch_size=1``--max_target_length 128`
### 模型预测
@ -444,12 +457,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--predict_with_generate
```
## TODO
- [ ] 实现 flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention))。
- [ ] 在推理阶段使用 Multi-query attention 进行加速。
- [ ] 支持 RLHF 的全参数微调。
## 协议
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。

View File

View File

@ -0,0 +1,771 @@
# coding=utf-8
# Modified from:
# [1] https://huggingface.co/Birchlabs/flash_llama/blob/main/modeling_flash_llama.py
# [2] https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/modeling_flash_llama.py
# [3] https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# With fix from Alex Birch: https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/17
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
from transformers.models.llama.configuration_llama import LlamaConfig
try:
from flash_attn.flash_attn_interface import (
flash_attn_kvpacked_func,
flash_attn_varlen_kvpacked_func,
)
from flash_attn.bert_padding import unpad_input, pad_input
flash_attn_v2_installed = True
print('>>>> Flash Attention installed')
except ImportError:
flash_attn_v2_installed = False
raise ImportError('Please install Flash Attention: `pip install flash-attn --no-build-isolation`')
try:
from flash_attn.layers.rotary import apply_rotary_emb_func
flash_rope_installed = True
print('>>>> Flash RoPE installed')
except ImportError:
flash_rope_installed = False
raise ImportError('Please install RoPE kernels: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`')
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig"
def rmsnorm_func(hidden_states, weight, variance_epsilon):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
return (weight * hidden_states).to(input_dtype)
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.register_buffer(
"variance_epsilon",
torch.tensor(eps),
persistent=False,
)
def forward(self, hidden_states):
return rmsnorm_func(hidden_states, self.weight, self.variance_epsilon)
class FlashRotaryEmbedding(torch.nn.Module):
"""
The rotary position embeddings from RoFormer_ (Su et. al).
A crucial insight from the method is that the query and keys are
transformed by rotation matrices which depend on the relative positions.
Other implementations are available in the Rotary Transformer repo_ and in
GPT-NeoX_, GPT-NeoX was an inspiration
.. _RoFormer: https://arxiv.org/abs/2104.09864
.. _repo: https://github.com/ZhuiyiTechnology/roformer
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
"""
def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None,
scaling_factor=1.0, pos_idx_in_fp32=True, device=None):
"""
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
otherwise they might be in lower precision.
This option was added because previously (before 2023-07-02), when we construct
the position indices, we use the dtype of self.inv_freq. In most cases this would
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
self.inv_freq would be bf16, and the position indices are also in bf16.
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
embeddings for some positions will coincide.
To maintain compatibility with models previously trained in pure bf16,
we add this option.
scaling_factor: RotaryEmbedding extended with linear scaling.
"""
super().__init__()
self.dim = dim
self.base = float(base)
self.pos_idx_in_fp32 = pos_idx_in_fp32
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = self._compute_inv_freq(device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.interleaved = interleaved
self.scale_base = scale_base
self.scaling_factor = scaling_factor
scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
/ (1.4 * dim) if scale_base is not None else None)
self.register_buffer("scale", scale)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
def _compute_inv_freq(self, device=None):
return 1 / (self.base ** (torch.arange(0, self.dim, 2, device=device,
dtype=torch.float32) / self.dim))
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
# Reset the tables if the sequence length has changed,
# if we're on a new device (possibly due to tracing for instance),
# or if we're switching from inference mode to training
if (seqlen > self._seq_len_cached or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
or (self.training and self._cos_cached.is_inference())):
self._seq_len_cached = seqlen
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
if self.pos_idx_in_fp32:
t = torch.arange(seqlen, device=device, dtype=torch.float32)
t /= self.scaling_factor
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
# will be large. Having it in bf16 will lose a lot of precision and cause the
# cos & sin output to change significantly.
# We want to recompute self.inv_freq if it was not loaded in fp32
if self.inv_freq.dtype != torch.float32:
inv_freq = self.inv_freq.to(torch.float32)
else:
inv_freq = self.inv_freq
else:
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
t /= self.scaling_factor
inv_freq = self.inv_freq
# Don't do einsum, it converts fp32 to fp16 under AMP
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, inv_freq)
if self.scale is None:
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
else:
power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
- seqlen // 2) / self.scale_base)
scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
# We want the multiplication by scale to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
def forward(self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
"""
q: (batch, seqlen, nheads, headdim)
k: (batch, seqlen, nheads, headdim)
seqlen_offset: can be used in generation where the qkv being passed in is only the last
token in the batch.
"""
self._update_cos_sin_cache(q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype)
if self.scale is None:
return apply_rotary_emb_func(
q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
self.interleaved, True # inplace=True
), apply_rotary_emb_func(
k, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
self.interleaved, True # inplace=True
)
else:
assert False
class LlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
if self.config.pretraining_tp > 1:
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
gate_proj = torch.cat(
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, slen, _, num_key_value_heads, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, :, :, None, :].expand(batch, slen, 2, num_key_value_heads, n_rep, head_dim)
return hidden_states.reshape(batch, slen, 2, num_key_value_heads * n_rep, head_dim)
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.register_buffer(
"norm_factor",
torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
persistent=False,
)
if self.config.rope_scaling is None:
scaling_factor = 1
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
assert scaling_type == 'linear'
self.rotary_emb = FlashRotaryEmbedding(
self.head_dim, base=10000, interleaved=False, scaling_factor=scaling_factor,
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
is_padded_inputs: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, h_size = hidden_states.size()
has_layer_past = past_key_value is not None
if has_layer_past:
past_kv = past_key_value[0]
past_len = past_key_value[1]
else:
past_len = 0
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
q = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
q = torch.cat(q, dim=-1)
k = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
k = torch.cat(k, dim=-1)
v = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
v = torch.cat(v, dim=-1)
else:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = q.view(bsz, q_len, self.num_heads, self.head_dim)
k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
v = v.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
q, k = self.rotary_emb(q, k, past_len)
kv = torch.stack([k, v], 2)
kv = repeat_kv(kv, self.num_key_value_groups)
# Cache QKV values
if has_layer_past:
new_len = past_len+q.size(1)
if new_len > past_kv.size(1):
past_kv = torch.cat([past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)], 1)
past_kv[:, past_len:new_len] = kv
kv = past_kv[:, :new_len]
else:
past_kv = kv
past_key_value = (past_kv, past_len+q.size(1)) if use_cache else None
if is_padded_inputs:
# varlen, ignore padding tokens, efficient for large batch with many paddings
logger.warning_once("padded")
assert attention_mask is not None
unpadded_kv, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(kv, attention_mask)
unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask[:, -q.size(1):])
attn_outputs = flash_attn_varlen_kvpacked_func(
unpadded_q, unpadded_kv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k,
dropout_p=0.0, softmax_scale=1.0/self.norm_factor,
causal=(not has_layer_past), return_attn_probs=output_attentions
)
attn_output = attn_outputs[0] if output_attentions else attn_outputs
attn_output = pad_input(
attn_output, indices_q, bsz, q_len
).reshape(bsz, q_len, h_size)
attn_weights = attn_outputs[2] if output_attentions else None
else:
# no padding tokens, more efficient
attn_outputs = flash_attn_kvpacked_func(
q, kv, dropout_p=0.0, softmax_scale=1.0/self.norm_factor, causal=(not has_layer_past), return_attn_probs=output_attentions)
attn_output = attn_outputs[0] if output_attentions else attn_outputs
attn_output = attn_output.reshape(bsz, q_len, h_size)
attn_weights = attn_outputs[2] if output_attentions else None
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
is_padded_inputs: Optional[bool] = False,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
is_padded_inputs=is_padded_inputs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
LLAMA_START_DOCSTRING, LLAMA_INPUTS_DOCSTRING = "", ""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
config_class = LlamaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, LlamaModel):
module.gradient_checkpointing = value
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
is_padded_inputs: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
position_ids = None
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
is_padded_inputs
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
is_padded_inputs=is_padded_inputs,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class LlamaForCausalLM(LlamaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = LlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
is_padded_inputs: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
is_padded_inputs = ((attention_mask is not None) and (not attention_mask.all().item()))
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: "CausalLMOutputWithPast" = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
is_padded_inputs=is_padded_inputs,
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get("position_ids", None)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"is_padded_inputs": ((attention_mask is not None) and (not attention_mask.all().item()))
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past

View File

@ -206,9 +206,6 @@ def get_template_and_fix_tokenizer(
name: str,
tokenizer: "PreTrainedTokenizer"
) -> Template:
template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name)
if tokenizer.eos_token_id is None:
tokenizer.eos_token = "<|endoftext|>"
logger.info("Add eos token: {}".format(tokenizer.eos_token))
@ -217,6 +214,11 @@ def get_template_and_fix_tokenizer(
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
if name is None:
return None
template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name)
tokenizer.add_special_tokens(
dict(additional_special_tokens=template.stop_words),
replace_additional_special_tokens=False

View File

@ -43,6 +43,10 @@ class ModelArguments:
default=None,
metadata={"help": "Adopt scaled rotary positional embeddings."}
)
flash_attn: Optional[bool] = field(
default=False,
metadata={"help": "Enable flash attention for faster training."}
)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}

View File

@ -4,6 +4,7 @@ import torch
from types import MethodType
from typing import TYPE_CHECKING, Literal, Optional, Tuple
import transformers
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@ -84,7 +85,8 @@ def load_model_and_tokenizer(
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
if is_trainable and hasattr(config, "fp16") and hasattr(config, "bf16"): # fix Qwen config
# Fix config (for Qwen)
if is_trainable and hasattr(config, "fp16") and hasattr(config, "bf16"):
if model_args.compute_dtype == torch.bfloat16:
setattr(config, "bf16", True)
else:
@ -105,6 +107,7 @@ def load_model_and_tokenizer(
if is_trainable:
if model_args.rope_scaling == "dynamic":
assert not model_args.flash_attn, "Flash attention does not support dynamic rope scaling."
logger.warning(
"Dynamic NTK may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
@ -127,6 +130,15 @@ def load_model_and_tokenizer(
else:
logger.warning("Current model does not support RoPE scaling.")
# Set flash attention
if model_args.flash_attn and getattr(config, "model_type", None) == "llama":
from llmtuner.extras.models.flash_llama import LlamaForCausalLM
transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
if not hasattr(config, "num_key_value_heads"):
setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))
if getattr(config, "pretraining_tp", 1) != 1:
setattr(config, "pretraining_tp", 1)
# Quantization configurations (using bitsandbytes library).
is_mergeable = True
if model_args.quantization_bit is not None:

View File

@ -33,27 +33,28 @@ class Seq2SeqPeftTrainer(PeftTrainer):
Subclass and override to inject custom behavior.
"""
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
if prompt_len > label_len:
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
if label_len > prompt_len:
inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])
if "attention_mask" in inputs:
inputs["attention_mask"] = self._pad_tensors_to_target_len(
inputs["attention_mask"], inputs["labels"], pad_token_id=0
)
if "position_ids" in inputs:
inputs["position_ids"] = self._pad_tensors_to_target_len(
inputs["position_ids"], inputs["labels"], pad_token_id=0
)
if self.args.predict_with_generate:
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
assert self.tokenizer.pad_token_id is not None, "Pad token is required."
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
if prompt_len > label_len:
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
if label_len > prompt_len:
inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])
if "attention_mask" in inputs:
inputs["attention_mask"] = self._pad_tensors_to_target_len(
inputs["attention_mask"], inputs["labels"], pad_token_id=0
)
if "position_ids" in inputs:
inputs["position_ids"] = self._pad_tensors_to_target_len(
inputs["position_ids"], inputs["labels"], pad_token_id=0
)
loss, generated_tokens, labels = super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
)
if generated_tokens is not None:
generated_tokens[:, :max(prompt_len, label_len)] = (
self.tokenizer.pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)])
)
if generated_tokens is not None and self.args.predict_with_generate:
generated_tokens[:, :max(prompt_len, label_len)] = self.tokenizer.pad_token_id
generated_tokens = generated_tokens.contiguous()
return loss, generated_tokens, labels
@ -66,16 +67,8 @@ class Seq2SeqPeftTrainer(PeftTrainer):
) -> torch.Tensor:
r"""
Pads the tensor to the same length as the target tensor.
Should only be called when predict_with_generate=True.
"""
if pad_token_id is None:
if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
pad_token_id = self.tokenizer.pad_token_id
else:
raise ValueError("PAD token is required.")
pad_token_id = pad_token_id if pad_token_id is not None else self.tokenizer.pad_token_id
padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
return padded_tensor.contiguous() # in contiguous memory

View File

@ -1,36 +0,0 @@
# Test Template Encode
# Usage: python .\tests\template_encode.py --model_name_and_path D:\llm\chinese-alpaca-2-7b
# --template llama2_zh --query 'how are you?'
# --history '[[\"Hello!\",\"HiI am llama2.\"]]'
import sys
import fire
from typing import List, Optional, Tuple
from transformers import AutoTokenizer
sys.path.append("./src")
from llmtuner.extras.template import get_template_and_fix_tokenizer
def encode(
model_name_and_path: str,
template: str,
query: str,
resp: Optional[str] = "",
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None):
tokenizer = AutoTokenizer.from_pretrained(
model_name_and_path,
trust_remote_code=True
)
template = get_template_and_fix_tokenizer(template, tokenizer)
encoded_pairs = template.encode_multiturn(tokenizer, query, resp, history, system)
for prompt_ids, answer_ids in encoded_pairs:
print("="*50)
print("prompt_ids: {}, answer_ids: {}".format(prompt_ids, answer_ids))
print("prompt decode: {}".format(tokenizer.decode(prompt_ids)))
if __name__ == '__main__':
fire.Fire(encode)