From 1bd13d7ca197edaa9a1143b061249b4fa6003b97 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sat, 3 Jun 2023 23:22:05 +0800
Subject: [PATCH] fix int8 inference

---
 src/cli_demo.py     | 9 ---------
 src/utils/common.py | 9 +++------
 src/web_demo.py     | 9 ---------
 3 files changed, 3 insertions(+), 24 deletions(-)

diff --git a/src/cli_demo.py b/src/cli_demo.py
index 90e0e7bd..72091d0c 100644
--- a/src/cli_demo.py
+++ b/src/cli_demo.py
@@ -17,15 +17,6 @@ def main():
     model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
     model, tokenizer = load_pretrained(model_args, finetuning_args)
 
-    if torch.cuda.device_count() > 1:
-        from accelerate import dispatch_model, infer_auto_device_map
-        device_map = infer_auto_device_map(model)
-        model = dispatch_model(model, device_map)
-    else:
-        model = model.cuda()
-
-    model.eval()
-
     def format_example_alpaca(query, history):
         prompt = "Below is an instruction that describes a task. "
         prompt += "Write a response that appropriately completes the request.\n"
diff --git a/src/utils/common.py b/src/utils/common.py
index b784d1fc..a0b9b551 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -172,16 +172,13 @@ def load_pretrained(
         #require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git")
         #require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
         #require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git")
-        from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible
-        cuda = get_cuda_lib_handle()
-        cc = get_compute_capability(cuda)
-        assert is_cublasLt_compatible(cc), "The current GPU(s) is incompatible with quantization."
-
         config_kwargs["load_in_8bit"] = True
-        config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
         is_mergeable = False
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
+    if model_args.quantization_bit is not None or (not is_trainable): # automatically load in CUDA
+        config_kwargs["device_map"] = "auto"
+
     # Load and prepare pretrained models (without valuehead).
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
diff --git a/src/web_demo.py b/src/web_demo.py
index 5cd05c34..c5c0ddf3 100644
--- a/src/web_demo.py
+++ b/src/web_demo.py
@@ -15,15 +15,6 @@ require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher
 model_args, data_args, finetuning_args = prepare_infer_args()
 model, tokenizer = load_pretrained(model_args, finetuning_args)
 
-if torch.cuda.device_count() > 1:
-    from accelerate import dispatch_model, infer_auto_device_map
-    device_map = infer_auto_device_map(model)
-    model = dispatch_model(model, device_map)
-else:
-    model = model.cuda()
-
-model.eval()
-
 """Override Chatbot.postprocess"""
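
Note (not part of the patch): a minimal sketch of how a demo script loads the model after this change. Only prepare_infer_args() and load_pretrained() and their argument names are taken from the diff; the import path and everything else here is an assumption for illustration.

    # Hypothetical usage sketch of the post-fix loading path.
    from utils import load_pretrained, prepare_infer_args  # assumed import path

    model_args, data_args, finetuning_args = prepare_infer_args()

    # With this patch, load_pretrained() sets config_kwargs["device_map"] = "auto"
    # whenever quantization_bit is set or the model is loaded for inference, so
    # accelerate handles device placement of the (possibly int8) weights itself.
    model, tokenizer = load_pretrained(model_args, finetuning_args)

    # The manual dispatch_model()/model.cuda() block removed from cli_demo.py and
    # web_demo.py is therefore no longer needed here.

The design change is that device placement now lives in one place (load_pretrained) instead of being repeated, and presumably conflicting with 8-bit loading, in each demo script.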