update stream_chat
commit 8528a84e74 (parent 657cf0f55a)
@@ -23,7 +23,7 @@ class ChatModel:
         self.generating_args = generating_args

     def process_args(
-        self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
+        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
     ) -> Tuple[Dict[str, Any], int]:
         prefix = prefix if prefix else self.source_prefix

@@ -59,7 +59,7 @@ class ChatModel:
         return gen_kwargs, prompt_length

     def chat(
-        self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
+        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
     ) -> Tuple[str, Tuple[int, int]]:
         gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
         generation_output = self.model.generate(**gen_kwargs)
@@ -69,7 +69,7 @@ class ChatModel:
         return response, (prompt_length, response_length)

     def stream_chat(
-        self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
+        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
     ) -> Generator[str, None, None]:
         gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
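A minimal usage sketch of the updated ChatModel interface, assuming an already-initialized `chat_model` instance (construction is not shown in this commit); with `history` now defaulting to `None`, it can be omitted on the first turn.

    # Sketch only: `chat_model` is assumed to be a constructed ChatModel instance.
    response, (prompt_length, response_length) = chat_model.chat("What is the capital of France?")

    # stream_chat() yields decoded text chunks as they are generated.
    history = [("What is the capital of France?", response)]
    for new_text in chat_model.stream_chat("And of Germany?", history=history):
        print(new_text, end="", flush=True)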
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Tuple
 from dataclasses import dataclass


@@ -139,25 +139,33 @@ class Template:
         else:
             raise ValueError("Template {} does not exist.".format(self.name))

-    def get_prompt(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> str:
+    def get_prompt(
+        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
+    ) -> str:
         r"""
         Returns a string containing prompt without response.
         """
         return "".join(self._format_example(query, history, prefix))

-    def get_dialog(self, query: str, resp: str, history: Optional[list] = None, prefix: Optional[str] = "") -> List[str]:
+    def get_dialog(
+        self, query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
+    ) -> List[str]:
         r"""
         Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
         """
         return self._format_example(query, history, prefix) + [resp]

-    def _register_template(self, prefix: str, prompt: str, sep: str, use_history: Optional[bool] = True) -> None:
+    def _register_template(
+        self, prefix: str, prompt: str, sep: str, use_history: Optional[bool] = True
+    ) -> None:
         self.prefix = prefix
         self.prompt = prompt
         self.sep = sep
         self.use_history = use_history

-    def _format_example(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> List[str]:
+    def _format_example(
+        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
+    ) -> List[str]:
         prefix = prefix if prefix else self.prefix  # use prefix if provided
         prefix = prefix + self.sep if prefix else ""  # add separator for non-empty prefix
         history = history if (history and self.use_history) else []
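A similarly minimal sketch of the retyped Template helpers, assuming `template` is an already-registered Template instance whose prefix/prompt/sep strings are defined elsewhere and not shown here.

    # Sketch only: `template` is assumed to be an existing Template instance.
    history = [("Hi", "Hello! How can I assist you?")]  # List[Tuple[str, str]]
    prompt = template.get_prompt("What can you do?", history=history)
    dialog = template.get_dialog("What can you do?", "I can chat with you.", history=history)
    assert len(dialog) % 2 == 0  # alternating query/response elements, per the docstring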
|