diff --git a/examples/README.md b/examples/README.md
index 902d26b1..007a81ab 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -195,6 +195,12 @@ llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
 ```
 
+#### PiSSA Fine-Tuning
+
+```bash
+llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
+```
+
 #### Mixture-of-Depths Fine-Tuning
 
 ```bash
@@ -211,11 +217,5 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
 #### FSDP+QLoRA Fine-Tuning
 
 ```bash
-bash examples/extras/fsdp_qlora/single_node.sh
-```
-
-#### PiSSA Fine-Tuning
-
-```bash
-llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
+bash examples/extras/fsdp_qlora/train.sh
 ```
diff --git a/examples/README_zh.md b/examples/README_zh.md
index 586e498c..b9d90f25 100644
--- a/examples/README_zh.md
+++ b/examples/README_zh.md
@@ -195,6 +195,12 @@ llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
 ```
 
+#### PiSSA 微调
+
+```bash
+llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
+```
+
 #### 深度混合微调
 
 ```bash
@@ -211,11 +217,5 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
 #### FSDP+QLoRA 微调
 
 ```bash
-bash examples/extras/fsdp_qlora/single_node.sh
-```
-
-#### PiSSA 微调
-
-```bash
-llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
+bash examples/extras/fsdp_qlora/train.sh
 ```
diff --git a/examples/extras/fsdp_qlora/single_node.sh b/examples/extras/fsdp_qlora/train.sh
similarity index 100%
rename from examples/extras/fsdp_qlora/single_node.sh
rename to examples/extras/fsdp_qlora/train.sh
diff --git a/scripts/llama_pro.py b/scripts/llama_pro.py
index 395375ef..17bf6fc2 100644
--- a/scripts/llama_pro.py
+++ b/scripts/llama_pro.py
@@ -120,7 +120,7 @@ def block_expansion(
         json.dump(index, f, indent=2, sort_keys=True)
 
     print("Model weights saved in {}".format(output_dir))
-    print("Fine-tune this model with:")
+    print("- Fine-tune this model with:")
     print("model_name_or_path: {}".format(output_dir))
     print("finetuning_type: freeze")
     print("freeze_trainable_layers: {}".format(num_expand))
diff --git a/scripts/loftq_init.py b/scripts/loftq_init.py
index 556f342c..b9506fa3 100644
--- a/scripts/loftq_init.py
+++ b/scripts/loftq_init.py
@@ -74,7 +74,7 @@ def quantize_loftq(
    tokenizer.save_pretrained(output_dir)
 
     print("Model weights saved in {}".format(output_dir))
-    print("Fine-tune this model with:")
+    print("- Fine-tune this model with:")
     print("model_name_or_path: {}".format(output_dir))
     print("adapter_name_or_path: {}".format(loftq_dir))
     print("finetuning_type: lora")
diff --git a/scripts/pissa_init.py b/scripts/pissa_init.py
index 1b673c45..10b81efc 100644
--- a/scripts/pissa_init.py
+++ b/scripts/pissa_init.py
@@ -68,11 +68,14 @@ def quantize_pissa(
     tokenizer.save_pretrained(output_dir)
 
     print("Model weights saved in {}".format(output_dir))
-    print("Fine-tune this model with:")
+    print("- Fine-tune this model with:")
     print("model_name_or_path: {}".format(output_dir))
     print("adapter_name_or_path: {}".format(pissa_dir))
     print("finetuning_type: lora")
+    print("pissa_init: false")
     print("pissa_convert: true")
+    print("- and optionally with:")
+    print("quantization_bit: 4")
 
 
 if __name__ == "__main__":
diff --git a/tests/model/test_lora.py b/tests/model/test_lora.py
index 64566fe8..630e5f75 100644
--- a/tests/model/test_lora.py
+++ b/tests/model/test_lora.py
@@ -56,9 +56,15 @@ INFER_ARGS = {
 }
 
 
-def load_reference_model() -> "torch.nn.Module":
-    model = AutoModelForCausalLM.from_pretrained(TINY_LLAMA)
-    return PeftModel.from_pretrained(model, TINY_LLAMA_ADAPTER)
+def load_reference_model(is_trainable: bool = False) -> "LoraModel":
+    model = AutoModelForCausalLM.from_pretrained(
+        TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
+    )
+    lora_model = PeftModel.from_pretrained(model, TINY_LLAMA_ADAPTER, is_trainable=is_trainable)
+    for param in filter(lambda p: p.requires_grad, lora_model.parameters()):
+        param.data = param.data.to(torch.float32)
+
+    return lora_model
 
 
 def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []):
@@ -148,13 +154,7 @@ def test_lora_train_old_adapters():
     tokenizer_module = load_tokenizer(model_args)
     model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
 
-    base_model = AutoModelForCausalLM.from_pretrained(
-        TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
-    )
-    ref_model = PeftModel.from_pretrained(base_model, TINY_LLAMA_ADAPTER, is_trainable=True)
-    for param in filter(lambda p: p.requires_grad, ref_model.parameters()):
-        param.data = param.data.to(torch.float32)
-
+    ref_model = load_reference_model(is_trainable=True)
     compare_model(model, ref_model)
 
 
@@ -165,13 +165,7 @@ def test_lora_train_new_adapters():
     tokenizer_module = load_tokenizer(model_args)
     model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
 
-    base_model = AutoModelForCausalLM.from_pretrained(
-        TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
-    )
-    ref_model = PeftModel.from_pretrained(base_model, TINY_LLAMA_ADAPTER, is_trainable=True)
-    for param in filter(lambda p: p.requires_grad, ref_model.parameters()):
-        param.data = param.data.to(torch.float32)
-
+    ref_model = load_reference_model(is_trainable=True)
     compare_model(
         model, ref_model, diff_keys=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"]
     )
@@ -200,9 +194,5 @@ def test_lora_inference():
     tokenizer_module = load_tokenizer(model_args)
     model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
 
-    base_model = AutoModelForCausalLM.from_pretrained(
-        TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
-    )
-    ref_model: "LoraModel" = PeftModel.from_pretrained(base_model, TINY_LLAMA_ADAPTER)
-    ref_model = ref_model.merge_and_unload()
+    ref_model = load_reference_model().merge_and_unload()
     compare_model(model, ref_model)
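
For reference, the hints printed by the updated `scripts/pissa_init.py` assemble into a training config along these lines. This is a sketch, not a file shipped in the patch: the two paths are placeholders standing in for the script's `output_dir` and `pissa_dir` values, and every key is taken directly from the `print` statements above.

```yaml
# Sketch of the fine-tuning recipe suggested by pissa_init.py's output.
model_name_or_path: models/llama3-pissa            # placeholder: the script's output_dir
adapter_name_or_path: models/llama3-pissa/pissa_init  # placeholder: the script's pissa_dir
finetuning_type: lora
pissa_init: false      # newly printed hint: skip SVD init, the adapter is already PiSSA-initialized
pissa_convert: true
# optional, per the "- and optionally with:" hint:
quantization_bit: 4
```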