support lora for llama pro
This commit is contained in:
parent
02c8c55ce3
commit
9aeb404a94
|
@ -99,12 +99,12 @@ def preprocess_packed_supervised_dataset(
|
|||
continue
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(tokenizer, messages, examples["system"][i], examples["tools"][i])
|
||||
for source_ids, target_ids in template.encode_multiturn(
|
||||
tokenizer, messages, examples["system"][i], examples["tools"][i]
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
elif len(input_ids) != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
@ -112,9 +112,9 @@ def preprocess_packed_supervised_dataset(
|
|||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
total_length = len(input_ids)
|
||||
block_size = data_args.cutoff_len
|
||||
|
@ -122,9 +122,10 @@ def preprocess_packed_supervised_dataset(
|
|||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of cutoff_len
|
||||
for i in range(0, total_length, block_size):
|
||||
model_inputs["input_ids"].append(input_ids[i : i + block_size])
|
||||
model_inputs["attention_mask"].append([1] * block_size)
|
||||
model_inputs["labels"].append(labels[i : i + block_size])
|
||||
if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
|
||||
model_inputs["input_ids"].append(input_ids[i : i + block_size])
|
||||
model_inputs["attention_mask"].append([1] * block_size)
|
||||
model_inputs["labels"].append(labels[i : i + block_size])
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
@ -180,7 +181,6 @@ def preprocess_pairwise_dataset(
|
|||
|
||||
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
|
||||
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
|
||||
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
chosen_messages,
|
||||
|
|
|
@ -26,10 +26,6 @@ class FreezeArguments:
|
|||
default=3,
|
||||
metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
|
||||
)
|
||||
use_llama_pro: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use llama pro for partial-parameter (freeze) fine-tuning."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -170,6 +166,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
|
|||
default="lora",
|
||||
metadata={"help": "Which fine-tuning method to use."},
|
||||
)
|
||||
use_llama_pro: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
|
||||
)
|
||||
disable_version_checking: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to disable version checking."},
|
||||
|
@ -195,13 +195,13 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
|
|||
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
|
||||
if self.stage == "ppo" and self.reward_model is None:
|
||||
raise ValueError("Reward model is necessary for PPO training.")
|
||||
raise ValueError("`reward_model` is necessary for PPO training.")
|
||||
|
||||
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
|
||||
raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
|
||||
raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
|
||||
|
||||
if self.use_llama_pro and self.finetuning_type != "freeze":
|
||||
raise ValueError("`use_llama_pro` is only valid for the Freeze method.")
|
||||
if self.use_llama_pro and self.finetuning_type == "full":
|
||||
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
|
|
|
@ -5,7 +5,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
|
|||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .utils import find_all_linear_modules
|
||||
from .utils import find_all_linear_modules, find_expanded_modules
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -82,6 +82,8 @@ def init_adapter(
|
|||
else:
|
||||
param.requires_grad_(False)
|
||||
|
||||
logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
|
||||
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
logger.info("Fine-tuning method: LoRA")
|
||||
adapter_to_resume = None
|
||||
|
@ -118,6 +120,9 @@ def init_adapter(
|
|||
else:
|
||||
target_modules = finetuning_args.lora_target
|
||||
|
||||
if finetuning_args.use_llama_pro:
|
||||
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
|
||||
|
||||
peft_kwargs = {
|
||||
"r": finetuning_args.lora_rank,
|
||||
"target_modules": target_modules,
|
||||
|
|
|
@ -76,6 +76,33 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
|
|||
return list(module_names)
|
||||
|
||||
|
||||
def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
|
||||
r"""
|
||||
Finds the modules in the expanded blocks to apply lora.
|
||||
"""
|
||||
num_layers = getattr(model.config, "num_hidden_layers", None)
|
||||
if not num_layers:
|
||||
raise ValueError("Model was not supported.")
|
||||
|
||||
if num_layers % num_layer_trainable != 0:
|
||||
raise ValueError(
|
||||
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
|
||||
)
|
||||
|
||||
stride = num_layers // num_layer_trainable
|
||||
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
|
||||
trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
|
||||
module_names = []
|
||||
for name, _ in model.named_modules():
|
||||
if any(target_module in name for target_module in target_modules) and any(
|
||||
trainable_layer in name for trainable_layer in trainable_layers
|
||||
):
|
||||
module_names.append(name)
|
||||
|
||||
logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
|
||||
return module_names
|
||||
|
||||
|
||||
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Loads value head parameters from Hugging Face Hub or local disk.
|
||||
|
|
|
@ -108,6 +108,18 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Freeze config", open=False) as freeze_tab:
|
||||
with gr.Row():
|
||||
num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1, scale=2)
|
||||
name_module_trainable = gr.Textbox(scale=3)
|
||||
|
||||
input_elems.update({num_layer_trainable, name_module_trainable})
|
||||
elem_dict.update(
|
||||
dict(
|
||||
freeze_tab=freeze_tab, num_layer_trainable=num_layer_trainable, name_module_trainable=name_module_trainable
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
|
||||
with gr.Row():
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
|
||||
|
|
|
@ -508,6 +508,45 @@ LOCALES = {
|
|||
"info": "仅训练块扩展后的参数。",
|
||||
},
|
||||
},
|
||||
"freeze_tab": {
|
||||
"en": {
|
||||
"label": "Freeze tuning configurations",
|
||||
},
|
||||
"ru": {
|
||||
"label": "конфигурации для настройки заморозки",
|
||||
},
|
||||
"zh": {
|
||||
"label": "部分参数微调设置",
|
||||
},
|
||||
},
|
||||
"num_layer_trainable": {
|
||||
"en": {
|
||||
"label": "Trainable layers",
|
||||
"info": "The number of trainable layers.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Обучаемые слои",
|
||||
"info": "Количество обучаемых слоев.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "可训练层数",
|
||||
"info": "可训练模型层的数量。",
|
||||
},
|
||||
},
|
||||
"name_module_trainable": {
|
||||
"en": {
|
||||
"label": "Trainable modules",
|
||||
"info": "The name of trainable modules. Use commas to separate multiple modules.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Обучаемые модули",
|
||||
"info": "Название обучаемых модулей. Используйте запятые для разделения нескольких модулей.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "可训练模块",
|
||||
"info": "可训练模块的名称。使用英文逗号分隔多个名称。",
|
||||
},
|
||||
},
|
||||
"lora_tab": {
|
||||
"en": {
|
||||
"label": "LoRA configurations",
|
||||
|
|
|
@ -129,26 +129,34 @@ class Runner:
|
|||
sft_packing=get("train.sft_packing"),
|
||||
upcast_layernorm=get("train.upcast_layernorm"),
|
||||
use_llama_pro=get("train.use_llama_pro"),
|
||||
lora_rank=get("train.lora_rank"),
|
||||
lora_dropout=get("train.lora_dropout"),
|
||||
lora_target=get("train.lora_target") or get_module(get("top.model_name")),
|
||||
additional_target=get("train.additional_target") or None,
|
||||
use_rslora=get("train.use_rslora"),
|
||||
create_new_adapter=get("train.create_new_adapter"),
|
||||
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
|
||||
fp16=(get("train.compute_type") == "fp16"),
|
||||
bf16=(get("train.compute_type") == "bf16"),
|
||||
)
|
||||
args["disable_tqdm"] = True
|
||||
|
||||
if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
|
||||
args["create_new_adapter"] = args["quantization_bit"] is None
|
||||
if args["finetuning_type"] == "freeze":
|
||||
args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
|
||||
args["name_module_trainable"] = get("train.name_module_trainable")
|
||||
elif args["finetuning_type"] == "lora":
|
||||
args["lora_rank"] = get("train.lora_rank")
|
||||
args["lora_dropout"] = get("train.lora_dropout")
|
||||
args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
|
||||
args["additional_target"] = get("train.additional_target") or None
|
||||
args["use_rslora"] = get("train.use_rslora")
|
||||
if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
|
||||
args["create_new_adapter"] = args["quantization_bit"] is None
|
||||
else:
|
||||
args["create_new_adapter"] = get("train.create_new_adapter")
|
||||
|
||||
if args["use_llama_pro"]:
|
||||
args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
|
||||
|
||||
if args["stage"] == "ppo":
|
||||
args["reward_model"] = get_save_dir(
|
||||
get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
|
||||
)
|
||||
args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"
|
||||
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
|
||||
|
||||
if args["stage"] == "dpo":
|
||||
args["dpo_beta"] = get("train.dpo_beta")
|
||||
|
|
Loading…
Reference in New Issue