mask_history args verify valid

This commit is contained in:
“Wzw” 2024-08-08 10:12:01 +08:00
parent b5ca86cc07
commit 2fa1e0b2ad
2 changed files with 5 additions and 0 deletions

View File

@ -206,6 +206,8 @@ def get_dataset(
template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format) template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
if data_args.train_on_prompt and template.efficient_eos: if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.") raise ValueError("Current template does not support `train_on_prompt`.")
if stage!="sft" and data_args.mask_history:
raise ValueError("`Train on the last turn only` is only valid for sft training.")
# Load tokenized dataset # Load tokenized dataset
if data_args.tokenized_path is not None: if data_args.tokenized_path is not None:

View File

@ -141,3 +141,6 @@ class DataArguments:
if self.streaming and self.max_samples is not None: if self.streaming and self.max_samples is not None:
raise ValueError("`max_samples` is incompatible with `streaming`.") raise ValueError("`max_samples` is incompatible with `streaming`.")
if self.mask_history and self.train_on_prompt:
raise ValueError("`Train on the last turn only` does not support `train_on_prompt`.")