diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py
index 07fc7bf0..73c8f332 100644
--- a/src/llmtuner/chat/stream_chat.py
+++ b/src/llmtuner/chat/stream_chat.py
@@ -23,7 +23,7 @@ class ChatModel:
         self.generating_args = generating_args
 
     def process_args(
-        self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
+        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
     ) -> Tuple[Dict[str, Any], int]:
         prefix = prefix if prefix else self.source_prefix
 
@@ -59,7 +59,7 @@ class ChatModel:
         return gen_kwargs, prompt_length
 
     def chat(
-        self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
+        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
     ) -> Tuple[str, Tuple[int, int]]:
         gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
         generation_output = self.model.generate(**gen_kwargs)
@@ -69,7 +69,7 @@ class ChatModel:
         return response, (prompt_length, response_length)
 
     def stream_chat(
-        self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
+        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
     ) -> Generator[str, None, None]:
         gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py
index 67c82c7a..88469d5c 100644
--- a/src/llmtuner/extras/template.py
+++ b/src/llmtuner/extras/template.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Tuple
 from dataclasses import dataclass
 
 
@@ -139,25 +139,33 @@ class Template:
         else:
             raise ValueError("Template {} does not exist.".format(self.name))
 
-    def get_prompt(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> str:
+    def get_prompt(
+        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
+    ) -> str:
         r"""
         Returns a string containing prompt without response.
         """
         return "".join(self._format_example(query, history, prefix))
 
-    def get_dialog(self, query: str, resp: str, history: Optional[list] = None, prefix: Optional[str] = "") -> List[str]:
+    def get_dialog(
+        self, query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
+    ) -> List[str]:
         r"""
         Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
         """
         return self._format_example(query, history, prefix) + [resp]
 
-    def _register_template(self, prefix: str, prompt: str, sep: str, use_history: Optional[bool] = True) -> None:
+    def _register_template(
+        self, prefix: str, prompt: str, sep: str, use_history: Optional[bool] = True
+    ) -> None:
         self.prefix = prefix
         self.prompt = prompt
         self.sep = sep
         self.use_history = use_history
 
-    def _format_example(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> List[str]:
+    def _format_example(
+        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
+    ) -> List[str]:
         prefix = prefix if prefix else self.prefix  # use prefix if provided
         prefix = prefix + self.sep if prefix else ""  # add separator for non-empty prefix
         history = history if (history and self.use_history) else []