Add a monkey patch for lora_target

hiyouga 2023-07-18 00:31:40 +08:00
parent f8193e8009
commit 262252d67b
2 changed files with 11 additions and 0 deletions

View File

@@ -29,3 +29,12 @@ SUPPORTED_MODELS = {
     "InternLM-7B-Base": "internlm/internlm-7b",
     "InternLM-7B-Chat": "internlm/internlm-chat-7b"
 }
+
+DEFAULT_MODULE = { # will be deprecated
+    "LLaMA": "q_proj,v_proj",
+    "BLOOM": "query_key_value",
+    "BLOOMZ": "query_key_value",
+    "Falcon": "query_key_value",
+    "Baichuan": "W_pack",
+    "InternLM": "q_proj,v_proj"
+}
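
For readers unfamiliar with the convention, each DEFAULT_MODULE value is a comma-separated list of module names that LoRA adapters should be attached to; downstream the string is typically split before being handed to PEFT. A minimal sketch of that step, assuming a standard peft.LoraConfig (the hyperparameter values are illustrative, not taken from this commit):

# Minimal sketch: turning one DEFAULT_MODULE entry into a PEFT LoraConfig.
# Assumes the `peft` package is installed; r/alpha/dropout values are placeholders.
from peft import LoraConfig

lora_target = "q_proj,v_proj"  # e.g. DEFAULT_MODULE["LLaMA"]
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=[name.strip() for name in lora_target.split(",")],  # -> ["q_proj", "v_proj"]
)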

View File

@@ -6,6 +6,7 @@
 import transformers
 from typing import Optional, Tuple
 from llmtuner.extras.callbacks import LogCallback
+from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
 from llmtuner.tuner import get_train_args, run_sft
@@ -79,6 +80,7 @@ class Runner:
             model_name_or_path=model_name_or_path,
             do_train=True,
             finetuning_type=finetuning_type,
+            lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
             prompt_template=template,
             dataset=",".join(dataset),
             dataset_dir=dataset_dir,
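
The patched lora_target line derives the lookup key from the web UI model name by taking the text before the first hyphen, and the `or` fallback covers both a missing key and an empty value, defaulting to "q_proj,v_proj". A small self-contained sketch of that behavior (the dictionary is trimmed and the model names are examples):

# Sketch of the fallback lookup performed by the patched line above.
DEFAULT_MODULE = {
    "LLaMA": "q_proj,v_proj",
    "Falcon": "query_key_value",
    "Baichuan": "W_pack",
}

def resolve_lora_target(model_name: str) -> str:
    # "Baichuan-7B" -> "Baichuan" -> "W_pack";
    # an unknown prefix falls back to "q_proj,v_proj".
    return DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj"

print(resolve_lora_target("Baichuan-7B"))  # W_pack
print(resolve_lora_target("NewModel-7B"))  # q_proj,v_proj (fallback)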