fix tokenizer #417
This commit is contained in:
parent
caa0eda27d
commit
eecc4b2131
|
@ -5,7 +5,7 @@ from threading import Thread
|
|||
from transformers import PreTrainedModel, TextIteratorStreamer
|
||||
|
||||
from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria
|
||||
from llmtuner.extras.template import get_template
|
||||
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
||||
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
|
||||
|
||||
|
||||
|
@ -16,7 +16,7 @@ class ChatModel:
|
|||
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
self.model = dispatch_model(self.model)
|
||||
self.model = self.model.eval() # change to eval mode
|
||||
self.template = get_template(data_args.template)
|
||||
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
||||
self.source_prefix = data_args.source_prefix
|
||||
self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
|
||||
self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
|
||||
|
|
|
@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal
|
|||
from itertools import chain
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.template import get_template
|
||||
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset
|
||||
|
@ -19,7 +19,7 @@ def preprocess_dataset(
|
|||
stage: Literal["pt", "sft", "rm", "ppo"]
|
||||
) -> "Dataset":
|
||||
column_names = list(dataset.column_names)
|
||||
template = get_template(data_args.template)
|
||||
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
|
||||
|
||||
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
|
||||
for i in range(len(examples["prompt"])):
|
||||
|
@ -119,10 +119,9 @@ def preprocess_dataset(
|
|||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print("labels:\n{}".format(''.join([
|
||||
tokenizer.decode(d, skip_special_tokens=False)
|
||||
if d != IGNORE_INDEX else '-100' for d in example["labels"]
|
||||
])))
|
||||
print("labels:\n{}".format(tokenizer.decode([
|
||||
token_id if token_id != IGNORE_INDEX else tokenizer.pad_token_id for token_id in example["labels"]
|
||||
], skip_special_tokens=False)))
|
||||
|
||||
def print_pairwise_dataset_example(example):
|
||||
print("accept_ids:\n{}".format(example["accept_ids"]))
|
||||
|
|
|
@ -67,15 +67,15 @@ class Template:
|
|||
self,
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
if tokenizer.bos_token_id and getattr(tokenizer, "add_bos_token", False):
|
||||
if tokenizer.bos_token_id and getattr(tokenizer, "add_bos_token", True):
|
||||
bos_ids = [tokenizer.bos_token_id]
|
||||
else: # bos token is optional
|
||||
bos_ids = []
|
||||
else:
|
||||
bos_ids = [] # bos token is optional
|
||||
|
||||
if tokenizer.eos_token_id and getattr(tokenizer, "add_eos_token", False):
|
||||
if tokenizer.eos_token_id and getattr(tokenizer, "add_eos_token", True):
|
||||
eos_ids = [tokenizer.eos_token_id]
|
||||
else: # use the first stop word as the eos token
|
||||
eos_ids = [tokenizer.convert_tokens_to_ids(self.stop_words[0])]
|
||||
else:
|
||||
raise ValueError("EOS token is required.")
|
||||
|
||||
return bos_ids, eos_ids
|
||||
|
||||
|
@ -172,9 +172,19 @@ def register_template(
|
|||
)
|
||||
|
||||
|
||||
def get_template(name: str) -> Template:
|
||||
def get_template_and_fix_tokenizer(
|
||||
name: str,
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
) -> Template:
|
||||
template = templates.get(name, None)
|
||||
assert template is not None, "Template {} does not exist.".format(name)
|
||||
|
||||
if tokenizer.eos_token_id is None and len(template.stop_words): # inplace method
|
||||
tokenizer.eos_token = template.stop_words[0]
|
||||
|
||||
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
return template
|
||||
|
||||
|
||||
|
|
|
@ -68,8 +68,6 @@ def load_model_and_tokenizer(
|
|||
padding_side=model_args.padding_side,
|
||||
**config_kwargs
|
||||
)
|
||||
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None: # add pad token
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
|
||||
model_to_load = model_args.checkpoint_dir[0]
|
||||
|
|
Loading…
Reference in New Issue