From c87023d539875cd8e622d40212a5627c9c182fb8 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Fri, 9 Aug 2024 18:03:00 +0800 Subject: [PATCH] follow #5115 --- examples/train_full/llama3_full_predict.yaml | 2 +- src/llamafactory/data/loader.py | 2 -- .../data/processors/supervised.py | 18 ++++++---- src/llamafactory/data/template.py | 12 +++---- src/llamafactory/hparams/data_args.py | 2 +- src/llamafactory/hparams/parser.py | 33 ++++++++++--------- 6 files changed, 35 insertions(+), 34 deletions(-) diff --git a/examples/train_full/llama3_full_predict.yaml b/examples/train_full/llama3_full_predict.yaml index 965c8e89..5d2b6028 100644 --- a/examples/train_full/llama3_full_predict.yaml +++ b/examples/train_full/llama3_full_predict.yaml @@ -7,7 +7,7 @@ do_predict: true finetuning_type: full ### dataset -eval_dataset: alpaca_en_demo +eval_dataset: identity,alpaca_en_demo template: llama3 cutoff_len: 1024 max_samples: 50 diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index df2854a8..069ea199 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -206,8 +206,6 @@ def get_dataset( template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format) if data_args.train_on_prompt and template.efficient_eos: raise ValueError("Current template does not support `train_on_prompt`.") - if stage!="sft" and data_args.mask_history: - raise ValueError("`Train on the last turn only` is only valid for sft training.") # Load tokenized dataset if data_args.tokenized_path is not None: diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index 1231b2d9..950de12a 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -53,8 +53,11 @@ def _encode_supervised_example( input_ids += [image_token_id] * getattr(processor, "image_seq_length") labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length") - encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools, mask_history) + encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools) total_length = 1 if template.efficient_eos else 0 + if mask_history: + encoded_pairs = encoded_pairs[::-1] # high priority for last turns + for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs): if total_length >= cutoff_len: break @@ -66,20 +69,23 @@ def _encode_supervised_example( if train_on_prompt: source_label = source_ids - elif turn_idx != 0 and template.efficient_eos: + elif template.efficient_eos: source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1) else: source_label = [IGNORE_INDEX] * source_len - if mask_history: - target_label = target_ids if turn_idx==0 else [IGNORE_INDEX] * target_len + if mask_history and turn_idx != 0: # train on the last turn only + target_label = [IGNORE_INDEX] * target_len + else: + target_label = target_ids + + if mask_history: # reversed sequences input_ids = source_ids + target_ids + input_ids labels = source_label + target_label + labels else: - target_label = target_ids input_ids += source_ids + target_ids labels += source_label + target_label - + if template.efficient_eos: input_ids += [tokenizer.eos_token_id] labels += [tokenizer.eos_token_id] diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 096f50ec..5d4b3011 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -69,16 +69,12 @@ class Template: messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - mask_history: bool = False, ) -> List[Tuple[List[int], List[int]]]: r""" Returns multiple pairs of token ids representing prompts and responses respectively. """ encoded_messages = self._encode(tokenizer, messages, system, tools) - if not mask_history: - return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] - else: - return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(len(encoded_messages)-2, -1, -2)] + return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]: r""" @@ -594,10 +590,10 @@ _register_template( format_separator=EmptyFormatter(slots=["\n"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), default_system=( - "You are an AI programming assistant, utilizing the Deepseek Coder model, " - "developed by Deepseek Company, and you only answer questions related to computer science. " + "You are an AI programming assistant, utilizing the DeepSeek Coder model, " + "developed by DeepSeek Company, and you only answer questions related to computer science. " "For politically sensitive questions, security and privacy issues, " - "and other non-computer science questions, you will refuse to answer\n" + "and other non-computer science questions, you will refuse to answer.\n" ), ) diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 5a13c9cf..0cb4a56d 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -143,4 +143,4 @@ class DataArguments: raise ValueError("`max_samples` is incompatible with `streaming`.") if self.mask_history and self.train_on_prompt: - raise ValueError("`Train on the last turn only` does not support `train_on_prompt`.") + raise ValueError("`mask_history` is incompatible with `train_on_prompt`.") diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index f40e5693..72a18378 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -163,11 +163,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: if finetuning_args.stage != "pt" and data_args.template is None: raise ValueError("Please specify which `template` to use.") - if finetuning_args.stage != "sft" and training_args.predict_with_generate: - raise ValueError("`predict_with_generate` cannot be set as True except SFT.") + if finetuning_args.stage != "sft": + if training_args.predict_with_generate: + raise ValueError("`predict_with_generate` cannot be set as True except SFT.") - if finetuning_args.stage != "sft" and data_args.neat_packing: - raise ValueError("`neat_packing` cannot be set as True except SFT.") + if data_args.neat_packing: + raise ValueError("`neat_packing` cannot be set as True except SFT.") + + if data_args.train_on_prompt or data_args.mask_history: + raise ValueError("`train_on_prompt` or `mask_history` cannot be set as True except SFT.") if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: raise ValueError("Please enable `predict_with_generate` to save model predictions.") @@ -175,21 +179,18 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end: raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.") - if finetuning_args.stage == "ppo" and not training_args.do_train: - raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.") + if finetuning_args.stage == "ppo": + if not training_args.do_train: + raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.") - if finetuning_args.stage == "ppo" and model_args.shift_attn: - raise ValueError("PPO training is incompatible with S^2-Attn.") + if model_args.shift_attn: + raise ValueError("PPO training is incompatible with S^2-Attn.") - if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth: - raise ValueError("Unsloth does not support lora reward model.") + if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth: + raise ValueError("Unsloth does not support lora reward model.") - if ( - finetuning_args.stage == "ppo" - and training_args.report_to - and training_args.report_to[0] not in ["wandb", "tensorboard"] - ): - raise ValueError("PPO only accepts wandb or tensorboard logger.") + if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]: + raise ValueError("PPO only accepts wandb or tensorboard logger.") if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED: raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")