fix tokenizer

hiyouga 2023-08-09 17:52:15 +08:00
parent ef5b299b18
commit 572ea3bafb
1 changed file with 13 additions and 3 deletions


@@ -1,10 +1,15 @@
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 from dataclasses import dataclass
 
+from llmtuner.extras.logging import get_logger
+
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
 
 
+logger = get_logger(__name__)
+
+
 @dataclass
 class Template:
 
@@ -179,11 +184,16 @@ def get_template_and_fix_tokenizer(
     template = templates.get(name, None)
     assert template is not None, "Template {} does not exist.".format(name)
 
-    if tokenizer.eos_token_id is None and len(template.stop_words): # inplace method
-        tokenizer.eos_token = template.stop_words[0]
+    if tokenizer.eos_token_id is None: # inplace method
+        if len(template.stop_words):
+            tokenizer.eos_token = template.stop_words[0]
+        else:
+            tokenizer.eos_token = "<|endoftext|>"
+        logger.info("Add eos token: {}".format(tokenizer.eos_token))
 
-    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
+    if tokenizer.pad_token_id is None:
         tokenizer.pad_token = tokenizer.eos_token
+        logger.info("Add pad token: {}".format(tokenizer.pad_token))
 
     tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words))
     return template
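
In short, the commit makes get_template_and_fix_tokenizer guarantee that both eos_token and pad_token end up set. Previously, a template with no stop_words left eos_token unset, and the `tokenizer.eos_token_id is not None` guard then skipped setting the pad token as well. A minimal standalone sketch of the patched logic follows; the helper name fix_special_tokens is illustrative, not part of the repository:

def fix_special_tokens(tokenizer, stop_words):
    # Sketch of the patched behavior; `tokenizer` is any object exposing the
    # transformers-style eos_token / pad_token attributes and *_id properties.
    if tokenizer.eos_token_id is None:
        # Prefer the template's first stop word; otherwise fall back to
        # "<|endoftext|>" instead of leaving eos_token unset.
        tokenizer.eos_token = stop_words[0] if stop_words else "<|endoftext|>"
    if tokenizer.pad_token_id is None:
        # eos_token is now guaranteed by the branch above, so the old
        # `and tokenizer.eos_token_id is not None` guard is no longer needed.
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

With this shape, a tokenizer that ships with neither token gets "<|endoftext|>" as eos and reuses it as pad, whereas before the commit both would have stayed None.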