diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py
index be853604..9b48c89a 100644
--- a/src/llamafactory/webui/components/train.py
+++ b/src/llamafactory/webui/components/train.py
@@ -184,14 +184,25 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
 
     with gr.Accordion(open=False) as rlhf_tab:
         with gr.Row():
-            dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
-            dpo_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
-            orpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
+            pref_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
+            pref_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
+            pref_loss = gr.Dropdown(choices=["sigmoid", "hinge", "ipo", "kto_pair"], value="sigmoid")
             reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
+            with gr.Column():
+                ppo_score_norm = gr.Checkbox()
+                ppo_whiten_rewards = gr.Checkbox()
 
-    input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
+    input_elems.update({pref_beta, pref_ftx, pref_loss, reward_model, ppo_score_norm, ppo_whiten_rewards})
     elem_dict.update(
-        dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model)
+        dict(
+            rlhf_tab=rlhf_tab,
+            pref_beta=pref_beta,
+            pref_ftx=pref_ftx,
+            pref_loss=pref_loss,
+            reward_model=reward_model,
+            ppo_score_norm=ppo_score_norm,
+            ppo_whiten_rewards=ppo_whiten_rewards,
+        )
     )
 
     with gr.Accordion(open=False) as galore_tab:
diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py
index 7afe6ec3..bd4a4205 100644
--- a/src/llamafactory/webui/locales.py
+++ b/src/llamafactory/webui/locales.py
@@ -774,52 +774,52 @@ LOCALES = {
             "label": "RLHF 参数设置",
         },
     },
-    "dpo_beta": {
+    "pref_beta": {
         "en": {
-            "label": "DPO beta",
-            "info": "Value of the beta parameter in the DPO loss.",
+            "label": "Beta value",
+            "info": "Value of the beta parameter in the loss.",
         },
         "ru": {
-            "label": "DPO бета",
-            "info": "Значение параметра бета в функции потерь DPO.",
+            "label": "Бета значение",
+            "info": "Значение параметра бета в функции потерь.",
         },
         "zh": {
-            "label": "DPO beta 参数",
-            "info": "DPO 损失函数中 beta 超参数大小。",
+            "label": "Beta 参数",
+            "info": "损失函数中 beta 超参数大小。",
         },
     },
-    "dpo_ftx": {
+    "pref_ftx": {
         "en": {
-            "label": "DPO-ftx weight",
-            "info": "The weight of SFT loss in the DPO-ftx.",
+            "label": "Ftx gamma",
+            "info": "The weight of SFT loss in the final loss.",
         },
         "ru": {
-            "label": "Вес DPO-ftx",
-            "info": "Вес функции потерь SFT в DPO-ftx.",
+            "label": "Ftx гамма",
+            "info": "Вес потери SFT в итоговой потере.",
         },
         "zh": {
-            "label": "DPO-ftx 权重",
-            "info": "DPO-ftx 中 SFT 损失的权重大小。",
+            "label": "Ftx gamma",
+            "info": "损失函数中 SFT 损失的权重大小。",
        },
     },
-    "orpo_beta": {
+    "pref_loss": {
         "en": {
-            "label": "ORPO beta",
-            "info": "Value of the beta parameter in the ORPO loss.",
+            "label": "Loss type",
+            "info": "The type of the loss function.",
         },
         "ru": {
-            "label": "ORPO бета",
-            "info": "Значение параметра бета в функции потерь ORPO.",
+            "label": "Тип потерь",
+            "info": "Тип функции потерь.",
         },
         "zh": {
-            "label": "ORPO beta 参数",
-            "info": "ORPO 损失函数中 beta 超参数大小。",
+            "label": "损失类型",
+            "info": "损失函数的类型。",
         },
     },
     "reward_model": {
         "en": {
             "label": "Reward model",
-            "info": "Adapter of the reward model for PPO training.",
+            "info": "Adapter of the reward model in PPO training.",
         },
         "ru": {
             "label": "Модель вознаграждения",
@@ -830,6 +830,34 @@ LOCALES = {
             "info": "PPO 训练中奖励模型的适配器路径。",
         },
     },
+    "ppo_score_norm": {
+        "en": {
"label": "Score norm", + "info": "Normalizing scores in PPO training.", + }, + "ru": { + "label": "Норма оценок", + "info": "Нормализация оценок в тренировке PPO.", + }, + "zh": { + "label": "奖励模型", + "info": "PPO 训练中归一化奖励分数。", + }, + }, + "ppo_whiten_rewards": { + "en": { + "label": "Whiten rewards", + "info": "Whiten the rewards in PPO training.", + }, + "ru": { + "label": "Белые вознаграждения", + "info": "Осветлите вознаграждения в обучении PPO.", + }, + "zh": { + "label": "白化奖励", + "info": "PPO 训练中将奖励分数做白化处理。", + }, + }, "galore_tab": { "en": { "label": "GaLore configurations", diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index ef911a16..24046e62 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -145,11 +145,14 @@ class Runner: plot_loss=True, ) + # freeze config if args["finetuning_type"] == "freeze": args["freeze_trainable_layers"] = get("train.freeze_trainable_layers") args["freeze_trainable_modules"] = get("train.freeze_trainable_modules") args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None - elif args["finetuning_type"] == "lora": + + # lora config + if args["finetuning_type"] == "lora": args["lora_rank"] = get("train.lora_rank") args["lora_alpha"] = get("train.lora_alpha") args["lora_dropout"] = get("train.lora_dropout") @@ -163,6 +166,7 @@ class Runner: if args["use_llama_pro"]: args["num_layer_trainable"] = get("train.num_layer_trainable") + # rlhf config if args["stage"] == "ppo": args["reward_model"] = ",".join( [ @@ -171,31 +175,41 @@ class Runner: ] ) args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full" + args["ppo_score_norm"] = get("train.ppo_score_norm") + args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards") + args["top_k"] = 0 + args["top_p"] = 0.9 elif args["stage"] == "dpo": - args["dpo_beta"] = get("train.dpo_beta") - args["dpo_ftx"] = get("train.dpo_ftx") + args["dpo_beta"] = get("train.pref_beta") + args["dpo_ftx"] = get("train.pref_ftx") + args["dpo_loss"] = get("train.pref_loss") + elif args["stage"] == "kto": + args["kto_beta"] = get("train.pref_beta") + args["kto_ftx"] = get("train.pref_ftx") elif args["stage"] == "orpo": - args["orpo_beta"] = get("train.orpo_beta") - - if get("train.val_size") > 1e-6 and args["stage"] != "ppo": - args["val_size"] = get("train.val_size") - args["evaluation_strategy"] = "steps" - args["eval_steps"] = args["save_steps"] - args["per_device_eval_batch_size"] = args["per_device_train_batch_size"] - args["load_best_model_at_end"] = args["stage"] not in ["rm", "ppo"] + args["orpo_beta"] = get("train.pref_beta") + # galore config if args["use_galore"]: args["galore_rank"] = get("train.galore_rank") args["galore_update_interval"] = get("train.galore_update_interval") args["galore_scale"] = get("train.galore_scale") args["galore_target"] = get("train.galore_target") + # badam config if args["use_badam"]: args["badam_mode"] = get("train.badam_mode") args["badam_switch_mode"] = get("train.badam_switch_mode") args["badam_switch_interval"] = get("train.badam_switch_interval") args["badam_update_ratio"] = get("train.badam_update_ratio") + # eval config + if get("train.val_size") > 1e-6 and args["stage"] != "ppo": + args["val_size"] = get("train.val_size") + args["evaluation_strategy"] = "steps" + args["eval_steps"] = args["save_steps"] + args["per_device_eval_batch_size"] = args["per_device_train_batch_size"] + return args def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: