diff --git a/src/utils/common.py b/src/utils/common.py index 396c91d7..b784d1fc 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -55,11 +55,12 @@ require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1") logger = get_logger(__name__) -def init_adapter( +def _init_adapter( model: PreTrainedModel, model_args: ModelArguments, finetuning_args: FinetuningArguments, - is_trainable: bool + is_trainable: bool, + is_mergeable: bool ) -> PreTrainedModel: r""" Initializes the adapters. @@ -84,16 +85,19 @@ def init_adapter( else: param.data = param.data.to(torch.float32) - if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None: - assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." - load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods + if model_args.checkpoint_dir is not None: + if finetuning_args.finetuning_type != "lora": + assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." + load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods + else: + assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint." if finetuning_args.finetuning_type == "lora": logger.info("Fine-tuning method: LoRA") lastest_checkpoint = None if model_args.checkpoint_dir is not None: - if is_trainable and model_args.resume_lora_training: # continually train on the lora weights + if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1] else: checkpoints_to_merge = model_args.checkpoint_dir @@ -105,8 +109,8 @@ def init_adapter( if len(checkpoints_to_merge) > 0: logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge))) - if lastest_checkpoint is not None: # resume lora training - model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=True) + if lastest_checkpoint is not None: # resume lora training or quantized inference + model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=is_trainable) if is_trainable and lastest_checkpoint is None: # create new lora weights while training lora_config = LoraConfig( @@ -159,11 +163,15 @@ def load_pretrained( tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the token config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) + is_mergeable = True # Quantization configurations (using bitsandbytes library). if model_args.quantization_bit is not None: assert model_args.quantization_bit == 8, "We only accept 8-bit quantization." - require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.1") + require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") + #require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git") + #require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git") + #require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git") from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible cuda = get_cuda_lib_handle() cc = get_compute_capability(cuda) @@ -171,6 +179,7 @@ def load_pretrained( config_kwargs["load_in_8bit"] = True config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit + is_mergeable = False logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) # Load and prepare pretrained models (without valuehead). @@ -182,7 +191,7 @@ def load_pretrained( **config_kwargs ) model = prepare_model_for_training(model) if is_trainable else model - model = init_adapter(model, model_args, finetuning_args, is_trainable) + model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable) if stage == "rm" or stage == "ppo": # add value head model = AutoModelForCausalLMWithValueHead.from_pretrained(model)