Merge pull request #5115 from YeQiuO/main

fix: `Train on the last turn only` truncate bug
Committed by hoshi-hiyouga via GitHub on 2024-08-09 17:58:27 +08:00
Commit: 51542cb15f
4 changed files with 18 additions and 8 deletions

@@ -206,6 +206,8 @@ def get_dataset(
     template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
+    if stage != "sft" and data_args.mask_history:
+        raise ValueError("`Train on the last turn only` is only valid for sft training.")
 
     # Load tokenized dataset
     if data_args.tokenized_path is not None:

@@ -53,7 +53,7 @@ def _encode_supervised_example(
         input_ids += [image_token_id] * getattr(processor, "image_seq_length")
         labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
 
-    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
+    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools, mask_history)
     total_length = 1 if template.efficient_eos else 0
     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
         if total_length >= cutoff_len:
@@ -71,11 +71,12 @@ def _encode_supervised_example(
         else:
             source_label = [IGNORE_INDEX] * source_len
 
-        if mask_history and turn_idx != len(encoded_pairs) - 1:
-            target_label = [IGNORE_INDEX] * target_len
+        if mask_history:
+            target_label = target_ids if turn_idx == 0 else [IGNORE_INDEX] * target_len
+            input_ids = source_ids + target_ids + input_ids
+            labels = source_label + target_label + labels
         else:
             target_label = target_ids
-
-        input_ids += source_ids + target_ids
-        labels += source_label + target_label
+            input_ids += source_ids + target_ids
+            labels += source_label + target_label
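This hunk is the core of the fix: with `mask_history` enabled, `encode_multiturn` (changed in the next hunk) returns the pairs last-turn-first, only the pair at index 0 keeps real labels, and each processed pair is prepended rather than appended. When the `cutoff_len` budget runs out, it is therefore the older, masked turns that get dropped, not the final turn being trained on. A minimal sketch of that behaviour, with made-up token ids and without the per-turn `infer_seqlen` truncation of the real code:

# Minimal sketch (not the project's code): with mask_history, encoded pairs
# arrive last-turn-first and are prepended, so when the cutoff_len budget is
# exhausted it is the oldest turns that fall away, never the trained last turn.
IGNORE_INDEX = -100
cutoff_len = 6

# toy token ids per turn, already reversed: index 0 is the last (trained) turn
encoded_pairs = [([7, 8], [9]), ([4, 5], [6]), ([1, 2], [3])]

input_ids, labels, total_length = [], [], 0
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
    if total_length >= cutoff_len:
        break  # older turns beyond the budget are simply dropped
    total_length += len(source_ids) + len(target_ids)
    source_label = [IGNORE_INDEX] * len(source_ids)
    # only the last turn (index 0 after reversal) keeps real labels
    target_label = target_ids if turn_idx == 0 else [IGNORE_INDEX] * len(target_ids)
    input_ids = source_ids + target_ids + input_ids  # prepend to restore chronological order
    labels = source_label + target_label + labels

print(input_ids)  # [4, 5, 6, 7, 8, 9] -- the last turn survives truncation
print(labels)     # [-100, -100, -100, -100, -100, 9]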

@@ -69,12 +69,16 @@ class Template:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        mask_history: bool = False,
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
         """
         encoded_messages = self._encode(tokenizer, messages, system, tools)
-        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+        if not mask_history:
+            return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+        else:
+            return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(len(encoded_messages) - 2, -1, -2)]
 
     def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
         r"""

@@ -141,3 +141,6 @@ class DataArguments:
 
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
+
+        if self.mask_history and self.train_on_prompt:
+            raise ValueError("`Train on the last turn only` does not support `train_on_prompt`.")