update stream_chat

This commit is contained in:
hiyouga 2023-07-15 19:51:02 +08:00
parent 657cf0f55a
commit 8528a84e74
2 changed files with 16 additions and 8 deletions

View File

@ -23,7 +23,7 @@ class ChatModel:
self.generating_args = generating_args
def process_args(
self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Tuple[Dict[str, Any], int]:
prefix = prefix if prefix else self.source_prefix
@ -59,7 +59,7 @@ class ChatModel:
return gen_kwargs, prompt_length
def chat(
self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Tuple[str, Tuple[int, int]]:
gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
generation_output = self.model.generate(**gen_kwargs)
@ -69,7 +69,7 @@ class ChatModel:
return response, (prompt_length, response_length)
def stream_chat(
self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Generator[str, None, None]:
gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

View File

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Tuple
from dataclasses import dataclass
@ -139,25 +139,33 @@ class Template:
else:
raise ValueError("Template {} does not exist.".format(self.name))
def get_prompt(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> str:
def get_prompt(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
) -> str:
r"""
Returns a string containing prompt without response.
"""
return "".join(self._format_example(query, history, prefix))
def get_dialog(self, query: str, resp: str, history: Optional[list] = None, prefix: Optional[str] = "") -> List[str]:
def get_dialog(
self, query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
) -> List[str]:
r"""
Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
"""
return self._format_example(query, history, prefix) + [resp]
def _register_template(self, prefix: str, prompt: str, sep: str, use_history: Optional[bool] = True) -> None:
def _register_template(
self, prefix: str, prompt: str, sep: str, use_history: Optional[bool] = True
) -> None:
self.prefix = prefix
self.prompt = prompt
self.sep = sep
self.use_history = use_history
def _format_example(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> List[str]:
def _format_example(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
) -> List[str]:
prefix = prefix if prefix else self.prefix # use prefix if provided
prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
history = history if (history and self.use_history) else []