support orpo in webui
This commit is contained in:
parent
17bf8a2c3a
commit
5195add324
|
@ -169,10 +169,13 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||||
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
|
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
|
||||||
|
orpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||||
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2)
|
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2)
|
||||||
|
|
||||||
input_elems.update({dpo_beta, dpo_ftx, reward_model})
|
input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
|
||||||
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model))
|
elem_dict.update(
|
||||||
|
dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model)
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Accordion(open=False) as galore_tab:
|
with gr.Accordion(open=False) as galore_tab:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
|
|
@ -757,6 +757,20 @@ LOCALES = {
|
||||||
"info": "DPO-ftx 中 SFT 损失的权重大小。",
|
"info": "DPO-ftx 中 SFT 损失的权重大小。",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"orpo_beta": {
|
||||||
|
"en": {
|
||||||
|
"label": "ORPO beta",
|
||||||
|
"info": "Value of the beta parameter in the ORPO loss.",
|
||||||
|
},
|
||||||
|
"ru": {
|
||||||
|
"label": "ORPO бета",
|
||||||
|
"info": "Значение параметра бета в функции потерь ORPO.",
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"label": "ORPO beta 参数",
|
||||||
|
"info": "ORPO 损失函数中 beta 超参数大小。",
|
||||||
|
},
|
||||||
|
},
|
||||||
"reward_model": {
|
"reward_model": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Reward model",
|
"label": "Reward model",
|
||||||
|
|
|
@ -174,10 +174,11 @@ class Runner:
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
|
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
|
||||||
|
elif args["stage"] == "dpo":
|
||||||
if args["stage"] == "dpo":
|
|
||||||
args["dpo_beta"] = get("train.dpo_beta")
|
args["dpo_beta"] = get("train.dpo_beta")
|
||||||
args["dpo_ftx"] = get("train.dpo_ftx")
|
args["dpo_ftx"] = get("train.dpo_ftx")
|
||||||
|
elif args["stage"] == "orpo":
|
||||||
|
args["orpo_beta"] = get("train.orpo_beta")
|
||||||
|
|
||||||
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
|
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
|
||||||
args["val_size"] = get("train.val_size")
|
args["val_size"] = get("train.val_size")
|
||||||
|
|
Loading…
Reference in New Issue