update webUI, fix #179
This commit is contained in:
parent
b9fe83fb75
commit
12d8a8633f
|
@ -108,7 +108,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||||
replace_model(unwrapped_model, target="reward")
|
replace_model(unwrapped_model, target="reward")
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
_, _, values = self.model(**self.prepare_model_inputs(queries, responses))
|
_, _, values = self.model(**self.prepare_model_inputs(queries, responses))
|
||||||
rewards = [reward for reward in values[-1].to(torch.float32)] # use float32 type
|
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
|
||||||
replace_model(unwrapped_model, target="default")
|
replace_model(unwrapped_model, target="default")
|
||||||
|
|
||||||
# Run PPO step
|
# Run PPO step
|
||||||
|
|
|
@ -17,8 +17,14 @@ class WebChatModel(ChatModel):
|
||||||
self.generating_args = GeneratingArguments()
|
self.generating_args = GeneratingArguments()
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
self, lang: str, model_name: str, checkpoints: list,
|
self,
|
||||||
finetuning_type: str, template: str, quantization_bit: str
|
lang: str,
|
||||||
|
model_name: str,
|
||||||
|
checkpoints: List[str],
|
||||||
|
finetuning_type: str,
|
||||||
|
quantization_bit: str,
|
||||||
|
template: str,
|
||||||
|
source_prefix: str
|
||||||
):
|
):
|
||||||
if self.model is not None:
|
if self.model is not None:
|
||||||
yield ALERTS["err_exists"][lang]
|
yield ALERTS["err_exists"][lang]
|
||||||
|
@ -43,10 +49,11 @@ class WebChatModel(ChatModel):
|
||||||
yield ALERTS["info_loading"][lang]
|
yield ALERTS["info_loading"][lang]
|
||||||
args = dict(
|
args = dict(
|
||||||
model_name_or_path=model_name_or_path,
|
model_name_or_path=model_name_or_path,
|
||||||
finetuning_type=finetuning_type,
|
|
||||||
prompt_template=template,
|
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
quantization_bit=int(quantization_bit) if quantization_bit else None
|
finetuning_type=finetuning_type,
|
||||||
|
quantization_bit=int(quantization_bit) if quantization_bit else None,
|
||||||
|
prompt_template=template,
|
||||||
|
source_prefix=source_prefix
|
||||||
)
|
)
|
||||||
super().__init__(*get_infer_args(args))
|
super().__init__(*get_infer_args(args))
|
||||||
|
|
||||||
|
|
|
@ -10,8 +10,8 @@ from llmtuner.webui.utils import can_preview, get_preview
|
||||||
|
|
||||||
def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
|
def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2)
|
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||||
dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
|
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||||
preview_btn = gr.Button(interactive=False, scale=1)
|
preview_btn = gr.Button(interactive=False, scale=1)
|
||||||
|
|
||||||
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
||||||
|
@ -21,9 +21,8 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
|
||||||
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
|
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
max_samples = gr.Textbox(value="100000", interactive=True)
|
max_samples = gr.Textbox(value="100000")
|
||||||
batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1, interactive=True)
|
batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1)
|
||||||
quantization_bit = gr.Dropdown([8, 4])
|
|
||||||
predict = gr.Checkbox(value=True)
|
predict = gr.Checkbox(value=True)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -35,9 +34,18 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
|
||||||
start_btn.click(
|
start_btn.click(
|
||||||
runner.run_eval,
|
runner.run_eval,
|
||||||
[
|
[
|
||||||
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
|
top_elems["lang"],
|
||||||
top_elems["finetuning_type"], top_elems["template"],
|
top_elems["model_name"],
|
||||||
dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
|
top_elems["checkpoints"],
|
||||||
|
top_elems["finetuning_type"],
|
||||||
|
top_elems["quantization_bit"],
|
||||||
|
top_elems["template"],
|
||||||
|
top_elems["source_prefix"],
|
||||||
|
dataset_dir,
|
||||||
|
dataset,
|
||||||
|
max_samples,
|
||||||
|
batch_size,
|
||||||
|
predict
|
||||||
],
|
],
|
||||||
[output_box]
|
[output_box]
|
||||||
)
|
)
|
||||||
|
@ -52,7 +60,6 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
|
||||||
close_btn=close_btn,
|
close_btn=close_btn,
|
||||||
max_samples=max_samples,
|
max_samples=max_samples,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
quantization_bit=quantization_bit,
|
|
||||||
predict=predict,
|
predict=predict,
|
||||||
start_btn=start_btn,
|
start_btn=start_btn,
|
||||||
stop_btn=stop_btn,
|
stop_btn=stop_btn,
|
||||||
|
|
|
@ -11,7 +11,6 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
load_btn = gr.Button()
|
load_btn = gr.Button()
|
||||||
unload_btn = gr.Button()
|
unload_btn = gr.Button()
|
||||||
quantization_bit = gr.Dropdown([8, 4])
|
|
||||||
|
|
||||||
info_box = gr.Markdown()
|
info_box = gr.Markdown()
|
||||||
|
|
||||||
|
@ -21,9 +20,13 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
|
||||||
load_btn.click(
|
load_btn.click(
|
||||||
chat_model.load_model,
|
chat_model.load_model,
|
||||||
[
|
[
|
||||||
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
|
top_elems["lang"],
|
||||||
top_elems["finetuning_type"], top_elems["template"],
|
top_elems["model_name"],
|
||||||
quantization_bit
|
top_elems["checkpoints"],
|
||||||
|
top_elems["finetuning_type"],
|
||||||
|
top_elems["quantization_bit"],
|
||||||
|
top_elems["template"],
|
||||||
|
top_elems["source_prefix"]
|
||||||
],
|
],
|
||||||
[info_box]
|
[info_box]
|
||||||
).then(
|
).then(
|
||||||
|
@ -39,7 +42,6 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
|
||||||
)
|
)
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
quantization_bit=quantization_bit,
|
|
||||||
info_box=info_box,
|
info_box=info_box,
|
||||||
load_btn=load_btn,
|
load_btn=load_btn,
|
||||||
unload_btn=unload_btn,
|
unload_btn=unload_btn,
|
||||||
|
|
|
@ -12,8 +12,8 @@ from llmtuner.webui.utils import can_preview, get_preview, gen_plot
|
||||||
|
|
||||||
def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
|
def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2)
|
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||||
dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
|
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||||
preview_btn = gr.Button(interactive=False, scale=1)
|
preview_btn = gr.Button(interactive=False, scale=1)
|
||||||
|
|
||||||
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
||||||
|
@ -23,22 +23,21 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
|
||||||
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
|
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
learning_rate = gr.Textbox(value="5e-5", interactive=True)
|
learning_rate = gr.Textbox(value="5e-5")
|
||||||
num_train_epochs = gr.Textbox(value="3.0", interactive=True)
|
num_train_epochs = gr.Textbox(value="3.0")
|
||||||
max_samples = gr.Textbox(value="100000", interactive=True)
|
max_samples = gr.Textbox(value="100000")
|
||||||
quantization_bit = gr.Dropdown([8, 4])
|
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1, interactive=True)
|
batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1)
|
||||||
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1, interactive=True)
|
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1)
|
||||||
lr_scheduler_type = gr.Dropdown(
|
lr_scheduler_type = gr.Dropdown(
|
||||||
value="cosine", choices=[scheduler.value for scheduler in SchedulerType], interactive=True
|
value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
|
||||||
)
|
)
|
||||||
fp16 = gr.Checkbox(value=True)
|
fp16 = gr.Checkbox(value=True)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5, interactive=True)
|
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
|
||||||
save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10, interactive=True)
|
save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
start_btn = gr.Button()
|
start_btn = gr.Button()
|
||||||
|
@ -55,11 +54,25 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
|
||||||
start_btn.click(
|
start_btn.click(
|
||||||
runner.run_train,
|
runner.run_train,
|
||||||
[
|
[
|
||||||
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
|
top_elems["lang"],
|
||||||
top_elems["finetuning_type"], top_elems["template"],
|
top_elems["model_name"],
|
||||||
dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
|
top_elems["checkpoints"],
|
||||||
fp16, quantization_bit, batch_size, gradient_accumulation_steps,
|
top_elems["finetuning_type"],
|
||||||
lr_scheduler_type, logging_steps, save_steps, output_dir
|
top_elems["quantization_bit"],
|
||||||
|
top_elems["template"],
|
||||||
|
top_elems["source_prefix"],
|
||||||
|
dataset_dir,
|
||||||
|
dataset,
|
||||||
|
learning_rate,
|
||||||
|
num_train_epochs,
|
||||||
|
max_samples,
|
||||||
|
batch_size,
|
||||||
|
gradient_accumulation_steps,
|
||||||
|
lr_scheduler_type,
|
||||||
|
fp16,
|
||||||
|
logging_steps,
|
||||||
|
save_steps,
|
||||||
|
output_dir
|
||||||
],
|
],
|
||||||
[output_box]
|
[output_box]
|
||||||
)
|
)
|
||||||
|
@ -79,7 +92,6 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
num_train_epochs=num_train_epochs,
|
num_train_epochs=num_train_epochs,
|
||||||
max_samples=max_samples,
|
max_samples=max_samples,
|
||||||
quantization_bit=quantization_bit,
|
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||||
lr_scheduler_type=lr_scheduler_type,
|
lr_scheduler_type=lr_scheduler_type,
|
||||||
|
|
|
@ -6,29 +6,40 @@ from gradio.components import Component
|
||||||
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
|
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
|
||||||
from llmtuner.extras.template import templates
|
from llmtuner.extras.template import templates
|
||||||
from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
|
from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
|
||||||
|
from llmtuner.webui.utils import can_quantize
|
||||||
|
|
||||||
|
|
||||||
def create_top() -> Dict[str, Component]:
|
def create_top() -> Dict[str, Component]:
|
||||||
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
|
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
lang = gr.Dropdown(choices=["en", "zh"], value="en", interactive=True, scale=1)
|
lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1)
|
||||||
model_name = gr.Dropdown(choices=available_models, scale=3)
|
model_name = gr.Dropdown(choices=available_models, scale=3)
|
||||||
model_path = gr.Textbox(scale=3)
|
model_path = gr.Textbox(scale=3)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1)
|
finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
|
||||||
template = gr.Dropdown(value="default", choices=list(templates.keys()), interactive=True, scale=1)
|
checkpoints = gr.Dropdown(multiselect=True, scale=5)
|
||||||
checkpoints = gr.Dropdown(multiselect=True, interactive=True, scale=4)
|
|
||||||
refresh_btn = gr.Button(scale=1)
|
refresh_btn = gr.Button(scale=1)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
quantization_bit = gr.Dropdown([8, 4], scale=1)
|
||||||
|
template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=2)
|
||||||
|
source_prefix = gr.Textbox(scale=4)
|
||||||
|
|
||||||
model_name.change(
|
model_name.change(
|
||||||
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||||
).then(
|
).then(
|
||||||
get_model_path, [model_name], [model_path]
|
get_model_path, [model_name], [model_path]
|
||||||
) # do not save config since the below line will save
|
) # do not save config since the below line will save
|
||||||
model_path.change(save_config, [model_name, model_path])
|
model_path.change(save_config, [model_name, model_path])
|
||||||
finetuning_type.change(list_checkpoint, [model_name, finetuning_type], [checkpoints])
|
|
||||||
|
finetuning_type.change(
|
||||||
|
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||||
|
).then(
|
||||||
|
can_quantize, [finetuning_type], [quantization_bit]
|
||||||
|
)
|
||||||
|
|
||||||
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
|
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
|
@ -38,5 +49,7 @@ def create_top() -> Dict[str, Component]:
|
||||||
finetuning_type=finetuning_type,
|
finetuning_type=finetuning_type,
|
||||||
template=template,
|
template=template,
|
||||||
checkpoints=checkpoints,
|
checkpoints=checkpoints,
|
||||||
refresh_btn=refresh_btn
|
refresh_btn=refresh_btn,
|
||||||
|
quantization_bit=quantization_bit,
|
||||||
|
source_prefix=source_prefix
|
||||||
)
|
)
|
||||||
|
|
|
@ -25,6 +25,14 @@ LOCALES = {
|
||||||
"info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
|
"info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"finetuning_type": {
|
||||||
|
"en": {
|
||||||
|
"label": "Finetuning method"
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"label": "微调方法"
|
||||||
|
}
|
||||||
|
},
|
||||||
"checkpoints": {
|
"checkpoints": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Checkpoints"
|
"label": "Checkpoints"
|
||||||
|
@ -33,14 +41,6 @@ LOCALES = {
|
||||||
"label": "模型断点"
|
"label": "模型断点"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"template": {
|
|
||||||
"en": {
|
|
||||||
"label": "Prompt template"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "提示模板"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"refresh_btn": {
|
"refresh_btn": {
|
||||||
"en": {
|
"en": {
|
||||||
"value": "Refresh checkpoints"
|
"value": "Refresh checkpoints"
|
||||||
|
@ -49,6 +49,36 @@ LOCALES = {
|
||||||
"value": "刷新断点"
|
"value": "刷新断点"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"quantization_bit": {
|
||||||
|
"en": {
|
||||||
|
"label": "Quantization bit (optional)",
|
||||||
|
"info": "Enable 4/8-bit model quantization."
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"label": "量化等级(非必填)",
|
||||||
|
"info": "启用 4/8 比特模型量化。"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"template": {
|
||||||
|
"en": {
|
||||||
|
"label": "Prompt template",
|
||||||
|
"info": "The template used in constructing prompts."
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"label": "提示模板",
|
||||||
|
"info": "构建提示词时使用的模板"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"source_prefix": {
|
||||||
|
"en": {
|
||||||
|
"label": "Source prefix (optional)",
|
||||||
|
"info": "A sequence used as the prefix of each samples."
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"label": "前缀序列(非必填)",
|
||||||
|
"info": "作为每个输入样本前缀的序列"
|
||||||
|
}
|
||||||
|
},
|
||||||
"dataset_dir": {
|
"dataset_dir": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Data dir",
|
"label": "Data dir",
|
||||||
|
@ -99,68 +129,6 @@ LOCALES = {
|
||||||
"value": "关闭"
|
"value": "关闭"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"max_samples": {
|
|
||||||
"en": {
|
|
||||||
"label": "Max samples",
|
|
||||||
"info": "Maximum samples per dataset."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "最大样本数",
|
|
||||||
"info": "每个数据集最多使用的样本数。"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"batch_size": {
|
|
||||||
"en": {
|
|
||||||
"label": "Batch size",
|
|
||||||
"info": "Number of samples to process per GPU."
|
|
||||||
},
|
|
||||||
"zh":{
|
|
||||||
"label": "批处理大小",
|
|
||||||
"info": "每块 GPU 上处理的样本数量。"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"quantization_bit": {
|
|
||||||
"en": {
|
|
||||||
"label": "Quantization bit",
|
|
||||||
"info": "Enable 4/8-bit model quantization."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "量化",
|
|
||||||
"info": "启用 4/8 比特模型量化。"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"start_btn": {
|
|
||||||
"en": {
|
|
||||||
"value": "Start"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "开始"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"stop_btn": {
|
|
||||||
"en": {
|
|
||||||
"value": "Abort"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "中断"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"output_box": {
|
|
||||||
"en": {
|
|
||||||
"value": "Ready."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "准备就绪。"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"finetuning_type": {
|
|
||||||
"en": {
|
|
||||||
"label": "Finetuning method"
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"label": "微调方法"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"learning_rate": {
|
"learning_rate": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Learning rate",
|
"label": "Learning rate",
|
||||||
|
@ -181,6 +149,26 @@ LOCALES = {
|
||||||
"info": "需要执行的训练总轮数。"
|
"info": "需要执行的训练总轮数。"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"max_samples": {
|
||||||
|
"en": {
|
||||||
|
"label": "Max samples",
|
||||||
|
"info": "Maximum samples per dataset."
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"label": "最大样本数",
|
||||||
|
"info": "每个数据集最多使用的样本数。"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"batch_size": {
|
||||||
|
"en": {
|
||||||
|
"label": "Batch size",
|
||||||
|
"info": "Number of samples to process per GPU."
|
||||||
|
},
|
||||||
|
"zh":{
|
||||||
|
"label": "批处理大小",
|
||||||
|
"info": "每块 GPU 上处理的样本数量。"
|
||||||
|
}
|
||||||
|
},
|
||||||
"gradient_accumulation_steps": {
|
"gradient_accumulation_steps": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Gradient accumulation",
|
"label": "Gradient accumulation",
|
||||||
|
@ -231,6 +219,22 @@ LOCALES = {
|
||||||
"info": "每两次断点保存间的更新步数。"
|
"info": "每两次断点保存间的更新步数。"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"start_btn": {
|
||||||
|
"en": {
|
||||||
|
"value": "Start"
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"value": "开始"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"stop_btn": {
|
||||||
|
"en": {
|
||||||
|
"value": "Abort"
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"value": "中断"
|
||||||
|
}
|
||||||
|
},
|
||||||
"output_dir": {
|
"output_dir": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Checkpoint name",
|
"label": "Checkpoint name",
|
||||||
|
@ -241,6 +245,14 @@ LOCALES = {
|
||||||
"info": "保存模型断点的文件夹名称。"
|
"info": "保存模型断点的文件夹名称。"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"output_box": {
|
||||||
|
"en": {
|
||||||
|
"value": "Ready."
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"value": "准备就绪。"
|
||||||
|
}
|
||||||
|
},
|
||||||
"loss_viewer": {
|
"loss_viewer": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Loss"
|
"label": "Loss"
|
||||||
|
@ -257,14 +269,6 @@ LOCALES = {
|
||||||
"label": "保存预测结果"
|
"label": "保存预测结果"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"info_box": {
|
|
||||||
"en": {
|
|
||||||
"value": "Model unloaded, please load a model first."
|
|
||||||
},
|
|
||||||
"zh": {
|
|
||||||
"value": "模型未加载,请先加载模型。"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"load_btn": {
|
"load_btn": {
|
||||||
"en": {
|
"en": {
|
||||||
"value": "Load model"
|
"value": "Load model"
|
||||||
|
@ -281,6 +285,14 @@ LOCALES = {
|
||||||
"value": "卸载模型"
|
"value": "卸载模型"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"info_box": {
|
||||||
|
"en": {
|
||||||
|
"value": "Model unloaded, please load a model first."
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"value": "模型未加载,请先加载模型。"
|
||||||
|
}
|
||||||
|
},
|
||||||
"query": {
|
"query": {
|
||||||
"en": {
|
"en": {
|
||||||
"placeholder": "Input..."
|
"placeholder": "Input..."
|
||||||
|
@ -305,12 +317,12 @@ LOCALES = {
|
||||||
"value": "清空历史"
|
"value": "清空历史"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"max_new_tokens": {
|
"max_length": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Maximum new tokens"
|
"label": "Maximum length"
|
||||||
},
|
},
|
||||||
"zh": {
|
"zh": {
|
||||||
"label": "最大生成长度"
|
"label": "最大长度"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"top_p": {
|
"top_p": {
|
||||||
|
|
|
@ -3,7 +3,7 @@ import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import transformers
|
import transformers
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from llmtuner.extras.callbacks import LogCallback
|
from llmtuner.extras.callbacks import LogCallback
|
||||||
from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated
|
from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated
|
||||||
|
@ -59,10 +59,26 @@ class Runner:
|
||||||
return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
|
return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
|
||||||
|
|
||||||
def run_train(
|
def run_train(
|
||||||
self, lang, model_name, checkpoints, finetuning_type, template,
|
self,
|
||||||
dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
|
lang: str,
|
||||||
fp16, quantization_bit, batch_size, gradient_accumulation_steps,
|
model_name: str,
|
||||||
lr_scheduler_type, logging_steps, save_steps, output_dir
|
checkpoints: List[str],
|
||||||
|
finetuning_type: str,
|
||||||
|
quantization_bit: str,
|
||||||
|
template: str,
|
||||||
|
source_prefix: str,
|
||||||
|
dataset_dir: str,
|
||||||
|
dataset: List[str],
|
||||||
|
learning_rate: str,
|
||||||
|
num_train_epochs: str,
|
||||||
|
max_samples: str,
|
||||||
|
batch_size: int,
|
||||||
|
gradient_accumulation_steps: int,
|
||||||
|
lr_scheduler_type: str,
|
||||||
|
fp16: bool,
|
||||||
|
logging_steps: int,
|
||||||
|
save_steps: int,
|
||||||
|
output_dir: str
|
||||||
):
|
):
|
||||||
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
||||||
if error:
|
if error:
|
||||||
|
@ -79,24 +95,25 @@ class Runner:
|
||||||
args = dict(
|
args = dict(
|
||||||
model_name_or_path=model_name_or_path,
|
model_name_or_path=model_name_or_path,
|
||||||
do_train=True,
|
do_train=True,
|
||||||
finetuning_type=finetuning_type,
|
|
||||||
lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
|
|
||||||
prompt_template=template,
|
|
||||||
dataset=",".join(dataset),
|
|
||||||
dataset_dir=dataset_dir,
|
|
||||||
max_samples=int(max_samples),
|
|
||||||
output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir),
|
|
||||||
checkpoint_dir=checkpoint_dir,
|
|
||||||
overwrite_cache=True,
|
overwrite_cache=True,
|
||||||
|
lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
|
||||||
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
finetuning_type=finetuning_type,
|
||||||
|
quantization_bit=int(quantization_bit) if quantization_bit else None,
|
||||||
|
prompt_template=template,
|
||||||
|
source_prefix=source_prefix,
|
||||||
|
dataset_dir=dataset_dir,
|
||||||
|
dataset=",".join(dataset),
|
||||||
|
learning_rate=float(learning_rate),
|
||||||
|
num_train_epochs=float(num_train_epochs),
|
||||||
|
max_samples=int(max_samples),
|
||||||
per_device_train_batch_size=batch_size,
|
per_device_train_batch_size=batch_size,
|
||||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||||
lr_scheduler_type=lr_scheduler_type,
|
lr_scheduler_type=lr_scheduler_type,
|
||||||
|
fp16=fp16,
|
||||||
logging_steps=logging_steps,
|
logging_steps=logging_steps,
|
||||||
save_steps=save_steps,
|
save_steps=save_steps,
|
||||||
learning_rate=float(learning_rate),
|
output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
|
||||||
num_train_epochs=float(num_train_epochs),
|
|
||||||
fp16=fp16,
|
|
||||||
quantization_bit=int(quantization_bit) if quantization_bit else None
|
|
||||||
)
|
)
|
||||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
|
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
|
||||||
|
|
||||||
|
@ -120,8 +137,19 @@ class Runner:
|
||||||
yield self.finalize(lang)
|
yield self.finalize(lang)
|
||||||
|
|
||||||
def run_eval(
|
def run_eval(
|
||||||
self, lang, model_name, checkpoints, finetuning_type, template,
|
self,
|
||||||
dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
|
lang: str,
|
||||||
|
model_name: str,
|
||||||
|
checkpoints: List[str],
|
||||||
|
finetuning_type: str,
|
||||||
|
quantization_bit: str,
|
||||||
|
template: str,
|
||||||
|
source_prefix: str,
|
||||||
|
dataset_dir: str,
|
||||||
|
dataset: List[str],
|
||||||
|
max_samples: str,
|
||||||
|
batch_size: int,
|
||||||
|
predict: bool
|
||||||
):
|
):
|
||||||
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
||||||
if error:
|
if error:
|
||||||
|
@ -140,17 +168,18 @@ class Runner:
|
||||||
args = dict(
|
args = dict(
|
||||||
model_name_or_path=model_name_or_path,
|
model_name_or_path=model_name_or_path,
|
||||||
do_eval=True,
|
do_eval=True,
|
||||||
finetuning_type=finetuning_type,
|
|
||||||
prompt_template=template,
|
|
||||||
dataset=",".join(dataset),
|
|
||||||
dataset_dir=dataset_dir,
|
|
||||||
max_samples=int(max_samples),
|
|
||||||
output_dir=output_dir,
|
|
||||||
checkpoint_dir=checkpoint_dir,
|
|
||||||
overwrite_cache=True,
|
overwrite_cache=True,
|
||||||
predict_with_generate=True,
|
predict_with_generate=True,
|
||||||
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
finetuning_type=finetuning_type,
|
||||||
|
quantization_bit=int(quantization_bit) if quantization_bit else None,
|
||||||
|
prompt_template=template,
|
||||||
|
source_prefix=source_prefix,
|
||||||
|
dataset_dir=dataset_dir,
|
||||||
|
dataset=",".join(dataset),
|
||||||
|
max_samples=int(max_samples),
|
||||||
per_device_eval_batch_size=batch_size,
|
per_device_eval_batch_size=batch_size,
|
||||||
quantization_bit=int(quantization_bit) if quantization_bit else None
|
output_dir=output_dir
|
||||||
)
|
)
|
||||||
|
|
||||||
if predict:
|
if predict:
|
||||||
|
|
|
@ -3,7 +3,7 @@ import json
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import matplotlib.figure
|
import matplotlib.figure
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from typing import Tuple
|
from typing import Any, Dict, Tuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from llmtuner.extras.ploting import smooth
|
from llmtuner.extras.ploting import smooth
|
||||||
|
@ -23,7 +23,7 @@ def get_time() -> str:
|
||||||
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||||
|
|
||||||
|
|
||||||
def can_preview(dataset_dir: str, dataset: list) -> dict:
|
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
|
||||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||||
dataset_info = json.load(f)
|
dataset_info = json.load(f)
|
||||||
if (
|
if (
|
||||||
|
@ -36,7 +36,7 @@ def can_preview(dataset_dir: str, dataset: list) -> dict:
|
||||||
return gr.update(interactive=False)
|
return gr.update(interactive=False)
|
||||||
|
|
||||||
|
|
||||||
def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
|
def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]:
|
||||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||||
dataset_info = json.load(f)
|
dataset_info = json.load(f)
|
||||||
data_file = dataset_info[dataset[0]]["file_name"]
|
data_file = dataset_info[dataset[0]]["file_name"]
|
||||||
|
@ -45,6 +45,13 @@ def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
|
||||||
return len(data), data[:2], gr.update(visible=True)
|
return len(data), data[:2], gr.update(visible=True)
|
||||||
|
|
||||||
|
|
||||||
|
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
||||||
|
if finetuning_type != "lora":
|
||||||
|
return gr.update(value="", interactive=False)
|
||||||
|
else:
|
||||||
|
return gr.update(interactive=True)
|
||||||
|
|
||||||
|
|
||||||
def get_eval_results(path: os.PathLike) -> str:
|
def get_eval_results(path: os.PathLike) -> str:
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
result = json.dumps(json.load(f), indent=4)
|
result = json.dumps(json.load(f), indent=4)
|
||||||
|
@ -66,6 +73,10 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
|
||||||
if log_info.get("loss", None):
|
if log_info.get("loss", None):
|
||||||
steps.append(log_info["current_steps"])
|
steps.append(log_info["current_steps"])
|
||||||
losses.append(log_info["loss"])
|
losses.append(log_info["loss"])
|
||||||
|
|
||||||
|
if len(losses) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
ax.plot(steps, losses, alpha=0.4, label="original")
|
ax.plot(steps, losses, alpha=0.4, label="original")
|
||||||
ax.plot(steps, smooth(losses), label="smoothed")
|
ax.plot(steps, smooth(losses), label="smoothed")
|
||||||
ax.legend()
|
ax.legend()
|
||||||
|
|
Loading…
Reference in New Issue