CPM-9G-8B/9G-Train/cpm/layers/attention.py

312 lines
12 KiB
Python
Raw Normal View History

2024-02-27 14:33:33 +08:00
import math
from typing import Optional
from typing import Tuple
try:
from .flash_triton import FlashAttnFunc
except:
FlashAttnFunc = None
import bmtrain as bmt
import torch
from einops import rearrange
from .linear import ColumnParallelLinear
from .linear import Linear
from .position_embedding import apply_chatglm_rotary_pos_emb
try:
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except:
flash_attn_varlen_func = None
try:
from flash_attn.bert_padding import pad_input
from flash_attn.bert_padding import unpad_input
except:
pad_input = None
unpad_input = None
class OpFlash(torch.autograd.Function):
@staticmethod
def forward(ctx, self, record, q, k, v, cu_seqlens, max_seqlen, dropout_p, causal):
ctx.self = self
ctx.cu_seqlens = cu_seqlens
ctx.max_length = max_seqlen
ctx.dropout_p = dropout_p
ctx.causal = causal
ctx.softmax_scale = q.shape[-1] ** (-0.5)
if not record and "out" in self._layer_dict:
out = self._layer_dict.pop("out")
softmax_lse = self._layer_dict.pop("softmax_lse")
rng_state = self._layer_dict.pop("rng_state")
else:
out, _, _, _, _, softmax_lse, _, rng_state = _flash_attn_varlen_forward(
q,
k,
v,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
dropout_p,
ctx.softmax_scale,
causal=causal,
return_softmax=False,
)
if record:
self._layer_dict["out"] = out
self._layer_dict["softmax_lse"] = softmax_lse
self._layer_dict["rng_state"] = rng_state
ctx.save_for_backward(q, k, v, out, softmax_lse, rng_state)
return out
@staticmethod
def backward(ctx, dout):
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
_flash_attn_varlen_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
ctx.cu_seqlens,
ctx.cu_seqlens,
ctx.max_length,
ctx.max_length,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
rng_state=rng_state,
)
return None, None, dq, dk, dv, None, None, None, None
class Attention(bmt.DistributedModule):
def __init__(
self,
dim_model: int,
num_heads: int,
num_kv_heads: int,
dim_head: int,
dtype: torch.dtype = torch.half,
dropout_p: Optional[float] = None,
scale: bool = True,
add_qkv_bias: bool = False,
use_flash_attn: bool = False,
tp: int = 0,
) -> None:
super().__init__()
self.dim_model = dim_model
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_groups = num_heads // num_kv_heads
self.dim_head = dim_head
self.project_q = Linear(
self.dim_model,
self.num_heads * self.dim_head,
bias=add_qkv_bias,
dtype=dtype,
scale=scale,
tp=tp,
)
self.project_k = Linear(
self.dim_model,
self.num_kv_heads * self.dim_head,
bias=add_qkv_bias,
dtype=dtype,
scale=scale,
tp=tp,
)
self.project_v = Linear(
self.dim_model,
self.num_kv_heads * self.dim_head,
bias=add_qkv_bias,
dtype=dtype,
scale=scale,
tp=tp,
)
self.attention_out = Linear(
self.num_heads * self.dim_head,
self.dim_model,
dtype=dtype,
scale=scale,
tp=tp * 2,
)
self.softmax = torch.nn.Softmax(dim=-1)
if dropout_p is not None:
self.dropout = torch.nn.Dropout(p=dropout_p)
self.dropout_p = dropout_p
else:
self.dropout = None
self.use_flash_attn = use_flash_attn
self._layer_dict = {}
def forward(
self,
hidden_q: torch.Tensor,
hidden_kv: torch.Tensor,
attention_mask: torch.BoolTensor,
position_bias: torch.Tensor,
use_cache: bool = False,
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
pos_bias_type: Optional[str] = "relative",
length_mask: Optional[torch.Tensor] = None,
attention_mask_bias: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: int = None,
position_ids: Optional[torch.Tensor] = None,
):
"""This model inherits from bmt.DistributedModule.
Args:
hidden_q (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix.
hidden_kv (:obj:`torch.Tensor` of shape ``(batch, len_k, dim_model)``): Length of input sequence before padding.
attention_mask (:obj:`torch.Tensor` of shape ``(batch, len_q, len_k)``): Used to avoid performing attention on padding token indices.
position_bias(:obj:`torch.Tensor` of shape ``(num_heads, len_q, len_k)`` or ``(1, num_heads, len_k, len_q)``): Provide positional information about tensor `key_value` and `query`.
Return:
out (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): The attention output.
""" # noqa: E501
len_q = hidden_q.size(1)
len_k = hidden_kv.size(1)
if isinstance(self.project_q, ColumnParallelLinear):
assert hidden_q.data_ptr() == hidden_kv.data_ptr()
if self.project_q.scale and self.project_q.scale_before:
hidden_q = hidden_q / math.sqrt(self.project_q.dim_in)
hidden_q = bmt.nn.OpParallelLinear.apply(
hidden_q,
torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0),
torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0)
if self.project_q.bias is not None
else None,
True,
False,
False,
None,
)
if self.project_q.scale and not self.project_q.scale_before:
hidden_q = hidden_q / math.sqrt(self.project_q.dim_in)
block_size = hidden_q.shape[-1] // (self.head_groups + 1 + 1)
h_q = hidden_q[..., : block_size * self.head_groups]
h_k = hidden_q[..., block_size * self.head_groups : block_size * (self.head_groups + 1)]
h_v = hidden_q[..., block_size * (self.head_groups + 1) :]
else:
h_q = self.project_q(hidden_q)
h_k = self.project_k(hidden_kv)
h_v = self.project_v(hidden_kv)
batch_size = h_q.size(0)
if not self.use_flash_attn:
h_q = h_q / math.sqrt(math.sqrt(self.dim_head))
h_k = h_k / math.sqrt(math.sqrt(self.dim_head))
h_q = h_q.view(batch_size, len_q, -1, self.dim_head).permute(0, 2, 1, 3)
h_k = h_k.view(batch_size, len_k, -1, self.dim_head).permute(0, 2, 1, 3)
h_v = h_v.view(batch_size, len_k, -1, self.dim_head).permute(0, 2, 1, 3)
if pos_bias_type == "rotary":
# b h s d
h_q, h_k = position_bias(h_q, h_k, -2, offset=past_kv[0].size(-2) if past_kv is not None else 0)
elif pos_bias_type == "chatglm_rotary":
h_q = apply_chatglm_rotary_pos_emb(h_q, position_bias)
h_k = apply_chatglm_rotary_pos_emb(h_k, position_bias)
if past_kv is not None:
h_k = torch.cat([past_kv[0], h_k], dim=-2)
h_v = torch.cat([past_kv[1], h_v], dim=-2)
len_k = h_k.size(-2)
# (b, n_h, len_q, d_h) @ (b, n_h, d_h, len_k) -> (b, n_h, len_q, len_k)
# (b, n_kv_h, n_h_groups*len_q, d_h) @ (b, n_kv_h, d_h, len_k) -> (b, n_kv_h, n_h_groups*len_q, len_k) -> (b, n_h, len_q, len_k)
if self.head_groups == 1:
score = torch.matmul(h_q, h_k.transpose(-1, -2)) # / math.sqrt(self.dim_head) moved to line 75~76
else:
score = torch.matmul(
h_q.reshape(batch_size, -1, self.head_groups * len_q, self.dim_head),
h_k.transpose(-1, -2),
).view(
batch_size, -1, len_q, len_k
) # / math.sqrt(self.dim_head) moved to line 75~76
if pos_bias_type == "relative":
if len_q == 1: # inference with cache
if len(position_bias.size()) == 4:
position_bias = position_bias[:, :, -1:, :]
else:
position_bias = position_bias[:, -1:, :]
score = score + position_bias
score = torch.masked_fill(
score,
attention_mask.view(batch_size, 1, len_q, len_k) == False,
torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
)
score = self.softmax(score)
score = torch.masked_fill(
score,
attention_mask.view(batch_size, 1, len_q, len_k) == False,
torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
)
if self.dropout is not None:
score = self.dropout(score)
# (b, n_h, len_q, len_k) @ (b, n_h, len_k, d_h) -> (b, n_h, len_q, d_h)
# (b, n_kv_h, n_h_groups*len_q, len_k) @ (b, n_kv_h, len_k, d_h) -> (b, n_kv_h, n_h_groups*len_q, d_h) -> (b, n_h, len_q, d_h)
score = torch.matmul(score.view(batch_size, -1, self.head_groups * len_q, len_k), h_v).view(
batch_size, -1, len_q, self.dim_head
)
score = score.view(batch_size, -1, len_q, self.dim_head).permute(0, 2, 1, 3)
score = score.contiguous().view(batch_size, len_q, -1)
else:
if attention_mask_bias is not None:
assert pos_bias_type == "rotary"
h_q = h_q.view(batch_size, len_q, -1, self.dim_head) # .permute(0, 2, 1, 3)
h_k = h_k.view(batch_size, len_k, -1, self.dim_head) # .permute(0, 2, 1, 3)
h_v = h_v.view(batch_size, len_k, -1, self.dim_head) # .permute(0, 2, 1, 3)
h_q, h_k = position_bias(h_q, h_k, -3)
score = FlashAttnFunc.apply(h_q, h_k, h_v, attention_mask_bias, False, None)
else:
if pos_bias_type == "chatglm_rotary":
raise NotImplemented("No FlashAttn version for ChatGLM at present!")
h_q = h_q.view(batch_size * len_q, -1, self.dim_head) # .permute(0, 2, 1, 3)
h_k = h_k.view(batch_size * len_k, -1, self.dim_head) # .permute(0, 2, 1, 3)
h_v = h_v.view(batch_size * len_k, -1, self.dim_head) # .permute(0, 2, 1, 3)
h_q, h_k = position_bias(
h_q, h_k, -3, cu_seqlens=cu_seqlens, max_length=max_seqlen, position_ids=position_ids
)
# score = flash_attn_varlen_func(
# h_q, h_k, h_v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, self.dropout_p, causal=True
# )
score = OpFlash.apply(
self, not torch.is_grad_enabled(), h_q, h_k, h_v, cu_seqlens, max_seqlen, self.dropout_p, True
)
score = score.view(batch_size, len_q, -1)
score = self.attention_out(score)
if use_cache:
return score, (h_k, h_v)
else:
return score