This commit is contained in:
hiyouga 2023-08-08 17:55:55 +08:00
parent a9980617f5
commit 4b841a6b35
3 changed files with 3 additions and 5 deletions

View File

@@ -15,6 +15,7 @@ class ChatModel:
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args) model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args) self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
self.model = dispatch_model(self.model) self.model = dispatch_model(self.model)
self.model = self.model.eval() # change to eval mode
self.template = get_template(data_args.template) self.template = get_template(data_args.template)
self.source_prefix = data_args.source_prefix self.source_prefix = data_args.source_prefix
self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words) self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)

View File

@@ -75,7 +75,7 @@ class Template:
if tokenizer.eos_token_id and getattr(tokenizer, "add_eos_token", False): if tokenizer.eos_token_id and getattr(tokenizer, "add_eos_token", False):
eos_ids = [tokenizer.eos_token_id] eos_ids = [tokenizer.eos_token_id]
else: # use the first stop word as the eos token else: # use the first stop word as the eos token
eos_ids = tokenizer.convert_tokens_to_ids(self.stop_words[0]) eos_ids = [tokenizer.convert_tokens_to_ids(self.stop_words[0])]
return bos_ids, eos_ids return bos_ids, eos_ids

View File

@@ -84,10 +84,7 @@ def load_model_and_tokenizer(
if model_args.quantization_bit == 8: if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["load_in_8bit"] = True config_kwargs["load_in_8bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig( config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
load_in_8bit=True,
llm_int8_threshold=6.0
)
elif model_args.quantization_bit == 4: elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")