fix data entry
This commit is contained in:
parent
6bf4c1274f
commit
354f13c01a
|
@ -75,11 +75,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||
|
||||
semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
|
||||
role_mapping = {
|
||||
Role.USER: DataRole.USER,
|
||||
Role.ASSISTANT: DataRole.ASSISTANT,
|
||||
Role.SYSTEM: DataRole.SYSTEM,
|
||||
Role.FUNCTION: DataRole.FUNCTION,
|
||||
Role.TOOL: DataRole.OBSERVATION,
|
||||
Role.USER: DataRole.USER.value,
|
||||
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
||||
Role.SYSTEM: DataRole.SYSTEM.value,
|
||||
Role.FUNCTION: DataRole.FUNCTION.value,
|
||||
Role.TOOL: DataRole.OBSERVATION.value,
|
||||
}
|
||||
|
||||
@app.get("/v1/models", response_model=ModelList)
|
||||
|
@ -95,7 +95,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||
|
||||
if role_mapping[request.messages[0].role] == DataRole.SYSTEM:
|
||||
if request.messages[0].role == Role.SYSTEM:
|
||||
system = request.messages.pop(0).content
|
||||
else:
|
||||
system = ""
|
||||
|
@ -105,11 +105,12 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||
|
||||
input_messages = []
|
||||
for i, message in enumerate(request.messages):
|
||||
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
|
||||
input_messages.append({"role": role_mapping[message.role], "content": message.content})
|
||||
if i % 2 == 0 and input_messages[i]["role"] not in [DataRole.USER, DataRole.OBSERVATION]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
elif i % 2 == 1 and input_messages[i]["role"] not in [DataRole.ASSISTANT, DataRole.FUNCTION]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
|
||||
tool_list = request.tools
|
||||
if isinstance(tool_list, list) and len(tool_list):
|
||||
|
|
|
@ -19,8 +19,8 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
|
|||
prompt = []
|
||||
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
|
||||
for old_prompt, old_response in examples[dataset_attr.history][i]:
|
||||
prompt.append({"role": Role.USER, "content": old_prompt})
|
||||
prompt.append({"role": Role.ASSISTANT, "content": old_response})
|
||||
prompt.append({"role": Role.USER.value, "content": old_prompt})
|
||||
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
|
||||
|
||||
content = []
|
||||
if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
|
||||
|
@ -29,12 +29,14 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
|
|||
if dataset_attr.query and examples[dataset_attr.query][i]:
|
||||
content.append(examples[dataset_attr.query][i])
|
||||
|
||||
prompt.append({"role": Role.USER, "content": "\n".join(content)})
|
||||
prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
|
||||
|
||||
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
|
||||
response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
|
||||
response = [
|
||||
{"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
|
||||
]
|
||||
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
|
||||
response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
|
||||
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
||||
else:
|
||||
response = []
|
||||
|
||||
|
@ -49,11 +51,11 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
|
|||
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||
tag_mapping = {
|
||||
dataset_attr.user_tag: Role.USER,
|
||||
dataset_attr.assistant_tag: Role.ASSISTANT,
|
||||
dataset_attr.observation_tag: Role.OBSERVATION,
|
||||
dataset_attr.function_tag: Role.FUNCTION,
|
||||
dataset_attr.system_tag: Role.SYSTEM,
|
||||
dataset_attr.user_tag: Role.USER.value,
|
||||
dataset_attr.assistant_tag: Role.ASSISTANT.value,
|
||||
dataset_attr.observation_tag: Role.OBSERVATION.value,
|
||||
dataset_attr.function_tag: Role.FUNCTION.value,
|
||||
dataset_attr.system_tag: Role.SYSTEM.value,
|
||||
}
|
||||
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
|
||||
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
|
||||
|
|
|
@ -146,7 +146,7 @@ def preprocess_unsupervised_dataset(
|
|||
if len(examples["response"][i]) == 1:
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
else:
|
||||
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT, "content": ""}]
|
||||
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
|
||||
input_ids, labels = template.encode_oneturn(
|
||||
tokenizer,
|
||||
|
|
|
@ -88,16 +88,16 @@ class Template:
|
|||
elif i > 0 and i % 2 == 0:
|
||||
elements += self.format_separator.apply()
|
||||
|
||||
if message["role"] == Role.USER:
|
||||
if message["role"] == Role.USER.value:
|
||||
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
|
||||
elif message["role"] == Role.ASSISTANT:
|
||||
elif message["role"] == Role.ASSISTANT.value:
|
||||
elements += self.format_assistant.apply(content=message["content"])
|
||||
elif message["role"] == Role.OBSERVATION:
|
||||
elif message["role"] == Role.OBSERVATION.value:
|
||||
elements += self.format_observation.apply(content=message["content"])
|
||||
elif message["role"] == Role.FUNCTION:
|
||||
elif message["role"] == Role.FUNCTION.value:
|
||||
elements += self.format_function.apply(content=message["content"])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
|
||||
|
||||
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
||||
|
||||
|
@ -179,16 +179,16 @@ class Llama2Template(Template):
|
|||
elif i > 0 and i % 2 == 0:
|
||||
elements += self.format_separator.apply()
|
||||
|
||||
if message["role"] == Role.USER:
|
||||
if message["role"] == Role.USER.value:
|
||||
elements += self.format_user.apply(content=system_text + message["content"])
|
||||
elif message["role"] == Role.ASSISTANT:
|
||||
elif message["role"] == Role.ASSISTANT.value:
|
||||
elements += self.format_assistant.apply(content=message["content"])
|
||||
elif message["role"] == Role.OBSERVATION:
|
||||
elif message["role"] == Role.OBSERVATION.value:
|
||||
elements += self.format_observation.apply(content=message["content"])
|
||||
elif message["role"] == Role.FUNCTION:
|
||||
elif message["role"] == Role.FUNCTION.value:
|
||||
elements += self.format_function.apply(content=message["content"])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
|
||||
|
||||
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
||||
|
||||
|
|
|
@ -115,7 +115,7 @@ class WebChatModel(ChatModel):
|
|||
temperature: float,
|
||||
) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
|
||||
chatbot.append([query, ""])
|
||||
query_messages = messages + [{"role": Role.USER, "content": query}]
|
||||
query_messages = messages + [{"role": Role.USER.value, "content": query}]
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
|
@ -130,10 +130,10 @@ class WebChatModel(ChatModel):
|
|||
name, arguments = result
|
||||
arguments = json.loads(arguments)
|
||||
tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
|
||||
output_messages = query_messages + [{"role": Role.FUNCTION, "content": tool_call}]
|
||||
output_messages = query_messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
|
||||
bot_text = "```json\n" + tool_call + "\n```"
|
||||
else:
|
||||
output_messages = query_messages + [{"role": Role.ASSISTANT, "content": result}]
|
||||
output_messages = query_messages + [{"role": Role.ASSISTANT.value, "content": result}]
|
||||
bot_text = result
|
||||
|
||||
chatbot[-1] = [query, self.postprocess(bot_text)]
|
||||
|
|
Loading…
Reference in New Issue