fix aquila template, repair sft packing mechanism

This commit is contained in:
hiyouga 2023-10-10 18:49:55 +08:00
parent e1dcb8e4dc
commit be420e4179
2 changed files with 18 additions and 7 deletions

View File

@ -22,9 +22,6 @@ def preprocess_dataset(
column_names = list(next(iter(dataset)).keys()) column_names = list(next(iter(dataset)).keys())
template = get_template_and_fix_tokenizer(data_args.template, tokenizer) template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
if template is not None and template.efficient_eos and data_args.sft_packing:
raise ValueError("Current template is incompatible with packing.")
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]: def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
for i in range(len(examples["prompt"])): for i in range(len(examples["prompt"])):
query, response = examples["prompt"][i], examples["response"][i] query, response = examples["prompt"][i], examples["response"][i]
@ -105,9 +102,19 @@ def preprocess_dataset(
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
input_ids, labels = [], [] input_ids, labels = [], []
for query, response, history, system in construct_example(examples): for query, response, history, system in construct_example(examples):
for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system): for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, query, response, history, system
)):
if turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids input_ids += source_ids + target_ids
labels += source_ids + target_ids # TODO: try masking source_ids here labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
total_length = len(input_ids) total_length = len(input_ids)
block_size = data_args.cutoff_len block_size = data_args.cutoff_len

View File

@ -423,7 +423,7 @@ register_template(
r""" r"""
Supports: https://huggingface.co/qhduan/aquilachat-7b Supports: https://huggingface.co/BAAI/AquilaChat-7B
""" """
register_template( register_template(
name="aquila", name="aquila",
@ -439,7 +439,11 @@ register_template(
), ),
sep=[ sep=[
"###" "###"
] ],
stop_words=[
"</s>"
],
efficient_eos=True
) )