diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py
index 00b25b45..b6d71fcf 100644
--- a/src/llmtuner/chat/stream_chat.py
+++ b/src/llmtuner/chat/stream_chat.py
@@ -3,7 +3,7 @@ from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer

-from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria
+from llmtuner.extras.misc import dispatch_model, get_logits_processor
 from llmtuner.extras.template import get_template_and_fix_tokenizer
 from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer

@@ -49,10 +49,9 @@ class ChatModel:
             top_p=top_p or gen_kwargs["top_p"],
             top_k=top_k or gen_kwargs["top_k"],
             repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
-            eos_token_id=self.tokenizer.eos_token_id,
+            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             pad_token_id=self.tokenizer.pad_token_id,
-            logits_processor=get_logits_processor(),
-            stopping_criteria=get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
+            logits_processor=get_logits_processor()
         ))

         if max_length:
diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index b3e0e4b1..db91b337 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -2,9 +2,8 @@ import torch
 from typing import TYPE_CHECKING, List, Optional, Tuple
 from transformers import (
     LogitsProcessor,
-    LogitsProcessorList,
-    StoppingCriteria,
-    StoppingCriteriaList
+    InfNanRemoveLogitsProcessor,
+    LogitsProcessorList
 )

 from llmtuner.extras.constants import LAYERNORM_NAMES
@@ -33,37 +32,12 @@ class AverageMeter:
         self.avg = self.sum / self.count


-class InvalidScoreLogitsProcessor(LogitsProcessor):
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        if torch.isnan(scores).any() or torch.isinf(scores).any():
-            scores.zero_()
-            scores[..., 0] = 1.0
-        return scores
-
-
 def get_logits_processor() -> LogitsProcessorList:
     logits_processor = LogitsProcessorList()
-    logits_processor.append(InvalidScoreLogitsProcessor())
+    logits_processor.append(InfNanRemoveLogitsProcessor())
     return logits_processor


-class StopWordsCriteria(StoppingCriteria):
-
-    def __init__(self, stop_ids: List[int]) -> None:
-        super().__init__()
-        self.stop_ids = stop_ids
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        return any([stop_id in input_ids[:, -1] for stop_id in self.stop_ids])
-
-
-def get_stopping_criteria(stop_ids: List[int]) -> StoppingCriteriaList:
-    stopping_criteria = StoppingCriteriaList()
-    stopping_criteria.append(StopWordsCriteria(stop_ids))
-    return stopping_criteria
-
-
 def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     r"""
     Returns the number of trainable parameters and number of all parameters in the model.
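With the custom InvalidScoreLogitsProcessor and StopWordsCriteria removed, stopping on template-specific stop tokens is handled by `model.generate` itself: `eos_token_id` accepts a list, so generation ends on the base EOS token or on any additional special token, while the built-in `InfNanRemoveLogitsProcessor` takes over NaN/Inf cleanup. A minimal sketch of the resulting generate call, assuming a causal LM whose tokenizer registers its stop words as additional special tokens (the model name below is only a placeholder, not something this patch prescribes):

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    InfNanRemoveLogitsProcessor,
    LogitsProcessorList,
)

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Hello, how are you?", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=64,
    do_sample=True,
    # A list-valued eos_token_id stops generation on the base EOS token
    # or on any additional special token (e.g. template stop words).
    eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
    pad_token_id=tokenizer.pad_token_id,
    # Built-in processor that replaces NaN/Inf logits with finite values,
    # standing in for the removed InvalidScoreLogitsProcessor.
    logits_processor=LogitsProcessorList([InfNanRemoveLogitsProcessor()]),
)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))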
diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py
index fa4170f6..4474b5bb 100644
--- a/src/llmtuner/tuner/ppo/trainer.py
+++ b/src/llmtuner/tuner/ppo/trainer.py
@@ -10,7 +10,7 @@ from trl import PPOTrainer
 from trl.core import LengthSampler

 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor, get_stopping_criteria
+from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
 from llmtuner.tuner.core.trainer import PeftTrainer
 from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model

@@ -74,10 +74,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):

         # Keyword arguments for `model.generate`
         gen_kwargs = self.generating_args.to_dict()
-        gen_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
+        gen_kwargs["eos_token_id"] = [self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids
         gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
         gen_kwargs["logits_processor"] = get_logits_processor()
-        gen_kwargs["stopping_criteria"] = get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)

         length_sampler = LengthSampler(max_target_length // 2, max_target_length)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
diff --git a/src/llmtuner/tuner/sft/trainer.py b/src/llmtuner/tuner/sft/trainer.py
index 6243928f..1ddaec1f 100644
--- a/src/llmtuner/tuner/sft/trainer.py
+++ b/src/llmtuner/tuner/sft/trainer.py
@@ -50,9 +50,10 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         loss, generated_tokens, labels = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
-        generated_tokens = (
-            generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
-        )
+        if generated_tokens is not None:
+            generated_tokens[:, :max(prompt_len, label_len)] = (
+                self.tokenizer.pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)])
+            )

         return (loss, generated_tokens, labels)

@@ -72,10 +73,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
             assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
             pad_token_id = self.tokenizer.pad_token_id
         else:
-            if self.model.config.pad_token_id is not None:
-                pad_token_id = self.model.config.pad_token_id
-            else:
-                raise ValueError("Pad_token_id must be set in the configuration of the model.")
+            raise ValueError("PAD token is required.")

         padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
         padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py
index f2d72fc6..5b0f836b 100644
--- a/src/llmtuner/tuner/sft/workflow.py
+++ b/src/llmtuner/tuner/sft/workflow.py
@@ -5,7 +5,7 @@ from transformers import DataCollatorForSeq2Seq

 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.misc import get_logits_processor, get_stopping_criteria
+from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.sft.metric import ComputeMetrics
@@ -52,10 +52,9 @@ def run_sft(

     # Keyword arguments for `model.generate`
     gen_kwargs = generating_args.to_dict()
-    gen_kwargs["eos_token_id"] = tokenizer.eos_token_id
+    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
     gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
     gen_kwargs["logits_processor"] = get_logits_processor()
-    gen_kwargs["stopping_criteria"] = get_stopping_criteria(tokenizer.additional_special_tokens_ids)

     # Training
     if training_args.do_train:
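The `prediction_step` change above keeps `generated_tokens` at its original length and overwrites the prompt region with pad tokens instead of slicing it off, so the tensor stays aligned with the left-padded labels when it is decoded later. A small self-contained sketch of that masking step, using hypothetical tensors in place of the trainer's real outputs:

import torch

pad_token_id = 0          # hypothetical pad id
prompt_len, label_len = 5, 4

# Hypothetical output of model.generate(): [batch, prompt tokens + new tokens]
generated_tokens = torch.tensor([[11, 12, 13, 14, 15, 21, 22, 23]])

if generated_tokens is not None:
    # Mask the prompt region with pad tokens; the sequence length is unchanged.
    generated_tokens[:, :max(prompt_len, label_len)] = (
        pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)])
    )

print(generated_tokens)   # tensor([[ 0,  0,  0,  0,  0, 21, 22, 23]])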