diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py
index 3e41c54d..3be70616 100644
--- a/src/llmtuner/chat/stream_chat.py
+++ b/src/llmtuner/chat/stream_chat.py
@@ -5,7 +5,7 @@ from threading import Thread
 from transformers import PreTrainedModel, TextIteratorStreamer
 
 from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria
-from llmtuner.extras.template import get_template
+from llmtuner.extras.template import get_template_and_fix_tokenizer
 from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
@@ -16,7 +16,7 @@ class ChatModel:
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
         self.model = dispatch_model(self.model)
         self.model = self.model.eval() # change to eval mode
-        self.template = get_template(data_args.template)
+        self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
         self.source_prefix = data_args.source_prefix
         self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
         self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py
index 2a25805e..dc01a77d 100644
--- a/src/llmtuner/dsets/preprocess.py
+++ b/src/llmtuner/dsets/preprocess.py
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal
 from itertools import chain
 
 from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.template import get_template
+from llmtuner.extras.template import get_template_and_fix_tokenizer
 
 if TYPE_CHECKING:
     from datasets import Dataset
@@ -19,7 +19,7 @@ def preprocess_dataset(
     stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> "Dataset":
     column_names = list(dataset.column_names)
-    template = get_template(data_args.template)
+    template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
 
     def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
         for i in range(len(examples["prompt"])):
@@ -119,10 +119,9 @@ def preprocess_dataset(
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
         print("label_ids:\n{}".format(example["labels"]))
-        print("labels:\n{}".format(''.join([
-            tokenizer.decode(d, skip_special_tokens=False)
-            if d != IGNORE_INDEX else '-100' for d in example["labels"]
-        ])))
+        print("labels:\n{}".format(tokenizer.decode([
+            token_id if token_id != IGNORE_INDEX else tokenizer.pad_token_id for token_id in example["labels"]
+        ], skip_special_tokens=False)))
 
     def print_pairwise_dataset_example(example):
         print("accept_ids:\n{}".format(example["accept_ids"]))
diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py
index ed0a6fbc..402aa3b5 100644
--- a/src/llmtuner/extras/template.py
+++ b/src/llmtuner/extras/template.py
@@ -67,15 +67,15 @@ class Template:
         self,
         tokenizer: "PreTrainedTokenizer"
     ) -> Tuple[List[int], List[int]]:
-        if tokenizer.bos_token_id and getattr(tokenizer, "add_bos_token", False):
+        if tokenizer.bos_token_id and getattr(tokenizer, "add_bos_token", True):
             bos_ids = [tokenizer.bos_token_id]
-        else: # bos token is optional
-            bos_ids = []
+        else:
+            bos_ids = [] # bos token is optional
 
-        if tokenizer.eos_token_id and getattr(tokenizer, "add_eos_token", False):
+        if tokenizer.eos_token_id and getattr(tokenizer, "add_eos_token", True):
             eos_ids = [tokenizer.eos_token_id]
-        else: # use the first stop word as the eos token
-            eos_ids = [tokenizer.convert_tokens_to_ids(self.stop_words[0])]
+        else:
+            raise ValueError("EOS token is required.")
 
         return bos_ids, eos_ids
@@ -172,9 +172,19 @@ def register_template(
     )
 
 
-def get_template(name: str) -> Template:
+def get_template_and_fix_tokenizer(
+    name: str,
+    tokenizer: "PreTrainedTokenizer"
+) -> Template:
     template = templates.get(name, None)
     assert template is not None, "Template {} does not exist.".format(name)
+
+    if tokenizer.eos_token_id is None and len(template.stop_words): # inplace method
+        tokenizer.eos_token = template.stop_words[0]
+
+    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
+        tokenizer.pad_token = tokenizer.eos_token
+
     return template
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index d66c6060..c06eabfa 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -68,8 +68,6 @@ def load_model_and_tokenizer(
         padding_side=model_args.padding_side,
         **config_kwargs
    )
-    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None: # add pad token
-        tokenizer.pad_token = tokenizer.eos_token
 
     if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
         model_to_load = model_args.checkpoint_dir[0]
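
Usage note: with this change the tokenizer is patched in place when the template is resolved, so the pad-token fallback removed from loader.py now lives in get_template_and_fix_tokenizer, together with an EOS fallback taken from the template's stop words. Below is a minimal sketch of the new call pattern; the checkpoint name and the "llama2" template name are illustrative examples, not part of this diff.

    from transformers import AutoTokenizer

    from llmtuner.extras.template import get_template_and_fix_tokenizer

    # Any causal-LM checkpoint with a tokenizer works; this name is an example only.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    # Resolves the prompt template and fixes the tokenizer in place:
    # - if eos_token is unset, the template's first stop word becomes the EOS token
    # - if pad_token is unset, it falls back to the EOS token
    template = get_template_and_fix_tokenizer("llama2", tokenizer)

    print(tokenizer.eos_token, tokenizer.pad_token)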