diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py
index 7b1560d3..b47734e0 100644
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -75,11 +75,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
     semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))

     role_mapping = {
-        Role.USER: DataRole.USER,
-        Role.ASSISTANT: DataRole.ASSISTANT,
-        Role.SYSTEM: DataRole.SYSTEM,
-        Role.FUNCTION: DataRole.FUNCTION,
-        Role.TOOL: DataRole.OBSERVATION,
+        Role.USER: DataRole.USER.value,
+        Role.ASSISTANT: DataRole.ASSISTANT.value,
+        Role.SYSTEM: DataRole.SYSTEM.value,
+        Role.FUNCTION: DataRole.FUNCTION.value,
+        Role.TOOL: DataRole.OBSERVATION.value,
     }

     @app.get("/v1/models", response_model=ModelList)
@@ -95,7 +95,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         if len(request.messages) == 0:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")

-        if role_mapping[request.messages[0].role] == DataRole.SYSTEM:
+        if request.messages[0].role == Role.SYSTEM:
             system = request.messages.pop(0).content
         else:
             system = ""
@@ -105,11 +105,12 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":

         input_messages = []
         for i, message in enumerate(request.messages):
+            if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+            elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+
             input_messages.append({"role": role_mapping[message.role], "content": message.content})
-            if i % 2 == 0 and input_messages[i]["role"] not in [DataRole.USER, DataRole.OBSERVATION]:
-                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
-            elif i % 2 == 1 and input_messages[i]["role"] not in [DataRole.ASSISTANT, DataRole.FUNCTION]:
-                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")

         tool_list = request.tools
         if isinstance(tool_list, list) and len(tool_list):
diff --git a/src/llmtuner/data/aligner.py b/src/llmtuner/data/aligner.py
index fbf3a32d..4de37e6d 100644
--- a/src/llmtuner/data/aligner.py
+++ b/src/llmtuner/data/aligner.py
@@ -19,8 +19,8 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
         prompt = []
         if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
             for old_prompt, old_response in examples[dataset_attr.history][i]:
-                prompt.append({"role": Role.USER, "content": old_prompt})
-                prompt.append({"role": Role.ASSISTANT, "content": old_response})
+                prompt.append({"role": Role.USER.value, "content": old_prompt})
+                prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

         content = []
         if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
@@ -29,12 +29,14 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
         if dataset_attr.query and examples[dataset_attr.query][i]:
             content.append(examples[dataset_attr.query][i])

-        prompt.append({"role": Role.USER, "content": "\n".join(content)})
+        prompt.append({"role": Role.USER.value, "content": "\n".join(content)})

         if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
-            response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
+            response = [
+                {"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
+            ]
         elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
-            response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
+            response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
         else:
             response = []

@@ -49,11 +51,11 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
 def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
     outputs = {"prompt": [], "response": [], "system": [], "tools": []}
     tag_mapping = {
-        dataset_attr.user_tag: Role.USER,
-        dataset_attr.assistant_tag: Role.ASSISTANT,
-        dataset_attr.observation_tag: Role.OBSERVATION,
-        dataset_attr.function_tag: Role.FUNCTION,
-        dataset_attr.system_tag: Role.SYSTEM,
+        dataset_attr.user_tag: Role.USER.value,
+        dataset_attr.assistant_tag: Role.ASSISTANT.value,
+        dataset_attr.observation_tag: Role.OBSERVATION.value,
+        dataset_attr.function_tag: Role.FUNCTION.value,
+        dataset_attr.system_tag: Role.SYSTEM.value,
     }
     odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
     even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
diff --git a/src/llmtuner/data/preprocess.py b/src/llmtuner/data/preprocess.py
index 14f2a388..60dccefe 100644
--- a/src/llmtuner/data/preprocess.py
+++ b/src/llmtuner/data/preprocess.py
@@ -146,7 +146,7 @@ def preprocess_unsupervised_dataset(
         if len(examples["response"][i]) == 1:
             messages = examples["prompt"][i] + examples["response"][i]
         else:
-            messages = examples["prompt"][i] + [{"role": Role.ASSISTANT, "content": ""}]
+            messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]

         input_ids, labels = template.encode_oneturn(
             tokenizer,
diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py
index 43e92d65..77f93c5e 100644
--- a/src/llmtuner/data/template.py
+++ b/src/llmtuner/data/template.py
@@ -88,16 +88,16 @@ class Template:
             elif i > 0 and i % 2 == 0:
                 elements += self.format_separator.apply()

-            if message["role"] == Role.USER:
+            if message["role"] == Role.USER.value:
                 elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
-            elif message["role"] == Role.ASSISTANT:
+            elif message["role"] == Role.ASSISTANT.value:
                 elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION:
+            elif message["role"] == Role.OBSERVATION.value:
                 elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION:
+            elif message["role"] == Role.FUNCTION.value:
                 elements += self.format_function.apply(content=message["content"])
             else:
-                raise NotImplementedError
+                raise NotImplementedError("Unexpected role: {}".format(message["role"]))

             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

@@ -179,16 +179,16 @@ class Llama2Template(Template):
             elif i > 0 and i % 2 == 0:
                 elements += self.format_separator.apply()

-            if message["role"] == Role.USER:
+            if message["role"] == Role.USER.value:
                 elements += self.format_user.apply(content=system_text + message["content"])
-            elif message["role"] == Role.ASSISTANT:
+            elif message["role"] == Role.ASSISTANT.value:
                 elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION:
+            elif message["role"] == Role.OBSERVATION.value:
                 elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION:
+            elif message["role"] == Role.FUNCTION.value:
                 elements += self.format_function.apply(content=message["content"])
             else:
-                raise NotImplementedError
+                raise NotImplementedError("Unexpected role: {}".format(message["role"]))

             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py
index de8b0ca0..41d657ff 100644
--- a/src/llmtuner/webui/chatter.py
+++ b/src/llmtuner/webui/chatter.py
@@ -115,7 +115,7 @@ class WebChatModel(ChatModel):
         temperature: float,
     ) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
         chatbot.append([query, ""])
-        query_messages = messages + [{"role": Role.USER, "content": query}]
+        query_messages = messages + [{"role": Role.USER.value, "content": query}]
         response = ""
         for new_text in self.stream_chat(
             query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
@@ -130,10 +130,10 @@ class WebChatModel(ChatModel):
             name, arguments = result
             arguments = json.loads(arguments)
             tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
-            output_messages = query_messages + [{"role": Role.FUNCTION, "content": tool_call}]
+            output_messages = query_messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
             bot_text = "```json\n" + tool_call + "\n```"
         else:
-            output_messages = query_messages + [{"role": Role.ASSISTANT, "content": result}]
+            output_messages = query_messages + [{"role": Role.ASSISTANT.value, "content": result}]
             bot_text = result

         chatbot[-1] = [query, self.postprocess(bot_text)]
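
A note on the `.value` pattern this diff standardizes on: a plain `enum.Enum` member is neither equal to its string value nor JSON-serializable, so building message dicts from `Role.USER.value` keeps them as ordinary strings that survive serialization and string comparison. A minimal sketch of the distinction (the `Role` enum below is a simplified stand-in, not the project's exact definition):

```python
import json
from enum import Enum, unique


@unique
class Role(Enum):
    USER = "user"
    ASSISTANT = "assistant"


# An Enum member is not its value: comparing against a plain string fails.
assert Role.USER != "user"
assert Role.USER.value == "user"

# Storing members in message dicts breaks json.dumps; .value keeps them plain.
message = {"role": Role.USER.value, "content": "hello"}
print(json.dumps(message))  # -> {"role": "user", "content": "hello"}

# By contrast, json.dumps({"role": Role.USER}) would raise
# TypeError: Object of type Role is not JSON serializable.
```

This is also why the checks in `template.py` above now compare `message["role"]` against `Role.X.value`: roles that come back from a serialized dataset as plain strings still dispatch to the intended formatter.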