fix cli_demo
This commit is contained in:
parent
cf818a2598
commit
c550987a72
|
@ -34,6 +34,7 @@ def main():
|
||||||
print("History has been removed.")
|
print("History has been removed.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
messages.append({"role": "user", "content": query})
|
||||||
print("Assistant: ", end="", flush=True)
|
print("Assistant: ", end="", flush=True)
|
||||||
|
|
||||||
response = ""
|
response = ""
|
||||||
|
@ -41,8 +42,6 @@ def main():
|
||||||
print(new_text, end="", flush=True)
|
print(new_text, end="", flush=True)
|
||||||
response += new_text
|
response += new_text
|
||||||
print()
|
print()
|
||||||
|
|
||||||
messages.append({"role": "user", "content": query})
|
|
||||||
messages.append({"role": "assistant", "content": response})
|
messages.append({"role": "assistant", "content": response})
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -37,6 +37,7 @@ class ChatModel:
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> Tuple[Dict[str, Any], int]:
|
) -> Tuple[Dict[str, Any], int]:
|
||||||
|
messages += [{"role": "assistant", "content": ""}]
|
||||||
prompt, _ = self.template.encode_oneturn(
|
prompt, _ = self.template.encode_oneturn(
|
||||||
tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
|
tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
|
||||||
)
|
)
|
||||||
|
|
|
@ -141,7 +141,7 @@ class Template:
|
||||||
max_source_len, max_target_len = infer_max_len(
|
max_source_len, max_target_len = infer_max_len(
|
||||||
source_len=len(encoded_messages[i]),
|
source_len=len(encoded_messages[i]),
|
||||||
target_len=len(encoded_messages[i + 1]),
|
target_len=len(encoded_messages[i + 1]),
|
||||||
cutoff_len=(cutoff_len - total_length),
|
max_len=(cutoff_len - total_length),
|
||||||
reserved_label_len=reserved_label_len,
|
reserved_label_len=reserved_label_len,
|
||||||
)
|
)
|
||||||
encoded_messages[i] = encoded_messages[i][: max_source_len]
|
encoded_messages[i] = encoded_messages[i][: max_source_len]
|
||||||
|
|
|
@ -38,10 +38,10 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
|
||||||
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
|
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
|
||||||
|
|
||||||
|
|
||||||
def infer_max_len(source_len: int, target_len: int, cutoff_len: int, reserved_label_len: int) -> Tuple[int, int]:
|
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
|
||||||
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
|
max_target_len = int(max_len * (target_len / (source_len + target_len)))
|
||||||
max_target_len = max(max_target_len, reserved_label_len)
|
max_target_len = max(max_target_len, reserved_label_len)
|
||||||
max_source_len = cutoff_len - max_target_len
|
max_source_len = max_len - max_target_len
|
||||||
return max_source_len, max_target_len
|
return max_source_len, max_target_len
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue