tiny fix
This commit is contained in:
parent
2780792754
commit
ff98f1cba8
|
@ -4,7 +4,7 @@ from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from transformers import PreTrainedModel, TextIteratorStreamer
|
from transformers import PreTrainedModel, TextIteratorStreamer
|
||||||
|
|
||||||
from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopwords_criteria
|
from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria
|
||||||
from llmtuner.extras.template import get_template
|
from llmtuner.extras.template import get_template
|
||||||
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
|
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ class ChatModel:
|
||||||
self.source_prefix = data_args.source_prefix
|
self.source_prefix = data_args.source_prefix
|
||||||
self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
|
self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
|
||||||
self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
|
self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
|
||||||
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # a monkey fix for qwen model
|
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
|
||||||
|
|
||||||
def process_args(
|
def process_args(
|
||||||
self,
|
self,
|
||||||
|
@ -52,7 +52,7 @@ class ChatModel:
|
||||||
top_k=top_k or gen_kwargs["top_k"],
|
top_k=top_k or gen_kwargs["top_k"],
|
||||||
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
|
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
|
||||||
logits_processor=get_logits_processor(),
|
logits_processor=get_logits_processor(),
|
||||||
stopping_criteria=get_stopwords_criteria(self.stop_ids)
|
stopping_criteria=get_stopping_criteria(self.stop_ids)
|
||||||
))
|
))
|
||||||
|
|
||||||
if max_length:
|
if max_length:
|
||||||
|
|
|
@ -29,7 +29,6 @@ class AverageMeter:
|
||||||
self.avg = self.sum / self.count
|
self.avg = self.sum / self.count
|
||||||
|
|
||||||
|
|
||||||
# Avoids runtime error in model.generate(do_sample=True).
|
|
||||||
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
@ -55,10 +54,10 @@ class StopWordsCriteria(StoppingCriteria):
|
||||||
return any([stop_id in input_ids[:, -1] for stop_id in self.stop_ids])
|
return any([stop_id in input_ids[:, -1] for stop_id in self.stop_ids])
|
||||||
|
|
||||||
|
|
||||||
def get_stopwords_criteria(stop_ids: List[int]) -> StoppingCriteriaList:
|
def get_stopping_criteria(stop_ids: List[int]) -> StoppingCriteriaList:
|
||||||
stopwords_criteria = StoppingCriteriaList()
|
stopping_criteria = StoppingCriteriaList()
|
||||||
stopwords_criteria.append(StopWordsCriteria(stop_ids))
|
stopping_criteria.append(StopWordsCriteria(stop_ids))
|
||||||
return stopwords_criteria
|
return stopping_criteria
|
||||||
|
|
||||||
|
|
||||||
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||||
|
|
Loading…
Reference in New Issue