fix generation bug #532

This commit is contained in:
hiyouga 2023-08-17 22:21:34 +08:00
parent b0ed0dec5e
commit be21fc83f9
5 changed files with 15 additions and 46 deletions

View File

@ -3,7 +3,7 @@ from typing import Any, Dict, Generator, List, Optional, Tuple
from threading import Thread from threading import Thread
from transformers import TextIteratorStreamer from transformers import TextIteratorStreamer
from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria from llmtuner.extras.misc import dispatch_model, get_logits_processor
from llmtuner.extras.template import get_template_and_fix_tokenizer from llmtuner.extras.template import get_template_and_fix_tokenizer
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
@ -49,10 +49,9 @@ class ChatModel:
top_p=top_p or gen_kwargs["top_p"], top_p=top_p or gen_kwargs["top_p"],
top_k=top_k or gen_kwargs["top_k"], top_k=top_k or gen_kwargs["top_k"],
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"], repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
eos_token_id=self.tokenizer.eos_token_id, eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
pad_token_id=self.tokenizer.pad_token_id, pad_token_id=self.tokenizer.pad_token_id,
logits_processor=get_logits_processor(), logits_processor=get_logits_processor()
stopping_criteria=get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
)) ))
if max_length: if max_length:

View File

@ -2,9 +2,8 @@ import torch
from typing import TYPE_CHECKING, List, Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple
from transformers import ( from transformers import (
LogitsProcessor, LogitsProcessor,
LogitsProcessorList, InfNanRemoveLogitsProcessor,
StoppingCriteria, LogitsProcessorList
StoppingCriteriaList
) )
from llmtuner.extras.constants import LAYERNORM_NAMES from llmtuner.extras.constants import LAYERNORM_NAMES
@ -33,37 +32,12 @@ class AverageMeter:
self.avg = self.sum / self.count self.avg = self.sum / self.count
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 0] = 1.0
return scores
def get_logits_processor() -> LogitsProcessorList: def get_logits_processor() -> LogitsProcessorList:
logits_processor = LogitsProcessorList() logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor()) logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor return logits_processor
class StopWordsCriteria(StoppingCriteria):
def __init__(self, stop_ids: List[int]) -> None:
super().__init__()
self.stop_ids = stop_ids
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any([stop_id in input_ids[:, -1] for stop_id in self.stop_ids])
def get_stopping_criteria(stop_ids: List[int]) -> StoppingCriteriaList:
stopping_criteria = StoppingCriteriaList()
stopping_criteria.append(StopWordsCriteria(stop_ids))
return stopping_criteria
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
r""" r"""
Returns the number of trainable parameters and number of all parameters in the model. Returns the number of trainable parameters and number of all parameters in the model.

View File

@ -10,7 +10,7 @@ from trl import PPOTrainer
from trl.core import LengthSampler from trl.core import LengthSampler
from llmtuner.extras.logging import get_logger from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor, get_stopping_criteria from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.tuner.core.trainer import PeftTrainer from llmtuner.tuner.core.trainer import PeftTrainer
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
@ -74,10 +74,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
# Keyword arguments for `model.generate` # Keyword arguments for `model.generate`
gen_kwargs = self.generating_args.to_dict() gen_kwargs = self.generating_args.to_dict()
gen_kwargs["eos_token_id"] = self.tokenizer.eos_token_id gen_kwargs["eos_token_id"] = [self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
gen_kwargs["logits_processor"] = get_logits_processor() gen_kwargs["logits_processor"] = get_logits_processor()
gen_kwargs["stopping_criteria"] = get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
length_sampler = LengthSampler(max_target_length // 2, max_target_length) length_sampler = LengthSampler(max_target_length // 2, max_target_length)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)

View File

@ -50,9 +50,10 @@ class Seq2SeqPeftTrainer(PeftTrainer):
loss, generated_tokens, labels = super().prediction_step( loss, generated_tokens, labels = super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
) )
generated_tokens = ( if generated_tokens is not None:
generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None generated_tokens[:, :max(prompt_len, label_len)] = (
) self.tokenizer.pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)])
)
return (loss, generated_tokens, labels) return (loss, generated_tokens, labels)
@ -72,10 +73,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor." assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
pad_token_id = self.tokenizer.pad_token_id pad_token_id = self.tokenizer.pad_token_id
else: else:
if self.model.config.pad_token_id is not None: raise ValueError("PAD token is required.")
pad_token_id = self.model.config.pad_token_id
else:
raise ValueError("Pad_token_id must be set in the configuration of the model.")
padded_tensor = pad_token_id * torch.ones_like(tgt_tensor) padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding

View File

@ -5,7 +5,7 @@ from transformers import DataCollatorForSeq2Seq
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor, get_stopping_criteria from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics from llmtuner.tuner.sft.metric import ComputeMetrics
@ -52,10 +52,9 @@ def run_sft(
# Keyword arguments for `model.generate` # Keyword arguments for `model.generate`
gen_kwargs = generating_args.to_dict() gen_kwargs = generating_args.to_dict()
gen_kwargs["eos_token_id"] = tokenizer.eos_token_id gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
gen_kwargs["logits_processor"] = get_logits_processor() gen_kwargs["logits_processor"] = get_logits_processor()
gen_kwargs["stopping_criteria"] = get_stopping_criteria(tokenizer.additional_special_tokens_ids)
# Training # Training
if training_args.do_train: if training_args.do_train: