fix unusual output of 8bit models #278 #391

This commit is contained in:
hiyouga 2023-08-12 00:25:29 +08:00
parent a48cb0d474
commit dd51c24203
2 changed files with 4 additions and 1 deletions

View File

@ -142,6 +142,9 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
"""
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
return model
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory

View File

@ -92,7 +92,7 @@ def load_model_and_tokenizer(
)
is_mergeable = False
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
# Load and prepare pretrained models (without valuehead).