fix generation bug #532
commit be21fc83f9
parent b0ed0dec5e
@@ -3,7 +3,7 @@ from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer

-from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria
+from llmtuner.extras.misc import dispatch_model, get_logits_processor
 from llmtuner.extras.template import get_template_and_fix_tokenizer
 from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer

@@ -49,10 +49,9 @@ class ChatModel:
             top_p=top_p or gen_kwargs["top_p"],
             top_k=top_k or gen_kwargs["top_k"],
             repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
-            eos_token_id=self.tokenizer.eos_token_id,
+            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             pad_token_id=self.tokenizer.pad_token_id,
-            logits_processor=get_logits_processor(),
-            stopping_criteria=get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
+            logits_processor=get_logits_processor()
         ))

         if max_length:

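A minimal sketch (not part of the commit, using the placeholder "gpt2" checkpoint) of the behaviour the hunk above relies on: `generate` accepts either a single id or a list of ids for `eos_token_id`, so listing the additional special tokens stops generation on any of them without a custom stopping criterion.

# Hedged sketch; "gpt2" is only a stand-in model for illustration.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Stop on the regular EOS token or on any additional special token.
stop_ids = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
inputs = tokenizer("Hello, world", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16, eos_token_id=stop_ids)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
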
@@ -2,9 +2,8 @@ import torch
 from typing import TYPE_CHECKING, List, Optional, Tuple
 from transformers import (
     LogitsProcessor,
-    LogitsProcessorList,
-    StoppingCriteria,
-    StoppingCriteriaList
+    InfNanRemoveLogitsProcessor,
+    LogitsProcessorList
 )

 from llmtuner.extras.constants import LAYERNORM_NAMES

@@ -33,37 +32,12 @@ class AverageMeter:
         self.avg = self.sum / self.count


-class InvalidScoreLogitsProcessor(LogitsProcessor):
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        if torch.isnan(scores).any() or torch.isinf(scores).any():
-            scores.zero_()
-            scores[..., 0] = 1.0
-        return scores
-
-
 def get_logits_processor() -> LogitsProcessorList:
     logits_processor = LogitsProcessorList()
-    logits_processor.append(InvalidScoreLogitsProcessor())
+    logits_processor.append(InfNanRemoveLogitsProcessor())
     return logits_processor


-class StopWordsCriteria(StoppingCriteria):
-
-    def __init__(self, stop_ids: List[int]) -> None:
-        super().__init__()
-        self.stop_ids = stop_ids
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        return any([stop_id in input_ids[:, -1] for stop_id in self.stop_ids])
-
-
-def get_stopping_criteria(stop_ids: List[int]) -> StoppingCriteriaList:
-    stopping_criteria = StoppingCriteriaList()
-    stopping_criteria.append(StopWordsCriteria(stop_ids))
-    return stopping_criteria
-
-
 def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     r"""
     Returns the number of trainable parameters and number of all parameters in the model.

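A short sketch, not part of the commit: the stock InfNanRemoveLogitsProcessor covers the same failure mode as the removed hand-written InvalidScoreLogitsProcessor, mapping nan scores to 0.0 and +inf scores to the largest finite value of the dtype; the score tensor below is made up.

# Hedged sketch with an artificial score tensor.
import torch
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList

logits_processor = LogitsProcessorList([InfNanRemoveLogitsProcessor()])
scores = torch.tensor([[1.0, float("nan"), float("inf")]])
input_ids = torch.zeros((1, 1), dtype=torch.long)  # ignored by this processor
print(logits_processor(input_ids, scores))  # every entry is finite afterwards
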
@@ -10,7 +10,7 @@ from trl import PPOTrainer
 from trl.core import LengthSampler

 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor, get_stopping_criteria
+from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
 from llmtuner.tuner.core.trainer import PeftTrainer
 from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model

@@ -74,10 +74,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):

         # Keyword arguments for `model.generate`
         gen_kwargs = self.generating_args.to_dict()
-        gen_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
+        gen_kwargs["eos_token_id"] = [self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids
         gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
         gen_kwargs["logits_processor"] = get_logits_processor()
-        gen_kwargs["stopping_criteria"] = get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)

         length_sampler = LengthSampler(max_target_length // 2, max_target_length)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)

@@ -50,9 +50,10 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         loss, generated_tokens, labels = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
-        generated_tokens = (
-            generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
-        )
+        if generated_tokens is not None:
+            generated_tokens[:, :max(prompt_len, label_len)] = (
+                self.tokenizer.pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)])
+            )

         return (loss, generated_tokens, labels)

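A toy illustration with made-up ids and lengths, not from the repository: instead of slicing the generated sequence, the prompt positions are overwritten with the pad token, so the tensor keeps its shape and later decoding with skip_special_tokens simply drops the blanked prefix.

# Hedged sketch; ids, lengths, and pad_token_id below are assumptions.
import torch

pad_token_id = 0
prompt_len, label_len = 4, 3
generated_tokens = torch.tensor([[11, 12, 13, 14, 21, 22, 23]])

cut = max(prompt_len, label_len)
generated_tokens[:, :cut] = pad_token_id * torch.ones_like(generated_tokens[:, :cut])
print(generated_tokens)  # tensor([[ 0,  0,  0,  0, 21, 22, 23]])
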
@@ -72,10 +73,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
             assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
             pad_token_id = self.tokenizer.pad_token_id
         else:
-            if self.model.config.pad_token_id is not None:
-                pad_token_id = self.model.config.pad_token_id
-            else:
-                raise ValueError("Pad_token_id must be set in the configuration of the model.")
+            raise ValueError("PAD token is required.")

         padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
         padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding

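A toy sketch with made-up values, not from the repository, of the left-padding pattern kept by the hunk above: build a pad-filled tensor of the target length and copy the shorter source into its right-hand end.

# Hedged sketch; pad_token_id and the tensors are assumptions for illustration.
import torch

pad_token_id = 0
src_tensor = torch.tensor([[5, 6, 7]])              # shorter, right-aligned content
tgt_tensor = torch.zeros((1, 6), dtype=torch.long)  # only its length matters here

padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor  # adopt left-padding
print(padded_tensor)  # tensor([[0, 0, 0, 5, 6, 7]])
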
@@ -5,7 +5,7 @@ from transformers import DataCollatorForSeq2Seq

 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.misc import get_logits_processor, get_stopping_criteria
+from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.sft.metric import ComputeMetrics

@@ -52,10 +52,9 @@ def run_sft(

     # Keyword arguments for `model.generate`
     gen_kwargs = generating_args.to_dict()
-    gen_kwargs["eos_token_id"] = tokenizer.eos_token_id
+    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
     gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
     gen_kwargs["logits_processor"] = get_logits_processor()
-    gen_kwargs["stopping_criteria"] = get_stopping_criteria(tokenizer.additional_special_tokens_ids)

     # Training
     if training_args.do_train: