add kto to webui

This commit is contained in:
hiyouga 2024-05-20 21:20:25 +08:00
parent d52fae2fa8
commit 9b0f4d7602
3 changed files with 91 additions and 38 deletions

View File

@ -184,14 +184,25 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(open=False) as rlhf_tab: with gr.Accordion(open=False) as rlhf_tab:
with gr.Row(): with gr.Row():
dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01) pref_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
dpo_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01) pref_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
orpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01) pref_loss = gr.Dropdown(choices=["sigmoid", "hinge", "ipo", "kto_pair"], value="sigmoid")
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True) reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
with gr.Column():
ppo_score_norm = gr.Checkbox()
ppo_whiten_rewards = gr.Checkbox()
input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model}) input_elems.update({pref_beta, pref_ftx, pref_loss, reward_model, ppo_score_norm, ppo_whiten_rewards})
elem_dict.update( elem_dict.update(
dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model) dict(
rlhf_tab=rlhf_tab,
pref_beta=pref_beta,
pref_ftx=pref_ftx,
pref_loss=pref_loss,
reward_model=reward_model,
ppo_score_norm=ppo_score_norm,
ppo_whiten_rewards=ppo_whiten_rewards,
)
) )
with gr.Accordion(open=False) as galore_tab: with gr.Accordion(open=False) as galore_tab:

View File

