support LongLoRA

parent 4dd9b4d982
commit 90375f600d
@@ -14,6 +14,8 @@

 ## Changelog

+[23/09/27] We supported **S^2-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA). Try `--shift_attn` argument to enable shift short attention.
+
 [23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.

 [23/09/10] We supported using **[FlashAttention](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
@@ -50,7 +52,7 @@
 | [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
 | [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
 | [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
-| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
+| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | chatml |
 | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
 | [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
 | [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
@@ -14,6 +14,8 @@

 ## Changelog

+[23/09/27] We supported **S^2-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA). Use the `--shift_attn` argument to enable this feature.
+
 [23/09/23] We integrated the MMLU, C-Eval and CMMLU benchmarks into this repo. See [this example](#模型评估) for usage.

 [23/09/10] We supported **[FlashAttention](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. If you are using RTX4090, A100 or H100 GPUs, use the `--flash_attn` argument to enable FlashAttention-2 (experimental feature).
@@ -50,7 +52,7 @@
 | [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
 | [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
 | [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
-| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
+| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | chatml |
 | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
 | [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
 | [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
@@ -22,6 +22,9 @@ def preprocess_dataset(
     column_names = list(next(iter(dataset)).keys())
     template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
 
+    if template.efficient_eos and data_args.sft_packing:
+        raise ValueError("Current template is incompatible with packing.")
+
     def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
         for i in range(len(examples["prompt"])):
             query, response = examples["prompt"][i], examples["response"][i]
@@ -96,6 +99,28 @@ def preprocess_dataset(
 
         return model_inputs
 
+    def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
+        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
+        # we do not mask the inputs in packed training.
+        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+        input_ids, labels = [], []
+        for query, response, history, system in construct_example(examples):
+            for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
+                input_ids += source_ids + target_ids
+                labels += source_ids + target_ids # TODO: try masking source_ids here
+
+        total_length = len(input_ids)
+        block_size = data_args.cutoff_len
+        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
+        total_length = (total_length // block_size) * block_size
+        # split by chunks of cutoff_len
+        for i in range(0, total_length, block_size):
+            model_inputs["input_ids"].append(input_ids[i: i + block_size])
+            model_inputs["attention_mask"].append([1] * block_size)
+            model_inputs["labels"].append(labels[i: i + block_size])
+
+        return model_inputs
+
     def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
         # build inputs with format `<bos> X` and labels with format `Y <eos>`
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
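For intuition, the packing step above can be reproduced in isolation. The sketch below is a minimal stand-alone version, with toy token IDs in place of a real tokenizer and template; `pack_examples` and the sample data are illustrative, not part of the repository:

```python
# Minimal sketch of sequence packing: concatenate all tokenized examples,
# drop the remainder that does not fill a block, and split into fixed-size blocks.
from typing import Dict, List

def pack_examples(token_sequences: List[List[int]], cutoff_len: int) -> Dict[str, List[List[int]]]:
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
    input_ids: List[int] = []
    for seq in token_sequences:
        input_ids += seq  # sources and targets are concatenated, nothing is masked

    total_length = (len(input_ids) // cutoff_len) * cutoff_len  # drop the small remainder
    for i in range(0, total_length, cutoff_len):
        block = input_ids[i: i + cutoff_len]
        model_inputs["input_ids"].append(block)
        model_inputs["attention_mask"].append([1] * cutoff_len)
        model_inputs["labels"].append(block)  # labels mirror inputs in packed training
    return model_inputs

# Toy usage: three short "conversations", packed into blocks of 8 tokens.
packed = pack_examples([[1, 2, 3, 4], [5, 6, 7], [8, 9, 10, 11, 12]], cutoff_len=8)
print(packed["input_ids"])  # [[1, 2, 3, 4, 5, 6, 7, 8]] -- the last 4 tokens are dropped
```

Whatever does not fill a complete block is dropped, so packed training trades a small amount of data for uniformly sized, padding-free batches.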
@@ -166,19 +191,19 @@ def preprocess_dataset(
 
     if stage == "pt":
         dataset = dataset.filter(lambda example: example["prompt"])
-        preprocess_function = preprocess_pretrain_dataset
+        preprocess_func = preprocess_pretrain_dataset
         print_function = print_unsupervised_dataset_example
     elif stage == "sft" and not training_args.predict_with_generate:
         dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
-        preprocess_function = preprocess_supervised_dataset
+        preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
         print_function = print_supervised_dataset_example
     elif stage == "rm":
         dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
-        preprocess_function = preprocess_pairwise_dataset
+        preprocess_func = preprocess_pairwise_dataset
         print_function = print_pairwise_dataset_example
     else:
         dataset = dataset.filter(lambda example: example["prompt"])
-        preprocess_function = preprocess_unsupervised_dataset
+        preprocess_func = preprocess_unsupervised_dataset
         print_function = print_unsupervised_dataset_example
 
     with training_args.main_process_first(desc="dataset map pre-processing"):
@@ -191,7 +216,7 @@ def preprocess_dataset(
             )
 
         dataset = dataset.map(
-            preprocess_function,
+            preprocess_func,
             batched=True,
             remove_columns=column_names,
             **kwargs
@@ -1,301 +0,0 @@
-# coding=utf-8
-# Modified from:
-# [1] https://huggingface.co/Birchlabs/flash_llama/blob/main/modeling_flash_llama.py
-# [2] https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama2_flash_attn_monkey_patch.py
-# [3] https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/modeling_flash_llama.py
-# [4] https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
-# With fix from Alex Birch: https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/17
-
-import torch
-from typing import TYPE_CHECKING, Optional, Tuple
-from transformers.utils import logging
-
-if TYPE_CHECKING:
-    from transformers.models.llama.configuration_llama import LlamaConfig
-
-try:
-    from flash_attn.flash_attn_interface import (
-        flash_attn_kvpacked_func,
-        flash_attn_varlen_kvpacked_func
-    )
-    from flash_attn.bert_padding import pad_input, unpad_input
-    print(">>>> FlashAttention installed")
-except ImportError:
-    raise ImportError("Please install FlashAttention from https://github.com/Dao-AILab/flash-attention")
-
-try:
-    from flash_attn.layers.rotary import apply_rotary_emb_func
-    print(">>>> Flash RoPE installed")
-except ImportError:
-    raise ImportError("Please install RoPE kernels from https://github.com/Dao-AILab/flash-attention")
-
-
-logger = logging.get_logger(__name__)
-
-
-class LlamaRMSNorm(torch.nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return (self.weight * hidden_states).to(input_dtype) # for fp32 weight
-
-
-class FlashRotaryEmbedding(torch.nn.Module):
-
-    def __init__(
-        self,
-        dim: int,
-        base=10000.0,
-        interleaved=False,
-        scale_base=None,
-        scaling_factor=1.0,
-        pos_idx_in_fp32=True,
-        device=None
-    ):
-        super().__init__()
-        self.dim = dim
-        self.base = float(base)
-        self.pos_idx_in_fp32 = pos_idx_in_fp32
-        # Generate and save the inverse frequency buffer (non trainable)
-        inv_freq = self._compute_inv_freq(device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.interleaved = interleaved
-        self.scale_base = scale_base
-        self.scaling_factor = scaling_factor
-        scale = (
-            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
-            if scale_base is not None else None
-        )
-        self.register_buffer("scale", scale)
-
-        self._seq_len_cached = 0
-        self._cos_cached = None
-        self._sin_cached = None
-        self._cos_k_cached = None
-        self._sin_k_cached = None
-
-    def _compute_inv_freq(self, device=None):
-        return 1 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
-
-    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
-        if (
-            seqlen > self._seq_len_cached or self._cos_cached.device != device
-            or self._cos_cached.dtype != dtype
-            or (self.training and self._cos_cached.is_inference())
-        ):
-            self._seq_len_cached = seqlen
-            if self.pos_idx_in_fp32:
-                t = torch.arange(seqlen, device=device, dtype=torch.float32)
-                t /= self.scaling_factor
-                if self.inv_freq.dtype != torch.float32:
-                    inv_freq = self.inv_freq.to(torch.float32)
-                else:
-                    inv_freq = self.inv_freq
-            else:
-                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
-                t /= self.scaling_factor
-                inv_freq = self.inv_freq
-            freqs = torch.outer(t, inv_freq)
-            if self.scale is None:
-                self._cos_cached = torch.cos(freqs).to(dtype)
-                self._sin_cached = torch.sin(freqs).to(dtype)
-            else:
-                power = (
-                    (torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2) / self.scale_base
-                )
-                scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
-                # We want the multiplication by scale to happen in fp32
-                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
-                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
-                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
-                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
-
-    def forward(self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
-        r"""
-        q: (batch, seqlen, nheads, headdim)
-        k: (batch, seqlen, nheads, headdim)
-        seqlen_offset: can be used in generation where the qkv being passed in is only the last
-        token in the batch.
-        """
-        self._update_cos_sin_cache(q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype)
-        if self.scale is None:
-            return apply_rotary_emb_func(
-                q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
-                self.interleaved, True # inplace=True
-            ), apply_rotary_emb_func(
-                k, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
-                self.interleaved, True # inplace=True
-            )
-        else:
-            assert False
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    r"""
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, slen, _, num_key_value_heads, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, :, :, None, :].expand(batch, slen, 2, num_key_value_heads, n_rep, head_dim)
-    return hidden_states.reshape(batch, slen, 2, num_key_value_heads * n_rep, head_dim)
-
-
-class LlamaAttention(torch.nn.Module):
-
-    def __init__(self, config: "LlamaConfig"):
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.num_key_value_heads = config.num_key_value_heads
-        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.max_position_embeddings = config.max_position_embeddings
-
-        if (self.head_dim * self.num_heads) != self.hidden_size:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {self.num_heads})."
-            )
-
-        self.q_proj = torch.nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
-        self.k_proj = torch.nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.v_proj = torch.nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-
-        self.register_buffer(
-            "norm_factor",
-            torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
-            persistent=False,
-        )
-
-        if self.config.rope_scaling is None:
-            scaling_factor = 1
-        else:
-            scaling_type = self.config.rope_scaling["type"]
-            scaling_factor = self.config.rope_scaling["factor"]
-            assert scaling_type == "linear"
-
-        self.rotary_emb = FlashRotaryEmbedding(
-            self.head_dim, base=10000, interleaved=False, scaling_factor=scaling_factor
-        )
-
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, h_size = hidden_states.size()
-
-        has_layer_past = past_key_value is not None
-
-        if has_layer_past:
-            past_kv = past_key_value[0]
-            past_len = past_key_value[1]
-        else:
-            past_len = 0
-
-        q = self.q_proj(hidden_states)
-        k = self.k_proj(hidden_states)
-        v = self.v_proj(hidden_states)
-
-        q = q.view(bsz, q_len, self.num_heads, self.head_dim)
-        k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-        v = v.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-
-        q, k = self.rotary_emb(q, k, past_len)
-
-        kv = torch.stack([k, v], 2)
-        kv = repeat_kv(kv, self.num_key_value_groups)
-
-        # Cache QKV values
-        if has_layer_past:
-            new_len = past_len+q.size(1)
-            if new_len > past_kv.size(1):
-                past_kv = torch.cat(
-                    [past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)],
-                    dim=1
-                )
-            past_kv[:, past_len:new_len] = kv
-            kv = past_kv[:, :new_len]
-        else:
-            past_kv = kv
-
-        past_key_value = (past_kv, past_len + q.size(1)) if use_cache else None
-
-        if attention_mask is not None:
-            # varlen, ignore padding tokens, efficient for large batch with many paddings
-            logger.warning_once("padded sequences is less efficient")
-
-            unpadded_kv, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(kv, attention_mask)
-            unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask[:, -q.size(1):])
-            attn_outputs = flash_attn_varlen_kvpacked_func(
-                unpadded_q, unpadded_kv, cu_seqlens_q, cu_seqlens_k,
-                max_seqlen_q, max_seqlen_k,
-                dropout_p=0.0, softmax_scale=1.0 / self.norm_factor,
-                causal=(not has_layer_past), return_attn_probs=output_attentions
-            )
-
-            attn_output = attn_outputs[0] if output_attentions else attn_outputs
-            attn_output = pad_input(attn_output, indices_q, bsz, q_len).reshape(bsz, q_len, h_size)
-            attn_weights = attn_outputs[2] if output_attentions else None
-
-        else:
-            # no padding tokens, more efficient
-            attn_outputs = flash_attn_kvpacked_func(
-                q, kv, dropout_p=0.0, softmax_scale=1.0 / self.norm_factor,
-                causal=(not has_layer_past), return_attn_probs=output_attentions
-            )
-            attn_output = attn_outputs[0] if output_attentions else attn_outputs
-            attn_output = attn_output.reshape(bsz, q_len, h_size)
-            attn_weights = attn_outputs[2] if output_attentions else None
-
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
-
-# Disable the transformation of the attention mask in LlamaModel as flash attention
-# takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
-def _prepare_decoder_attention_mask(
-    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
-):
-    # [bsz, seq_len]
-    if past_key_values_length > 0 and attention_mask is not None:
-        attention_mask = torch.cat(
-            (
-                torch.full(
-                    (input_shape[0], past_key_values_length),
-                    True,
-                    dtype=attention_mask.dtype,
-                    device=attention_mask.device
-                ),
-                attention_mask
-            ),
-            dim=-1
-        )
-
-    if attention_mask is not None and torch.all(attention_mask):
-        return None # This uses the faster call when training with full samples
-
-    return attention_mask
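The deleted module above packs keys and values into a single tensor of shape (batch, seqlen, 2, num_key_value_heads, head_dim) before handing them to the kv-packed FlashAttention kernels. A shape-only sketch of that layout and of the `repeat_kv` expansion, with toy dimensions, runnable on CPU without FlashAttention:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as the deleted helper: expand each key/value head n_rep times.
    batch, slen, _, num_key_value_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, :, :, None, :].expand(batch, slen, 2, num_key_value_heads, n_rep, head_dim)
    return hidden_states.reshape(batch, slen, 2, num_key_value_heads * n_rep, head_dim)

# Toy dimensions: 2 key/value heads expanded to match 8 query heads (n_rep = 4).
k = torch.randn(1, 16, 2, 64)    # (batch, seqlen, num_key_value_heads, head_dim)
v = torch.randn(1, 16, 2, 64)
kv = torch.stack([k, v], dim=2)  # (batch, seqlen, 2, num_key_value_heads, head_dim)
print(repeat_kv(kv, n_rep=4).shape)  # torch.Size([1, 16, 2, 8, 64])
```

Expanding grouped key/value heads to match the query heads is what lets the same kernel serve both multi-head and grouped-query checkpoints.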
@@ -0,0 +1,232 @@
+# coding=utf-8
+# Modified from:
+# [1] https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+
+import math
+import torch
+import torch.nn as nn
+from typing import Optional, Tuple
+from transformers.utils import logging
+from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
+
+try:
+    from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
+    from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
+except ImportError:
+    raise ImportError("Please install FlashAttention from https://github.com/Dao-AILab/flash-attention")
+
+
+logger = logging.get_logger(__name__)
+
+
+class LlamaRMSNorm(nn.Module):
+
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return (self.weight * hidden_states).to(input_dtype)
+
+
+class LlamaShiftShortAttention(LlamaAttention):
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None: # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        if getattr(self, "num_key_value_groups"):
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        if getattr(self, "shift_ratio", None) and self.training: # shift
+            group_size = int(q_len * getattr(self, "shift_ratio"))
+            if q_len % group_size > 0:
+                raise ValueError("q_len {} should be divisible by group size {}.".format(q_len, group_size))
+            num_group = q_len // group_size
+            for state in (query_states, key_states, value_states):
+                state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
+                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-group_size//2, dims=1)
+                state = state.reshape(bsz * num_group, group_size, self.num_heads, self.head_dim).transpose(1, 2)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        if getattr(self, "shift_ratio", None) and self.training: # shift back
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
+            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1)
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class LlamaFlashAttention2(LlamaAttention):
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # LlamaFlashAttention2 attention does not support output_attentions
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None: # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        if getattr(self, "num_key_value_groups"):
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
+        key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
+        value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
+
+        if getattr(self, "shift_ratio", None) and self.training: # shift
+            group_size = int(q_len * getattr(self, "shift_ratio"))
+            if q_len % group_size > 0:
+                raise ValueError("q_len {} should be divisible by group size {}.".format(q_len, group_size))
+            num_group = q_len // group_size
+            for state in (query_states, key_states, value_states):
+                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-group_size//2, dims=1)
+                state = state.reshape(bsz * num_group, group_size, self.num_heads, self.head_dim)
+
+        if attention_mask is not None:
+            logger.warning_once("Padded sequences are less efficient.")
+            batch_size = query_states.shape[0]
+            # -q_len: assumes left padding
+            unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
+            unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
+            unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
+            attn_output_unpad = flash_attn_varlen_func(
+                unpadded_q,
+                unpadded_k,
+                unpadded_v,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_k,
+                dropout_p=0.0,
+                softmax_scale=None,
+                causal=True,
+            )
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)
+        else:
+            attn_output = flash_attn_func(
+                query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True
+            )
+
+        if getattr(self, "shift_ratio", None) and self.training: # shift back
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
+            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1)
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+# Disable the transformation of the attention mask in LlamaModel as flash attention
+# takes a boolean padding_mask. Fills in the past kv length for use in forward.
+def _prepare_decoder_attention_mask(
+    self,
+    attention_mask: torch.Tensor,
+    input_shape: torch.Tensor,
+    inputs_embeds: torch.Tensor,
+    past_key_values_length: int
+) -> torch.Tensor:
+    if attention_mask is not None and torch.all(attention_mask):
+        return None # This uses the faster call when training with full samples
+
+    return attention_mask
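To make the shift step concrete: half of the attention heads are rolled by half a group before the sequence is folded into groups, which is what lets tokens near a group boundary still see their neighbours. The sketch below reproduces the intended transformation on a toy tensor; all sizes are made up for illustration:

```python
import torch

def shift_and_group(state: torch.Tensor, group_size: int, num_heads: int) -> torch.Tensor:
    # state: (bsz, n_heads, seq_len, head_dim), as used inside the attention module
    bsz, _, seq_len, head_dim = state.size()
    assert seq_len % group_size == 0, "seq_len should be divisible by group size"
    num_group = seq_len // group_size
    state = state.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)
    # Half of the heads attend within shifted groups: roll them by half a group.
    state[:, :, num_heads // 2:] = state[:, :, num_heads // 2:].roll(-group_size // 2, dims=1)
    # Fold the sequence dimension into groups so each group attends only within itself.
    return state.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2)

# Toy usage: batch 1, 4 heads, sequence 8, head_dim 2, group size 4 -> 2 groups.
q = torch.randn(1, 4, 8, 2)
print(shift_and_group(q, group_size=4, num_heads=4).shape)  # torch.Size([2, 4, 4, 2])
```

After this regrouping, ordinary (or FlashAttention) attention over the first dimension computes within-group attention only, and the inverse roll is applied to the output, mirroring the `# shift back` branch above.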
@@ -90,6 +90,10 @@ class DataArguments:
         default=0,
         metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
     )
+    sft_packing: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
+    )
 
     def init_for_training(self): # support mixing multiple datasets
         dataset_names = [ds.strip() for ds in self.dataset.split(",")]
@@ -45,7 +45,11 @@ class ModelArguments:
     )
     flash_attn: Optional[bool] = field(
         default=False,
-        metadata={"help": "Enable flash attention for faster training."}
+        metadata={"help": "Enable FlashAttention-2 for faster training."}
+    )
+    shift_attn: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
     )
     checkpoint_dir: Optional[str] = field(
         default=None,
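For reference, fields declared this way become command-line switches through `transformers.HfArgumentParser`, which is how dataclass arguments are typically consumed in this codebase. The dataclass below is a cut-down stand-in used only for illustration, not the full `ModelArguments`:

```python
from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser

@dataclass
class MiniModelArguments:  # stand-in with only the two switches discussed here
    flash_attn: Optional[bool] = field(
        default=False,
        metadata={"help": "Enable FlashAttention-2 for faster training."}
    )
    shift_attn: Optional[bool] = field(
        default=False,
        metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
    )

parser = HfArgumentParser(MiniModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(args=["--flash_attn", "True", "--shift_attn", "True"])
print(model_args.flash_attn, model_args.shift_attn)  # True True
```

The new `--sft_packing` flag in `DataArguments` is wired up the same way.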
@@ -13,17 +13,19 @@ from transformers import (
     PreTrainedModel,
     PreTrainedTokenizerBase
 )
+from transformers.models.llama import modeling_llama as LlamaModule
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from trl import AutoModelForCausalLMWithValueHead
 
 try:
     from transformers.integrations import is_deepspeed_zero3_enabled
-except ImportError:
+except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1
     from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from llmtuner.extras.logging import reset_logging, get_logger
 from llmtuner.extras.misc import count_parameters
+from llmtuner.extras.patches import llama_patch as LlamaPatches
 from llmtuner.extras.save_and_load import load_valuehead_params
 from llmtuner.hparams import FinetuningArguments
 from llmtuner.tuner.core.adapter import init_adapter
@@ -73,10 +75,6 @@ def load_model_and_tokenizer(
         **config_kwargs
     )
 
-    # Fix tokenizer (for ChatGLM2)
-    if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
-        tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
-
     if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
         model_to_load = model_args.checkpoint_dir[0]
     else:
@@ -84,10 +82,15 @@ def load_model_and_tokenizer(
 
     config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
 
+    # Fix tokenizer (for ChatGLM2)
+    if getattr(config, "model_type", None) == "chatglm":
+        tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
+
     # Fix config (for Qwen)
-    if hasattr(config, "fp16") and hasattr(config, "bf16"):
+    if getattr(config, "model_type", None) == "qwen":
         setattr(config, "fp16", model_args.compute_dtype == torch.float16)
         setattr(config, "bf16", model_args.compute_dtype == torch.bfloat16)
+        setattr(config, "fp32", model_args.compute_dtype == torch.float32)
 
     # Set RoPE scaling
     if model_args.rope_scaling is not None:
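The ChatGLM2 tokenizer fix above uses `types.MethodType` to graft the generic `PreTrainedTokenizerBase._pad` implementation onto one tokenizer instance without touching its class. A minimal illustration of that rebinding pattern, with toy classes rather than the real tokenizer:

```python
from types import MethodType

class Base:
    def _pad(self, value):
        return f"padded({value})"

class Custom(Base):
    def _pad(self, value):  # an override with incompatible behavior, as in some remote-code tokenizers
        raise NotImplementedError

obj = Custom()
# Rebind the base-class implementation onto this instance only, leaving the class untouched.
obj._pad = MethodType(Base._pad, obj)
print(obj._pad("x"))  # padded(x)
```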
@@ -103,7 +106,6 @@ def load_model_and_tokenizer(
             require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
             if is_trainable:
                 if model_args.rope_scaling == "dynamic":
-                    assert not model_args.flash_attn, "Flash attention does not support dynamic rope scaling."
                     logger.warning(
                         "Dynamic NTK may not work well with fine-tuning. "
                         "See: https://github.com/huggingface/transformers/pull/24653"
@@ -126,17 +128,23 @@ def load_model_and_tokenizer(
         else:
             logger.warning("Current model does not support RoPE scaling.")
 
-    # Set flash attention
-    if model_args.flash_attn and getattr(config, "model_type", None) == "llama":
-        import transformers.models.llama.modeling_llama as LlamaModule
-        import llmtuner.extras.patches.flash_llama as FlashLlama
-        LlamaModule.LlamaRMSNorm = FlashLlama.LlamaRMSNorm
-        LlamaModule.LlamaAttention = FlashLlama.LlamaAttention
-        LlamaModule.LlamaModel._prepare_decoder_attention_mask = FlashLlama._prepare_decoder_attention_mask
-        if not hasattr(config, "num_key_value_heads"): # for LLaMA-1 models
-            setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))
-        if getattr(config, "pretraining_tp", 1) != 1:
-            setattr(config, "pretraining_tp", 1)
+    # Fix RMSNorm in fp32 weight (https://github.com/huggingface/transformers/pull/23535)
+    if getattr(config, "model_type", None) == "llama":
+        LlamaModule.LlamaRMSNorm = LlamaPatches.LlamaRMSNorm
+
+    # Set FlashAttention-2
+    if model_args.flash_attn:
+        if getattr(config, "model_type", None) == "llama":
+            LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
+            LlamaModule.LlamaModel._prepare_decoder_attention_mask = (
+                LlamaPatches._prepare_decoder_attention_mask
+            )
+            logger.info("Using FlashAttention-2 for faster training and inference.")
+        else:
+            logger.warning("Current model does not support FlashAttention-2.")
+    elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
+        LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
+        logger.warning("Using `--flash_attn` for faster training in large context length.")
 
     # Quantization configurations (using bitsandbytes library).
     is_mergeable = True
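The monkey patching above takes effect because, in the transformers version this commit targets (circa 4.33), `LlamaDecoderLayer` looks up the `LlamaAttention` name in `modeling_llama` when the model is constructed. A small sketch of that mechanism using a tiny random-weight config; the subclass and sizes are illustrative only:

```python
from transformers.models.llama import modeling_llama as LlamaModule
from transformers.models.llama.configuration_llama import LlamaConfig

class TaggedLlamaAttention(LlamaModule.LlamaAttention):
    # Illustrative subclass: identical behavior, just tagged so we can detect it.
    patched = True

LlamaModule.LlamaAttention = TaggedLlamaAttention  # patch before the model is constructed

config = LlamaConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                     num_attention_heads=4, num_key_value_heads=4, vocab_size=128)
model = LlamaModule.LlamaModel(config)
print(getattr(model.layers[0].self_attn, "patched", False))  # True on ~4.33-era transformers
```

This is also why the patch has to run before the model is instantiated later in `load_model_and_tokenizer`.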
@@ -172,12 +180,20 @@ def load_model_and_tokenizer(
         **config_kwargs
     )
 
-    # Disable custom generate method (for Qwen)
+    # Set shift short attention (S^2-Attn)
+    if is_trainable and model_args.shift_attn:
+        if getattr(config, "model_type", None) == "llama":
+            setattr(model, "shift_ratio", 0.25)
+            logger.info("Using shift short attention proposed by LongLoRA.")
+        else:
+            logger.warning("Current model does not support shift short attention.")
+
+    # Disable custom generate method (for Qwen and Baichuan2)
     if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)
 
     # Fix LM head (for ChatGLM2)
-    if not hasattr(model, "lm_head") and hasattr(model, "transformer"):
+    if getattr(config, "model_type", None) == "chatglm":
         setattr(model, "lm_head", model.transformer.output_layer)
 
     # Register auto class to save the custom code files.
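With `shift_ratio` fixed at 0.25, the attention group is always a quarter of the training sequence length, and the patched attention raises an error when the length does not divide evenly. A quick check over a few illustrative cutoff lengths:

```python
shift_ratio = 0.25  # as set on the model above

for q_len in (1024, 4096, 8192):
    group_size = int(q_len * shift_ratio)
    assert q_len % group_size == 0, f"q_len {q_len} should be divisible by group size {group_size}."
    print(f"q_len={q_len}: {q_len // group_size} groups of {group_size} tokens")
```

So, for example, fine-tuning at an 8192-token cutoff runs attention in four shifted groups of 2048 tokens each.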