fix cli_demo

This commit is contained in:
hiyouga 2024-01-20 23:27:10 +08:00
parent cf818a2598
commit c550987a72
4 changed files with 6 additions and 6 deletions

View File

@ -34,6 +34,7 @@ def main():
print("History has been removed.")
continue
messages.append({"role": "user", "content": query})
print("Assistant: ", end="", flush=True)
response = ""
@ -41,8 +42,6 @@ def main():
print(new_text, end="", flush=True)
response += new_text
print()
messages.append({"role": "user", "content": query})
messages.append({"role": "assistant", "content": response})

View File

@ -37,6 +37,7 @@ class ChatModel:
tools: Optional[str] = None,
**input_kwargs,
) -> Tuple[Dict[str, Any], int]:
messages += [{"role": "assistant", "content": ""}]
prompt, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
)

View File

@ -141,7 +141,7 @@ class Template:
max_source_len, max_target_len = infer_max_len(
source_len=len(encoded_messages[i]),
target_len=len(encoded_messages[i + 1]),
cutoff_len=(cutoff_len - total_length),
max_len=(cutoff_len - total_length),
reserved_label_len=reserved_label_len,
)
encoded_messages[i] = encoded_messages[i][: max_source_len]

View File

@ -38,10 +38,10 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def infer_max_len(source_len: int, target_len: int, cutoff_len: int, reserved_label_len: int) -> Tuple[int, int]:
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
max_target_len = int(max_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, reserved_label_len)
max_source_len = cutoff_len - max_target_len
max_source_len = max_len - max_target_len
return max_source_len, max_target_len