update stream_chat
This commit is contained in:
parent
657cf0f55a
commit
8528a84e74
|
@ -23,7 +23,7 @@ class ChatModel:
|
|||
self.generating_args = generating_args
|
||||
|
||||
def process_args(
|
||||
self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
prefix = prefix if prefix else self.source_prefix
|
||||
|
||||
|
@ -59,7 +59,7 @@ class ChatModel:
|
|||
return gen_kwargs, prompt_length
|
||||
|
||||
def chat(
|
||||
self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
|
||||
) -> Tuple[str, Tuple[int, int]]:
|
||||
gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
|
||||
generation_output = self.model.generate(**gen_kwargs)
|
||||
|
@ -69,7 +69,7 @@ class ChatModel:
|
|||
return response, (prompt_length, response_length)
|
||||
|
||||
def stream_chat(
|
||||
self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
|
||||
) -> Generator[str, None, None]:
|
||||
gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
|
||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
|
@ -139,25 +139,33 @@ class Template:
|
|||
else:
|
||||
raise ValueError("Template {} does not exist.".format(self.name))
|
||||
|
||||
def get_prompt(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> str:
|
||||
def get_prompt(
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
|
||||
) -> str:
|
||||
r"""
|
||||
Returns a string containing prompt without response.
|
||||
"""
|
||||
return "".join(self._format_example(query, history, prefix))
|
||||
|
||||
def get_dialog(self, query: str, resp: str, history: Optional[list] = None, prefix: Optional[str] = "") -> List[str]:
|
||||
def get_dialog(
|
||||
self, query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
|
||||
) -> List[str]:
|
||||
r"""
|
||||
Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
|
||||
"""
|
||||
return self._format_example(query, history, prefix) + [resp]
|
||||
|
||||
def _register_template(self, prefix: str, prompt: str, sep: str, use_history: Optional[bool] = True) -> None:
|
||||
def _register_template(
|
||||
self, prefix: str, prompt: str, sep: str, use_history: Optional[bool] = True
|
||||
) -> None:
|
||||
self.prefix = prefix
|
||||
self.prompt = prompt
|
||||
self.sep = sep
|
||||
self.use_history = use_history
|
||||
|
||||
def _format_example(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> List[str]:
|
||||
def _format_example(
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
|
||||
) -> List[str]:
|
||||
prefix = prefix if prefix else self.prefix # use prefix if provided
|
||||
prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
|
||||
history = history if (history and self.use_history) else []
|
||||
|
|
Loading…
Reference in New Issue