fix flash shift short attention

hiyouga 2023-10-09 17:54:48 +08:00
parent 6b24f29c8a
commit 0a356bc897
1 changed file with 12 additions and 4 deletions


@@ -61,7 +61,9 @@ class LlamaShiftShortAttention(LlamaAttention):
             num_groups = q_len // groupsz
             def shift(state: torch.Tensor) -> torch.Tensor:
                 state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
-                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                state = torch.cat((
+                    state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                ), dim=2)
                 return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
 
             query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
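
In both attention classes the change is the same: the grouped shift on the second half of the heads is no longer written as an in-place slice assignment but as a torch.cat of the untouched first half with the rolled second half along the head dimension. The commit message does not spell out the motivation, but the visible effect is that the shift now returns a fresh tensor instead of mutating its input. A minimal, self-contained sketch of the two forms (shapes and names below are made up for illustration, not taken from the repo):

import torch

bsz, q_len, n_heads, head_dim = 2, 8, 4, 3
groupsz = 4  # q_len must be divisible by groupsz
state = torch.randn(bsz, q_len, n_heads, head_dim)

# Old form: roll the second half of the heads in place.
shifted_inplace = state.clone()
shifted_inplace[:, :, n_heads // 2:] = shifted_inplace[:, :, n_heads // 2:].roll(-groupsz // 2, dims=1)

# New form: concatenate the untouched first half with the rolled second half
# along the head dimension (dim=2), leaving `state` unmodified.
shifted_cat = torch.cat((
    state[:, :, :n_heads // 2],
    state[:, :, n_heads // 2:].roll(-groupsz // 2, dims=1),
), dim=2)

assert torch.equal(shifted_inplace, shifted_cat)  # same values, no in-place write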
@@ -80,7 +82,9 @@ class LlamaShiftShortAttention(LlamaAttention):
         if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
             attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
-            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
+            attn_output = torch.cat((
+                attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
+            ))
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
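
The shift back is the mirror image: after attention, the second half of the heads is rolled forward by groupsz//2 so tokens return to their original positions. One detail worth noting: unlike the shift above, the new torch.cat call in this hunk passes no dim argument, and torch.cat defaults to dim=0 (the batch axis). The sketch below, again with made-up shapes, shows the difference; dim=2 is my assumption of the intended head-dimension concatenation, not something stated in the commit.

import torch

bsz, q_len, n_heads, head_dim = 2, 8, 4, 3
groupsz = 4
attn_output = torch.randn(bsz, q_len, n_heads, head_dim)

# Concatenating along the head dimension keeps the original shape.
shift_back_heads = torch.cat((
    attn_output[:, :, :n_heads // 2],
    attn_output[:, :, n_heads // 2:].roll(groupsz // 2, dims=1),
), dim=2)
print(shift_back_heads.shape)    # torch.Size([2, 8, 4, 3])

# With the default dim=0, the halves are stacked along the batch axis instead.
shift_back_default = torch.cat((
    attn_output[:, :, :n_heads // 2],
    attn_output[:, :, n_heads // 2:].roll(groupsz // 2, dims=1),
))
print(shift_back_default.shape)  # torch.Size([4, 8, 2, 3])

Separately, the unchanged context line attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) discards its result (reshape is not in-place), but that line predates this commit and is not part of the change shown here.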
@@ -151,7 +155,9 @@ class LlamaFlashAttention2(LlamaAttention):
             assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
             num_groups = q_len // groupsz
             def shift(state: torch.Tensor) -> torch.Tensor:
-                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                state = torch.cat((
+                    state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                ), dim=2)
                 return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
 
             query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
@@ -184,7 +190,9 @@ class LlamaFlashAttention2(LlamaAttention):
         if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
             attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
-            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
+            attn_output = torch.cat((
+                attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
+            ))
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
         attn_output = self.o_proj(attn_output)
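
The two hunks in LlamaFlashAttention2 mirror the first pair; the only difference is that the flash-attention path keeps states in (bsz, seq_len, n_heads, head_dim) layout, so the shift needs no transpose and the grouped reshape stays in that layout. Putting the pieces together, here is a hedged end-to-end check under the same assumptions as above (head-dimension concatenation for both steps, toy shapes): shift, fold the groups into the batch as the flash-attention path does, unfold, and shift back. With an identity "attention" in the middle, the round trip should reproduce the input exactly, which is what the paired shift / shift-back hunks are meant to guarantee.

import torch

bsz, q_len, n_heads, head_dim = 2, 8, 4, 3
groupsz = 4
num_groups = q_len // groupsz

x = torch.randn(bsz, q_len, n_heads, head_dim)

# Shift (flash-attention layout, no transpose) and fold groups into the batch.
shifted = torch.cat((
    x[:, :, :n_heads // 2],
    x[:, :, n_heads // 2:].roll(-groupsz // 2, dims=1),
), dim=2)
grouped = shifted.reshape(bsz * num_groups, groupsz, n_heads, head_dim)

# Pretend attention returned its input unchanged, then unfold and shift back.
attn_output = grouped.reshape(bsz, q_len, n_heads, head_dim)
restored = torch.cat((
    attn_output[:, :, :n_heads // 2],
    attn_output[:, :, n_heads // 2:].roll(groupsz // 2, dims=1),
), dim=2)

assert torch.equal(restored, x)  # the two rolls cancel out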