fix cli_demo

This commit is contained in:
hiyouga 2024-01-20 23:27:10 +08:00
parent cf818a2598
commit c550987a72
4 changed files with 6 additions and 6 deletions

View File

@ -34,6 +34,7 @@ def main():
print("History has been removed.") print("History has been removed.")
continue continue
messages.append({"role": "user", "content": query})
print("Assistant: ", end="", flush=True) print("Assistant: ", end="", flush=True)
response = "" response = ""
@ -41,8 +42,6 @@ def main():
print(new_text, end="", flush=True) print(new_text, end="", flush=True)
response += new_text response += new_text
print() print()
messages.append({"role": "user", "content": query})
messages.append({"role": "assistant", "content": response}) messages.append({"role": "assistant", "content": response})

View File

@ -37,6 +37,7 @@ class ChatModel:
tools: Optional[str] = None, tools: Optional[str] = None,
**input_kwargs, **input_kwargs,
) -> Tuple[Dict[str, Any], int]: ) -> Tuple[Dict[str, Any], int]:
messages += [{"role": "assistant", "content": ""}]
prompt, _ = self.template.encode_oneturn( prompt, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=messages, system=system, tools=tools tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
) )

View File

@ -141,7 +141,7 @@ class Template:
max_source_len, max_target_len = infer_max_len( max_source_len, max_target_len = infer_max_len(
source_len=len(encoded_messages[i]), source_len=len(encoded_messages[i]),
target_len=len(encoded_messages[i + 1]), target_len=len(encoded_messages[i + 1]),
cutoff_len=(cutoff_len - total_length), max_len=(cutoff_len - total_length),
reserved_label_len=reserved_label_len, reserved_label_len=reserved_label_len,
) )
encoded_messages[i] = encoded_messages[i][: max_source_len] encoded_messages[i] = encoded_messages[i][: max_source_len]

View File

@ -38,10 +38,10 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0])) logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def infer_max_len(source_len: int, target_len: int, cutoff_len: int, reserved_label_len: int) -> Tuple[int, int]: def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
max_target_len = int(cutoff_len * (target_len / (source_len + target_len))) max_target_len = int(max_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, reserved_label_len) max_target_len = max(max_target_len, reserved_label_len)
max_source_len = cutoff_len - max_target_len max_source_len = max_len - max_target_len
return max_source_len, max_target_len return max_source_len, max_target_len