fix int8 inference

hiyouga 2023-06-03 23:22:05 +08:00
parent 926291940d
commit 1bd13d7ca1
3 changed files with 3 additions and 24 deletions

@@ -17,15 +17,6 @@ def main():
     model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
     model, tokenizer = load_pretrained(model_args, finetuning_args)
-    if torch.cuda.device_count() > 1:
-        from accelerate import dispatch_model, infer_auto_device_map
-        device_map = infer_auto_device_map(model)
-        model = dispatch_model(model, device_map)
-    else:
-        model = model.cuda()
-    model.eval()
     def format_example_alpaca(query, history):
         prompt = "Below is an instruction that describes a task. "
         prompt += "Write a response that appropriately completes the request.\n"

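This hunk and the identical one at the end of the commit (in what looks like the Gradio web demo) drop the demos' manual device placement: sharding the model across GPUs with accelerate's infer_auto_device_map/dispatch_model, or calling model.cuda() on a single GPU. Once the model is loaded in int8 with device_map="auto" (next hunk), it already sits on the GPU, and moving or re-dispatching a bitsandbytes-quantized model afterwards typically fails, which is presumably the int8 inference breakage this commit fixes. A rough sketch of what the demo start-up reduces to, assuming load_pretrained keeps the signature shown here and puts the model in eval mode itself; the prompt and generation settings are placeholders:

    # prepare_infer_args / load_pretrained come from this project's utils (imports omitted).
    model_args, data_args, finetuning_args = prepare_infer_args()
    model, tokenizer = load_pretrained(model_args, finetuning_args)  # already on GPU via device_map="auto"

    # Hypothetical smoke test, not part of the commit.
    inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
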
@@ -172,16 +172,13 @@ def load_pretrained(
         #require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git")
         #require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
         #require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git")
-        from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible
-        cuda = get_cuda_lib_handle()
-        cc = get_compute_capability(cuda)
-        assert is_cublasLt_compatible(cc), "The current GPU(s) is incompatible with quantization."
         config_kwargs["load_in_8bit"] = True
+        config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
         is_mergeable = False
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
-    if model_args.quantization_bit is not None or (not is_trainable): # automatically load in CUDA
-        config_kwargs["device_map"] = "auto"
     # Load and prepare pretrained models (without valuehead).
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,

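The core fix lives in load_pretrained: device_map="auto" is now set right next to load_in_8bit in config_kwargs, instead of in a separate branch that applied to any inference run, and the pre-flight check built on bitsandbytes' private cuda_setup internals is dropped. For int8 inference the loader therefore ends up making roughly the call sketched below; the model path is a placeholder, and a CUDA GPU with bitsandbytes installed is assumed:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name_or_path = "path/to/model"  # placeholder, not from the commit

    # load_in_8bit=True has bitsandbytes quantize the weights as they are loaded;
    # device_map="auto" lets accelerate place them directly on the GPU(s), which is
    # why the two options are now set together.
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        load_in_8bit=True,
        device_map="auto",
        torch_dtype=torch.float16,  # keep non-quantized modules in fp16 (assumption for this sketch)
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model.eval()
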
@@ -15,15 +15,6 @@ require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher
 model_args, data_args, finetuning_args = prepare_infer_args()
 model, tokenizer = load_pretrained(model_args, finetuning_args)
-if torch.cuda.device_count() > 1:
-    from accelerate import dispatch_model, infer_auto_device_map
-    device_map = infer_auto_device_map(model)
-    model = dispatch_model(model, device_map)
-else:
-    model = model.cuda()
-model.eval()
 """Override Chatbot.postprocess"""