forked from p04798526/LLaMA-Factory-Mirror

fix mask_history tiny bug

parent 18e455c232
commit b5ca86cc07
@@ -53,7 +53,7 @@ def _encode_supervised_example(
         input_ids += [image_token_id] * getattr(processor, "image_seq_length")
         labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
 
-    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
+    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools, mask_history)
     total_length = 1 if template.efficient_eos else 0
     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
         if total_length >= cutoff_len:
@@ -71,11 +71,12 @@ def _encode_supervised_example(
         else:
             source_label = [IGNORE_INDEX] * source_len
 
-        if mask_history and turn_idx != len(encoded_pairs) - 1:
-            target_label = [IGNORE_INDEX] * target_len
+        if mask_history:
+            target_label = target_ids if turn_idx == 0 else [IGNORE_INDEX] * target_len
+            input_ids = source_ids + target_ids + input_ids
+            labels = source_label + target_label + labels
         else:
             target_label = target_ids
-
-        input_ids += source_ids + target_ids
-        labels += source_label + target_label
+            input_ids += source_ids + target_ids
+            labels += source_label + target_label
 
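Reading the two hunks above: previously the turns were iterated oldest-to-newest and every turn except the last was masked (`turn_idx != len(encoded_pairs) - 1`), so when the `cutoff_len` check broke out of the loop early, it could be the supervised final turn that got truncated away. After the fix, encode_multiturn yields the pairs newest-first when mask_history is set, only turn_idx == 0 (the final turn) keeps its labels, and each older turn is prepended, so truncation drops the oldest history first. A minimal sketch of the fixed branch, not part of the commit, with illustrative toy token ids (IGNORE_INDEX = -100 as used for masked labels):

IGNORE_INDEX = -100

# Toy token ids: three turns, newest first, as encode_multiturn would
# yield them when mask_history=True.
encoded_pairs = [
    ([7, 8], [9]),  # turn_idx == 0 -> the final turn, labels kept
    ([4, 5], [6]),  # older history, labels masked
    ([1, 2], [3]),  # oldest history, labels masked
]

input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
    source_label = [IGNORE_INDEX] * len(source_ids)
    target_label = target_ids if turn_idx == 0 else [IGNORE_INDEX] * len(target_ids)
    # Prepending restores chronological order while iterating newest-to-oldest.
    input_ids = source_ids + target_ids + input_ids
    labels = source_label + target_label + labels

print(input_ids)  # [1, 2, 3, 4, 5, 6, 7, 8, 9] -- chronological order
print(labels)     # [-100, -100, -100, -100, -100, -100, -100, -100, 9]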
@@ -69,12 +69,16 @@ class Template:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        mask_history: bool = False,
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
        Returns multiple pairs of token ids representing prompts and responses respectively.
         """
         encoded_messages = self._encode(tokenizer, messages, system, tools)
-        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+        if not mask_history:
+            return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+        else:
+            return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(len(encoded_messages) - 2, -1, -2)]
 
     def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
         r"""
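A minimal sketch, not from the commit, of the two pairing orders encode_multiturn now returns, using toy single-token chunks. It assumes encoded_messages alternates prompt/response chunks as the comprehensions in the hunk imply:

encoded_messages = [[1], [2], [3], [4], [5], [6]]  # 3 turns of toy token ids

# mask_history=False: chronological (prompt, response) pairs.
forward = [(encoded_messages[i], encoded_messages[i + 1])
           for i in range(0, len(encoded_messages), 2)]

# mask_history=True: same pairs, newest turn first, so that turn_idx == 0
# in _encode_supervised_example is the turn that keeps its labels.
reversed_pairs = [(encoded_messages[i], encoded_messages[i + 1])
                  for i in range(len(encoded_messages) - 2, -1, -2)]

print(forward)         # [([1], [2]), ([3], [4]), ([5], [6])]
print(reversed_pairs)  # [([5], [6]), ([3], [4]), ([1], [2])]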