fix processors
This commit is contained in:
parent
e43809bced
commit
9f33f1edf5
|
@ -56,7 +56,7 @@ def _encode_feedback_example(
|
||||||
kl_messages = prompt + [kl_response[1]]
|
kl_messages = prompt + [kl_response[1]]
|
||||||
|
|
||||||
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
|
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
|
||||||
_, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
|
kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
|
||||||
|
|
||||||
if template.efficient_eos:
|
if template.efficient_eos:
|
||||||
response_ids += [tokenizer.eos_token_id]
|
response_ids += [tokenizer.eos_token_id]
|
||||||
|
@ -65,17 +65,19 @@ def _encode_feedback_example(
|
||||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||||
|
kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids
|
||||||
|
|
||||||
# do not consider the kl_response
|
|
||||||
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
|
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
|
||||||
prompt_ids = prompt_ids[:source_len]
|
prompt_ids = prompt_ids[:source_len]
|
||||||
response_ids = response_ids[:target_len]
|
response_ids = response_ids[:target_len]
|
||||||
kl_response_ids = kl_response_ids[:target_len]
|
kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), data_args.cutoff_len)
|
||||||
|
kl_prompt_ids = kl_prompt_ids[:kl_source_len]
|
||||||
|
kl_response_ids = kl_response_ids[:kl_target_len]
|
||||||
|
|
||||||
input_ids = prompt_ids + response_ids
|
input_ids = prompt_ids + response_ids
|
||||||
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
|
labels = [IGNORE_INDEX] * source_len + response_ids
|
||||||
kl_input_ids = prompt_ids + kl_response_ids
|
kl_input_ids = kl_prompt_ids + kl_response_ids
|
||||||
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
|
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
|
||||||
|
|
||||||
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
|
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
|
||||||
|
|
||||||
|
|
|
@ -63,9 +63,9 @@ def _encode_pairwise_example(
|
||||||
rejected_ids = rejected_ids[:target_len]
|
rejected_ids = rejected_ids[:target_len]
|
||||||
|
|
||||||
chosen_input_ids = prompt_ids + chosen_ids
|
chosen_input_ids = prompt_ids + chosen_ids
|
||||||
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
|
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
|
||||||
rejected_input_ids = prompt_ids + rejected_ids
|
rejected_input_ids = prompt_ids + rejected_ids
|
||||||
rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids
|
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
|
||||||
|
|
||||||
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
|
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
|
||||||
|
|
||||||
|
|
|
@ -65,9 +65,9 @@ def _encode_supervised_example(
|
||||||
if data_args.train_on_prompt:
|
if data_args.train_on_prompt:
|
||||||
source_mask = source_ids
|
source_mask = source_ids
|
||||||
elif turn_idx != 0 and template.efficient_eos:
|
elif turn_idx != 0 and template.efficient_eos:
|
||||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
|
||||||
else:
|
else:
|
||||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
source_mask = [IGNORE_INDEX] * source_len
|
||||||
|
|
||||||
input_ids += source_ids + target_ids
|
input_ids += source_ids + target_ids
|
||||||
labels += source_mask + target_ids
|
labels += source_mask + target_ids
|
||||||
|
|
|
@ -875,8 +875,7 @@ _register_template(
|
||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="zephyr",
|
name="zephyr",
|
||||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
|
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]),
|
||||||
format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
|
|
||||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
||||||
default_system="You are Zephyr, a helpful assistant.",
|
default_system="You are Zephyr, a helpful assistant.",
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue