diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index c096ddc7..b8eafd00 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -22,9 +22,6 @@ def preprocess_dataset( column_names = list(next(iter(dataset)).keys()) template = get_template_and_fix_tokenizer(data_args.template, tokenizer) - if template is not None and template.efficient_eos and data_args.sft_packing: - raise ValueError("Current template is incompatible with packing.") - def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]: for i in range(len(examples["prompt"])): query, response = examples["prompt"][i], examples["response"][i] @@ -105,9 +102,19 @@ def preprocess_dataset( model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} input_ids, labels = [], [] for query, response, history, system in construct_example(examples): - for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system): + for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn( + tokenizer, query, response, history, system + )): + if turn_idx != 0 and template.efficient_eos: + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) + else: + source_mask = [IGNORE_INDEX] * len(source_ids) input_ids += source_ids + target_ids - labels += source_ids + target_ids # TODO: try masking source_ids here + labels += source_mask + target_ids + + if template.efficient_eos: + input_ids += [tokenizer.eos_token_id] + labels += [tokenizer.eos_token_id] total_length = len(input_ids) block_size = data_args.cutoff_len diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 786ad5d1..ae486c69 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -423,7 +423,7 @@ register_template( r""" -Supports: https://huggingface.co/qhduan/aquilachat-7b +Supports: https://huggingface.co/BAAI/AquilaChat-7B """ register_template( name="aquila", @@ -439,7 +439,11 @@ register_template( ), sep=[ "###" - ] + ], + stop_words=[ + "" + ], + efficient_eos=True )