lint
This commit is contained in:
parent
7857c0990b
commit
24c160df3d
|
@ -50,7 +50,7 @@ def quantize_pissa(
|
||||||
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
||||||
lora_dropout=lora_dropout,
|
lora_dropout=lora_dropout,
|
||||||
target_modules=[name.strip() for name in lora_target.split(",")],
|
target_modules=[name.strip() for name in lora_target.split(",")],
|
||||||
init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter)
|
init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init PiSSA model
|
# Init PiSSA model
|
||||||
|
|
|
@ -352,7 +352,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||||
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
|
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
|
||||||
self.galore_target: List[str] = split_arg(self.galore_target)
|
self.galore_target: List[str] = split_arg(self.galore_target)
|
||||||
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
|
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
|
||||||
self.use_ref_model = (self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"])
|
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
|
||||||
|
|
||||||
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
||||||
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||||
|
|
Loading…
Reference in New Issue