diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py
index eec5d957..07405db5 100644
--- a/src/llamafactory/model/model_utils/packing.py
+++ b/src/llamafactory/model/model_utils/packing.py
@@ -74,13 +74,13 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
     """
     bsz = attention_mask.size(0)
     dtype, device = attention_mask.dtype, attention_mask.device
-    max_num = torch.max(attention_mask)
+    max_num = torch.max(attention_mask).item()
     counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
     for i in range(max_num):
         counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
 
     counts = counts.flatten()
-    seqlens = counts[counts.nonzero().squeeze()]
+    seqlens = counts[counts.nonzero().squeeze(dim=-1)]
     return seqlens
 
 
diff --git a/tests/model/model_utils/test_packing.py b/tests/model/model_utils/test_packing.py
index 6fd9ba3b..8056099f 100644
--- a/tests/model/model_utils/test_packing.py
+++ b/tests/model/model_utils/test_packing.py
@@ -28,6 +28,11 @@ def test_get_seqlens_in_batch():
     assert list(seqlens_in_batch.size()) == [5]
     assert torch.all(seqlens_in_batch == torch.tensor([2, 3, 1, 2, 3]))
 
+    attention_mask_with_indices = torch.tensor([[1, 1, 1]])
+    seqlens_in_batch = get_seqlens_in_batch(attention_mask_with_indices)
+    assert list(seqlens_in_batch.size()) == [1]
+    assert torch.all(seqlens_in_batch == torch.tensor([3]))
+
 
 def test_get_unpad_data():
     attention_mask_with_indices = torch.tensor(
@@ -40,3 +45,9 @@ def test_get_unpad_data():
     assert torch.all(indices == torch.tensor([0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]))
     assert torch.all(cu_seqlens == torch.tensor([0, 2, 5, 6, 8, 11], dtype=torch.int32))
     assert max_seqlen_in_batch == 3
+
+    attention_mask_with_indices = torch.tensor([[1, 1, 1]])
+    indices, cu_seqlens, max_seqlen_in_batch = get_unpad_data(attention_mask_with_indices)
+    assert torch.all(indices == torch.tensor([0, 1, 2]))
+    assert torch.all(cu_seqlens == torch.tensor([0, 3], dtype=torch.int32))
+    assert max_seqlen_in_batch == 3
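
For reviewers, a minimal sketch (not part of the patch; `counts` below is illustrative) of the edge case the `squeeze(dim=-1)` change fixes. When the batch packs a single sequence, `counts.nonzero()` has shape `(1, 1)`, and a bare `.squeeze()` collapses it all the way to a 0-dim scalar, so indexing yields a 0-dim result instead of the length-1 vector the function is expected to return:

```python
import torch

# One packed sequence of length 3 in the batch.
counts = torch.tensor([3])

idx_old = counts.nonzero().squeeze()        # shape () -- squeezed to a 0-dim scalar
idx_new = counts.nonzero().squeeze(dim=-1)  # shape (1,) -- stays 1-dim

print(counts[idx_old].shape)  # torch.Size([])  -> fails `list(seqlens.size()) == [1]`
print(counts[idx_new].shape)  # torch.Size([1]) -> matches the new test's expectation
```

Pinning the squeeze to the last dimension keeps the index 1-D regardless of how many sequences are packed, which is exactly what the new `[[1, 1, 1]]` test cases assert. The companion `.item()` change converts the 0-dim tensor returned by `torch.max` into a plain Python int before it is used as a shape and `range()` bound.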