fix bug: put the model in eval mode, return eos_ids as a list, and simplify the 8-bit quantization config
parent a9980617f5
commit 4b841a6b35
@@ -15,6 +15,7 @@ class ChatModel:
         model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
         self.model = dispatch_model(self.model)
+        self.model = self.model.eval() # change to eval mode
         self.template = get_template(data_args.template)
         self.source_prefix = data_args.source_prefix
         self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
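Note: the line added in this hunk switches the loaded model to inference mode. In PyTorch, calling .eval() on a module disables training-only behaviour such as dropout, so repeated forward passes on the same input stop being randomly perturbed. A minimal sketch of the effect, using a throwaway module rather than the project's model:

    import torch
    import torch.nn as nn

    # Stand-in module: a linear layer followed by dropout.
    layer = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
    x = torch.ones(1, 4)

    layer.train()
    print(layer(x))  # dropout active: activations are randomly zeroed

    layer.eval()     # the same switch the commit applies to self.model
    print(layer(x))  # dropout disabled: deterministic forward pass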
@@ -75,7 +75,7 @@ class Template:
         if tokenizer.eos_token_id and getattr(tokenizer, "add_eos_token", False):
             eos_ids = [tokenizer.eos_token_id]
         else: # use the first stop word as the eos token
-            eos_ids = tokenizer.convert_tokens_to_ids(self.stop_words[0])
+            eos_ids = [tokenizer.convert_tokens_to_ids(self.stop_words[0])]

         return bos_ids, eos_ids

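Note: this hunk is the actual bug fix. tokenizer.convert_tokens_to_ids returns a plain int when given a single token string, while the other branch builds eos_ids as a list, so the two branches produced inconsistent types. Wrapping the call in brackets restores the list shape. A short sketch of the behaviour, using the gpt2 tokenizer purely as an example:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    stop_word = tokenizer.eos_token  # stands in for self.stop_words[0]

    single_id = tokenizer.convert_tokens_to_ids(stop_word)   # int
    eos_ids = [tokenizer.convert_tokens_to_ids(stop_word)]   # [int], matches eos_ids = [tokenizer.eos_token_id]

    print(type(single_id), eos_ids)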
@@ -84,10 +84,7 @@ def load_model_and_tokenizer(
         if model_args.quantization_bit == 8:
             require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
             config_kwargs["load_in_8bit"] = True
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(
-                load_in_8bit=True,
-                llm_int8_threshold=6.0
-            )
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

         elif model_args.quantization_bit == 4:
             require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
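Note: the 8-bit branch now builds BitsAndBytesConfig with load_in_8bit=True only. llm_int8_threshold already defaults to 6.0 inside BitsAndBytesConfig, so dropping the explicit argument does not change behaviour. A sketch of how such a config is typically consumed when loading a model (placeholder model id; needs bitsandbytes>=0.37.0 and a CUDA GPU):

    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    quant_config = BitsAndBytesConfig(load_in_8bit=True)  # llm_int8_threshold defaults to 6.0

    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m",            # placeholder; any causal LM checkpoint works
        quantization_config=quant_config,
        device_map="auto",
    )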