@ -774,52 +774,52 @@ LOCALES = {
"label": "RLHF 参数设置", "label": "RLHF 参数设置",
}, },
}, },
"dpo_beta": { "pref_beta": {
"en": { "en": {
"label": "DPO beta", "label": "Beta value",
"info": "Value of the beta parameter in the DPO loss.", "info": "Value of the beta parameter in the loss.",
}, },
"ru": { "ru": {
"label": "DPO бета", "label": "Бета значение",
"info": "Значение параметра бета в функции потерь DPO.", "info": "Значение параметра бета в функции потерь.",
}, },
"zh": { "zh": {
"label": "DPO beta 参数", "label": "Beta 参数",
"info": "DPO 损失函数中 beta 超参数大小。", "info": "损失函数中 beta 超参数大小。",
}, },
}, },
"dpo_ftx": { "pref_ftx": {
"en": { "en": {
"label": "DPO-ftx weight", "label": "Ftx gamma",
"info": "The weight of SFT loss in the DPO-ftx.", "info": "The weight of SFT loss in the final loss.",
}, },
"ru": { "ru": {
"label": "Вес DPO-ftx", "label": "Ftx гамма",
"info": "Вес функции потерь SFT в DPO-ftx.", "info": "Вес потери SFT в итоговой потере.",
}, },
"zh": { "zh": {
"label": "DPO-ftx 权重", "label": "Ftx gamma",
"info": "DPO-ftx 中 SFT 损失的权重大小。", "info": "损失函数中 SFT 损失的权重大小。",
}, },
}, },
"orpo_beta": { "pref_loss": {
"en": { "en": {
"label": "ORPO beta", "label": "Loss type",
"info": "Value of the beta parameter in the ORPO loss.", "info": "The type of the loss function.",
}, },
"ru": { "ru": {
"label": "ORPO бета", "label": "Тип потерь",
"info": "Значение параметра бета в функции потерь ORPO.", "info": "Тип функции потерь.",
}, },
"zh": { "zh": {
"label": "ORPO beta 参数", "label": "损失类型",
"info": "ORPO 损失函数中 beta 超参数大小", "info": "损失函数的类型",
}, },
}, },
"reward_model": { "reward_model": {
"en": { "en": {
"label": "Reward model", "label": "Reward model",
"info": "Adapter of the reward model for PPO training.", "info": "Adapter of the reward model in PPO training.",
}, },
"ru": { "ru": {
"label": "Модель вознаграждения", "label": "Модель вознаграждения",
@ -830,6 +830,34 @@ LOCALES = {
"info": "PPO 训练中奖励模型的适配器路径。", "info": "PPO 训练中奖励模型的适配器路径。",
}, },
}, },
"ppo_score_norm": {
"en": {
"label": "Score norm",
"info": "Normalizing scores in PPO training.",
},
"ru": {
"label": "Норма оценок",
"info": "Нормализация оценок в тренировке PPO.",
},
"zh": {
"label": "奖励模型",
"info": "PPO 训练中归一化奖励分数。",
},
},
"ppo_whiten_rewards": {
"en": {
"label": "Whiten rewards",
"info": "Whiten the rewards in PPO training.",
},
"ru": {
"label": "Белые вознаграждения",
"info": "Осветлите вознаграждения в обучении PPO.",
},
"zh": {
"label": "白化奖励",
"info": "PPO 训练中将奖励分数做白化处理。",
},
},
"galore_tab": { "galore_tab": {
"en": { "en": {
"label": "GaLore configurations", "label": "GaLore configurations",

View File

@ -145,11 +145,14 @@ class Runner:
plot_loss=True, plot_loss=True,
) )
# freeze config
if args["finetuning_type"] == "freeze": if args["finetuning_type"] == "freeze":
args["freeze_trainable_layers"] = get("train.freeze_trainable_layers") args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
args["freeze_trainable_modules"] = get("train.freeze_trainable_modules") args["freeze_trainable_modules"] = get("train.freeze_trainable_modules")
args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None
elif args["finetuning_type"] == "lora":
# lora config
if args["finetuning_type"] == "lora":
args["lora_rank"] = get("train.lora_rank") args["lora_rank"] = get("train.lora_rank")
args["lora_alpha"] = get("train.lora_alpha") args["lora_alpha"] = get("train.lora_alpha")
args["lora_dropout"] = get("train.lora_dropout") args["lora_dropout"] = get("train.lora_dropout")
@ -163,6 +166,7 @@ class Runner:
if args["use_llama_pro"]: if args["use_llama_pro"]:
args["num_layer_trainable"] = get("train.num_layer_trainable") args["num_layer_trainable"] = get("train.num_layer_trainable")
# rlhf config
if args["stage"] == "ppo": if args["stage"] == "ppo":
args["reward_model"] = ",".join( args["reward_model"] = ",".join(
[ [
@ -171,31 +175,41 @@ class Runner:
] ]
) )
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full" args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
args["ppo_score_norm"] = get("train.ppo_score_norm")
args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards")
args["top_k"] = 0
args["top_p"] = 0.9
elif args["stage"] == "dpo": elif args["stage"] == "dpo":
args["dpo_beta"] = get("train.dpo_beta") args["dpo_beta"] = get("train.pref_beta")
args["dpo_ftx"] = get("train.dpo_ftx") args["dpo_ftx"] = get("train.pref_ftx")
args["dpo_loss"] = get("train.pref_loss")
elif args["stage"] == "kto":
args["kto_beta"] = get("train.pref_beta")
args["kto_ftx"] = get("train.pref_ftx")
elif args["stage"] == "orpo": elif args["stage"] == "orpo":
args["orpo_beta"] = get("train.orpo_beta") args["orpo_beta"] = get("train.pref_beta")
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
args["val_size"] = get("train.val_size")
args["evaluation_strategy"] = "steps"
args["eval_steps"] = args["save_steps"]
args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
args["load_best_model_at_end"] = args["stage"] not in ["rm", "ppo"]
# galore config
if args["use_galore"]: if args["use_galore"]:
args["galore_rank"] = get("train.galore_rank") args["galore_rank"] = get("train.galore_rank")
args["galore_update_interval"] = get("train.galore_update_interval") args["galore_update_interval"] = get("train.galore_update_interval")
args["galore_scale"] = get("train.galore_scale") args["galore_scale"] = get("train.galore_scale")
args["galore_target"] = get("train.galore_target") args["galore_target"] = get("train.galore_target")
# badam config
if args["use_badam"]: if args["use_badam"]:
args["badam_mode"] = get("train.badam_mode") args["badam_mode"] = get("train.badam_mode")
args["badam_switch_mode"] = get("train.badam_switch_mode") args["badam_switch_mode"] = get("train.badam_switch_mode")
args["badam_switch_interval"] = get("train.badam_switch_interval") args["badam_switch_interval"] = get("train.badam_switch_interval")
args["badam_update_ratio"] = get("train.badam_update_ratio") args["badam_update_ratio"] = get("train.badam_update_ratio")
# eval config
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
args["val_size"] = get("train.val_size")
args["evaluation_strategy"] = "steps"
args["eval_steps"] = args["save_steps"]
args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
return args return args
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: