fix processors

This commit is contained in:
hiyouga 2024-07-05 08:33:22 +08:00
parent e43809bced
commit 9f33f1edf5
4 changed files with 13 additions and 12 deletions

View File

@ -56,7 +56,7 @@ def _encode_feedback_example(
kl_messages = prompt + [kl_response[1]] kl_messages = prompt + [kl_response[1]]
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools) prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
_, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools) kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
if template.efficient_eos: if template.efficient_eos:
response_ids += [tokenizer.eos_token_id] response_ids += [tokenizer.eos_token_id]
@ -65,17 +65,19 @@ def _encode_feedback_example(
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids
# do not consider the kl_response
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len) source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
prompt_ids = prompt_ids[:source_len] prompt_ids = prompt_ids[:source_len]
response_ids = response_ids[:target_len] response_ids = response_ids[:target_len]
kl_response_ids = kl_response_ids[:target_len] kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), data_args.cutoff_len)
kl_prompt_ids = kl_prompt_ids[:kl_source_len]
kl_response_ids = kl_response_ids[:kl_target_len]
input_ids = prompt_ids + response_ids input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids labels = [IGNORE_INDEX] * source_len + response_ids
kl_input_ids = prompt_ids + kl_response_ids kl_input_ids = kl_prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag return input_ids, labels, kl_input_ids, kl_labels, kto_tag

View File

@ -63,9 +63,9 @@ def _encode_pairwise_example(
rejected_ids = rejected_ids[:target_len] rejected_ids = rejected_ids[:target_len]
chosen_input_ids = prompt_ids + chosen_ids chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels

View File

@ -65,9 +65,9 @@ def _encode_supervised_example(
if data_args.train_on_prompt: if data_args.train_on_prompt:
source_mask = source_ids source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos: elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
else: else:
source_mask = [IGNORE_INDEX] * len(source_ids) source_mask = [IGNORE_INDEX] * source_len
input_ids += source_ids + target_ids input_ids += source_ids + target_ids
labels += source_mask + target_ids labels += source_mask + target_ids

View File

@ -875,8 +875,7 @@ _register_template(
_register_template( _register_template(
name="zephyr", name="zephyr",
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]),
format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]), format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
default_system="You are Zephyr, a helpful assistant.", default_system="You are Zephyr, a helpful assistant.",
) )