fix #4683
This commit is contained in:
parent
ed232311e8
commit
e43809bced
|
@ -79,6 +79,9 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") ->
|
||||||
|
|
||||||
|
|
||||||
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
|
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
|
||||||
|
r"""
|
||||||
|
Computes the real sequence length after truncation by the cutoff_len.
|
||||||
|
"""
|
||||||
if target_len * 2 < cutoff_len: # truncate source
|
if target_len * 2 < cutoff_len: # truncate source
|
||||||
max_target_len = cutoff_len
|
max_target_len = cutoff_len
|
||||||
elif source_len * 2 < cutoff_len: # truncate target
|
elif source_len * 2 < cutoff_len: # truncate target
|
||||||
|
@ -87,5 +90,6 @@ def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int
|
||||||
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
|
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
|
||||||
|
|
||||||
new_target_len = min(max_target_len, target_len)
|
new_target_len = min(max_target_len, target_len)
|
||||||
new_source_len = max(cutoff_len - new_target_len, 0)
|
max_source_len = max(cutoff_len - new_target_len, 0)
|
||||||
|
new_source_len = min(max_source_len, source_len)
|
||||||
return new_source_len, new_target_len
|
return new_source_len, new_target_len
|
||||||
|
|
|
@ -26,6 +26,9 @@ from llamafactory.data.processors.processor_utils import infer_seqlen
|
||||||
((2000, 3000, 1000), (400, 600)),
|
((2000, 3000, 1000), (400, 600)),
|
||||||
((1000, 100, 1000), (900, 100)),
|
((1000, 100, 1000), (900, 100)),
|
||||||
((100, 1000, 1000), (100, 900)),
|
((100, 1000, 1000), (100, 900)),
|
||||||
|
((100, 500, 1000), (100, 500)),
|
||||||
|
((500, 100, 1000), (500, 100)),
|
||||||
|
((10, 10, 1000), (10, 10)),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]):
|
def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]):
|
||||||
|
|
Loading…
Reference in New Issue