From ff98f1cba8d3be5b6a516b26a6019f867365110e Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 3 Aug 2023 17:42:28 +0800 Subject: [PATCH] tiny fix --- src/llmtuner/chat/stream_chat.py | 6 +++--- src/llmtuner/extras/misc.py | 9 ++++----- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py index 7796e90f..13168007 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/stream_chat.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Generator, List, Optional, Tuple from threading import Thread from transformers import PreTrainedModel, TextIteratorStreamer -from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopwords_criteria +from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria from llmtuner.extras.template import get_template from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer @@ -19,7 +19,7 @@ class ChatModel: self.source_prefix = data_args.source_prefix self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words) self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words)) - self.model.generate = MethodType(PreTrainedModel.generate, self.model) # a monkey fix for qwen model + self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen) def process_args( self, @@ -52,7 +52,7 @@ class ChatModel: top_k=top_k or gen_kwargs["top_k"], repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"], logits_processor=get_logits_processor(), - stopping_criteria=get_stopwords_criteria(self.stop_ids) + stopping_criteria=get_stopping_criteria(self.stop_ids) )) if max_length: diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index 766de40d..e1fbb156 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -29,7 +29,6 @@ class AverageMeter: self.avg = self.sum / self.count -# Avoids runtime error in model.generate(do_sample=True). class InvalidScoreLogitsProcessor(LogitsProcessor): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: @@ -55,10 +54,10 @@ class StopWordsCriteria(StoppingCriteria): return any([stop_id in input_ids[:, -1] for stop_id in self.stop_ids]) -def get_stopwords_criteria(stop_ids: List[int]) -> StoppingCriteriaList: - stopwords_criteria = StoppingCriteriaList() - stopwords_criteria.append(StopWordsCriteria(stop_ids)) - return stopwords_criteria +def get_stopping_criteria(stop_ids: List[int]) -> StoppingCriteriaList: + stopping_criteria = StoppingCriteriaList() + stopping_criteria.append(StopWordsCriteria(stop_ids)) + return stopping_criteria def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: