tiny fix
parent 44747cebd2
commit 0c699de39d
@@ -74,13 +74,13 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
     """
     bsz = attention_mask.size(0)
     dtype, device = attention_mask.dtype, attention_mask.device
-    max_num = torch.max(attention_mask)
+    max_num = torch.max(attention_mask).item()
     counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
     for i in range(max_num):
         counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)

     counts = counts.flatten()
-    seqlens = counts[counts.nonzero().squeeze()]
+    seqlens = counts[counts.nonzero().squeeze(dim=-1)]
     return seqlens

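The commit message doesn't spell out the motivation, but both changes harden the helper for edge cases: torch.max returns a zero-dimensional tensor, so .item() turns max_num into a plain Python int before it is used as a shape argument and a range() bound, and squeeze() without a dim collapses the (1, 1) nonzero result to a scalar when only one sequence is present, whereas squeeze(dim=-1) keeps it one-dimensional. A minimal sketch of the difference (illustrative only, not part of the commit):

import torch

attention_mask = torch.tensor([[1, 1, 1]])            # one packed sequence in the batch

max_num = torch.max(attention_mask).item()            # plain Python int, not a 0-dim tensor
print(type(max_num))                                  # <class 'int'>

counts = torch.tensor([3])                            # a single nonzero count
idx_old = counts.nonzero().squeeze()                  # shape (1, 1) -> 0-dim scalar
idx_new = counts.nonzero().squeeze(dim=-1)            # shape (1, 1) -> shape (1,)
print(counts[idx_old].shape)                          # torch.Size([])  -- loses the batch dim
print(counts[idx_new].shape)                          # torch.Size([1]) -- what the tests expect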
@@ -28,6 +28,11 @@ def test_get_seqlens_in_batch():
     assert list(seqlens_in_batch.size()) == [5]
     assert torch.all(seqlens_in_batch == torch.tensor([2, 3, 1, 2, 3]))
+
+    attention_mask_with_indices = torch.tensor([[1, 1, 1]])
+    seqlens_in_batch = get_seqlens_in_batch(attention_mask_with_indices)
+    assert list(seqlens_in_batch.size()) == [1]
+    assert torch.all(seqlens_in_batch == torch.tensor([3]))


 def test_get_unpad_data():
     attention_mask_with_indices = torch.tensor(
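For reference, here is the patched helper re-assembled from the first hunk (docstring omitted); the added test case exercises exactly the single-sequence mask that the old squeeze() handled wrongly:

import torch

def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
    bsz = attention_mask.size(0)
    dtype, device = attention_mask.dtype, attention_mask.device
    max_num = torch.max(attention_mask).item()
    counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
    for i in range(max_num):
        counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)

    counts = counts.flatten()
    seqlens = counts[counts.nonzero().squeeze(dim=-1)]
    return seqlens

print(get_seqlens_in_batch(torch.tensor([[1, 1, 1]])))  # tensor([3]), matching the new assertion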
@@ -40,3 +45,9 @@ def test_get_unpad_data():
     assert torch.all(indices == torch.tensor([0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]))
     assert torch.all(cu_seqlens == torch.tensor([0, 2, 5, 6, 8, 11], dtype=torch.int32))
     assert max_seqlen_in_batch == 3
+
+    attention_mask_with_indices = torch.tensor([[1, 1, 1]])
+    indices, cu_seqlens, max_seqlen_in_batch = get_unpad_data(attention_mask_with_indices)
+    assert torch.all(indices == torch.tensor([0, 1, 2]))
+    assert torch.all(cu_seqlens == torch.tensor([0, 3], dtype=torch.int32))
+    assert max_seqlen_in_batch == 3
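get_unpad_data itself is not shown in this diff, so the sketch below only re-derives the values asserted for the new single-sequence case from standard tensor ops; the real helper presumably builds them from the sequence lengths in a similar way:

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1]])

seqlens = torch.tensor([3], dtype=torch.int32)                      # what get_seqlens_in_batch yields here
indices = torch.nonzero(attention_mask.flatten()).squeeze(dim=-1)   # tensor([0, 1, 2])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3], dtype=int32)
max_seqlen_in_batch = int(seqlens.max())                            # 3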