From 4b841a6b35585120c65e2718d6002c69cc40b925 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 8 Aug 2023 17:55:55 +0800 Subject: [PATCH] fix bug --- src/llmtuner/chat/stream_chat.py | 1 + src/llmtuner/extras/template.py | 2 +- src/llmtuner/tuner/core/loader.py | 5 +---- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py index c33a7e61..3e41c54d 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/stream_chat.py @@ -15,6 +15,7 @@ class ChatModel: model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args) self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args) self.model = dispatch_model(self.model) + self.model = self.model.eval() # change to eval mode self.template = get_template(data_args.template) self.source_prefix = data_args.source_prefix self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words) diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 6463c3d5..ed0a6fbc 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -75,7 +75,7 @@ class Template: if tokenizer.eos_token_id and getattr(tokenizer, "add_eos_token", False): eos_ids = [tokenizer.eos_token_id] else: # use the first stop word as the eos token - eos_ids = tokenizer.convert_tokens_to_ids(self.stop_words[0]) + eos_ids = [tokenizer.convert_tokens_to_ids(self.stop_words[0])] return bos_ids, eos_ids diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index 690a6c80..d66c6060 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -84,10 +84,7 @@ def load_model_and_tokenizer( if model_args.quantization_bit == 8: require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") config_kwargs["load_in_8bit"] = True - config_kwargs["quantization_config"] = BitsAndBytesConfig( - load_in_8bit=True, - llm_int8_threshold=6.0 - ) + config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) elif model_args.quantization_bit == 4: require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")