fix arg dtype

This commit is contained in:
hiyouga 2024-03-05 20:53:30 +08:00
parent 259af60d28
commit e0c47358f9
2 changed files with 2 additions and 2 deletions

View File

@ -123,7 +123,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
with gr.Row():
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=0.1, scale=1)
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1, scale=1)
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
lora_target = gr.Textbox(scale=2)

View File

@ -144,7 +144,7 @@ class Runner:
args["name_module_trainable"] = get("train.name_module_trainable")
elif args["finetuning_type"] == "lora":
args["lora_rank"] = int(get("train.lora_rank"))
args["lora_alpha"] = float(get("train.lora_alpha"))
args["lora_alpha"] = int(get("train.lora_alpha"))
args["lora_dropout"] = float(get("train.lora_dropout"))
args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
args["use_rslora"] = get("train.use_rslora")