From cbbee7933e80df9a9af45160c8e6c076df00b4f8 Mon Sep 17 00:00:00 2001
From: codemayq
Date: Mon, 21 Aug 2023 20:51:24 +0800
Subject: [PATCH] add template encode test

---
 tests/template_encode.py | 40 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)
 create mode 100644 tests/template_encode.py

diff --git a/tests/template_encode.py b/tests/template_encode.py
new file mode 100644
index 00000000..c1c936b7
--- /dev/null
+++ b/tests/template_encode.py
@@ -0,0 +1,40 @@
+# Test Template Encode
+# Usage: python .\tests\template_encode.py --model_name_and_path D:\llm\chinese-alpaca-2-7b
+# --template llama2_zh --query 'how are you?'
+# --history '[[\"Hello!\",\"Hi,I am llama2.\"]]'
+
+import sys
+import fire
+from typing import List, Optional, Tuple
+from transformers import AutoTokenizer
+sys.path.append("./src")
+from llmtuner.extras.template import get_template_and_fix_tokenizer
+
+
+def encode(
+    model_name_and_path: str,
+    template: str,
+    query: str,
+    resp: str = "",
+    history: Optional[List[Tuple[str, str]]] = None,
+    system: Optional[str] = None):
+    """Encode a multi-turn conversation with the named template and print the token ids."""
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name_and_path,
+        trust_remote_code=True
+    )
+
+    # Keep the string argument intact; bind the resolved Template object to its own name.
+    template_obj = get_template_and_fix_tokenizer(template, tokenizer)
+
+    encoded_pairs = template_obj.encode_multiturn(tokenizer, query, resp, history, system)
+    for prompt_ids, answer_ids in encoded_pairs:
+        print("="*50)
+        print("prompt_ids: {}, answer_ids: {}".format(prompt_ids, answer_ids))
+        print("prompt decode: {}".format(tokenizer.decode(prompt_ids)))
+        # Decode answers too, for symmetry with the prompt output above.
+        print("answer decode: {}".format(tokenizer.decode(answer_ids)))
+
+
+if __name__ == '__main__':
+    fire.Fire(encode)