fix aquila template, repair sft packing mechanism
commit be420e4179
parent e1dcb8e4dc
@@ -22,9 +22,6 @@ def preprocess_dataset(
     column_names = list(next(iter(dataset)).keys())
     template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
 
-    if template is not None and template.efficient_eos and data_args.sft_packing:
-        raise ValueError("Current template is incompatible with packing.")
-
     def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
         for i in range(len(examples["prompt"])):
             query, response = examples["prompt"][i], examples["response"][i]
@@ -105,9 +102,19 @@ def preprocess_dataset(
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
         input_ids, labels = [], []
         for query, response, history, system in construct_example(examples):
-            for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
+            for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
+                tokenizer, query, response, history, system
+            )):
+                if turn_idx != 0 and template.efficient_eos:
+                    source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
+                else:
+                    source_mask = [IGNORE_INDEX] * len(source_ids)
                 input_ids += source_ids + target_ids
-                labels += source_ids + target_ids # TODO: try masking source_ids here
+                labels += source_mask + target_ids
+
+        if template.efficient_eos:
+            input_ids += [tokenizer.eos_token_id]
+            labels += [tokenizer.eos_token_id]
 
         total_length = len(input_ids)
         block_size = data_args.cutoff_len
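Note: the hunk above is why the "incompatible with packing" guard could be removed; efficient_eos templates now get their EOS supervision through the labels themselves. A minimal standalone sketch of the repaired packing path follows. It is illustrative only: IGNORE_INDEX, the toy `turns` data, and the block-splitting step are assumptions standing in for surrounding code the diff does not show.

from typing import List, Tuple

IGNORE_INDEX = -100   # assumed: the usual ignore index for the cross-entropy loss

# Toy stand-ins for template.encode_multiturn output: (source_ids, target_ids) per turn.
turns: List[Tuple[List[int], List[int]]] = [([5, 6, 7], [8, 9]), ([10, 11], [12, 13])]
eos_token_id = 2      # assumed EOS id
efficient_eos = True  # templates (like "aquila") whose encoding contains no EOS itself
cutoff_len = 8        # stands in for data_args.cutoff_len

input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(turns):
    if turn_idx != 0 and efficient_eos:
        # Label the first prompt token of every later turn with EOS, so the model
        # is trained to emit EOS right after finishing the previous response.
        source_mask = [eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
    else:
        source_mask = [IGNORE_INDEX] * len(source_ids)  # never learn to echo the prompt
    input_ids += source_ids + target_ids
    labels += source_mask + target_ids

if efficient_eos:
    input_ids += [eos_token_id]  # close the final response as well
    labels += [eos_token_id]

# Assumed packing step: split the concatenated stream into fixed-size blocks.
total_length = (len(input_ids) // cutoff_len) * cutoff_len
blocks = [
    (input_ids[i: i + cutoff_len], labels[i: i + cutoff_len])
    for i in range(0, total_length, cutoff_len)
]
print(blocks)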
@@ -423,7 +423,7 @@ register_template(
 
 
 r"""
-Supports: https://huggingface.co/qhduan/aquilachat-7b
+Supports: https://huggingface.co/BAAI/AquilaChat-7B
 """
 register_template(
     name="aquila",
@@ -439,7 +439,11 @@ register_template(
     ),
     sep=[
         "###"
-    ]
+    ],
+    stop_words=[
+        "</s>"
+    ],
+    efficient_eos=True
 )
 
 
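Because the aquila template's "###" separator encodes no EOS token, efficient_eos=True supplies the EOS supervision during training (see the masking change above), while the new stop_words entry truncates generation at inference time. As a hedged sketch of how a textual stop word can be enforced with a Transformers stopping criterion — illustrative only, not this project's actual implementation:

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnWords(StoppingCriteria):
    """Stop generation once any stop word appears in the decoded tail."""

    def __init__(self, tokenizer, stop_words, window=8):
        self.tokenizer = tokenizer
        self.stop_words = stop_words
        self.window = window  # only decode the last few tokens for speed

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        tail = self.tokenizer.decode(input_ids[0, -self.window:], skip_special_tokens=False)
        return any(word in tail for word in self.stop_words)

# Usage (tokenizer, model, and inputs assumed to be loaded elsewhere):
# criteria = StoppingCriteriaList([StopOnWords(tokenizer, ["</s>"])])
# model.generate(**inputs, stopping_criteria=criteria)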