diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 066f6c79..a6193744 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -98,12 +98,16 @@ class Template: r""" Converts context to token ids. """ + if hasattr(tokenizer, "tokenizer"): # for tiktoken tokenizer (Qwen) + kwargs = dict(allowed_special="all") + else: + kwargs = dict(add_special_tokens=False) + token_ids = [] for elem in context: if isinstance(elem, str): elem = elem.replace("{{query}}", query, 1) - elem = elem.replace("", "[MASK]") - token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False) + token_ids = token_ids + tokenizer.encode(elem, **kwargs) elif isinstance(elem, dict): token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] else: