support lora target auto find
This commit is contained in:
parent
d8d82ca281
commit
bca1a247bc
|
@ -62,7 +62,7 @@
|
|||
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
|
||||
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
|
||||
|
||||
- **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
|
||||
- **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
|
||||
- For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models.
|
||||
|
||||
## Supported Training Approaches
|
||||
|
@ -176,7 +176,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--dataset wiki_demo \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--output_dir path_to_pt_checkpoint \
|
||||
|
|
|
@ -62,7 +62,7 @@
|
|||
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
|
||||
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
|
||||
|
||||
- **默认模块**是 `--lora_target` 参数的部分可选项。请使用 `python src/train_bash.py -h` 查看全部可选项。
|
||||
- **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。
|
||||
- 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用对应的模板。
|
||||
|
||||
## 训练方法
|
||||
|
@ -176,7 +176,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--dataset wiki_demo \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--output_dir path_to_pt_checkpoint \
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
import gc
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
|
@ -29,12 +27,6 @@ class AverageMeter:
|
|||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def get_logits_processor() -> LogitsProcessorList:
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
return logits_processor
|
||||
|
||||
|
||||
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
r"""
|
||||
Returns the number of trainable parameters and number of all parameters in the model.
|
||||
|
@ -57,42 +49,10 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
|||
return trainable_params, all_param
|
||||
|
||||
|
||||
# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
|
||||
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
|
||||
def prepare_model_for_training(
|
||||
model: "PreTrainedModel",
|
||||
finetuning_type: str,
|
||||
output_layer_name: Optional[str] = "lm_head",
|
||||
use_gradient_checkpointing: Optional[bool] = True,
|
||||
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
|
||||
) -> "PreTrainedModel":
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if use_gradient_checkpointing:
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
model.gradient_checkpointing_enable()
|
||||
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
||||
|
||||
if finetuning_type != "full" and hasattr(model, output_layer_name):
|
||||
output_layer: torch.nn.Linear = getattr(model, output_layer_name)
|
||||
input_dtype = output_layer.weight.dtype
|
||||
|
||||
class CastOutputToFloat(torch.nn.Sequential):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return super().forward(x.to(input_dtype)).to(torch.float32)
|
||||
|
||||
setattr(model, output_layer_name, CastOutputToFloat(output_layer))
|
||||
|
||||
return model
|
||||
def get_logits_processor() -> LogitsProcessorList:
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
return logits_processor
|
||||
|
||||
|
||||
def torch_gc() -> None:
|
||||
|
|
|
@ -42,8 +42,8 @@ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -
|
|||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
|
||||
return False
|
||||
valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
|
||||
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
|
||||
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
|
||||
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
|
||||
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"], persistent=False)
|
||||
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"], persistent=False)
|
||||
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]), persistent=False)
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]), persistent=False)
|
||||
return True
|
||||
|
|
|
@ -12,6 +12,7 @@ from peft.utils import CONFIG_NAME, WEIGHTS_NAME
|
|||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.save_and_load import load_trainable_params
|
||||
from llmtuner.tuner.core.utils import find_all_linear_modules
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
@ -81,13 +82,18 @@ def init_adapter(
|
|||
model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
|
||||
|
||||
if is_trainable and latest_checkpoint is None: # create new lora weights while training
|
||||
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target == "all":
|
||||
target_modules = find_all_linear_modules(model, model_args.quantization_bit)
|
||||
else:
|
||||
target_modules = finetuning_args.lora_target
|
||||
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
r=finetuning_args.lora_rank,
|
||||
lora_alpha=finetuning_args.lora_alpha,
|
||||
lora_dropout=finetuning_args.lora_dropout,
|
||||
target_modules=finetuning_args.lora_target
|
||||
target_modules=target_modules
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
|
|
|
@ -23,10 +23,11 @@ except ImportError:
|
|||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
from llmtuner.extras.logging import reset_logging, get_logger
|
||||
from llmtuner.extras.misc import count_parameters, prepare_model_for_training
|
||||
from llmtuner.extras.misc import count_parameters
|
||||
from llmtuner.extras.save_and_load import load_valuehead_params
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
from llmtuner.tuner.core.adapter import init_adapter
|
||||
from llmtuner.tuner.core.utils import prepare_model_for_training
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
import torch
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
|
||||
def find_all_linear_modules(
|
||||
model: "PreTrainedModel",
|
||||
quantization_bit: Optional[int] = None,
|
||||
output_layer_name: Optional[str] = "lm_head"
|
||||
) -> List[str]:
|
||||
if quantization_bit is not None:
|
||||
import bitsandbytes as bnb
|
||||
linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
|
||||
else:
|
||||
linear_cls = torch.nn.Linear
|
||||
|
||||
module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, linear_cls):
|
||||
module_names.add(name.split(".")[-1])
|
||||
|
||||
if output_layer_name in module_names:
|
||||
module_names.pop(output_layer_name)
|
||||
|
||||
return list(module_names)
|
||||
|
||||
|
||||
def prepare_model_for_training(
|
||||
model: "PreTrainedModel",
|
||||
finetuning_type: str,
|
||||
output_layer_name: Optional[str] = "lm_head",
|
||||
use_gradient_checkpointing: Optional[bool] = True,
|
||||
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Includes:
|
||||
(1) cast the layernorm in fp32
|
||||
(2) make output embedding layer require grads
|
||||
(3) upcast the lm_head to fp32
|
||||
Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
|
||||
"""
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if use_gradient_checkpointing:
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
model.gradient_checkpointing_enable()
|
||||
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
||||
|
||||
if finetuning_type != "full" and hasattr(model, output_layer_name):
|
||||
output_layer: torch.nn.Linear = getattr(model, output_layer_name)
|
||||
input_dtype = output_layer.weight.dtype
|
||||
|
||||
class CastOutputToFloat(torch.nn.Sequential):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return super().forward(x.to(input_dtype)).to(torch.float32)
|
||||
|
||||
setattr(model, output_layer_name, CastOutputToFloat(output_layer))
|
||||
|
||||
return model
|
|
@ -16,8 +16,8 @@ USER_CONFIG = "user.config"
|
|||
DATA_CONFIG = "dataset_info.json"
|
||||
|
||||
|
||||
def get_save_dir(model_name: str) -> str:
|
||||
return os.path.join(DEFAULT_SAVE_DIR, os.path.split(model_name)[-1])
|
||||
def get_save_dir(*args) -> os.PathLike:
|
||||
return os.path.join(DEFAULT_SAVE_DIR, *args)
|
||||
|
||||
|
||||
def get_config_path() -> os.PathLike:
|
||||
|
@ -29,7 +29,7 @@ def load_config() -> Dict[str, Any]:
|
|||
with open(get_config_path(), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except:
|
||||
return {"lang": "", "last_model": "", "path_dict": {}}
|
||||
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
|
||||
|
||||
|
||||
def save_config(lang: str, model_name: str, model_path: str) -> None:
|
||||
|
@ -56,7 +56,7 @@ def get_template(model_name: str) -> str:
|
|||
|
||||
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
|
||||
checkpoints = []
|
||||
save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
|
||||
save_dir = get_save_dir(model_name, finetuning_type)
|
||||
if save_dir and os.path.isdir(save_dir):
|
||||
for checkpoint in os.listdir(save_dir):
|
||||
if (
|
||||
|
|
|
@ -16,7 +16,7 @@ def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"
|
|||
|
||||
info_box = gr.Textbox(show_label=False, interactive=False)
|
||||
|
||||
chat_model = WebChatModel()
|
||||
chat_model = WebChatModel(lazy_init=True)
|
||||
chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)
|
||||
|
||||
load_btn.click(
|
||||
|
|
|
@ -12,7 +12,7 @@ from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES
|
|||
from llmtuner.extras.logging import LoggerHandler
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.tuner import run_exp
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir, load_config
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
|
||||
|
||||
|
@ -97,21 +97,25 @@ class Runner:
|
|||
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
|
||||
if checkpoints:
|
||||
checkpoint_dir = ",".join(
|
||||
[os.path.join(get_save_dir(model_name), finetuning_type, ckpt) for ckpt in checkpoints]
|
||||
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
|
||||
)
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
|
||||
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
|
||||
output_dir = get_save_dir(model_name, finetuning_type, output_dir)
|
||||
|
||||
user_config = load_config()
|
||||
cache_dir = user_config.get("cache_dir", None)
|
||||
|
||||
args = dict(
|
||||
stage=TRAINING_STAGES[training_stage],
|
||||
model_name_or_path=get_model_path(model_name),
|
||||
do_train=True,
|
||||
overwrite_cache=True,
|
||||
overwrite_cache=False,
|
||||
cache_dir=cache_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit and quantization_bit != "None" else None,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
|
||||
template=template,
|
||||
system_prompt=system_prompt,
|
||||
dataset_dir=dataset_dir,
|
||||
|
@ -172,22 +176,26 @@ class Runner:
|
|||
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
|
||||
if checkpoints:
|
||||
checkpoint_dir = ",".join(
|
||||
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
|
||||
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
|
||||
)
|
||||
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_" + "_".join(checkpoints))
|
||||
output_dir = get_save_dir(model_name, finetuning_type, "eval_" + "_".join(checkpoints))
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")
|
||||
output_dir = get_save_dir(model_name, finetuning_type, "eval_base")
|
||||
|
||||
user_config = load_config()
|
||||
cache_dir = user_config.get("cache_dir", None)
|
||||
|
||||
args = dict(
|
||||
stage="sft",
|
||||
model_name_or_path=get_model_path(model_name),
|
||||
do_eval=True,
|
||||
overwrite_cache=True,
|
||||
overwrite_cache=False,
|
||||
predict_with_generate=True,
|
||||
cache_dir=cache_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit and quantization_bit != "None" else None,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
|
||||
template=template,
|
||||
system_prompt=system_prompt,
|
||||
dataset_dir=dataset_dir,
|
||||
|
|
|
@ -90,7 +90,7 @@ def get_eval_results(path: os.PathLike) -> str:
|
|||
|
||||
|
||||
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
|
||||
log_file = os.path.join(get_save_dir(base_model), finetuning_type, output_dir, "trainer_log.jsonl")
|
||||
log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
|
||||
if not os.path.isfile(log_file):
|
||||
return None
|
||||
|
||||
|
@ -139,7 +139,7 @@ def save_model(
|
|||
return
|
||||
|
||||
checkpoint_dir = ",".join(
|
||||
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
|
||||
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
|
||||
)
|
||||
|
||||
if not save_dir:
|
||||
|
|
Loading…
Reference in New Issue