fix data entry
This commit is contained in:
parent
6bf4c1274f
commit
354f13c01a
|
@ -75,11 +75,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||||
|
|
||||||
semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
|
semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
|
||||||
role_mapping = {
|
role_mapping = {
|
||||||
Role.USER: DataRole.USER,
|
Role.USER: DataRole.USER.value,
|
||||||
Role.ASSISTANT: DataRole.ASSISTANT,
|
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
||||||
Role.SYSTEM: DataRole.SYSTEM,
|
Role.SYSTEM: DataRole.SYSTEM.value,
|
||||||
Role.FUNCTION: DataRole.FUNCTION,
|
Role.FUNCTION: DataRole.FUNCTION.value,
|
||||||
Role.TOOL: DataRole.OBSERVATION,
|
Role.TOOL: DataRole.OBSERVATION.value,
|
||||||
}
|
}
|
||||||
|
|
||||||
@app.get("/v1/models", response_model=ModelList)
|
@app.get("/v1/models", response_model=ModelList)
|
||||||
|
@ -95,7 +95,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||||
if len(request.messages) == 0:
|
if len(request.messages) == 0:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||||
|
|
||||||
if role_mapping[request.messages[0].role] == DataRole.SYSTEM:
|
if request.messages[0].role == Role.SYSTEM:
|
||||||
system = request.messages.pop(0).content
|
system = request.messages.pop(0).content
|
||||||
else:
|
else:
|
||||||
system = ""
|
system = ""
|
||||||
|
@ -105,11 +105,12 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||||
|
|
||||||
input_messages = []
|
input_messages = []
|
||||||
for i, message in enumerate(request.messages):
|
for i, message in enumerate(request.messages):
|
||||||
|
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||||
|
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||||
|
|
||||||
input_messages.append({"role": role_mapping[message.role], "content": message.content})
|
input_messages.append({"role": role_mapping[message.role], "content": message.content})
|
||||||
if i % 2 == 0 and input_messages[i]["role"] not in [DataRole.USER, DataRole.OBSERVATION]:
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
|
||||||
elif i % 2 == 1 and input_messages[i]["role"] not in [DataRole.ASSISTANT, DataRole.FUNCTION]:
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
|
||||||
|
|
||||||
tool_list = request.tools
|
tool_list = request.tools
|
||||||
if isinstance(tool_list, list) and len(tool_list):
|
if isinstance(tool_list, list) and len(tool_list):
|
||||||
|
|
|
@ -19,8 +19,8 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
|
||||||
prompt = []
|
prompt = []
|
||||||
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
|
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
|
||||||
for old_prompt, old_response in examples[dataset_attr.history][i]:
|
for old_prompt, old_response in examples[dataset_attr.history][i]:
|
||||||
prompt.append({"role": Role.USER, "content": old_prompt})
|
prompt.append({"role": Role.USER.value, "content": old_prompt})
|
||||||
prompt.append({"role": Role.ASSISTANT, "content": old_response})
|
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
|
||||||
|
|
||||||
content = []
|
content = []
|
||||||
if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
|
if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
|
||||||
|
@ -29,12 +29,14 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
|
||||||
if dataset_attr.query and examples[dataset_attr.query][i]:
|
if dataset_attr.query and examples[dataset_attr.query][i]:
|
||||||
content.append(examples[dataset_attr.query][i])
|
content.append(examples[dataset_attr.query][i])
|
||||||
|
|
||||||
prompt.append({"role": Role.USER, "content": "\n".join(content)})
|
prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
|
||||||
|
|
||||||
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
|
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
|
||||||
response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
|
response = [
|
||||||
|
{"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
|
||||||
|
]
|
||||||
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
|
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
|
||||||
response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
|
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
||||||
else:
|
else:
|
||||||
response = []
|
response = []
|
||||||
|
|
||||||
|
@ -49,11 +51,11 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
|
||||||
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||||
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||||
tag_mapping = {
|
tag_mapping = {
|
||||||
dataset_attr.user_tag: Role.USER,
|
dataset_attr.user_tag: Role.USER.value,
|
||||||
dataset_attr.assistant_tag: Role.ASSISTANT,
|
dataset_attr.assistant_tag: Role.ASSISTANT.value,
|
||||||
dataset_attr.observation_tag: Role.OBSERVATION,
|
dataset_attr.observation_tag: Role.OBSERVATION.value,
|
||||||
dataset_attr.function_tag: Role.FUNCTION,
|
dataset_attr.function_tag: Role.FUNCTION.value,
|
||||||
dataset_attr.system_tag: Role.SYSTEM,
|
dataset_attr.system_tag: Role.SYSTEM.value,
|
||||||
}
|
}
|
||||||
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
|
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
|
||||||
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
|
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
|
||||||
|
|
|
@ -146,7 +146,7 @@ def preprocess_unsupervised_dataset(
|
||||||
if len(examples["response"][i]) == 1:
|
if len(examples["response"][i]) == 1:
|
||||||
messages = examples["prompt"][i] + examples["response"][i]
|
messages = examples["prompt"][i] + examples["response"][i]
|
||||||
else:
|
else:
|
||||||
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT, "content": ""}]
|
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||||
|
|
||||||
input_ids, labels = template.encode_oneturn(
|
input_ids, labels = template.encode_oneturn(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
|
|
@ -88,16 +88,16 @@ class Template:
|
||||||
elif i > 0 and i % 2 == 0:
|
elif i > 0 and i % 2 == 0:
|
||||||
elements += self.format_separator.apply()
|
elements += self.format_separator.apply()
|
||||||
|
|
||||||
if message["role"] == Role.USER:
|
if message["role"] == Role.USER.value:
|
||||||
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
|
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
|
||||||
elif message["role"] == Role.ASSISTANT:
|
elif message["role"] == Role.ASSISTANT.value:
|
||||||
elements += self.format_assistant.apply(content=message["content"])
|
elements += self.format_assistant.apply(content=message["content"])
|
||||||
elif message["role"] == Role.OBSERVATION:
|
elif message["role"] == Role.OBSERVATION.value:
|
||||||
elements += self.format_observation.apply(content=message["content"])
|
elements += self.format_observation.apply(content=message["content"])
|
||||||
elif message["role"] == Role.FUNCTION:
|
elif message["role"] == Role.FUNCTION.value:
|
||||||
elements += self.format_function.apply(content=message["content"])
|
elements += self.format_function.apply(content=message["content"])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
|
||||||
|
|
||||||
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
||||||
|
|
||||||
|
@ -179,16 +179,16 @@ class Llama2Template(Template):
|
||||||
elif i > 0 and i % 2 == 0:
|
elif i > 0 and i % 2 == 0:
|
||||||
elements += self.format_separator.apply()
|
elements += self.format_separator.apply()
|
||||||
|
|
||||||
if message["role"] == Role.USER:
|
if message["role"] == Role.USER.value:
|
||||||
elements += self.format_user.apply(content=system_text + message["content"])
|
elements += self.format_user.apply(content=system_text + message["content"])
|
||||||
elif message["role"] == Role.ASSISTANT:
|
elif message["role"] == Role.ASSISTANT.value:
|
||||||
elements += self.format_assistant.apply(content=message["content"])
|
elements += self.format_assistant.apply(content=message["content"])
|
||||||
elif message["role"] == Role.OBSERVATION:
|
elif message["role"] == Role.OBSERVATION.value:
|
||||||
elements += self.format_observation.apply(content=message["content"])
|
elements += self.format_observation.apply(content=message["content"])
|
||||||
elif message["role"] == Role.FUNCTION:
|
elif message["role"] == Role.FUNCTION.value:
|
||||||
elements += self.format_function.apply(content=message["content"])
|
elements += self.format_function.apply(content=message["content"])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
|
||||||
|
|
||||||
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
||||||
|
|
||||||
|
|
|
@ -115,7 +115,7 @@ class WebChatModel(ChatModel):
|
||||||
temperature: float,
|
temperature: float,
|
||||||
) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
|
) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
|
||||||
chatbot.append([query, ""])
|
chatbot.append([query, ""])
|
||||||
query_messages = messages + [{"role": Role.USER, "content": query}]
|
query_messages = messages + [{"role": Role.USER.value, "content": query}]
|
||||||
response = ""
|
response = ""
|
||||||
for new_text in self.stream_chat(
|
for new_text in self.stream_chat(
|
||||||
query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||||
|
@ -130,10 +130,10 @@ class WebChatModel(ChatModel):
|
||||||
name, arguments = result
|
name, arguments = result
|
||||||
arguments = json.loads(arguments)
|
arguments = json.loads(arguments)
|
||||||
tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
|
tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
|
||||||
output_messages = query_messages + [{"role": Role.FUNCTION, "content": tool_call}]
|
output_messages = query_messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
|
||||||
bot_text = "```json\n" + tool_call + "\n```"
|
bot_text = "```json\n" + tool_call + "\n```"
|
||||||
else:
|
else:
|
||||||
output_messages = query_messages + [{"role": Role.ASSISTANT, "content": result}]
|
output_messages = query_messages + [{"role": Role.ASSISTANT.value, "content": result}]
|
||||||
bot_text = result
|
bot_text = result
|
||||||
|
|
||||||
chatbot[-1] = [query, self.postprocess(bot_text)]
|
chatbot[-1] = [query, self.postprocess(bot_text)]
|
||||||
|
|
Loading…
Reference in New Issue