diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 069ea199..df2854a8 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -206,6 +206,8 @@ def get_dataset(
     template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
+    if stage != "sft" and data_args.mask_history:
+        raise ValueError("`Train on the last turn only` is only valid for sft training.")
 
     # Load tokenized dataset
     if data_args.tokenized_path is not None:
diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
index d4583b98..1231b2d9 100644
--- a/src/llamafactory/data/processors/supervised.py
+++ b/src/llamafactory/data/processors/supervised.py
@@ -53,7 +53,7 @@ def _encode_supervised_example(
         input_ids += [image_token_id] * getattr(processor, "image_seq_length")
         labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
 
-    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
+    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools, mask_history)
     total_length = 1 if template.efficient_eos else 0
     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
         if total_length >= cutoff_len:
@@ -71,14 +71,15 @@ def _encode_supervised_example(
         else:
             source_label = [IGNORE_INDEX] * source_len
 
-        if mask_history and turn_idx != len(encoded_pairs) - 1:
-            target_label = [IGNORE_INDEX] * target_len
+        if mask_history:
+            target_label = target_ids if turn_idx == 0 else [IGNORE_INDEX] * target_len
+            input_ids = source_ids + target_ids + input_ids
+            labels = source_label + target_label + labels
         else:
             target_label = target_ids
-
-        input_ids += source_ids + target_ids
-        labels += source_label + target_label
-
+            input_ids += source_ids + target_ids
+            labels += source_label + target_label
+
     if template.efficient_eos:
         input_ids += [tokenizer.eos_token_id]
         labels += [tokenizer.eos_token_id]
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index f4d73701..096f50ec 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -69,12 +69,16 @@ class Template:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        mask_history: bool = False,
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
""" encoded_messages = self._encode(tokenizer, messages, system, tools) - return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] + if not mask_history: + return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] + else: + return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(len(encoded_messages)-2, -1, -2)] def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]: r""" diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index cd762c75..5a13c9cf 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -141,3 +141,6 @@ class DataArguments: if self.streaming and self.max_samples is not None: raise ValueError("`max_samples` is incompatible with `streaming`.") + + if self.mask_history and self.train_on_prompt: + raise ValueError("`Train on the last turn only` does not support `train_on_prompt`.")