From b5ca86cc07d38cf342e351aab16cce4319245792 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CWzw=E2=80=9D?= <15700181042@163.com> Date: Thu, 8 Aug 2024 10:09:33 +0800 Subject: [PATCH] fix mask_history tiny bug --- src/llamafactory/data/processors/supervised.py | 15 ++++++++------- src/llamafactory/data/template.py | 6 +++++- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index d4583b98..1231b2d9 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -53,7 +53,7 @@ def _encode_supervised_example( input_ids += [image_token_id] * getattr(processor, "image_seq_length") labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length") - encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools) + encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools, mask_history) total_length = 1 if template.efficient_eos else 0 for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs): if total_length >= cutoff_len: @@ -71,14 +71,15 @@ def _encode_supervised_example( else: source_label = [IGNORE_INDEX] * source_len - if mask_history and turn_idx != len(encoded_pairs) - 1: - target_label = [IGNORE_INDEX] * target_len + if mask_history: + target_label = target_ids if turn_idx==0 else [IGNORE_INDEX] * target_len + input_ids = source_ids + target_ids + input_ids + labels = source_label + target_label + labels else: target_label = target_ids - - input_ids += source_ids + target_ids - labels += source_label + target_label - + input_ids += source_ids + target_ids + labels += source_label + target_label + if template.efficient_eos: input_ids += [tokenizer.eos_token_id] labels += [tokenizer.eos_token_id] diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index db0393d1..fc5994fe 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -69,12 +69,16 @@ class Template: messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, + mask_history: bool = False, ) -> List[Tuple[List[int], List[int]]]: r""" Returns multiple pairs of token ids representing prompts and responses respectively. """ encoded_messages = self._encode(tokenizer, messages, system, tools) - return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] + if not mask_history: + return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] + else: + return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(len(encoded_messages)-2, -1, -2)] def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]: r"""