add kto to webui
This commit is contained in:
parent
d52fae2fa8
commit
9b0f4d7602
|
@ -184,14 +184,25 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
|
|
||||||
with gr.Accordion(open=False) as rlhf_tab:
|
with gr.Accordion(open=False) as rlhf_tab:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
|
pref_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
|
||||||
dpo_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
|
pref_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
|
||||||
orpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
|
pref_loss = gr.Dropdown(choices=["sigmoid", "hinge", "ipo", "kto_pair"], value="sigmoid")
|
||||||
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
|
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
|
||||||
|
with gr.Column():
|
||||||
|
ppo_score_norm = gr.Checkbox()
|
||||||
|
ppo_whiten_rewards = gr.Checkbox()
|
||||||
|
|
||||||
input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
|
input_elems.update({pref_beta, pref_ftx, pref_loss, reward_model, ppo_score_norm, ppo_whiten_rewards})
|
||||||
elem_dict.update(
|
elem_dict.update(
|
||||||
dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model)
|
dict(
|
||||||
|
rlhf_tab=rlhf_tab,
|
||||||
|
pref_beta=pref_beta,
|
||||||
|
pref_ftx=pref_ftx,
|
||||||
|
pref_loss=pref_loss,
|
||||||
|
reward_model=reward_model,
|
||||||
|
ppo_score_norm=ppo_score_norm,
|
||||||
|
ppo_whiten_rewards=ppo_whiten_rewards,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Accordion(open=False) as galore_tab:
|
with gr.Accordion(open=False) as galore_tab:
|
||||||
|
|
|
@ -774,52 +774,52 @@ LOCALES = {
|
||||||
"label": "RLHF 参数设置",
|
"label": "RLHF 参数设置",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"dpo_beta": {
|
"pref_beta": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "DPO beta",
|
"label": "Beta value",
|
||||||
"info": "Value of the beta parameter in the DPO loss.",
|
"info": "Value of the beta parameter in the loss.",
|
||||||
},
|
},
|
||||||
"ru": {
|
"ru": {
|
||||||
"label": "DPO бета",
|
"label": "Бета значение",
|
||||||
"info": "Значение параметра бета в функции потерь DPO.",
|
"info": "Значение параметра бета в функции потерь.",
|
||||||
},
|
},
|
||||||
"zh": {
|
"zh": {
|
||||||
"label": "DPO beta 参数",
|
"label": "Beta 参数",
|
||||||
"info": "DPO 损失函数中 beta 超参数大小。",
|
"info": "损失函数中 beta 超参数大小。",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"dpo_ftx": {
|
"pref_ftx": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "DPO-ftx weight",
|
"label": "Ftx gamma",
|
||||||
"info": "The weight of SFT loss in the DPO-ftx.",
|
"info": "The weight of SFT loss in the final loss.",
|
||||||
},
|
},
|
||||||
"ru": {
|
"ru": {
|
||||||
"label": "Вес DPO-ftx",
|
"label": "Ftx гамма",
|
||||||
"info": "Вес функции потерь SFT в DPO-ftx.",
|
"info": "Вес потери SFT в итоговой потере.",
|
||||||
},
|
},
|
||||||
"zh": {
|
"zh": {
|
||||||
"label": "DPO-ftx 权重",
|
"label": "Ftx gamma",
|
||||||
"info": "DPO-ftx 中 SFT 损失的权重大小。",
|
"info": "损失函数中 SFT 损失的权重大小。",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"orpo_beta": {
|
"pref_loss": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "ORPO beta",
|
"label": "Loss type",
|
||||||
"info": "Value of the beta parameter in the ORPO loss.",
|
"info": "The type of the loss function.",
|
||||||
},
|
},
|
||||||
"ru": {
|
"ru": {
|
||||||
"label": "ORPO бета",
|
"label": "Тип потерь",
|
||||||
"info": "Значение параметра бета в функции потерь ORPO.",
|
"info": "Тип функции потерь.",
|
||||||
},
|
},
|
||||||
"zh": {
|
"zh": {
|
||||||
"label": "ORPO beta 参数",
|
"label": "损失类型",
|
||||||
"info": "ORPO 损失函数中 beta 超参数大小。",
|
"info": "损失函数的类型。",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"reward_model": {
|
"reward_model": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Reward model",
|
"label": "Reward model",
|
||||||
"info": "Adapter of the reward model for PPO training.",
|
"info": "Adapter of the reward model in PPO training.",
|
||||||
},
|
},
|
||||||
"ru": {
|
"ru": {
|
||||||
"label": "Модель вознаграждения",
|
"label": "Модель вознаграждения",
|
||||||
|
@ -830,6 +830,34 @@ LOCALES = {
|
||||||
"info": "PPO 训练中奖励模型的适配器路径。",
|
"info": "PPO 训练中奖励模型的适配器路径。",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"ppo_score_norm": {
|
||||||
|
"en": {
|
||||||
|
"label": "Score norm",
|
||||||
|
"info": "Normalizing scores in PPO training.",
|
||||||
|
},
|
||||||
|
"ru": {
|
||||||
|
"label": "Норма оценок",
|
||||||
|
"info": "Нормализация оценок в тренировке PPO.",
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"label": "奖励模型",
|
||||||
|
"info": "PPO 训练中归一化奖励分数。",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"ppo_whiten_rewards": {
|
||||||
|
"en": {
|
||||||
|
"label": "Whiten rewards",
|
||||||
|
"info": "Whiten the rewards in PPO training.",
|
||||||
|
},
|
||||||
|
"ru": {
|
||||||
|
"label": "Белые вознаграждения",
|
||||||
|
"info": "Осветлите вознаграждения в обучении PPO.",
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"label": "白化奖励",
|
||||||
|
"info": "PPO 训练中将奖励分数做白化处理。",
|
||||||
|
},
|
||||||
|
},
|
||||||
"galore_tab": {
|
"galore_tab": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "GaLore configurations",
|
"label": "GaLore configurations",
|
||||||
|
|
|
@ -145,11 +145,14 @@ class Runner:
|
||||||
plot_loss=True,
|
plot_loss=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# freeze config
|
||||||
if args["finetuning_type"] == "freeze":
|
if args["finetuning_type"] == "freeze":
|
||||||
args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
|
args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
|
||||||
args["freeze_trainable_modules"] = get("train.freeze_trainable_modules")
|
args["freeze_trainable_modules"] = get("train.freeze_trainable_modules")
|
||||||
args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None
|
args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None
|
||||||
elif args["finetuning_type"] == "lora":
|
|
||||||
|
# lora config
|
||||||
|
if args["finetuning_type"] == "lora":
|
||||||
args["lora_rank"] = get("train.lora_rank")
|
args["lora_rank"] = get("train.lora_rank")
|
||||||
args["lora_alpha"] = get("train.lora_alpha")
|
args["lora_alpha"] = get("train.lora_alpha")
|
||||||
args["lora_dropout"] = get("train.lora_dropout")
|
args["lora_dropout"] = get("train.lora_dropout")
|
||||||
|
@ -163,6 +166,7 @@ class Runner:
|
||||||
if args["use_llama_pro"]:
|
if args["use_llama_pro"]:
|
||||||
args["num_layer_trainable"] = get("train.num_layer_trainable")
|
args["num_layer_trainable"] = get("train.num_layer_trainable")
|
||||||
|
|
||||||
|
# rlhf config
|
||||||
if args["stage"] == "ppo":
|
if args["stage"] == "ppo":
|
||||||
args["reward_model"] = ",".join(
|
args["reward_model"] = ",".join(
|
||||||
[
|
[
|
||||||
|
@ -171,31 +175,41 @@ class Runner:
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
|
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
|
||||||
|
args["ppo_score_norm"] = get("train.ppo_score_norm")
|
||||||
|
args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards")
|
||||||
|
args["top_k"] = 0
|
||||||
|
args["top_p"] = 0.9
|
||||||
elif args["stage"] == "dpo":
|
elif args["stage"] == "dpo":
|
||||||
args["dpo_beta"] = get("train.dpo_beta")
|
args["dpo_beta"] = get("train.pref_beta")
|
||||||
args["dpo_ftx"] = get("train.dpo_ftx")
|
args["dpo_ftx"] = get("train.pref_ftx")
|
||||||
|
args["dpo_loss"] = get("train.pref_loss")
|
||||||
|
elif args["stage"] == "kto":
|
||||||
|
args["kto_beta"] = get("train.pref_beta")
|
||||||
|
args["kto_ftx"] = get("train.pref_ftx")
|
||||||
elif args["stage"] == "orpo":
|
elif args["stage"] == "orpo":
|
||||||
args["orpo_beta"] = get("train.orpo_beta")
|
args["orpo_beta"] = get("train.pref_beta")
|
||||||
|
|
||||||
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
|
|
||||||
args["val_size"] = get("train.val_size")
|
|
||||||
args["evaluation_strategy"] = "steps"
|
|
||||||
args["eval_steps"] = args["save_steps"]
|
|
||||||
args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
|
|
||||||
args["load_best_model_at_end"] = args["stage"] not in ["rm", "ppo"]
|
|
||||||
|
|
||||||
|
# galore config
|
||||||
if args["use_galore"]:
|
if args["use_galore"]:
|
||||||
args["galore_rank"] = get("train.galore_rank")
|
args["galore_rank"] = get("train.galore_rank")
|
||||||
args["galore_update_interval"] = get("train.galore_update_interval")
|
args["galore_update_interval"] = get("train.galore_update_interval")
|
||||||
args["galore_scale"] = get("train.galore_scale")
|
args["galore_scale"] = get("train.galore_scale")
|
||||||
args["galore_target"] = get("train.galore_target")
|
args["galore_target"] = get("train.galore_target")
|
||||||
|
|
||||||
|
# badam config
|
||||||
if args["use_badam"]:
|
if args["use_badam"]:
|
||||||
args["badam_mode"] = get("train.badam_mode")
|
args["badam_mode"] = get("train.badam_mode")
|
||||||
args["badam_switch_mode"] = get("train.badam_switch_mode")
|
args["badam_switch_mode"] = get("train.badam_switch_mode")
|
||||||
args["badam_switch_interval"] = get("train.badam_switch_interval")
|
args["badam_switch_interval"] = get("train.badam_switch_interval")
|
||||||
args["badam_update_ratio"] = get("train.badam_update_ratio")
|
args["badam_update_ratio"] = get("train.badam_update_ratio")
|
||||||
|
|
||||||
|
# eval config
|
||||||
|
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
|
||||||
|
args["val_size"] = get("train.val_size")
|
||||||
|
args["evaluation_strategy"] = "steps"
|
||||||
|
args["eval_steps"] = args["save_steps"]
|
||||||
|
args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
|
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
|
||||||
|
|
Loading…
Reference in New Issue