add template encode test
This commit is contained in:
parent
5235b15c91
commit
cbbee7933e
|
@ -0,0 +1,36 @@
|
||||||
|
# Test Template Encode
|
||||||
|
# Usage: python .\tests\template_encode.py --model_name_and_path D:\llm\chinese-alpaca-2-7b
|
||||||
|
# --template llama2_zh --query 'how are you?'
|
||||||
|
# --history '[[\"Hello!\",\"Hi,I am llama2.\"]]'
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import fire
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
sys.path.append("./src")
|
||||||
|
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
model_name_and_path: str,
|
||||||
|
template: str,
|
||||||
|
query: str,
|
||||||
|
resp: Optional[str] = "",
|
||||||
|
history: Optional[List[Tuple[str, str]]] = None,
|
||||||
|
system: Optional[str] = None):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_name_and_path,
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
|
||||||
|
template = get_template_and_fix_tokenizer(template, tokenizer)
|
||||||
|
|
||||||
|
encoded_pairs = template.encode_multiturn(tokenizer, query, resp, history, system)
|
||||||
|
for prompt_ids, answer_ids in encoded_pairs:
|
||||||
|
print("="*50)
|
||||||
|
print("prompt_ids: {}, answer_ids: {}".format(prompt_ids, answer_ids))
|
||||||
|
print("prompt decode: {}".format(tokenizer.decode(prompt_ids)))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
fire.Fire(encode)
|
Loading…
Reference in New Issue