update webUI, fix #179
parent b9fe83fb75
commit 12d8a8633f
@@ -108,7 +108,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         replace_model(unwrapped_model, target="reward")
         with torch.no_grad():
             _, _, values = self.model(**self.prepare_model_inputs(queries, responses))
-        rewards = [reward for reward in values[-1].to(torch.float32)] # use float32 type
+        rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
         replace_model(unwrapped_model, target="default")

         # Run PPO step
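The functional change in this hunk is the reward indexing. Assuming `values` from the value head is a `(batch_size, seq_len)` tensor, `values[-1]` takes the per-token values of the last sample only, while `values[:, -1]` takes the last-token value of every sample, which is what one reward per query requires. A minimal sketch of the difference on a toy tensor (not the trainer's real values):

    import torch

    values = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # 2 samples, 3 tokens

    print(values[-1])     # tensor([3., 4., 5.]) -> all tokens of the last sample (wrong axis)
    print(values[:, -1])  # tensor([2., 5.])     -> last token of each sample (one reward per sample)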
@@ -17,8 +17,14 @@ class WebChatModel(ChatModel):
         self.generating_args = GeneratingArguments()
 
     def load_model(
-        self, lang: str, model_name: str, checkpoints: list,
-        finetuning_type: str, template: str, quantization_bit: str
+        self,
+        lang: str,
+        model_name: str,
+        checkpoints: List[str],
+        finetuning_type: str,
+        quantization_bit: str,
+        template: str,
+        source_prefix: str
     ):
         if self.model is not None:
             yield ALERTS["err_exists"][lang]

@@ -43,10 +49,11 @@ class WebChatModel(ChatModel):
         yield ALERTS["info_loading"][lang]
         args = dict(
             model_name_or_path=model_name_or_path,
-            finetuning_type=finetuning_type,
-            prompt_template=template,
             checkpoint_dir=checkpoint_dir,
-            quantization_bit=int(quantization_bit) if quantization_bit else None
+            finetuning_type=finetuning_type,
+            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            prompt_template=template,
+            source_prefix=source_prefix
         )
         super().__init__(*get_infer_args(args))
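The signature declares `quantization_bit: str` because the value comes from a dropdown and may be empty, so the empty selection must map to `None` before the dict reaches `get_infer_args`. A standalone sketch of that guard, with a hypothetical helper name:

    def parse_quantization_bit(quantization_bit: str):
        # "" (nothing selected) -> None (no quantization); "8" or "4" -> int
        return int(quantization_bit) if quantization_bit else None

    assert parse_quantization_bit("") is None
    assert parse_quantization_bit("8") == 8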
@@ -10,8 +10,8 @@ from llmtuner.webui.utils import can_preview, get_preview
 
 def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
     with gr.Row():
-        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2)
-        dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
+        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
+        dataset = gr.Dropdown(multiselect=True, scale=4)
         preview_btn = gr.Button(interactive=False, scale=1)
 
     preview_box, preview_count, preview_samples, close_btn = create_preview_box()

@@ -21,9 +21,8 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
     preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
 
     with gr.Row():
-        max_samples = gr.Textbox(value="100000", interactive=True)
-        batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1, interactive=True)
-        quantization_bit = gr.Dropdown([8, 4])
+        max_samples = gr.Textbox(value="100000")
+        batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1)
         predict = gr.Checkbox(value=True)
 
     with gr.Row():

@@ -35,9 +34,18 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
     start_btn.click(
         runner.run_eval,
         [
-            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
-            top_elems["finetuning_type"], top_elems["template"],
-            dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
+            top_elems["lang"],
+            top_elems["model_name"],
+            top_elems["checkpoints"],
+            top_elems["finetuning_type"],
+            top_elems["quantization_bit"],
+            top_elems["template"],
+            top_elems["source_prefix"],
+            dataset_dir,
+            dataset,
+            max_samples,
+            batch_size,
+            predict
         ],
         [output_box]
     )

@@ -52,7 +60,6 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
         close_btn=close_btn,
         max_samples=max_samples,
         batch_size=batch_size,
-        quantization_bit=quantization_bit,
         predict=predict,
         start_btn=start_btn,
         stop_btn=stop_btn,
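The long reordered list matters because `gr.Button.click` passes inputs to the callback positionally: the component list must mirror `run_eval`'s parameter order exactly, which is why this commit reorders the list and the signature together. A trimmed, self-contained sketch of the pattern with a hypothetical three-argument callback:

    import gradio as gr

    def run_eval(lang, model_name, batch_size):
        return f"eval: {lang} / {model_name} / bs={batch_size}"

    with gr.Blocks() as demo:
        lang = gr.Dropdown(["en", "zh"], value="en")
        model_name = gr.Textbox(value="demo-model")
        batch_size = gr.Slider(1, 128, value=8, step=1)
        output_box = gr.Textbox()
        start_btn = gr.Button("Start")
        # The inputs list is positional; it must line up with run_eval's parameters.
        start_btn.click(run_eval, [lang, model_name, batch_size], [output_box])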
@@ -11,7 +11,6 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
     with gr.Row():
         load_btn = gr.Button()
         unload_btn = gr.Button()
-        quantization_bit = gr.Dropdown([8, 4])
 
     info_box = gr.Markdown()
 

@@ -21,9 +20,13 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
     load_btn.click(
         chat_model.load_model,
         [
-            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
-            top_elems["finetuning_type"], top_elems["template"],
-            quantization_bit
+            top_elems["lang"],
+            top_elems["model_name"],
+            top_elems["checkpoints"],
+            top_elems["finetuning_type"],
+            top_elems["quantization_bit"],
+            top_elems["template"],
+            top_elems["source_prefix"]
         ],
         [info_box]
     ).then(

@@ -39,7 +42,6 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
     )
 
     return dict(
-        quantization_bit=quantization_bit,
        info_box=info_box,
        load_btn=load_btn,
        unload_btn=unload_btn,
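The `load_btn.click(...).then(...)` chain runs its second callback only after the first finishes, so the status text can appear before any follow-up update. A minimal sketch of the chaining with hypothetical handlers, not the real ones:

    import gradio as gr

    def load(name):
        return f"Loaded {name}."

    with gr.Blocks() as demo:
        name = gr.Textbox(value="demo")
        info_box = gr.Markdown()
        chatbot = gr.Chatbot()
        load_btn = gr.Button("Load model")
        # Step 1 reports status; step 2 clears the history once loading is done.
        load_btn.click(load, [name], [info_box]).then(lambda: [], outputs=[chatbot])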
@@ -12,8 +12,8 @@ from llmtuner.webui.utils import can_preview, get_preview, gen_plot
 
 def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
     with gr.Row():
-        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2)
-        dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
+        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
+        dataset = gr.Dropdown(multiselect=True, scale=4)
         preview_btn = gr.Button(interactive=False, scale=1)
 
     preview_box, preview_count, preview_samples, close_btn = create_preview_box()

@@ -23,22 +23,21 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
     preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
 
     with gr.Row():
-        learning_rate = gr.Textbox(value="5e-5", interactive=True)
-        num_train_epochs = gr.Textbox(value="3.0", interactive=True)
-        max_samples = gr.Textbox(value="100000", interactive=True)
-        quantization_bit = gr.Dropdown([8, 4])
+        learning_rate = gr.Textbox(value="5e-5")
+        num_train_epochs = gr.Textbox(value="3.0")
+        max_samples = gr.Textbox(value="100000")
 
     with gr.Row():
-        batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1, interactive=True)
-        gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1, interactive=True)
+        batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1)
+        gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1)
         lr_scheduler_type = gr.Dropdown(
-            value="cosine", choices=[scheduler.value for scheduler in SchedulerType], interactive=True
+            value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
         )
         fp16 = gr.Checkbox(value=True)
 
     with gr.Row():
-        logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5, interactive=True)
-        save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10, interactive=True)
+        logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
+        save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10)
 
     with gr.Row():
         start_btn = gr.Button()

@@ -55,11 +54,25 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
     start_btn.click(
         runner.run_train,
         [
-            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
-            top_elems["finetuning_type"], top_elems["template"],
-            dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
-            fp16, quantization_bit, batch_size, gradient_accumulation_steps,
-            lr_scheduler_type, logging_steps, save_steps, output_dir
+            top_elems["lang"],
+            top_elems["model_name"],
+            top_elems["checkpoints"],
+            top_elems["finetuning_type"],
+            top_elems["quantization_bit"],
+            top_elems["template"],
+            top_elems["source_prefix"],
+            dataset_dir,
+            dataset,
+            learning_rate,
+            num_train_epochs,
+            max_samples,
+            batch_size,
+            gradient_accumulation_steps,
+            lr_scheduler_type,
+            fp16,
+            logging_steps,
+            save_steps,
+            output_dir
         ],
         [output_box]
     )

@@ -79,7 +92,6 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
         learning_rate=learning_rate,
         num_train_epochs=num_train_epochs,
         max_samples=max_samples,
-        quantization_bit=quantization_bit,
         batch_size=batch_size,
         gradient_accumulation_steps=gradient_accumulation_steps,
         lr_scheduler_type=lr_scheduler_type,
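The `lr_scheduler_type` choices are taken directly from transformers' `SchedulerType` enum rather than hard-coded, so the dropdown tracks whatever the installed library supports. For reference (the exact list depends on the transformers version):

    from transformers import SchedulerType

    print([scheduler.value for scheduler in SchedulerType])
    # e.g. ['linear', 'cosine', 'cosine_with_restarts', 'polynomial',
    #       'constant', 'constant_with_warmup']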
@@ -6,29 +6,40 @@ from gradio.components import Component
 from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
 from llmtuner.extras.template import templates
 from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
+from llmtuner.webui.utils import can_quantize
 
 
 def create_top() -> Dict[str, Component]:
     available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
 
     with gr.Row():
-        lang = gr.Dropdown(choices=["en", "zh"], value="en", interactive=True, scale=1)
+        lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1)
         model_name = gr.Dropdown(choices=available_models, scale=3)
         model_path = gr.Textbox(scale=3)
 
     with gr.Row():
-        finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1)
-        template = gr.Dropdown(value="default", choices=list(templates.keys()), interactive=True, scale=1)
-        checkpoints = gr.Dropdown(multiselect=True, interactive=True, scale=4)
+        finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
+        checkpoints = gr.Dropdown(multiselect=True, scale=5)
         refresh_btn = gr.Button(scale=1)
 
+    with gr.Row():
+        quantization_bit = gr.Dropdown([8, 4], scale=1)
+        template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=2)
+        source_prefix = gr.Textbox(scale=4)
+
     model_name.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
     ).then(
         get_model_path, [model_name], [model_path]
     ) # do not save config since the below line will save
     model_path.change(save_config, [model_name, model_path])
-    finetuning_type.change(list_checkpoint, [model_name, finetuning_type], [checkpoints])
+
+    finetuning_type.change(
+        list_checkpoint, [model_name, finetuning_type], [checkpoints]
+    ).then(
+        can_quantize, [finetuning_type], [quantization_bit]
+    )
 
     refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
 
     return dict(

@@ -38,5 +49,7 @@ def create_top() -> Dict[str, Component]:
         finetuning_type=finetuning_type,
         template=template,
         checkpoints=checkpoints,
-        refresh_btn=refresh_btn
+        refresh_btn=refresh_btn,
+        quantization_bit=quantization_bit,
+        source_prefix=source_prefix
     )
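Moving `quantization_bit` into the shared top panel lets one `.then()` chain keep it consistent with the training method: when `finetuning_type` changes away from "lora", `can_quantize` clears and disables the dropdown, presumably because 4/8-bit loading is only supported together with LoRA checkpoints here. A self-contained sketch of that wiring:

    import gradio as gr

    def can_quantize(finetuning_type: str):
        if finetuning_type != "lora":
            return gr.update(value="", interactive=False)
        return gr.update(interactive=True)

    with gr.Blocks() as demo:
        finetuning_type = gr.Dropdown(["lora", "freeze", "full"], value="lora")
        quantization_bit = gr.Dropdown([8, 4])
        finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit])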
@@ -25,6 +25,14 @@ LOCALES = {
             "info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
         }
     },
+    "finetuning_type": {
+        "en": {
+            "label": "Finetuning method"
+        },
+        "zh": {
+            "label": "微调方法"
+        }
+    },
     "checkpoints": {
         "en": {
             "label": "Checkpoints"

@@ -33,14 +41,6 @@ LOCALES = {
             "label": "模型断点"
         }
     },
-    "template": {
-        "en": {
-            "label": "Prompt template"
-        },
-        "zh": {
-            "label": "提示模板"
-        }
-    },
     "refresh_btn": {
         "en": {
             "value": "Refresh checkpoints"

@@ -49,6 +49,36 @@ LOCALES = {
             "value": "刷新断点"
         }
     },
+    "quantization_bit": {
+        "en": {
+            "label": "Quantization bit (optional)",
+            "info": "Enable 4/8-bit model quantization."
+        },
+        "zh": {
+            "label": "量化等级(非必填)",
+            "info": "启用 4/8 比特模型量化。"
+        }
+    },
+    "template": {
+        "en": {
+            "label": "Prompt template",
+            "info": "The template used in constructing prompts."
+        },
+        "zh": {
+            "label": "提示模板",
+            "info": "构建提示词时使用的模板"
+        }
+    },
+    "source_prefix": {
+        "en": {
+            "label": "Source prefix (optional)",
+            "info": "A sequence used as the prefix of each samples."
+        },
+        "zh": {
+            "label": "前缀序列(非必填)",
+            "info": "作为每个输入样本前缀的序列"
+        }
+    },
     "dataset_dir": {
         "en": {
             "label": "Data dir",

@@ -99,68 +129,6 @@ LOCALES = {
             "value": "关闭"
         }
     },
-    "max_samples": {
-        "en": {
-            "label": "Max samples",
-            "info": "Maximum samples per dataset."
-        },
-        "zh": {
-            "label": "最大样本数",
-            "info": "每个数据集最多使用的样本数。"
-        }
-    },
-    "batch_size": {
-        "en": {
-            "label": "Batch size",
-            "info": "Number of samples to process per GPU."
-        },
-        "zh":{
-            "label": "批处理大小",
-            "info": "每块 GPU 上处理的样本数量。"
-        }
-    },
-    "quantization_bit": {
-        "en": {
-            "label": "Quantization bit",
-            "info": "Enable 4/8-bit model quantization."
-        },
-        "zh": {
-            "label": "量化",
-            "info": "启用 4/8 比特模型量化。"
-        }
-    },
-    "start_btn": {
-        "en": {
-            "value": "Start"
-        },
-        "zh": {
-            "value": "开始"
-        }
-    },
-    "stop_btn": {
-        "en": {
-            "value": "Abort"
-        },
-        "zh": {
-            "value": "中断"
-        }
-    },
-    "output_box": {
-        "en": {
-            "value": "Ready."
-        },
-        "zh": {
-            "value": "准备就绪。"
-        }
-    },
-    "finetuning_type": {
-        "en": {
-            "label": "Finetuning method"
-        },
-        "zh": {
-            "label": "微调方法"
-        }
-    },
     "learning_rate": {
         "en": {
             "label": "Learning rate",

@@ -181,6 +149,26 @@ LOCALES = {
             "info": "需要执行的训练总轮数。"
         }
     },
+    "max_samples": {
+        "en": {
+            "label": "Max samples",
+            "info": "Maximum samples per dataset."
+        },
+        "zh": {
+            "label": "最大样本数",
+            "info": "每个数据集最多使用的样本数。"
+        }
+    },
+    "batch_size": {
+        "en": {
+            "label": "Batch size",
+            "info": "Number of samples to process per GPU."
+        },
+        "zh":{
+            "label": "批处理大小",
+            "info": "每块 GPU 上处理的样本数量。"
+        }
+    },
     "gradient_accumulation_steps": {
         "en": {
             "label": "Gradient accumulation",

@@ -231,6 +219,22 @@ LOCALES = {
             "info": "每两次断点保存间的更新步数。"
         }
     },
+    "start_btn": {
+        "en": {
+            "value": "Start"
+        },
+        "zh": {
+            "value": "开始"
+        }
+    },
+    "stop_btn": {
+        "en": {
+            "value": "Abort"
+        },
+        "zh": {
+            "value": "中断"
+        }
+    },
     "output_dir": {
         "en": {
             "label": "Checkpoint name",

@@ -241,6 +245,14 @@ LOCALES = {
             "info": "保存模型断点的文件夹名称。"
         }
     },
+    "output_box": {
+        "en": {
+            "value": "Ready."
+        },
+        "zh": {
+            "value": "准备就绪。"
+        }
+    },
     "loss_viewer": {
         "en": {
             "label": "Loss"

@@ -257,14 +269,6 @@ LOCALES = {
             "label": "保存预测结果"
         }
     },
-    "info_box": {
-        "en": {
-            "value": "Model unloaded, please load a model first."
-        },
-        "zh": {
-            "value": "模型未加载,请先加载模型。"
-        }
-    },
     "load_btn": {
         "en": {
             "value": "Load model"

@@ -281,6 +285,14 @@ LOCALES = {
             "value": "卸载模型"
         }
     },
+    "info_box": {
+        "en": {
+            "value": "Model unloaded, please load a model first."
+        },
+        "zh": {
+            "value": "模型未加载,请先加载模型。"
+        }
+    },
     "query": {
         "en": {
             "placeholder": "Input..."

@@ -305,12 +317,12 @@ LOCALES = {
             "value": "清空历史"
         }
     },
-    "max_new_tokens": {
+    "max_length": {
         "en": {
-            "label": "Maximum new tokens"
+            "label": "Maximum length"
         },
         "zh": {
-            "label": "最大生成长度"
+            "label": "最大长度"
         }
     },
     "top_p": {
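These locale hunks mostly move entries so their order matches the new component layout; the table itself keeps the shape component name -> language -> label/info kwargs. A sketch of the lookup the UI presumably performs when the language changes (hypothetical helper, entries copied from the diff):

    LOCALES = {
        "quantization_bit": {
            "en": {"label": "Quantization bit (optional)", "info": "Enable 4/8-bit model quantization."},
            "zh": {"label": "量化等级(非必填)", "info": "启用 4/8 比特模型量化。"},
        }
    }

    def localize(component: str, lang: str) -> dict:
        # The returned kwargs can feed gr.update(**kwargs) for that component.
        return LOCALES[component][lang]

    assert localize("quantization_bit", "en")["label"] == "Quantization bit (optional)"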
@@ -3,7 +3,7 @@ import os
 import threading
 import time
 import transformers
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated

@@ -59,10 +59,26 @@ class Runner:
         return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
 
     def run_train(
-        self, lang, model_name, checkpoints, finetuning_type, template,
-        dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
-        fp16, quantization_bit, batch_size, gradient_accumulation_steps,
-        lr_scheduler_type, logging_steps, save_steps, output_dir
+        self,
+        lang: str,
+        model_name: str,
+        checkpoints: List[str],
+        finetuning_type: str,
+        quantization_bit: str,
+        template: str,
+        source_prefix: str,
+        dataset_dir: str,
+        dataset: List[str],
+        learning_rate: str,
+        num_train_epochs: str,
+        max_samples: str,
+        batch_size: int,
+        gradient_accumulation_steps: int,
+        lr_scheduler_type: str,
+        fp16: bool,
+        logging_steps: int,
+        save_steps: int,
+        output_dir: str
     ):
         model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
         if error:

@@ -79,24 +95,25 @@ class Runner:
         args = dict(
             model_name_or_path=model_name_or_path,
             do_train=True,
-            finetuning_type=finetuning_type,
-            lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
-            prompt_template=template,
-            dataset=",".join(dataset),
-            dataset_dir=dataset_dir,
-            max_samples=int(max_samples),
-            output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir),
-            checkpoint_dir=checkpoint_dir,
             overwrite_cache=True,
+            lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
+            checkpoint_dir=checkpoint_dir,
+            finetuning_type=finetuning_type,
+            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            prompt_template=template,
+            source_prefix=source_prefix,
+            dataset_dir=dataset_dir,
+            dataset=",".join(dataset),
+            learning_rate=float(learning_rate),
+            num_train_epochs=float(num_train_epochs),
+            max_samples=int(max_samples),
             per_device_train_batch_size=batch_size,
             gradient_accumulation_steps=gradient_accumulation_steps,
             lr_scheduler_type=lr_scheduler_type,
+            fp16=fp16,
             logging_steps=logging_steps,
             save_steps=save_steps,
-            learning_rate=float(learning_rate),
-            num_train_epochs=float(num_train_epochs),
-            fp16=fp16,
-            quantization_bit=int(quantization_bit) if quantization_bit else None
+            output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
         )
         model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
 

@@ -120,8 +137,19 @@ class Runner:
             yield self.finalize(lang)
 
     def run_eval(
-        self, lang, model_name, checkpoints, finetuning_type, template,
-        dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
+        self,
+        lang: str,
+        model_name: str,
+        checkpoints: List[str],
+        finetuning_type: str,
+        quantization_bit: str,
+        template: str,
+        source_prefix: str,
+        dataset_dir: str,
+        dataset: List[str],
+        max_samples: str,
+        batch_size: int,
+        predict: bool
     ):
         model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
         if error:

@@ -140,17 +168,18 @@ class Runner:
         args = dict(
             model_name_or_path=model_name_or_path,
             do_eval=True,
-            finetuning_type=finetuning_type,
-            prompt_template=template,
-            dataset=",".join(dataset),
-            dataset_dir=dataset_dir,
-            max_samples=int(max_samples),
-            output_dir=output_dir,
-            checkpoint_dir=checkpoint_dir,
             overwrite_cache=True,
             predict_with_generate=True,
+            checkpoint_dir=checkpoint_dir,
+            finetuning_type=finetuning_type,
+            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            prompt_template=template,
+            source_prefix=source_prefix,
+            dataset_dir=dataset_dir,
+            dataset=",".join(dataset),
+            max_samples=int(max_samples),
             per_device_eval_batch_size=batch_size,
-            quantization_bit=int(quantization_bit) if quantization_bit else None
+            output_dir=output_dir
         )
 
         if predict:
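The casts in `run_train`/`run_eval` exist because Gradio Textboxes hand back strings while Sliders hand back numbers; everything is normalized before reaching the argument parser. In isolation:

    learning_rate, num_train_epochs, max_samples = "5e-5", "3.0", "100000"  # Textbox values

    args = dict(
        learning_rate=float(learning_rate),        # 5e-05
        num_train_epochs=float(num_train_epochs),  # 3.0
        max_samples=int(max_samples),              # 100000
    )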
@@ -3,7 +3,7 @@ import json
 import gradio as gr
 import matplotlib.figure
 import matplotlib.pyplot as plt
-from typing import Tuple
+from typing import Any, Dict, Tuple
 from datetime import datetime
 
 from llmtuner.extras.ploting import smooth

@@ -23,7 +23,7 @@ def get_time() -> str:
     return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
 
 
-def can_preview(dataset_dir: str, dataset: list) -> dict:
+def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
     with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
         dataset_info = json.load(f)
     if (

@@ -36,7 +36,7 @@ def can_preview(dataset_dir: str, dataset: list) -> dict:
         return gr.update(interactive=False)
 
 
-def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
+def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]:
     with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
         dataset_info = json.load(f)
     data_file = dataset_info[dataset[0]]["file_name"]

@@ -45,6 +45,13 @@ def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
     return len(data), data[:2], gr.update(visible=True)
 
 
+def can_quantize(finetuning_type: str) -> Dict[str, Any]:
+    if finetuning_type != "lora":
+        return gr.update(value="", interactive=False)
+    else:
+        return gr.update(interactive=True)
+
+
 def get_eval_results(path: os.PathLike) -> str:
     with open(path, "r", encoding="utf-8") as f:
         result = json.dumps(json.load(f), indent=4)

@@ -66,6 +73,10 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
         if log_info.get("loss", None):
             steps.append(log_info["current_steps"])
             losses.append(log_info["loss"])
+
+    if len(losses) == 0:
+        return None
+
     ax.plot(steps, losses, alpha=0.4, label="original")
     ax.plot(steps, smooth(losses), label="smoothed")
     ax.legend()
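The annotation change from `dict` to `Dict[str, Any]` matches what `gr.update` actually returns in Gradio 3.x: a plain dict of component properties plus a type tag that Gradio applies as a partial update. For example:

    import gradio as gr

    print(gr.update(interactive=False))
    # {'interactive': False, '__type__': 'generic_update'}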