fix reserved label len
parent 19d33ede13
commit db0ab4d601
@@ -55,7 +55,12 @@ def preprocess_supervised_dataset(
         input_ids, labels = [], []
         for turn_idx, (source_ids, target_ids) in enumerate(
             template.encode_multiturn(
-                tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+                tokenizer,
+                messages,
+                examples["system"][i],
+                examples["tools"][i],
+                data_args.cutoff_len,
+                data_args.reserved_label_len,
             )
         ):
             if data_args.train_on_prompt:
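Note: this hunk ends right before the label-building branch. For orientation, a hedged, self-contained sketch of the masking pattern that typically follows (the `IGNORE_INDEX` constant, the helper name, and the wiring are assumptions, not shown in this diff):

from typing import List

IGNORE_INDEX = -100  # assumption: PyTorch cross-entropy's default ignore index

def build_turn_labels(
    source_ids: List[int], target_ids: List[int], train_on_prompt: bool
) -> List[int]:
    # Either learn on the prompt tokens too, or mask them out of the loss.
    if train_on_prompt:
        source_mask = source_ids
    else:
        source_mask = [IGNORE_INDEX] * len(source_ids)
    return source_mask + target_ids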
@@ -143,7 +148,12 @@ def preprocess_unsupervised_dataset(
         messages = examples["prompt"][i] + [{"role": Role.ASSISTANT, "content": ""}]

         input_ids, labels = template.encode_oneturn(
-            tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+            tokenizer,
+            messages,
+            examples["system"][i],
+            examples["tools"][i],
+            data_args.cutoff_len,
+            data_args.reserved_label_len,
         )

         if template.efficient_eos:
@@ -172,10 +182,20 @@ def preprocess_pairwise_dataset(
         rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]

         prompt_ids, chosen_ids = template.encode_oneturn(
-            tokenizer, chosen_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+            tokenizer,
+            chosen_messages,
+            examples["system"][i],
+            examples["tools"][i],
+            data_args.cutoff_len,
+            data_args.reserved_label_len,
         )
         _, rejected_ids = template.encode_oneturn(
-            tokenizer, rejected_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+            tokenizer,
+            rejected_messages,
+            examples["system"][i],
+            examples["tools"][i],
+            data_args.cutoff_len,
+            data_args.reserved_label_len,
         )

         if template.efficient_eos:
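All three preprocessing hunks above make the same change: they thread `data_args.reserved_label_len` through the template encoding calls so truncation can guarantee a minimum token budget for the label side. A minimal sketch of the budget split this enables; the helper name and its exact arithmetic are assumptions, since only `max_len`, `reserved_label_len`, `max_source_len`, and `max_target_len` appear in this diff:

from typing import Tuple

def infer_max_len(
    source_len: int, target_len: int, max_len: int, reserved_label_len: int
) -> Tuple[int, int]:
    # Split the remaining budget proportionally between prompt and response,
    # then guarantee the response keeps at least `reserved_label_len` tokens.
    max_target_len = int(max_len * (target_len / (source_len + target_len)))
    max_target_len = max(max_target_len, reserved_label_len)
    max_source_len = max_len - max_target_len
    return max_source_len, max_target_len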
@@ -37,7 +37,7 @@ class Template:
         system: Optional[str] = None,
         tools: Optional[str] = None,
         cutoff_len: Optional[int] = 1_000_000,
-        reserved_label_len: Optional[int] = 16,
+        reserved_label_len: Optional[int] = 1,
     ) -> Tuple[List[int], List[int]]:
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
@@ -57,7 +57,7 @@ class Template:
         system: Optional[str] = None,
         tools: Optional[str] = None,
         cutoff_len: Optional[int] = 1_000_000,
-        reserved_label_len: Optional[int] = 16,
+        reserved_label_len: Optional[int] = 1,
     ) -> Sequence[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
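The default drops from 16 to 1 in both `encode_oneturn` and `encode_multiturn`, matching the `DataArguments` default below, so callers that want extra headroom for labels must now opt in. A hedged usage sketch, assuming a `template` and `tokenizer` are already constructed and that the message schema mirrors the preprocessing code above:

pairs = template.encode_multiturn(
    tokenizer,
    [
        {"role": "user", "content": "Hi!"},
        {"role": "assistant", "content": "Hello, how can I help?"},
    ],
    None,  # system
    None,  # tools
    1024,  # cutoff_len
    16,    # reserved_label_len: reserve room for the response side
)
for source_ids, target_ids in pairs:
    print(len(source_ids), len(target_ids))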
@@ -144,10 +144,10 @@ class Template:
                 max_len=(cutoff_len - total_length),
                 reserved_label_len=reserved_label_len,
             )
-            encoded_messages[i] = encoded_messages[i][:max_source_len]
-            encoded_messages[i + 1] = encoded_messages[i + 1][:max_target_len]
-            total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
-            encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
+            source_ids = encoded_messages[i][:max_source_len]
+            target_ids = encoded_messages[i + 1][:max_target_len]
+            total_length += len(source_ids) + len(target_ids)
+            encoded_pairs.append((source_ids, target_ids))

         return encoded_pairs
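The switch from index expressions to local `source_ids`/`target_ids` is behavior-preserving; it only makes the truncate-then-accumulate flow readable. For context, a reconstruction of the surrounding loop as a standalone function (the loop bounds and early-exit condition are assumptions; `infer_max_len` is the assumed helper sketched earlier):

from typing import List, Tuple

def make_pairs(
    encoded_messages: List[List[int]], cutoff_len: int, reserved_label_len: int
) -> List[Tuple[List[int], List[int]]]:
    encoded_pairs = []
    total_length = 0
    for i in range(0, len(encoded_messages), 2):  # consecutive (prompt, response) pairs
        if total_length >= cutoff_len:
            break
        # Shrink the remaining budget as pairs are emitted.
        max_source_len, max_target_len = infer_max_len(
            source_len=len(encoded_messages[i]),
            target_len=len(encoded_messages[i + 1]),
            max_len=(cutoff_len - total_length),
            reserved_label_len=reserved_label_len,
        )
        source_ids = encoded_messages[i][:max_source_len]
        target_ids = encoded_messages[i + 1][:max_target_len]
        total_length += len(source_ids) + len(target_ids)
        encoded_pairs.append((source_ids, target_ids))
    return encoded_pairs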
@@ -21,10 +21,10 @@ class DataArguments:
         default="train", metadata={"help": "Which dataset split to use for training and evaluation."}
     )
     cutoff_len: Optional[int] = field(
-        default=1024, metadata={"help": "The maximum length of the model inputs after tokenization."}
+        default=1024, metadata={"help": "The cutoff length of the model inputs after tokenization."}
     )
     reserved_label_len: Optional[int] = field(
-        default=1, metadata={"help": "The maximum length reserved for label after tokenization."}
+        default=1, metadata={"help": "The minimum cutoff length reserved for label after tokenization."}
     )
     train_on_prompt: Optional[bool] = field(
         default=False, metadata={"help": "Whether to disable the mask on the prompt or not."}
@@ -57,7 +57,7 @@ class DataArguments:
     ignore_pad_token_for_loss: Optional[bool] = field(
         default=True,
         metadata={
-            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
+            "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
         },
     )
     val_size: Optional[float] = field(
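`DataArguments` is a `field`-based dataclass, so the reworded help strings surface directly in `--help` output when parsed the usual way. A self-contained sketch; the `HfArgumentParser` wiring is an assumption about the surrounding code, while the two field definitions are copied from the hunks above:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class DataArguments:
    cutoff_len: Optional[int] = field(
        default=1024, metadata={"help": "The cutoff length of the model inputs after tokenization."}
    )
    reserved_label_len: Optional[int] = field(
        default=1, metadata={"help": "The minimum cutoff length reserved for label after tokenization."}
    )

# e.g. python train.py --cutoff_len 2048 --reserved_label_len 16
(data_args,) = HfArgumentParser(DataArguments).parse_args_into_dataclasses()
print(data_args.cutoff_len, data_args.reserved_label_len)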