a monkey patch for lora_target
This commit is contained in:
parent
f8193e8009
commit
262252d67b
|
@ -29,3 +29,12 @@ SUPPORTED_MODELS = {
|
|||
"InternLM-7B-Base": "internlm/internlm-7b",
|
||||
"InternLM-7B-Chat": "internlm/internlm-chat-7b"
|
||||
}
|
||||
|
||||
DEFAULT_MODULE = { # will be deprecated
|
||||
"LLaMA": "q_proj,v_proj",
|
||||
"BLOOM": "query_key_value",
|
||||
"BLOOMZ": "query_key_value",
|
||||
"Falcon": "query_key_value",
|
||||
"Baichuan": "W_pack",
|
||||
"InternLM": "q_proj,v_proj"
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import transformers
|
|||
from typing import Optional, Tuple
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated
|
||||
from llmtuner.extras.logging import LoggerHandler
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.tuner import get_train_args, run_sft
|
||||
|
@ -79,6 +80,7 @@ class Runner:
|
|||
model_name_or_path=model_name_or_path,
|
||||
do_train=True,
|
||||
finetuning_type=finetuning_type,
|
||||
lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
|
||||
prompt_template=template,
|
||||
dataset=",".join(dataset),
|
||||
dataset_dir=dataset_dir,
|
||||
|
|
Loading…
Reference in New Issue