diff --git a/src/llamafactory/data/processors/processor_utils.py b/src/llamafactory/data/processors/processor_utils.py index 455908ae..435cf6ca 100644 --- a/src/llamafactory/data/processors/processor_utils.py +++ b/src/llamafactory/data/processors/processor_utils.py @@ -79,6 +79,9 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]: + r""" + Computes the real sequence length after truncation by the cutoff_len. + """ if target_len * 2 < cutoff_len: # truncate source max_target_len = cutoff_len elif source_len * 2 < cutoff_len: # truncate target @@ -87,5 +90,6 @@ def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int max_target_len = int(cutoff_len * (target_len / (source_len + target_len))) new_target_len = min(max_target_len, target_len) - new_source_len = max(cutoff_len - new_target_len, 0) + max_source_len = max(cutoff_len - new_target_len, 0) + new_source_len = min(max_source_len, source_len) return new_source_len, new_target_len diff --git a/tests/data/test_processor.py b/tests/data/test_processor.py index fa8f7172..692fcaa1 100644 --- a/tests/data/test_processor.py +++ b/tests/data/test_processor.py @@ -26,6 +26,9 @@ from llamafactory.data.processors.processor_utils import infer_seqlen ((2000, 3000, 1000), (400, 600)), ((1000, 100, 1000), (900, 100)), ((100, 1000, 1000), (100, 900)), + ((100, 500, 1000), (100, 500)), + ((500, 100, 1000), (500, 100)), + ((10, 10, 1000), (10, 10)), ], ) def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]):