web UI integrating RLHF

commit ec94274ca1
parent 2f2fd55d81

@@ -68,7 +68,7 @@
 | ---------------------- | -------------- | ----------------- | ---- | ----- |
 | Pre-Training | ✅ | ✅ | ✅ | ✅ |
 | Supervised Fine-Tuning | ✅ | ✅ | ✅ | ✅ |
-| Reward Model Training | | | ✅ | ✅ |
+| Reward Modeling | | | ✅ | ✅ |
 | PPO Training | | | ✅ | ✅ |
 | DPO Training | ✅ | | ✅ | ✅ |

@@ -103,7 +103,7 @@
 - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
 - [UltraChat (en)](https://github.com/thunlp/UltraChat)
 - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
-- For reward modelling or DPO training:
+- For reward modeling or DPO training:
 - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
 - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
 - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)

@@ -206,7 +206,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```

-### Reward Model Training
+### Reward Modeling

 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \

@@ -37,7 +37,9 @@ def run_ppo(
         batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
         gradient_accumulation_steps=training_args.gradient_accumulation_steps,
         ppo_epochs=1,
-        max_grad_norm=training_args.max_grad_norm
+        max_grad_norm=training_args.max_grad_norm,
+        seed=training_args.seed,
+        optimize_cuda_cache=True
     )

     optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)

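The hunk above threads `seed` and `optimize_cuda_cache` into the PPO configuration. As a rough sketch of how these options come together with trl's `PPOConfig` (assuming a 2023-era trl release that accepts both keyword arguments; the literal values below are stand-ins for the project's training arguments):

```python
from trl import PPOConfig  # assumes trl ~0.4.x, where seed/optimize_cuda_cache are accepted

# Stand-in values; the real code derives these from Seq2SeqTrainingArguments.
per_device_train_batch_size = 4
gradient_accumulation_steps = 4

ppo_config = PPOConfig(
    batch_size=per_device_train_batch_size * gradient_accumulation_steps,
    gradient_accumulation_steps=gradient_accumulation_steps,
    ppo_epochs=1,
    max_grad_norm=1.0,          # the trailing comma added above makes room for the new kwargs
    seed=42,                    # now taken from the training arguments, for reproducible rollouts
    optimize_cuda_cache=True    # asks trl to release cached CUDA memory between PPO steps
)
```
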
@@ -29,14 +29,16 @@ def load_config() -> Dict[str, Any]:
         with open(get_config_path(), "r", encoding="utf-8") as f:
             return json.load(f)
     except:
-        return {"last_model": "", "path_dict": {}}
+        return {"lang": "", "last_model": "", "path_dict": {}}


-def save_config(model_name: str, model_path: str) -> None:
+def save_config(lang: str, model_name: str, model_path: str) -> None:
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     user_config = load_config()
-    user_config["last_model"] = model_name
-    user_config["path_dict"][model_name] = model_path
+    user_config["lang"] = lang or user_config["lang"]
+    if model_name:
+        user_config["last_model"] = model_name
+        user_config["path_dict"][model_name] = model_path
     with open(get_config_path(), "w", encoding="utf-8") as f:
         json.dump(user_config, f, indent=2, ensure_ascii=False)

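For orientation, a minimal self-contained sketch of the persistence pattern introduced here (the real helpers live in `llmtuner.webui.common` and use `get_config_path()` / `DEFAULT_CACHE_DIR`; the path below is a hypothetical stand-in):

```python
import json
import os
from typing import Any, Dict

CONFIG_PATH = os.path.join(os.path.expanduser("~"), ".cache", "llmtuner_demo", "user.config")  # stand-in


def load_config() -> Dict[str, Any]:
    try:
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        # The default config now carries the UI language alongside the last used model.
        return {"lang": "", "last_model": "", "path_dict": {}}


def save_config(lang: str, model_name: str, model_path: str) -> None:
    os.makedirs(os.path.dirname(CONFIG_PATH), exist_ok=True)
    user_config = load_config()
    user_config["lang"] = lang or user_config["lang"]
    if model_name:  # only touch the model entries when a model is actually selected
        user_config["last_model"] = model_name
        user_config["path_dict"][model_name] = model_path
    with open(CONFIG_PATH, "w", encoding="utf-8") as f:
        json.dump(user_config, f, indent=2, ensure_ascii=False)


if __name__ == "__main__":
    save_config(lang="zh", model_name="", model_path="")  # the language is saved even without a model
    print(load_config())
```

Guarding on `model_name` is what lets the language dropdown persist its value before any model has been chosen.
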
@@ -1,5 +1,5 @@
 from llmtuner.webui.components.top import create_top
-from llmtuner.webui.components.sft import create_sft_tab
+from llmtuner.webui.components.train import create_train_tab
 from llmtuner.webui.components.eval import create_eval_tab
 from llmtuner.webui.components.infer import create_infer_tab
 from llmtuner.webui.components.export import create_export_tab

@@ -20,22 +20,25 @@ def create_top() -> Dict[str, "Component"]:
         model_path = gr.Textbox(scale=3)

     with gr.Row():
-        finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
+        finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
         checkpoints = gr.Dropdown(multiselect=True, scale=5)
         refresh_btn = gr.Button(scale=1)

     with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
         with gr.Row():
-            quantization_bit = gr.Dropdown(["", "8", "4"], scale=1)
-            template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=1)
+            quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
+            template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
             source_prefix = gr.Textbox(scale=2)

+    lang.change(save_config, [lang, model_name, model_path])
+
     model_name.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
     ).then(
         get_model_path, [model_name], [model_path]
     ) # do not save config since the below line will save
-    model_path.change(save_config, [model_name, model_path])
+
+    model_path.change(save_config, [lang, model_name, model_path])

     finetuning_type.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]

@@ -43,7 +46,9 @@ def create_top() -> Dict[str, "Component"]:
         can_quantize, [finetuning_type], [quantization_bit]
     )

-    refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False)
+    refresh_btn.click(
+        list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
+    )

     return dict(
         lang=lang,

@@ -3,7 +3,7 @@ from transformers.trainer_utils import SchedulerType

 import gradio as gr

-from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
+from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
 from llmtuner.webui.components.data import create_preview_box
 from llmtuner.webui.utils import can_preview, get_preview, gen_plot

@@ -12,7 +12,7 @@ if TYPE_CHECKING:
     from llmtuner.webui.runner import Runner


-def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
+def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
     with gr.Row():
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
         dataset = gr.Dropdown(multiselect=True, scale=4)

@@ -40,7 +40,7 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
             batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
             gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
             lr_scheduler_type = gr.Dropdown(
-                value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
+                choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
             )
             max_grad_norm = gr.Textbox(value="1.0")
             val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)

@@ -60,6 +60,20 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
             lora_target = gr.Textbox(scale=2)
             resume_lora_training = gr.Checkbox(value=True, scale=1)

+    with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
+        with gr.Row():
+            rlhf_method = gr.Dropdown(choices=["None", "Reward Modeling", "PPO", "DPO"], value="None", scale=1)
+            dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2)
+            reward_model = gr.Dropdown(scale=2)
+            refresh_btn = gr.Button(scale=1)
+
+            refresh_btn.click(
+                list_checkpoint,
+                [top_elems["model_name"], top_elems["finetuning_type"]],
+                [reward_model],
+                queue=False
+            )
+
     with gr.Row():
         cmd_preview_btn = gr.Button()
         start_btn = gr.Button()

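A stripped-down, hypothetical Gradio sketch of the accordion this hunk adds, to show how the refresh button repopulates the reward-model dropdown (the real callback is `list_checkpoint` from `llmtuner.webui.common`; `list_reward_checkpoints` below is a stand-in, and a Gradio 3.x API is assumed):

```python
import gradio as gr


def list_reward_checkpoints():
    # Stand-in for list_checkpoint(model_name, finetuning_type): returns saved reward-model checkpoints.
    return gr.update(choices=["rm-checkpoint-1000", "rm-checkpoint-2000"])


with gr.Blocks() as demo:
    with gr.Accordion(label="RLHF config", open=False):
        with gr.Row():
            rlhf_method = gr.Dropdown(choices=["None", "Reward Modeling", "PPO", "DPO"], value="None", label="RLHF method")
            dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, label="DPO beta")
            reward_model = gr.Dropdown(choices=[], label="Reward model")
            refresh_btn = gr.Button("Refresh checkpoints")

    refresh_btn.click(list_reward_checkpoints, [], [reward_model], queue=False)

if __name__ == "__main__":
    demo.launch()
```
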
@@ -79,7 +93,7 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
         with gr.Column(scale=1):
             loss_viewer = gr.Plot()

-    input_list = [
+    input_components = [
         top_elems["lang"],
         top_elems["model_name"],
         top_elems["checkpoints"],

@@ -108,16 +122,19 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
         lora_dropout,
         lora_target,
         resume_lora_training,
+        rlhf_method,
+        dpo_beta,
+        reward_model,
         output_dir
     ]

-    output_list = [
+    output_components = [
         output_box,
         process_bar
     ]

-    cmd_preview_btn.click(runner.preview_train, input_list, output_list)
-    start_btn.click(runner.run_train, input_list, output_list)
+    cmd_preview_btn.click(runner.preview_train, input_components, output_components)
+    start_btn.click(runner.run_train, input_components, output_components)
     stop_btn.click(runner.set_abort, queue=False)

     process_bar.change(

@@ -152,6 +169,11 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
         lora_dropout=lora_dropout,
         lora_target=lora_target,
         resume_lora_training=resume_lora_training,
+        rlhf_tab=rlhf_tab,
+        rlhf_method=rlhf_method,
+        dpo_beta=dpo_beta,
+        reward_model=reward_model,
+        refresh_btn=refresh_btn,
         cmd_preview_btn=cmd_preview_btn,
         start_btn=start_btn,
         stop_btn=stop_btn,

@@ -3,7 +3,7 @@ from transformers.utils.versions import require_version

 from llmtuner.webui.components import (
     create_top,
-    create_sft_tab,
+    create_train_tab,
     create_eval_tab,
     create_infer_tab,
     create_export_tab,

@@ -24,8 +24,8 @@ def create_ui() -> gr.Blocks:
     with gr.Blocks(title="Web Tuner", css=CSS) as demo:
         top_elems = create_top()

-        with gr.Tab("SFT"):
-            sft_elems = create_sft_tab(top_elems, runner)
+        with gr.Tab("Train"):
+            train_elems = create_train_tab(top_elems, runner)

         with gr.Tab("Evaluate"):
             eval_elems = create_eval_tab(top_elems, runner)

@@ -36,7 +36,7 @@ def create_ui() -> gr.Blocks:
         with gr.Tab("Export"):
             export_elems = create_export_tab(top_elems)

-        elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
+        elem_list = [top_elems, train_elems, eval_elems, infer_elems, export_elems]
         manager = Manager(elem_list)

         demo.load(

@@ -59,7 +59,7 @@ def create_web_demo() -> gr.Blocks:
     chat_model = WebChatModel(lazy_init=False)

     with gr.Blocks(title="Web Demo", css=CSS) as demo:
-        lang = gr.Dropdown(choices=["en", "zh"], value="en")
+        lang = gr.Dropdown(choices=["en", "zh"], value="")

         _, _, _, chat_elems = create_chat_box(chat_model, visible=True)

@@ -335,6 +335,44 @@ LOCALES = {
             "info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。"
         }
     },
+    "rlhf_tab": {
+        "en": {
+            "label": "RLHF configurations"
+        },
+        "zh": {
+            "label": "RLHF 参数设置"
+        }
+    },
+    "rlhf_method": {
+        "en": {
+            "label": "RLHF method",
+            "info": "The RLHF algorithm to adopt."
+        },
+        "zh": {
+            "label": "RLHF 方法",
+            "info": "RLHF 阶段使用的算法。"
+        }
+    },
+    "dpo_beta": {
+        "en": {
+            "label": "DPO beta",
+            "info": "Value of the beta parameter in the DPO loss."
+        },
+        "zh": {
+            "label": "DPO beta 参数",
+            "info": "DPO 损失函数中 beta 超参数大小。"
+        }
+    },
+    "reward_model": {
+        "en": {
+            "label": "Reward model",
+            "info": "Checkpoint of the reward model for PPO training."
+        },
+        "zh": {
+            "label": "奖励模型",
+            "info": "PPO 训练中奖励模型的断点路径。"
+        }
+    },
     "cmd_preview_btn": {
         "en": {
             "value": "Preview command"

@@ -12,12 +12,18 @@ class Manager:
     def __init__(self, elem_list: List[Dict[str, Component]]):
         self.elem_list = elem_list

-    def gen_refresh(self) -> Dict[str, Any]:
+    def gen_refresh(self, lang: str) -> Dict[str, Any]:
         refresh_dict = {
             "dataset": {"choices": list_dataset()["choices"]},
             "output_dir": {"value": get_time()}
         }

         user_config = load_config()
+        if lang:
+            refresh_dict["lang"] = {"value": lang}
+        else:
+            refresh_dict["lang"] = {"value": user_config["lang"] if user_config["lang"] else "en"}

         if user_config["last_model"]:
             refresh_dict["model_name"] = {"value": user_config["last_model"]}
             refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}

@@ -26,10 +32,12 @@ class Manager:

     def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]: # cannot use TYPE_CHECKING
         update_dict = {}
-        refresh_dict = self.gen_refresh()
+        refresh_dict = self.gen_refresh(lang)

         for elems in self.elem_list:
             for name, component in elems.items():
-                update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))
+                update_dict[component] = gr.update(
+                    **LOCALES[name][refresh_dict["lang"]["value"]], **refresh_dict.get(name, {})
+                )

         return update_dict

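Read together, the two hunks above resolve the UI language in a fixed order: an explicit `lang` from the dropdown wins, then the cached value from the user config, then `"en"`. A hedged, standalone sketch of that precedence (the dict shape mirrors `load_config()`; the function name is otherwise hypothetical):

```python
from typing import Any, Dict


def resolve_lang(lang: str, user_config: Dict[str, Any]) -> str:
    """Pick the UI language: dropdown value > cached config value > English fallback."""
    if lang:
        return lang
    return user_config["lang"] if user_config["lang"] else "en"


assert resolve_lang("zh", {"lang": "en"}) == "zh"   # explicit choice wins
assert resolve_lang("", {"lang": "zh"}) == "zh"     # cached language restored on reload
assert resolve_lang("", {"lang": ""}) == "en"       # default when nothing is cached
```

`gen_label` then indexes `LOCALES` with the resolved value instead of the raw `lang` argument, so component labels stay consistent with whichever language actually ends up selected.
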
@@ -91,6 +91,9 @@ class Runner:
         lora_dropout: float,
         lora_target: str,
         resume_lora_training: bool,
+        rlhf_method: str,
+        dpo_beta: float,
+        reward_model: str,
         output_dir: str
     ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
         if checkpoints:

@@ -109,7 +112,7 @@ class Runner:
             overwrite_cache=True,
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
             template=template,
             source_prefix=source_prefix,
             dataset_dir=dataset_dir,

@@ -134,6 +137,21 @@ class Runner:
             output_dir=output_dir
         )
         args[compute_type] = True

+        if rlhf_method == "Reward Modeling":
+            args["stage"] = "rm"
+            args["resume_lora_training"] = False
+        elif rlhf_method == "PPO":
+            args["stage"] = "ppo"
+            args["resume_lora_training"] = False
+            args["reward_model"] = reward_model
+            args["padding_side"] = "left"
+            val_size = 0
+        elif rlhf_method == "DPO":
+            args["stage"] = "dpo"
+            args["resume_lora_training"] = False
+            args["dpo_beta"] = dpo_beta
+
         if val_size > 1e-6:
             args["val_size"] = val_size
             args["evaluation_strategy"] = "steps"

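The branch above maps the dropdown choice onto the command-line `stage` plus its stage-specific extras. A compact, hypothetical restatement of that mapping as a pure function (names are illustrative; the real logic mutates `args` in place inside `Runner`):

```python
from typing import Any, Dict


def rlhf_overrides(rlhf_method: str, dpo_beta: float, reward_model: str) -> Dict[str, Any]:
    """Return the extra training arguments implied by the selected RLHF method."""
    if rlhf_method == "Reward Modeling":
        return {"stage": "rm", "resume_lora_training": False}
    if rlhf_method == "PPO":
        # PPO needs a trained reward model and left padding; the diff also zeroes val_size for this stage.
        return {"stage": "ppo", "resume_lora_training": False,
                "reward_model": reward_model, "padding_side": "left"}
    if rlhf_method == "DPO":
        return {"stage": "dpo", "resume_lora_training": False, "dpo_beta": dpo_beta}
    return {}  # "None" keeps the plain SFT stage


print(rlhf_overrides("PPO", dpo_beta=0.1, reward_model="rm-checkpoint-1000"))
```
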
@@ -176,7 +194,7 @@ class Runner:
             predict_with_generate=True,
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
             template=template,
             source_prefix=source_prefix,
             dataset_dir=dataset_dir,

@@ -63,7 +63,8 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:


 def gen_cmd(args: Dict[str, Any]) -> str:
-    args["plot_loss"] = True
+    if args.get("do_train", None):
+        args["plot_loss"] = True
     cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
     for k, v in args.items():
         if v is not None and v != "":

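Finally, the `gen_cmd` change only attaches `--plot_loss` to commands that actually train. A self-contained sketch of that behaviour (an illustrative reimplementation, not the repository's exact function, which continues past this hunk):

```python
from typing import Any, Dict


def gen_cmd_sketch(args: Dict[str, Any]) -> str:
    """Build a shell command preview; only training runs get --plot_loss."""
    if args.get("do_train", None):
        args["plot_loss"] = True
    cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py "]
    for k, v in args.items():
        if v is not None and v != "":
            cmd_lines.append("    --{} {} ".format(k, str(v)))
    return "\\\n".join(cmd_lines)


print(gen_cmd_sketch({"do_train": True, "stage": "rm", "output_dir": "saves/rm"}))
print(gen_cmd_sketch({"do_predict": True, "output_dir": "saves/eval"}))  # no --plot_loss here
```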