enable cutoff len
This commit is contained in:
parent
83dbfce8c3
commit
f1067d2b58
|
@ -198,12 +198,12 @@
|
||||||
"hf_hub_url": "WizardLM/WizardLM_evol_instruct_V2_196k",
|
"hf_hub_url": "WizardLM/WizardLM_evol_instruct_V2_196k",
|
||||||
"formatting": "sharegpt"
|
"formatting": "sharegpt"
|
||||||
},
|
},
|
||||||
"glaive_func_call": {
|
"glaive_toolcall": {
|
||||||
"file_name": "glaive_func_call.json",
|
"file_name": "glaive_toolcall_10k.json",
|
||||||
"formatting": "sharegpt",
|
"formatting": "sharegpt",
|
||||||
"columns": {
|
"columns": {
|
||||||
"messages": "conversations",
|
"messages": "conversations",
|
||||||
"tool": "tools"
|
"tools": "tools"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"hh_rlhf_en": {
|
"hh_rlhf_en": {
|
||||||
|
|
|
@ -1,68 +0,0 @@
|
||||||
[
|
|
||||||
{
|
|
||||||
"conversations": [
|
|
||||||
{
|
|
||||||
"from": "human",
|
|
||||||
"value": "I need a new password. Can you generate one for me?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "gpt",
|
|
||||||
"value": "Of course. How long would you like your password to be? And would you like it to include symbols?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "human",
|
|
||||||
"value": "I would like it to be 12 characters long and yes, please include symbols."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "function_call",
|
|
||||||
"value": "{\"name\": \"generate_password\", \"arguments\": {\"length\": 12, \"include_symbols\": true}}"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "observation",
|
|
||||||
"value": "{\"password\": \"4&7j#9@1Q6*\"}"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "gpt",
|
|
||||||
"value": "Here is your new password: 4&7j#9@1Q6*. Please make sure to save it in a secure location."
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"tools": "[{\"name\": \"generate_password\", \"description\": \"Generate a random password\", \"parameters\": {\"type\": \"object\", \"properties\": {\"length\": {\"type\": \"integer\", \"description\": \"The length of the password\"}, \"include_symbols\": {\"type\": \"boolean\", \"description\": \"Whether to include symbols in the password\"}}, \"required\": [\"length\"]}}, {\"name\": \"create_task\", \"description\": \"Create a new task in a task management system\", \"parameters\": {\"type\": \"object\", \"properties\": {\"title\": {\"type\": \"string\", \"description\": \"The title of the task\"}, \"due_date\": {\"type\": \"string\", \"format\": \"date\", \"description\": \"The due date of the task\"}, \"priority\": {\"type\": \"string\", \"enum\": [\"low\", \"medium\", \"high\"], \"description\": \"The priority of the task\"}}, \"required\": [\"title\", \"due_date\", \"priority\"]}}]"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"conversations": [
|
|
||||||
{
|
|
||||||
"from": "human",
|
|
||||||
"value": "Can you tell me the latest news headlines for the United States?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "function_call",
|
|
||||||
"value": "{\"name\": \"get_news_headlines\", \"arguments\": {\"country\": \"United States\"}}"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "observation",
|
|
||||||
"value": "{\"headlines\": [\"Biden announces new vaccine mandates\", \"Hurricane Ida devastates Louisiana\", \"Apple unveils new iPhone\", \"NASA's Perseverance rover collects first Mars rock sample\"]}"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "gpt",
|
|
||||||
"value": "Here are the latest news headlines for the United States:"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "human",
|
|
||||||
"value": "That's interesting. What about the news in France?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "function_call",
|
|
||||||
"value": "{\"name\": \"get_news_headlines\", \"arguments\": {\"country\": \"France\"}}"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "observation",
|
|
||||||
"value": "{\"headlines\": [\"France recalls ambassadors to US and Australia\", \"French election: Macron's party braces for tough fight\", \"Louvre Museum to undergo major overhaul\", \"France to offer free birth control to all women under 25\"]}"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "gpt",
|
|
||||||
"value": "Here are the latest news headlines for France:"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"tools": "[{\"name\": \"get_news_headlines\", \"description\": \"Get the latest news headlines\", \"parameters\": {\"type\": \"object\", \"properties\": {\"country\": {\"type\": \"string\", \"description\": \"The country for which to fetch news\"}}, \"required\": [\"country\"]}}]"
|
|
||||||
}
|
|
||||||
]
|
|
File diff suppressed because one or more lines are too long
|
@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
|
|
||||||
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||||
outputs = {"prompt": [], "response": [], "system": [], "tool": []}
|
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||||
for i in range(len(examples[dataset_attr.prompt])):
|
for i in range(len(examples[dataset_attr.prompt])):
|
||||||
prompt = []
|
prompt = []
|
||||||
if dataset_attr.history:
|
if dataset_attr.history:
|
||||||
|
@ -33,13 +33,13 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
|
||||||
outputs["prompt"].append(prompt)
|
outputs["prompt"].append(prompt)
|
||||||
outputs["response"].append(response)
|
outputs["response"].append(response)
|
||||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||||
outputs["tool"].append("")
|
outputs["tools"].append("")
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||||
outputs = {"prompt": [], "response": [], "system": [], "tool": []}
|
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||||
tag_mapping = {
|
tag_mapping = {
|
||||||
dataset_attr.user_tag: Role.USER,
|
dataset_attr.user_tag: Role.USER,
|
||||||
dataset_attr.assistant_tag: Role.ASSISTANT,
|
dataset_attr.assistant_tag: Role.ASSISTANT,
|
||||||
|
@ -69,7 +69,7 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
|
||||||
outputs["prompt"].append(prompt)
|
outputs["prompt"].append(prompt)
|
||||||
outputs["response"].append(response)
|
outputs["response"].append(response)
|
||||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||||
outputs["tool"].append(examples[dataset_attr.tool][i] if dataset_attr.tool else "")
|
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ def align_dataset(
|
||||||
prompt: [{"role": "user", "content": "..."}]
|
prompt: [{"role": "user", "content": "..."}]
|
||||||
response: [{"role": "assistant", "content": "..."}]
|
response: [{"role": "assistant", "content": "..."}]
|
||||||
system: "..."
|
system: "..."
|
||||||
tool: "..."
|
tools: "..."
|
||||||
"""
|
"""
|
||||||
if dataset_attr.formatting == "alpaca":
|
if dataset_attr.formatting == "alpaca":
|
||||||
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
|
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
|
||||||
|
|
|
@ -93,6 +93,9 @@ class ToolFormatter:
|
||||||
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
|
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
|
||||||
try:
|
try:
|
||||||
tools = json.loads(content)
|
tools = json.loads(content)
|
||||||
|
if not len(tools):
|
||||||
|
return [""]
|
||||||
|
|
||||||
if self.type == "default":
|
if self.type == "default":
|
||||||
return [self._default(tools)]
|
return [self._default(tools)]
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
|
|
|
@ -29,7 +29,7 @@ class DatasetAttr:
|
||||||
history: Optional[str] = None
|
history: Optional[str] = None
|
||||||
|
|
||||||
messages: Optional[str] = "conversations"
|
messages: Optional[str] = "conversations"
|
||||||
tool: Optional[str] = None
|
tools: Optional[str] = None
|
||||||
|
|
||||||
role_tag: Optional[str] = "from"
|
role_tag: Optional[str] = "from"
|
||||||
content_tag: Optional[str] = "value"
|
content_tag: Optional[str] = "value"
|
||||||
|
@ -86,7 +86,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||||
if dataset_attr.formatting == "alpaca":
|
if dataset_attr.formatting == "alpaca":
|
||||||
column_names = ["prompt", "query", "response", "history"]
|
column_names = ["prompt", "query", "response", "history"]
|
||||||
else:
|
else:
|
||||||
column_names = ["messages", "tool"]
|
column_names = ["messages", "tools"]
|
||||||
|
|
||||||
column_names += ["system"]
|
column_names += ["system"]
|
||||||
for column_name in column_names:
|
for column_name in column_names:
|
||||||
|
|
|
@ -58,7 +58,7 @@ def preprocess_supervised_dataset(
|
||||||
messages = examples["prompt"][i] + examples["response"][i]
|
messages = examples["prompt"][i] + examples["response"][i]
|
||||||
input_ids, labels = [], []
|
input_ids, labels = [], []
|
||||||
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
||||||
tokenizer, messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
|
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
|
||||||
)):
|
)):
|
||||||
if data_args.train_on_prompt:
|
if data_args.train_on_prompt:
|
||||||
source_mask = source_ids
|
source_mask = source_ids
|
||||||
|
@ -97,7 +97,7 @@ def preprocess_packed_supervised_dataset(
|
||||||
|
|
||||||
messages = examples["prompt"][i] + examples["response"][i]
|
messages = examples["prompt"][i] + examples["response"][i]
|
||||||
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
||||||
tokenizer, messages, examples["system"][i], examples["tool"][i]
|
tokenizer, messages, examples["system"][i], examples["tools"][i]
|
||||||
)):
|
)):
|
||||||
if data_args.train_on_prompt:
|
if data_args.train_on_prompt:
|
||||||
source_mask = source_ids
|
source_mask = source_ids
|
||||||
|
@ -141,7 +141,7 @@ def preprocess_unsupervised_dataset(
|
||||||
|
|
||||||
messages = examples["prompt"][i] + examples["response"][i]
|
messages = examples["prompt"][i] + examples["response"][i]
|
||||||
input_ids, labels = template.encode_oneturn(
|
input_ids, labels = template.encode_oneturn(
|
||||||
tokenizer, messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
|
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
|
||||||
)
|
)
|
||||||
|
|
||||||
if template.efficient_eos:
|
if template.efficient_eos:
|
||||||
|
@ -170,10 +170,10 @@ def preprocess_pairwise_dataset(
|
||||||
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
|
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
|
||||||
|
|
||||||
prompt_ids, chosen_ids = template.encode_oneturn(
|
prompt_ids, chosen_ids = template.encode_oneturn(
|
||||||
tokenizer, chosen_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
|
tokenizer, chosen_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
|
||||||
)
|
)
|
||||||
_, rejected_ids = template.encode_oneturn(
|
_, rejected_ids = template.encode_oneturn(
|
||||||
tokenizer, rejected_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
|
tokenizer, rejected_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
|
||||||
)
|
)
|
||||||
|
|
||||||
if template.efficient_eos:
|
if template.efficient_eos:
|
||||||
|
|
|
@ -95,7 +95,21 @@ class Template:
|
||||||
|
|
||||||
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
||||||
|
|
||||||
return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)]
|
# TODO: need to improve
|
||||||
|
encoded_pairs = []
|
||||||
|
total_length = 0
|
||||||
|
for i in range(0, len(encoded_messages), 2):
|
||||||
|
if total_length >= cutoff_len:
|
||||||
|
break
|
||||||
|
|
||||||
|
encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
|
||||||
|
total_length += len(encoded_messages[i])
|
||||||
|
|
||||||
|
encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
|
||||||
|
total_length += len(encoded_messages[i+1])
|
||||||
|
encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
|
||||||
|
|
||||||
|
return encoded_pairs
|
||||||
|
|
||||||
def _convert_elements_to_ids(
|
def _convert_elements_to_ids(
|
||||||
self,
|
self,
|
||||||
|
@ -161,7 +175,21 @@ class Llama2Template(Template):
|
||||||
|
|
||||||
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
||||||
|
|
||||||
return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)]
|
# TODO: need to improve
|
||||||
|
encoded_pairs = []
|
||||||
|
total_length = 0
|
||||||
|
for i in range(0, len(encoded_messages), 2):
|
||||||
|
if total_length >= cutoff_len:
|
||||||
|
break
|
||||||
|
|
||||||
|
encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
|
||||||
|
total_length += len(encoded_messages[i])
|
||||||
|
|
||||||
|
encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
|
||||||
|
total_length += len(encoded_messages[i+1])
|
||||||
|
encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
|
||||||
|
|
||||||
|
return encoded_pairs
|
||||||
|
|
||||||
|
|
||||||
templates: Dict[str, Template] = {}
|
templates: Dict[str, Template] = {}
|
||||||
|
|
Loading…
Reference in New Issue