From 572ea3bafb1b495e33b1abd1998972f3a5e6f310 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 9 Aug 2023 17:52:15 +0800
Subject: [PATCH] fix tokenizer

---
 src/llmtuner/extras/template.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py
index 5b00af03..91595751 100644
--- a/src/llmtuner/extras/template.py
+++ b/src/llmtuner/extras/template.py
@@ -1,10 +1,15 @@
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 from dataclasses import dataclass
 
+from llmtuner.extras.logging import get_logger
+
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
 
 
+logger = get_logger(__name__)
+
+
 @dataclass
 class Template:
 
@@ -179,11 +184,16 @@ def get_template_and_fix_tokenizer(
     template = templates.get(name, None)
     assert template is not None, "Template {} does not exist.".format(name)
 
-    if tokenizer.eos_token_id is None and len(template.stop_words): # inplace method
-        tokenizer.eos_token = template.stop_words[0]
+    if tokenizer.eos_token_id is None: # inplace method
+        if len(template.stop_words):
+            tokenizer.eos_token = template.stop_words[0]
+        else:
+            tokenizer.eos_token = "<|endoftext|>"
+        logger.info("Add eos token: {}".format(tokenizer.eos_token))
 
-    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
+    if tokenizer.pad_token_id is None:
         tokenizer.pad_token = tokenizer.eos_token
+        logger.info("Add pad token: {}".format(tokenizer.pad_token))
 
     tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words))
     return template
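
Note (not part of the patch): a minimal usage sketch of the behavior this commit changes. It assumes get_template_and_fix_tokenizer takes a template name and a tokenizer; "vanilla" and "some/model" are illustrative placeholders, not taken from the patch.

    # Hypothetical sketch: before this commit, a tokenizer with no eos token and a
    # template with no stop words kept eos_token_id unset, and pad_token was only
    # filled in when an eos token already existed. After this commit, the function
    # falls back to "<|endoftext|>" for eos and reuses eos as pad, logging both.
    from transformers import AutoTokenizer

    from llmtuner.extras.template import get_template_and_fix_tokenizer

    tokenizer = AutoTokenizer.from_pretrained("some/model")  # placeholder model id
    template = get_template_and_fix_tokenizer("vanilla", tokenizer)  # assumed template name

    print(tokenizer.eos_token)  # e.g. "<|endoftext|>" if the model defined none
    print(tokenizer.pad_token)  # falls back to the eos token when undefined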