From 2fa1e0b2add60142c178e5e21ebaad7132fa5b00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CWzw=E2=80=9D?= <15700181042@163.com> Date: Thu, 8 Aug 2024 10:12:01 +0800 Subject: [PATCH] validate `mask_history` argument --- src/llamafactory/data/loader.py | 2 ++ src/llamafactory/hparams/data_args.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 069ea199..df2854a8 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -206,6 +206,8 @@ def get_dataset( template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format) if data_args.train_on_prompt and template.efficient_eos: raise ValueError("Current template does not support `train_on_prompt`.") + if stage != "sft" and data_args.mask_history: + raise ValueError("`mask_history` is only valid for SFT training.") # Load tokenized dataset if data_args.tokenized_path is not None: diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index cd762c75..5a13c9cf 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -141,3 +141,6 @@ class DataArguments: if self.streaming and self.max_samples is not None: raise ValueError("`max_samples` is incompatible with `streaming`.") + + if self.mask_history and self.train_on_prompt: + raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")