support LongLoRA

This commit is contained in:
hiyouga 2023-09-27 21:55:50 +08:00
parent 4dd9b4d982
commit 90375f600d
8 changed files with 313 additions and 329 deletions

View File

@ -14,6 +14,8 @@
## Changelog
[23/09/27] We supported **S^2-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA). Try `--shift_attn` argument to enable shift short attention.
[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
[23/09/10] We supported using **[FlashAttention](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
@ -50,7 +52,7 @@
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |

View File

@ -14,6 +14,8 @@
## 更新日志
[23/09/27] 我们支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **S^2-Attn**。请使用 `--shift_attn` 参数以启用该功能。
[23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
[23/09/10] 我们支持了 LLaMA 模型的 **[FlashAttention](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU请使用 `--flash_attn` 参数以启用 FlashAttention-2实验性功能
@ -50,7 +52,7 @@
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |

View File

@ -22,6 +22,9 @@ def preprocess_dataset(
column_names = list(next(iter(dataset)).keys())
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
if template.efficient_eos and data_args.sft_packing:
raise ValueError("Current template is incompatible with packing.")
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
for i in range(len(examples["prompt"])):
query, response = examples["prompt"][i], examples["response"][i]
@ -96,6 +99,28 @@ def preprocess_dataset(
return model_inputs
def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# we do not mask the inputs in packed training.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
input_ids, labels = [], []
for query, response, history, system in construct_example(examples):
for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
input_ids += source_ids + target_ids
labels += source_ids + target_ids # TODO: try masking source_ids here
total_length = len(input_ids)
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
for i in range(0, total_length, block_size):
model_inputs["input_ids"].append(input_ids[i: i + block_size])
model_inputs["attention_mask"].append([1] * len(block_size))
model_inputs["labels"].append(labels[i: i + block_size])
return model_inputs
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
@ -166,19 +191,19 @@ def preprocess_dataset(
if stage == "pt":
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_function = preprocess_pretrain_dataset
preprocess_func = preprocess_pretrain_dataset
print_function = print_unsupervised_dataset_example
elif stage == "sft" and not training_args.predict_with_generate:
dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
preprocess_function = preprocess_supervised_dataset
preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
print_function = print_supervised_dataset_example
elif stage == "rm":
dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
preprocess_function = preprocess_pairwise_dataset
preprocess_func = preprocess_pairwise_dataset
print_function = print_pairwise_dataset_example
else:
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_function = preprocess_unsupervised_dataset
preprocess_func = preprocess_unsupervised_dataset
print_function = print_unsupervised_dataset_example
with training_args.main_process_first(desc="dataset map pre-processing"):
@ -191,7 +216,7 @@ def preprocess_dataset(
)
dataset = dataset.map(
preprocess_function,
preprocess_func,
batched=True,
remove_columns=column_names,
**kwargs

View File

@ -1,301 +0,0 @@
# coding=utf-8
# Modified from:
# [1] https://huggingface.co/Birchlabs/flash_llama/blob/main/modeling_flash_llama.py
# [2] https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama2_flash_attn_monkey_patch.py
# [3] https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/modeling_flash_llama.py
# [4] https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# With fix from Alex Birch: https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/17
import torch
from typing import TYPE_CHECKING, Optional, Tuple
from transformers.utils import logging
if TYPE_CHECKING:
from transformers.models.llama.configuration_llama import LlamaConfig
try:
from flash_attn.flash_attn_interface import (
flash_attn_kvpacked_func,
flash_attn_varlen_kvpacked_func
)
from flash_attn.bert_padding import pad_input, unpad_input
print(">>>> FlashAttention installed")
except ImportError:
raise ImportError("Please install FlashAttention from https://github.com/Dao-AILab/flash-attention")
try:
from flash_attn.layers.rotary import apply_rotary_emb_func
print(">>>> Flash RoPE installed")
except ImportError:
raise ImportError("Please install RoPE kernels from https://github.com/Dao-AILab/flash-attention")
logger = logging.get_logger(__name__)
class LlamaRMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return (self.weight * hidden_states).to(input_dtype) # for fp32 weight
class FlashRotaryEmbedding(torch.nn.Module):
def __init__(
self,
dim: int,
base=10000.0,
interleaved=False,
scale_base=None,
scaling_factor=1.0,
pos_idx_in_fp32=True,
device=None
):
super().__init__()
self.dim = dim
self.base = float(base)
self.pos_idx_in_fp32 = pos_idx_in_fp32
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = self._compute_inv_freq(device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.interleaved = interleaved
self.scale_base = scale_base
self.scaling_factor = scaling_factor
scale = (
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
if scale_base is not None else None
)
self.register_buffer("scale", scale)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
def _compute_inv_freq(self, device=None):
return 1 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
if (
seqlen > self._seq_len_cached or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
or (self.training and self._cos_cached.is_inference())
):
self._seq_len_cached = seqlen
if self.pos_idx_in_fp32:
t = torch.arange(seqlen, device=device, dtype=torch.float32)
t /= self.scaling_factor
if self.inv_freq.dtype != torch.float32:
inv_freq = self.inv_freq.to(torch.float32)
else:
inv_freq = self.inv_freq
else:
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
t /= self.scaling_factor
inv_freq = self.inv_freq
freqs = torch.outer(t, inv_freq)
if self.scale is None:
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
else:
power = (
(torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2) / self.scale_base
)
scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
# We want the multiplication by scale to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
def forward(self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
q: (batch, seqlen, nheads, headdim)
k: (batch, seqlen, nheads, headdim)
seqlen_offset: can be used in generation where the qkv being passed in is only the last
token in the batch.
"""
self._update_cos_sin_cache(q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype)
if self.scale is None:
return apply_rotary_emb_func(
q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
self.interleaved, True # inplace=True
), apply_rotary_emb_func(
k, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
self.interleaved, True # inplace=True
)
else:
assert False
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
r"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, slen, _, num_key_value_heads, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, :, :, None, :].expand(batch, slen, 2, num_key_value_heads, n_rep, head_dim)
return hidden_states.reshape(batch, slen, 2, num_key_value_heads * n_rep, head_dim)
class LlamaAttention(torch.nn.Module):
def __init__(self, config: "LlamaConfig"):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = torch.nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = torch.nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = torch.nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.register_buffer(
"norm_factor",
torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
persistent=False,
)
if self.config.rope_scaling is None:
scaling_factor = 1
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
assert scaling_type == "linear"
self.rotary_emb = FlashRotaryEmbedding(
self.head_dim, base=10000, interleaved=False, scaling_factor=scaling_factor
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, h_size = hidden_states.size()
has_layer_past = past_key_value is not None
if has_layer_past:
past_kv = past_key_value[0]
past_len = past_key_value[1]
else:
past_len = 0
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = q.view(bsz, q_len, self.num_heads, self.head_dim)
k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
v = v.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
q, k = self.rotary_emb(q, k, past_len)
kv = torch.stack([k, v], 2)
kv = repeat_kv(kv, self.num_key_value_groups)
# Cache QKV values
if has_layer_past:
new_len = past_len+q.size(1)
if new_len > past_kv.size(1):
past_kv = torch.cat(
[past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)],
dim=1
)
past_kv[:, past_len:new_len] = kv
kv = past_kv[:, :new_len]
else:
past_kv = kv
past_key_value = (past_kv, past_len + q.size(1)) if use_cache else None
if attention_mask is not None:
# varlen, ignore padding tokens, efficient for large batch with many paddings
logger.warning_once("padded sequences is less efficient")
unpadded_kv, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(kv, attention_mask)
unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask[:, -q.size(1):])
attn_outputs = flash_attn_varlen_kvpacked_func(
unpadded_q, unpadded_kv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k,
dropout_p=0.0, softmax_scale=1.0 / self.norm_factor,
causal=(not has_layer_past), return_attn_probs=output_attentions
)
attn_output = attn_outputs[0] if output_attentions else attn_outputs
attn_output = pad_input(attn_output, indices_q, bsz, q_len).reshape(bsz, q_len, h_size)
attn_weights = attn_outputs[2] if output_attentions else None
else:
# no padding tokens, more efficient
attn_outputs = flash_attn_kvpacked_func(
q, kv, dropout_p=0.0, softmax_scale=1.0 / self.norm_factor,
causal=(not has_layer_past), return_attn_probs=output_attentions
)
attn_output = attn_outputs[0] if output_attentions else attn_outputs
attn_output = attn_output.reshape(bsz, q_len, h_size)
attn_weights = attn_outputs[2] if output_attentions else None
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Disable the transformation of the attention mask in LlamaModel as flash attention
# takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# [bsz, seq_len]
if past_key_values_length > 0 and attention_mask is not None:
attention_mask = torch.cat(
(
torch.full(
(input_shape[0], past_key_values_length),
True,
dtype=attention_mask.dtype,
device=attention_mask.device
),
attention_mask
),
dim=-1
)
if attention_mask is not None and torch.all(attention_mask):
return None # This uses the faster call when training with full samples
return attention_mask

View File

@ -0,0 +1,232 @@
# coding=utf-8
# Modified from:
# [1] https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
import math
import torch
import torch.nn as nn
from typing import Optional, Tuple
from transformers.utils import logging
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
try:
from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
except ImportError:
raise ImportError("Please install FlashAttention from https://github.com/Dao-AILab/flash-attention")
logger = logging.get_logger(__name__)
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return (self.weight * hidden_states).to(input_dtype)
class LlamaShiftShortAttention(LlamaAttention):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None: # reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
if getattr(self, "num_key_value_groups"):
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if getattr(self, "shift_ratio", None) and self.training: # shift
group_size = int(q_len * getattr(self, "shift_ratio"))
if q_len % group_size > 0:
raise ValueError("q_len {} should be divisible by group size {}.".format(q_len, group_size))
num_group = q_len // group_size
for state in (query_states, key_states, value_states):
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-group_size//2, dims=1)
state = state.reshape(bsz * num_group, group_size, self.num_heads, self.head_dim).transpose(1, 2)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
if getattr(self, "shift_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class LlamaFlashAttention2(LlamaAttention):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None: # reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
if getattr(self, "num_key_value_groups"):
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
if getattr(self, "shift_ratio", None) and self.training: # shift
group_size = int(q_len * getattr(self, "shift_ratio"))
if q_len % group_size > 0:
raise ValueError("q_len {} should be divisible by group size {}.".format(q_len, group_size))
num_group = q_len // group_size
for state in (query_states, key_states, value_states):
state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-group_size//2, dims=1)
state = state.reshape(bsz * num_group, group_size, self.num_heads, self.head_dim)
if attention_mask is not None:
logger.warning_once("Padded sequences are less efficient.")
batch_size = query_states.shape[0]
# -q_len: assumes left padding
unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
attn_output_unpad = flash_attn_varlen_func(
unpadded_q,
unpadded_k,
unpadded_v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=True,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True
)
if getattr(self, "shift_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Disable the transformation of the attention mask in LlamaModel as flash attention
# takes a boolean padding_mask. Fills in the past kv length for use in forward.
def _prepare_decoder_attention_mask(
self,
attention_mask: torch.Tensor,
input_shape: torch.Tensor,
inputs_embeds: torch.Tensor,
past_key_values_length: int
) -> torch.Tensor:
if attention_mask is not None and torch.all(attention_mask):
return None # This uses the faster call when training with full samples
return attention_mask

View File

@ -90,6 +90,10 @@ class DataArguments:
default=0,
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
)
sft_packing: Optional[bool] = field(
default=False,
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
)
def init_for_training(self): # support mixing multiple datasets
dataset_names = [ds.strip() for ds in self.dataset.split(",")]

View File

@ -45,7 +45,11 @@ class ModelArguments:
)
flash_attn: Optional[bool] = field(
default=False,
metadata={"help": "Enable flash attention for faster training."}
metadata={"help": "Enable FlashAttention-2 for faster training."}
)
shift_attn: Optional[bool] = field(
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
)
checkpoint_dir: Optional[str] = field(
default=None,

View File

@ -13,17 +13,19 @@ from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase
)
from transformers.models.llama import modeling_llama as LlamaModule
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from trl import AutoModelForCausalLMWithValueHead
try:
from transformers.integrations import is_deepspeed_zero3_enabled
except ImportError:
except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1
from transformers.deepspeed import is_deepspeed_zero3_enabled
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters
from llmtuner.extras.patches import llama_patch as LlamaPatches
from llmtuner.extras.save_and_load import load_valuehead_params
from llmtuner.hparams import FinetuningArguments
from llmtuner.tuner.core.adapter import init_adapter
@ -73,10 +75,6 @@ def load_model_and_tokenizer(
**config_kwargs
)
# Fix tokenizer (for ChatGLM2)
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
model_to_load = model_args.checkpoint_dir[0]
else:
@ -84,10 +82,15 @@ def load_model_and_tokenizer(
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
# Fix tokenizer (for ChatGLM2)
if getattr(config, "model_type", None) == "chatglm":
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
# Fix config (for Qwen)
if hasattr(config, "fp16") and hasattr(config, "bf16"):
if getattr(config, "model_type", None) == "qwen":
setattr(config, "fp16", model_args.compute_dtype == torch.float16)
setattr(config, "bf16", model_args.compute_dtype == torch.bfloat16)
setattr(config, "fp32", model_args.compute_dtype == torch.float32)
# Set RoPE scaling
if model_args.rope_scaling is not None:
@ -103,7 +106,6 @@ def load_model_and_tokenizer(
require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
if is_trainable:
if model_args.rope_scaling == "dynamic":
assert not model_args.flash_attn, "Flash attention does not support dynamic rope scaling."
logger.warning(
"Dynamic NTK may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
@ -126,17 +128,23 @@ def load_model_and_tokenizer(
else:
logger.warning("Current model does not support RoPE scaling.")
# Set flash attention
if model_args.flash_attn and getattr(config, "model_type", None) == "llama":
import transformers.models.llama.modeling_llama as LlamaModule
import llmtuner.extras.patches.flash_llama as FlashLlama
LlamaModule.LlamaRMSNorm = FlashLlama.LlamaRMSNorm
LlamaModule.LlamaAttention = FlashLlama.LlamaAttention
LlamaModule.LlamaModel._prepare_decoder_attention_mask = FlashLlama._prepare_decoder_attention_mask
if not hasattr(config, "num_key_value_heads"): # for LLaMA-1 models
setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))
if getattr(config, "pretraining_tp", 1) != 1:
setattr(config, "pretraining_tp", 1)
# Fix RMSNorm in fp32 weight (https://github.com/huggingface/transformers/pull/23535)
if getattr(config, "model_type", None) == "llama":
LlamaModule.LlamaRMSNorm = LlamaPatches.LlamaRMSNorm
# Set FlashAttention-2
if model_args.flash_attn:
if getattr(config, "model_type", None) == "llama":
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
LlamaModule.LlamaModel._prepare_decoder_attention_mask = (
LlamaPatches._prepare_decoder_attention_mask
)
logger.info("Using FlashAttention-2 for faster training and inference.")
else:
logger.warning("Current model does not support FlashAttention-2.")
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
logger.warning("Using `--flash_attn` for faster training in large context length.")
# Quantization configurations (using bitsandbytes library).
is_mergeable = True
@ -172,12 +180,20 @@ def load_model_and_tokenizer(
**config_kwargs
)
# Disable custom generate method (for Qwen)
# Set shift short attention (S^2-Attn)
if is_trainable and model_args.shift_attn:
if getattr(config, "model_type", None) == "llama":
setattr(model, "shift_ratio", 0.25)
logger.info("Using shift short attention proposed by LongLoRA.")
else:
logger.warning("Current model does not support shift short attention.")
# Disable custom generate method (for Qwen and Baichuan2)
if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
# Fix LM head (for ChatGLM2)
if not hasattr(model, "lm_head") and hasattr(model, "transformer"):
if getattr(config, "model_type", None) == "chatglm":
setattr(model, "lm_head", model.transformer.output_layer)
# Register auto class to save the custom code files.