This commit is contained in:
hiyouga 2023-08-03 17:42:28 +08:00
parent 2780792754
commit ff98f1cba8
2 changed files with 7 additions and 8 deletions

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, Generator, List, Optional, Tuple
from threading import Thread from threading import Thread
from transformers import PreTrainedModel, TextIteratorStreamer from transformers import PreTrainedModel, TextIteratorStreamer
from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopwords_criteria from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria
from llmtuner.extras.template import get_template from llmtuner.extras.template import get_template
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
@ -19,7 +19,7 @@ class ChatModel:
self.source_prefix = data_args.source_prefix self.source_prefix = data_args.source_prefix
self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words) self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words)) self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # a monkey fix for qwen model self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
def process_args( def process_args(
self, self,
@ -52,7 +52,7 @@ class ChatModel:
top_k=top_k or gen_kwargs["top_k"], top_k=top_k or gen_kwargs["top_k"],
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"], repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
logits_processor=get_logits_processor(), logits_processor=get_logits_processor(),
stopping_criteria=get_stopwords_criteria(self.stop_ids) stopping_criteria=get_stopping_criteria(self.stop_ids)
)) ))
if max_length: if max_length:

View File

@ -29,7 +29,6 @@ class AverageMeter:
self.avg = self.sum / self.count self.avg = self.sum / self.count
# Avoids runtime error in model.generate(do_sample=True).
class InvalidScoreLogitsProcessor(LogitsProcessor): class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
@ -55,10 +54,10 @@ class StopWordsCriteria(StoppingCriteria):
return any([stop_id in input_ids[:, -1] for stop_id in self.stop_ids]) return any([stop_id in input_ids[:, -1] for stop_id in self.stop_ids])
def get_stopping_criteria(stop_ids: List[int]) -> StoppingCriteriaList:
    """Build the stopping-criteria list used to halt generation on stop words.

    Args:
        stop_ids: Token ids handed to ``StopWordsCriteria``.

    Returns:
        A ``StoppingCriteriaList`` holding exactly one ``StopWordsCriteria``
        built from ``stop_ids``.
    """
    # Construct the list in a single expression instead of create-then-append.
    return StoppingCriteriaList([StopWordsCriteria(stop_ids)])
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: