diff --git a/src/cli_demo.py b/src/cli_demo.py
index 09d444c6..96007f1a 100644
--- a/src/cli_demo.py
+++ b/src/cli_demo.py
@@ -34,6 +34,7 @@ def main():
             print("History has been removed.")
             continue
 
+        messages.append({"role": "user", "content": query})
         print("Assistant: ", end="", flush=True)
 
         response = ""
@@ -41,8 +42,6 @@ def main():
             print(new_text, end="", flush=True)
             response += new_text
         print()
-
-        messages.append({"role": "user", "content": query})
         messages.append({"role": "assistant", "content": response})
 
 
diff --git a/src/llmtuner/chat/chat_model.py b/src/llmtuner/chat/chat_model.py
index cfa2700b..d92848c2 100644
--- a/src/llmtuner/chat/chat_model.py
+++ b/src/llmtuner/chat/chat_model.py
@@ -37,6 +37,7 @@ class ChatModel:
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> Tuple[Dict[str, Any], int]:
+        messages += [{"role": "assistant", "content": ""}]
         prompt, _ = self.template.encode_oneturn(
             tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
         )
diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py
index 77d053f3..d4fd88fc 100644
--- a/src/llmtuner/data/template.py
+++ b/src/llmtuner/data/template.py
@@ -141,7 +141,7 @@ class Template:
             max_source_len, max_target_len = infer_max_len(
                 source_len=len(encoded_messages[i]),
                 target_len=len(encoded_messages[i + 1]),
-                cutoff_len=(cutoff_len - total_length),
+                max_len=(cutoff_len - total_length),
                 reserved_label_len=reserved_label_len,
             )
             encoded_messages[i] = encoded_messages[i][: max_source_len]
diff --git a/src/llmtuner/data/utils.py b/src/llmtuner/data/utils.py
index b8dfa123..062d390f 100644
--- a/src/llmtuner/data/utils.py
+++ b/src/llmtuner/data/utils.py
@@ -38,10 +38,10 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
         logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
 
 
-def infer_max_len(source_len: int, target_len: int, cutoff_len: int, reserved_label_len: int) -> Tuple[int, int]:
-    max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
+def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
+    max_target_len = int(max_len * (target_len / (source_len + target_len)))
     max_target_len = max(max_target_len, reserved_label_len)
-    max_source_len = cutoff_len - max_target_len
+    max_source_len = max_len - max_target_len
     return max_source_len, max_target_len
 
 
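Note (not part of the patch): the `cli_demo.py` hunks move the user-message append so it happens before streaming starts rather than after. The ordering matters because the streaming call builds its prompt from `messages`; appended afterwards, the current query would be invisible to the model. A minimal sketch of the corrected loop shape, using a hypothetical stand-in for `ChatModel.stream_chat`:

```python
def stream_chat(messages):
    # Hypothetical stand-in for ChatModel.stream_chat: echoes the latest
    # user turn word by word -- just enough to show why ordering matters.
    for word in messages[-1]["content"].split():
        yield word + " "


messages = []
query = "hello there"
messages.append({"role": "user", "content": query})  # must precede generation

response = ""
for new_text in stream_chat(messages):
    print(new_text, end="", flush=True)
    response += new_text
print()
messages.append({"role": "assistant", "content": response})
```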
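Note (not part of the patch): the `cutoff_len` -> `max_len` rename in `infer_max_len` reflects that the call site passes the remaining budget `cutoff_len - total_length`, not the full cutoff. The split logic is easy to sanity-check in isolation; the function body below is copied from the patch, while the call values are made up for illustration:

```python
from typing import Tuple


def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
    # Split the remaining token budget (max_len) between source and target in
    # proportion to their raw lengths, but reserve at least reserved_label_len
    # tokens for the target so the labels are never truncated away entirely.
    max_target_len = int(max_len * (target_len / (source_len + target_len)))
    max_target_len = max(max_target_len, reserved_label_len)
    max_source_len = max_len - max_target_len
    return max_source_len, max_target_len


# Illustrative numbers: a 300-token source and a 100-token target squeezed
# into a 128-token budget, with 16 tokens reserved for labels.
print(infer_max_len(source_len=300, target_len=100, max_len=128, reserved_label_len=16))
# (96, 32): the target keeps int(128 * 100 / 400) = 32 tokens (>= 16),
# and the source gets the remaining 128 - 32 = 96.
```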