From 9f33f1edf544807e498f60881f30b00149fe570f Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Fri, 5 Jul 2024 08:33:22 +0800 Subject: [PATCH] fix processors --- src/llamafactory/data/processors/feedback.py | 14 ++++++++------ src/llamafactory/data/processors/pairwise.py | 4 ++-- src/llamafactory/data/processors/supervised.py | 4 ++-- src/llamafactory/data/template.py | 3 +-- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/llamafactory/data/processors/feedback.py b/src/llamafactory/data/processors/feedback.py index 7ba05e23..8eadeda0 100644 --- a/src/llamafactory/data/processors/feedback.py +++ b/src/llamafactory/data/processors/feedback.py @@ -56,7 +56,7 @@ def _encode_feedback_example( kl_messages = prompt + [kl_response[1]] prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools) - _, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools) + kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools) if template.efficient_eos: response_ids += [tokenizer.eos_token_id] @@ -65,17 +65,19 @@ def _encode_feedback_example( if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids + kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids - # do not consider the kl_response source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len) prompt_ids = prompt_ids[:source_len] response_ids = response_ids[:target_len] - kl_response_ids = kl_response_ids[:target_len] + kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), data_args.cutoff_len) + kl_prompt_ids = kl_prompt_ids[:kl_source_len] + kl_response_ids = kl_response_ids[:kl_target_len] input_ids = prompt_ids + response_ids - labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids - kl_input_ids = prompt_ids + kl_response_ids - kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids + labels = [IGNORE_INDEX] * source_len + response_ids + kl_input_ids = kl_prompt_ids + kl_response_ids + kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids return input_ids, labels, kl_input_ids, kl_labels, kto_tag diff --git a/src/llamafactory/data/processors/pairwise.py b/src/llamafactory/data/processors/pairwise.py index c6001e6e..9084c683 100644 --- a/src/llamafactory/data/processors/pairwise.py +++ b/src/llamafactory/data/processors/pairwise.py @@ -63,9 +63,9 @@ def _encode_pairwise_example( rejected_ids = rejected_ids[:target_len] chosen_input_ids = prompt_ids + chosen_ids - chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids + chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids rejected_input_ids = prompt_ids + rejected_ids - rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids + rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index 747a0c1b..141054f4 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -65,9 +65,9 @@ def _encode_supervised_example( if data_args.train_on_prompt: source_mask = source_ids elif turn_idx != 0 and template.efficient_eos: - source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1) else: - source_mask = [IGNORE_INDEX] * len(source_ids) + source_mask = [IGNORE_INDEX] * source_len input_ids += source_ids + target_ids labels += source_mask + target_ids diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index aefd5195..9f49ac92 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -875,8 +875,7 @@ _register_template( _register_template( name="zephyr", - format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]), - format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]), + format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]), format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]), default_system="You are Zephyr, a helpful assistant.", )