fix reserved label len

Author: hiyouga
Date: 2024-02-04 17:54:26 +08:00
Parent: 19d33ede13
Commit: db0ab4d601
3 changed files with 33 additions and 13 deletions

Changed file 1 of 3 — dataset preprocessing functions (preprocess_supervised_dataset, preprocess_unsupervised_dataset, preprocess_pairwise_dataset):

@@ -55,7 +55,12 @@ def preprocess_supervised_dataset(
         input_ids, labels = [], []
         for turn_idx, (source_ids, target_ids) in enumerate(
             template.encode_multiturn(
-                tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+                tokenizer,
+                messages,
+                examples["system"][i],
+                examples["tools"][i],
+                data_args.cutoff_len,
+                data_args.reserved_label_len,
             )
         ):
             if data_args.train_on_prompt:
@@ -143,7 +148,12 @@ def preprocess_unsupervised_dataset(
         messages = examples["prompt"][i] + [{"role": Role.ASSISTANT, "content": ""}]
         input_ids, labels = template.encode_oneturn(
-            tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+            tokenizer,
+            messages,
+            examples["system"][i],
+            examples["tools"][i],
+            data_args.cutoff_len,
+            data_args.reserved_label_len,
         )
         if template.efficient_eos:
@@ -172,10 +182,20 @@ def preprocess_pairwise_dataset(
         rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
         prompt_ids, chosen_ids = template.encode_oneturn(
-            tokenizer, chosen_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+            tokenizer,
+            chosen_messages,
+            examples["system"][i],
+            examples["tools"][i],
+            data_args.cutoff_len,
+            data_args.reserved_label_len,
         )
         _, rejected_ids = template.encode_oneturn(
-            tokenizer, rejected_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+            tokenizer,
+            rejected_messages,
+            examples["system"][i],
+            examples["tools"][i],
+            data_args.cutoff_len,
+            data_args.reserved_label_len,
         )
         if template.efficient_eos:
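For context, the (source_ids, target_ids) pairs yielded by encode_multiturn above are then assembled into model inputs, with the prompt normally masked out of the loss unless train_on_prompt is set. A minimal sketch of that assembly, assuming the usual -100 ignore index; the variable names and dummy data are illustrative, not the repository's exact code:

```python
IGNORE_INDEX = -100  # assumed label-ignore value used by the loss

# Illustrative stand-ins for what encode_multiturn yields and for the data argument.
pairs = [([1, 2, 3], [4, 5]), ([6, 7], [8])]
train_on_prompt = False

input_ids, labels = [], []
for source_ids, target_ids in pairs:
    # Mask prompt tokens from the loss unless train_on_prompt is set;
    # always learn on the response tokens.
    source_labels = source_ids if train_on_prompt else [IGNORE_INDEX] * len(source_ids)
    input_ids += source_ids + target_ids
    labels += source_labels + target_ids

print(input_ids)  # [1, 2, 3, 4, 5, 6, 7, 8]
print(labels)     # [-100, -100, -100, 4, 5, -100, -100, 8]
```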

Changed file 2 of 3 — the Template class (encode_oneturn / encode_multiturn):

@@ -37,7 +37,7 @@ class Template:
         system: Optional[str] = None,
         tools: Optional[str] = None,
         cutoff_len: Optional[int] = 1_000_000,
-        reserved_label_len: Optional[int] = 16,
+        reserved_label_len: Optional[int] = 1,
     ) -> Tuple[List[int], List[int]]:
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
@@ -57,7 +57,7 @@ class Template:
         system: Optional[str] = None,
         tools: Optional[str] = None,
         cutoff_len: Optional[int] = 1_000_000,
-        reserved_label_len: Optional[int] = 16,
+        reserved_label_len: Optional[int] = 1,
     ) -> Sequence[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
@@ -144,10 +144,10 @@ class Template:
                 max_len=(cutoff_len - total_length),
                 reserved_label_len=reserved_label_len,
             )
-            encoded_messages[i] = encoded_messages[i][:max_source_len]
-            encoded_messages[i + 1] = encoded_messages[i + 1][:max_target_len]
-            total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
-            encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
+            source_ids = encoded_messages[i][:max_source_len]
+            target_ids = encoded_messages[i + 1][:max_target_len]
+            total_length += len(source_ids) + len(target_ids)
+            encoded_pairs.append((source_ids, target_ids))
         return encoded_pairs
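The call above passes max_len and reserved_label_len to a length-splitting helper whose body is not part of this diff. A minimal sketch of the plausible behavior, assuming the remaining budget is split proportionally between source and target with reserved_label_len as a floor for the target side; the function name here is made up and the repository's actual helper may differ:

```python
from typing import Tuple

def split_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
    # Give the target a share of the budget proportional to its raw length,
    # but never less than `reserved_label_len`; the source gets the rest.
    max_target_len = int(max_len * target_len / (source_len + target_len))
    max_target_len = max(max_target_len, reserved_label_len)
    max_source_len = max_len - max_target_len
    return max_source_len, max_target_len

# Example: a 1000-token prompt, a 5-token answer, and 512 tokens of budget left.
# The proportional share would round the answer down to 2 tokens, but the
# reserved length keeps 16 tokens for it.
print(split_max_len(1000, 5, 512, 16))  # (496, 16)
```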

Changed file 3 of 3 — the DataArguments dataclass:

@@ -21,10 +21,10 @@ class DataArguments:
         default="train", metadata={"help": "Which dataset split to use for training and evaluation."}
     )
     cutoff_len: Optional[int] = field(
-        default=1024, metadata={"help": "The maximum length of the model inputs after tokenization."}
+        default=1024, metadata={"help": "The cutoff length of the model inputs after tokenization."}
     )
     reserved_label_len: Optional[int] = field(
-        default=1, metadata={"help": "The maximum length reserved for label after tokenization."}
+        default=1, metadata={"help": "The minimum cutoff length reserved for label after tokenization."}
     )
     train_on_prompt: Optional[bool] = field(
         default=False, metadata={"help": "Whether to disable the mask on the prompt or not."}
@@ -57,7 +57,7 @@ class DataArguments:
     ignore_pad_token_for_loss: Optional[bool] = field(
         default=True,
         metadata={
-            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
+            "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
         },
     )
     val_size: Optional[float] = field(
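Taken together, cutoff_len bounds the total tokenized input while reserved_label_len keeps a minimum slice of that budget for the label. A tiny self-contained illustration of the arithmetic (not repository code):

```python
cutoff_len = 1024          # total budget for prompt + label tokens
reserved_label_len = 16    # minimum budget kept for the label

# Even if the prompt alone would fill the whole budget, truncation leaves
# room for at least `reserved_label_len` label tokens.
max_prompt_tokens = cutoff_len - reserved_label_len
print(max_prompt_tokens)   # 1008
```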