support multiple modules in freeze training #1514

hiyouga 2023-11-15 17:08:18 +08:00
parent bbbce1f516
commit 4907452d95
2 changed files with 19 additions and 11 deletions


@@ -20,22 +20,23 @@ class FinetuningArguments:
         default=3,
         metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
     )
-    name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
+    name_module_trainable: Optional[str] = field(
         default="mlp",
         metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
+                  Use commas to separate multiple modules. \
                   LLaMA choices: [\"mlp\", \"self_attn\"], \
                   BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
                   Qwen choices: [\"mlp\", \"attn\"], \
                   Phi-1.5 choices: [\"mlp\", \"mixer\"], \
-                  LLaMA-2, BlueLM, Baichuan, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
+                  Others choices: the same as LLaMA."}
     )
     lora_rank: Optional[int] = field(
         default=8,
         metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
     )
     lora_alpha: Optional[float] = field(
-        default=32.0,
-        metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
+        default=None,
+        metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2.0)."}
     )
     lora_dropout: Optional[float] = field(
         default=0.1,
@@ -49,7 +50,7 @@ class FinetuningArguments:
                   Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
                   Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
                   Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
-                  LLaMA-2, BlueLM, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
+                  Others choices: the same as LLaMA."}
     )
     additional_target: Optional[str] = field(
         default=None,
@@ -93,12 +94,15 @@ class FinetuningArguments:
     )
 
     def __post_init__(self):
-        if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
-            self.lora_target = [target.strip() for target in self.lora_target.split(",")]
+        def split_arg(arg):
+            if isinstance(arg, str):
+                return [item.strip() for item in arg.split(",")]
+            return arg
 
-        if isinstance(self.additional_target, str):
-            self.additional_target = [target.strip() for target in self.additional_target.split(",")]
+        self.name_module_trainable = split_arg(self.name_module_trainable)
+        self.lora_alpha = self.lora_alpha or float(self.lora_rank * 2.0)
+        self.lora_target = split_arg(self.lora_target)
+        self.additional_target = split_arg(self.additional_target)
 
         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
 
     def save_to_json(self, json_path: str):
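
For reference, below is a minimal, self-contained sketch of the argument normalization this hunk adds to __post_init__ (reimplemented standalone rather than imported from the repository; the example values are made up): comma-separated strings become lists, and an unset lora_alpha falls back to lora_rank * 2.0.

# Standalone sketch of the parsing behavior shown in the diff above.
from typing import List, Optional, Union

def split_arg(arg: Optional[Union[str, List[str]]]) -> Optional[List[str]]:
    # Same behavior as in the diff: split strings on commas, pass lists/None through.
    if isinstance(arg, str):
        return [item.strip() for item in arg.split(",")]
    return arg

name_module_trainable = split_arg("mlp, self_attn")  # -> ["mlp", "self_attn"]
lora_target = split_arg(["q_proj", "v_proj"])        # lists pass through unchanged
additional_target = split_arg(None)                  # -> None

lora_rank, lora_alpha = 8, None                      # illustrative values
lora_alpha = lora_alpha or float(lora_rank * 2.0)    # -> 16.0 with the new default of None

In practice this means a user can pass something like --name_module_trainable mlp,self_attn (assuming the usual HfArgumentParser-style flag derived from the field name; the exact launcher wiring is not shown in this diff) to make several module types trainable during freeze training.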


@@ -46,7 +46,11 @@ def init_adapter(
         else: # fine-tuning the first n layers if num_layer_trainable < 0
             trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
 
-        trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids]
+        trainable_layers = []
+        for module_name in finetuning_args.name_module_trainable:
+            for idx in trainable_layer_ids:
+                trainable_layers.append("{:d}.{}".format(idx, module_name))
+
         for name, param in model.named_parameters():
             if not any(trainable_layer in name for trainable_layer in trainable_layers):
                 param.requires_grad_(False)
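
A hedged sketch of the freeze logic after this change, using a toy PyTorch model (ToyBlock, ToyModel, and the layer counts are invented for illustration; the selection loop mirrors the added lines above): every parameter whose name contains one of the "<layer index>.<module name>" patterns stays trainable, everything else is frozen.

# Toy model whose parameter names resemble "layers.<idx>.<module>.weight".
import torch.nn as nn

class ToyBlock(nn.Module):          # illustrative stand-in for a transformer block
    def __init__(self, hidden: int = 8):
        super().__init__()
        self.self_attn = nn.Linear(hidden, hidden)
        self.mlp = nn.Linear(hidden, hidden)

class ToyModel(nn.Module):
    def __init__(self, num_layers: int = 4, hidden: int = 8):
        super().__init__()
        self.layers = nn.ModuleList([ToyBlock(hidden) for _ in range(num_layers)])

num_layer_trainable = -2                                   # negative: first n layers, as in the else branch
name_module_trainable = ["mlp", "self_attn"]               # list, as produced by split_arg("mlp,self_attn")

trainable_layer_ids = [k for k in range(-num_layer_trainable)]  # -> [0, 1]

trainable_layers = []                                      # same construction as the added lines above
for module_name in name_module_trainable:
    for idx in trainable_layer_ids:
        trainable_layers.append("{:d}.{}".format(idx, module_name))

model = ToyModel(num_layers=4)
for name, param in model.named_parameters():
    if not any(trainable_layer in name for trainable_layer in trainable_layers):
        param.requires_grad_(False)

trainable = sorted({n.rsplit(".", 1)[0] for n, p in model.named_parameters() if p.requires_grad})
print(trainable)  # ['layers.0.mlp', 'layers.0.self_attn', 'layers.1.mlp', 'layers.1.self_attn']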