This commit is contained in:
hiyouga 2024-07-04 03:47:05 +08:00
parent 44747cebd2
commit 0c699de39d
2 changed files with 13 additions and 2 deletions

View File

@ -74,13 +74,13 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
"""
bsz = attention_mask.size(0)
dtype, device = attention_mask.dtype, attention_mask.device
max_num = torch.max(attention_mask)
max_num = torch.max(attention_mask).item()
counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
for i in range(max_num):
counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
counts = counts.flatten()
seqlens = counts[counts.nonzero().squeeze()]
seqlens = counts[counts.nonzero().squeeze(dim=-1)]
return seqlens

View File

@ -28,6 +28,11 @@ def test_get_seqlens_in_batch():
assert list(seqlens_in_batch.size()) == [5]
assert torch.all(seqlens_in_batch == torch.tensor([2, 3, 1, 2, 3]))
attention_mask_with_indices = torch.tensor([[1, 1, 1]])
seqlens_in_batch = get_seqlens_in_batch(attention_mask_with_indices)
assert list(seqlens_in_batch.size()) == [1]
assert torch.all(seqlens_in_batch == torch.tensor([3]))
def test_get_unpad_data():
attention_mask_with_indices = torch.tensor(
@ -40,3 +45,9 @@ def test_get_unpad_data():
assert torch.all(indices == torch.tensor([0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]))
assert torch.all(cu_seqlens == torch.tensor([0, 2, 5, 6, 8, 11], dtype=torch.int32))
assert max_seqlen_in_batch == 3
attention_mask_with_indices = torch.tensor([[1, 1, 1]])
indices, cu_seqlens, max_seqlen_in_batch = get_unpad_data(attention_mask_with_indices)
assert torch.all(indices == torch.tensor([0, 1, 2]))
assert torch.all(cu_seqlens == torch.tensor([0, 3], dtype=torch.int32))
assert max_seqlen_in_batch == 3