diff --git a/scripts/pissa_init.py b/scripts/pissa_init.py
index 10b81efc..50239727 100644
--- a/scripts/pissa_init.py
+++ b/scripts/pissa_init.py
@@ -50,7 +50,7 @@ def quantize_pissa(
         lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
         lora_dropout=lora_dropout,
         target_modules=[name.strip() for name in lora_target.split(",")],
-        init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter)
+        init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
     )

     # Init PiSSA model
diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 1ef46eca..b676891e 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -352,7 +352,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
         self.galore_target: List[str] = split_arg(self.galore_target)
         self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
-        self.use_ref_model = (self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"])
+        self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."