fix badam configs

hiyouga 2024-05-02 02:47:04 +08:00
parent f1c0eedeb3
commit 9433c8c215
5 changed files with 44 additions and 69 deletions

View File

@@ -221,16 +221,18 @@ class BAdamArgument:
         default=None,
         metadata={"help": "The starting block index for layer-wise BAdam."},
     )
-    badam_switch_block_every: Optional[int] = field(
-        default=50,
-        metadata={"help": "How often to switch model's block update. Set to -1 to disable the block update."},
-    )
     badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
         default="ascending",
         metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
     )
+    badam_switch_interval: Optional[int] = field(
+        default=50,
+        metadata={
+            "help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
+        },
+    )
     badam_update_ratio: float = field(
-        default=0.0,
+        default=0.05,
         metadata={"help": "The ratio of the update for ratio-wise BAdam."},
     )
     badam_mask_mode: Literal["adjacent", "scatter"] = field(
@@ -308,6 +310,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         if self.use_galore and self.finetuning_type == "lora":
             raise ValueError("Cannot use LoRA with GaLore together.")
+        if self.use_galore and self.use_badam:
+            raise ValueError("Cannot use GaLore with BAdam together.")
         if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
             raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")
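
For readers unfamiliar with HF-style argument dataclasses, here is a minimal, self-contained sketch (an editor's illustration, not part of the commit) of how fields like the ones above surface as CLI flags; `BAdamArgument` is re-declared with only the two fields this commit touches.

# Sketch: HF-style dataclass fields become CLI flags via HfArgumentParser.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class BAdamArgument:  # re-declares only the two fields touched by this commit
    badam_switch_interval: Optional[int] = field(
        default=50,
        metadata={"help": "Number of steps to update the block for layer-wise BAdam."},
    )
    badam_update_ratio: float = field(
        default=0.05,
        metadata={"help": "The ratio of the update for ratio-wise BAdam."},
    )


# Passing e.g. ["--badam_switch_interval", "100"] overrides the new default of 50.
(args,) = HfArgumentParser(BAdamArgument).parse_args_into_dataclasses(
    ["--badam_switch_interval", "100"]
)
assert args.badam_switch_interval == 100 and args.badam_update_ratio == 0.05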

View File

@@ -317,14 +317,14 @@ def _create_badam_optimizer(
         base_optimizer=base_optimizer,
         named_parameters_list=list(model.named_parameters()),
         block_prefix_list=None,
-        switch_block_every=finetuning_args.badam_switch_block_every,
+        switch_block_every=finetuning_args.badam_switch_interval,
         start_block=finetuning_args.badam_start_block,
         switch_mode=finetuning_args.badam_switch_mode,
         verbose=finetuning_args.badam_verbose,
     )
     logger.info(
         f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
-        f"switch block every {finetuning_args.badam_switch_block_every} steps, "
+        f"switch block every {finetuning_args.badam_switch_interval} steps, "
         f"default start block is {finetuning_args.badam_start_block}"
     )
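
The call above is the only consumer of the renamed field. A sketch of the same layer-wise wiring in isolation, assuming the `badam` package's `BlockOptimizer` with exactly the keyword arguments this diff shows; whether block inference succeeds for an arbitrary module depends on badam's own heuristics, so treat this as illustrative rather than verified.

import torch
from badam import BlockOptimizer

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
base_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
optimizer = BlockOptimizer(
    base_optimizer=base_optimizer,
    named_parameters_list=list(model.named_parameters()),
    block_prefix_list=None,  # None -> let BAdam infer the per-layer blocks
    switch_block_every=50,   # i.e. the renamed badam_switch_interval
    start_block=None,
    switch_mode="ascending",
    verbose=1,               # arbitrary choice for the sketch
)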

View File

@@ -215,17 +215,17 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             use_badam = gr.Checkbox()
             badam_mode = gr.Dropdown(choices=["layer", "ratio"], value="layer")
             badam_switch_mode = gr.Dropdown(choices=["ascending", "descending", "random", "fixed"], value="ascending")
-            badam_switch_block_every = gr.Slider(value=50, minimum=-1, maximum=200, step=1)
-            badam_update_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.01)
+            badam_switch_interval = gr.Slider(value=50, minimum=1, maximum=1024, step=1)
+            badam_update_ratio = gr.Slider(value=0.05, minimum=0, maximum=1, step=0.01)
-    input_elems.update({use_badam, badam_mode, badam_switch_mode, badam_switch_block_every, badam_update_ratio})
+    input_elems.update({use_badam, badam_mode, badam_switch_mode, badam_switch_interval, badam_update_ratio})
     elem_dict.update(
         dict(
             badam_tab=badam_tab,
             use_badam=use_badam,
             badam_mode=badam_mode,
             badam_switch_mode=badam_switch_mode,
-            badam_switch_block_every=badam_switch_block_every,
+            badam_switch_interval=badam_switch_interval,
             badam_update_ratio=badam_update_ratio,
         )
     )
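
Note the tightened slider range: the old control allowed -1 (disable) up to 200, while the new one spans 1 to 1024, so disabling the block switch is no longer reachable from the WebUI and is left to the config path shown earlier. A standalone rendering of the new slider, as a hypothetical snippet assuming only `gradio`:

import gradio as gr

with gr.Blocks() as demo:
    badam_switch_interval = gr.Slider(
        label="Switch interval",
        value=50,     # default matches badam_switch_interval in the dataclass
        minimum=1,    # -1 (disable) is no longer offered in the WebUI
        maximum=1024,
        step=1,
    )

demo.launch()  # serves the single slider for inspection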

View File

@@ -905,15 +905,15 @@ LOCALES = {
     "use_badam": {
         "en": {
             "label": "Use BAdam",
-            "info": "Enable the block coordinate optimization with Adam.",
+            "info": "Enable the BAdam optimizer.",
         },
         "ru": {
             "label": "Использовать BAdam",
-            "info": "Включите блочную оптимизацию координат с Adam.",
+            "info": "Включите оптимизатор BAdam.",
         },
         "zh": {
             "label": "使用 BAdam",
-            "info": "使用多Block协同的Adam优化器。",
+            "info": "使用 BAdam 优化器。",
         },
     },
     "badam_mode": {
@@ -923,25 +923,11 @@ LOCALES = {
         },
         "ru": {
             "label": "Режим BAdam",
-            "info": "Использовать оптимизатор BAdam с обработкой слоев или с обработкой коэффициентов.",
+            "info": "Использовать ли оптимизатор BAdam с послойной или пропорциональной настройкой.",
         },
         "zh": {
             "label": "BAdam 模式",
-            "info": "使用layer或者ratio比例模式。",
-        },
-    },
-    "badam_switch_block_every": {
-        "en": {
-            "label": "Switch block frequency",
-            "info": "How often to switch model's block update. Set to -1 to disable the block update.",
-        },
-        "ru": {
-            "label": "Частота переключения",
-            "info": "Как часто переключать обновление блока модели. Установите -1, чтобы отключить обновление блока.",
-        },
-        "zh": {
-            "label": "切换block的频率",
-            "info": "控制切换block的频率，如果是-1，则不切换。",
+            "info": "使用 layer-wise 或 ratio-wise BAdam 优化器。",
         },
     },
     "badam_switch_mode": {
@@ -950,12 +950,26 @@ LOCALES = {
             "info": "The strategy of picking block to update for layer-wise BAdam.",
         },
         "ru": {
-            "label": "Переключить режим",
-            "info": "Стратегия выбора блока для обновления в методе BAdam по слоям.",
+            "label": "Режим переключения",
+            "info": "Стратегия выбора блока для обновления для послойного BAdam.",
         },
         "zh": {
-            "label": "Block切换策略",
-            "info": "如果是layer类型的训练模式，如何切换block。",
+            "label": "切换策略",
+            "info": "Layer-wise BAdam 优化器的块切换策略。",
         },
     },
+    "badam_switch_interval": {
+        "en": {
+            "label": "Switch interval",
+            "info": "Number of steps to update the block for layer-wise BAdam.",
+        },
+        "ru": {
+            "label": "Интервал переключения",
+            "info": "Количество шагов для обновления блока для послойного BAdam.",
+        },
+        "zh": {
+            "label": "切换频率",
+            "info": "Layer-wise BAdam 优化器的块切换频率。",
+        },
+    },
     "badam_update_ratio": {
@@ -965,39 +965,11 @@ LOCALES = {
         },
         "ru": {
             "label": "Коэффициент обновления",
-            "info": "Коэффициент обновления для метода BAdam, основанного на коэффициентах.",
+            "info": "Коэффициент обновления для BAdam с учётом соотношений.",
         },
         "zh": {
-            "label": "Block更新比例",
-            "info": "如果是比例类型的训练模式，block每次更新的范围比例。",
-        },
-    },
-    "badam_mask_mode": {
-        "en": {
-            "label": "Mask mode",
-            "info": "The mode of the mask for BAdam optimizer.",
-        },
-        "ru": {
-            "label": "Режим маски",
-            "info": "Режим маски для оптимизатора BAdam.",
-        },
-        "zh": {
-            "label": "Mask模式",
-            "info": "BAdam优化器内训练参数的mask关系。",
-        },
-    },
-    "badam_verbose": {
-        "en": {
-            "label": "Verbosity level",
-            "info": "0 for no print, 1 for print the block prefix, 2 for print trainable parameters.",
-        },
-        "ru": {
-            "label": "Уровень многословности",
-            "info": "0 для отсутствия печати, 1 для печати префикса блока, 2 для печати обучаемых параметров.",
-        },
-        "zh": {
-            "label": "输出日志级别",
-            "info": "0：不输出，1：输出block前缀，2：输出可训练的参数。",
+            "label": "Block 更新比例",
+            "info": "Ratio-wise BAdam 优化器的更新比例。",
         },
     },
     "cmd_preview_btn": {

View File

@@ -147,11 +147,11 @@ class Runner:
             shift_attn=get("train.shift_attn"),
             report_to="all" if get("train.report_to") else "none",
             use_galore=get("train.use_galore"),
+            use_badam=get("train.use_badam"),
             output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
             fp16=(get("train.compute_type") == "fp16"),
             bf16=(get("train.compute_type") == "bf16"),
             pure_bf16=(get("train.compute_type") == "pure_bf16"),
-            use_badam=get("train.use_badam"),
         )
         args["disable_tqdm"] = True
@@ -201,11 +201,9 @@ class Runner:
         if args["use_badam"]:
             args["badam_mode"] = get("train.badam_mode")
-            args["badam_switch_block_every"] = get("train.badam_switch_block_every")
             args["badam_switch_mode"] = get("train.badam_switch_mode")
+            args["badam_switch_interval"] = get("train.badam_switch_interval")
             args["badam_update_ratio"] = get("train.badam_update_ratio")
-            args["badam_mask_mode"] = get("train.badam_mask_mode")
-            args["badam_verbose"] = get("train.badam_verbose")
         return args
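
After this change the runner forwards only four BAdam keys (mode, switch mode, switch interval, update ratio). For orientation, a generic sketch of how a flat args dict like this typically becomes `--key value` command-line flags; the entry point and helper are assumptions, not taken from this diff.

from typing import Any, Dict


def to_cli_flags(args: Dict[str, Any]) -> str:
    parts = []
    for key, value in args.items():
        if isinstance(value, bool):
            value = str(value).lower()  # render booleans as true/false
        parts.append(f"--{key} {value}")
    return " ".join(parts)


print(to_cli_flags({
    "badam_mode": "layer",
    "badam_switch_mode": "ascending",
    "badam_switch_interval": 50,
    "badam_update_ratio": 0.05,
}))
# --badam_mode layer --badam_switch_mode ascending --badam_switch_interval 50 --badam_update_ratio 0.05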