fix tokenizer
This commit is contained in:
parent
ef5b299b18
commit
572ea3bafb
|
@ -1,10 +1,15 @@
|
|||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Template:
|
||||
|
||||
|
@ -179,11 +184,16 @@ def get_template_and_fix_tokenizer(
|
|||
template = templates.get(name, None)
|
||||
assert template is not None, "Template {} does not exist.".format(name)
|
||||
|
||||
if tokenizer.eos_token_id is None and len(template.stop_words): # inplace method
|
||||
tokenizer.eos_token = template.stop_words[0]
|
||||
if tokenizer.eos_token_id is None: # inplace method
|
||||
if len(template.stop_words):
|
||||
tokenizer.eos_token = template.stop_words[0]
|
||||
else:
|
||||
tokenizer.eos_token = "<|endoftext|>"
|
||||
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
||||
|
||||
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
||||
|
||||
tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words))
|
||||
return template
|
||||
|
|
Loading…
Reference in New Issue