fix qwen tokenizer #361
This commit is contained in:
parent
1afa51c2fa
commit
7f18d2a335
|
@@ -98,12 +98,16 @@ class Template:
|
||||||
r"""
|
r"""
|
||||||
Converts context to token ids.
|
Converts context to token ids.
|
||||||
"""
|
"""
|
||||||
|
if hasattr(tokenizer, "tokenizer"): # for tiktoken tokenizer (Qwen)
|
||||||
|
kwargs = dict(allowed_special="all")
|
||||||
|
else:
|
||||||
|
kwargs = dict(add_special_tokens=False)
|
||||||
|
|
||||||
token_ids = []
|
token_ids = []
|
||||||
for elem in context:
|
for elem in context:
|
||||||
if isinstance(elem, str):
|
if isinstance(elem, str):
|
||||||
elem = elem.replace("{{query}}", query, 1)
|
elem = elem.replace("{{query}}", query, 1)
|
||||||
elem = elem.replace("<mask>", "[MASK]")
|
token_ids = token_ids + tokenizer.encode(elem, **kwargs)
|
||||||
token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False)
|
|
||||||
elif isinstance(elem, dict):
|
elif isinstance(elem, dict):
|
||||||
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
|
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue