fix cli_demo
This commit is contained in:
parent
cf818a2598
commit
c550987a72
|
@ -34,6 +34,7 @@ def main():
|
|||
print("History has been removed.")
|
||||
continue
|
||||
|
||||
messages.append({"role": "user", "content": query})
|
||||
print("Assistant: ", end="", flush=True)
|
||||
|
||||
response = ""
|
||||
|
@ -41,8 +42,6 @@ def main():
|
|||
print(new_text, end="", flush=True)
|
||||
response += new_text
|
||||
print()
|
||||
|
||||
messages.append({"role": "user", "content": query})
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@ class ChatModel:
|
|||
tools: Optional[str] = None,
|
||||
**input_kwargs,
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
messages += [{"role": "assistant", "content": ""}]
|
||||
prompt, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
|
||||
)
|
||||
|
|
|
@ -141,7 +141,7 @@ class Template:
|
|||
max_source_len, max_target_len = infer_max_len(
|
||||
source_len=len(encoded_messages[i]),
|
||||
target_len=len(encoded_messages[i + 1]),
|
||||
cutoff_len=(cutoff_len - total_length),
|
||||
max_len=(cutoff_len - total_length),
|
||||
reserved_label_len=reserved_label_len,
|
||||
)
|
||||
encoded_messages[i] = encoded_messages[i][: max_source_len]
|
||||
|
|
|
@ -38,10 +38,10 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
|
|||
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
|
||||
|
||||
|
||||
def infer_max_len(source_len: int, target_len: int, cutoff_len: int, reserved_label_len: int) -> Tuple[int, int]:
|
||||
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
|
||||
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
|
||||
max_target_len = int(max_len * (target_len / (source_len + target_len)))
|
||||
max_target_len = max(max_target_len, reserved_label_len)
|
||||
max_source_len = cutoff_len - max_target_len
|
||||
max_source_len = max_len - max_target_len
|
||||
return max_source_len, max_target_len
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue