update webUI, fix #179

hiyouga 2023-07-18 15:35:17 +08:00
parent b9fe83fb75
commit 12d8a8633f
9 changed files with 247 additions and 154 deletions

View File

@@ -108,7 +108,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             replace_model(unwrapped_model, target="reward")
             with torch.no_grad():
                 _, _, values = self.model(**self.prepare_model_inputs(queries, responses))
-            rewards = [reward for reward in values[-1].to(torch.float32)]  # use float32 type
+            rewards = [reward for reward in values[:, -1].to(torch.float32)]  # use float32 type
             replace_model(unwrapped_model, target="default")

             # Run PPO step
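Note on the fix: the value head returns a tensor shaped (batch_size, sequence_length), so values[-1] takes every token's value for only the last sequence in the batch, while values[:, -1] takes the last token's value for every sequence, i.e. one scalar reward per sample. A minimal sketch of the difference, with the shape assumed from the surrounding code:

import torch

# assumed: per-token value estimates of shape (batch_size, sequence_length)
values = torch.arange(12, dtype=torch.float16).reshape(3, 4)

last_sample = values[-1]     # shape (4,): all token values of the LAST sample only
last_token = values[:, -1]   # shape (3,): final-token value of EVERY sample

rewards = [reward for reward in last_token.to(torch.float32)]  # one float32 reward per sample
print(len(rewards))  # 3 == batch size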

View File

@@ -17,8 +17,14 @@ class WebChatModel(ChatModel):
         self.generating_args = GeneratingArguments()

     def load_model(
-        self, lang: str, model_name: str, checkpoints: list,
-        finetuning_type: str, template: str, quantization_bit: str
+        self,
+        lang: str,
+        model_name: str,
+        checkpoints: List[str],
+        finetuning_type: str,
+        quantization_bit: str,
+        template: str,
+        source_prefix: str
     ):
         if self.model is not None:
             yield ALERTS["err_exists"][lang]

@@ -43,10 +49,11 @@ class WebChatModel(ChatModel):
         yield ALERTS["info_loading"][lang]
         args = dict(
             model_name_or_path=model_name_or_path,
-            finetuning_type=finetuning_type,
-            prompt_template=template,
             checkpoint_dir=checkpoint_dir,
-            quantization_bit=int(quantization_bit) if quantization_bit else None
+            finetuning_type=finetuning_type,
+            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            prompt_template=template,
+            source_prefix=source_prefix
         )
         super().__init__(*get_infer_args(args))
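load_model remains a generator: every yield streams an interim status string into the bound output component, which is how the web UI shows loading progress. A minimal self-contained sketch of the pattern (the messages and model name are illustrative):

import time
import gradio as gr

def load_model_stub(name: str):
    # generator event handlers stream each yielded value to the output component
    yield f"Loading {name}..."
    time.sleep(1)  # stand-in for the actual model loading work
    yield "Model loaded."

with gr.Blocks() as demo:
    name = gr.Textbox(value="demo-model")
    info_box = gr.Markdown()
    gr.Button("Load").click(load_model_stub, [name], [info_box])

demo.queue()  # streaming (generator) handlers require the queue
demo.launch()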

View File

@@ -10,8 +10,8 @@ from llmtuner.webui.utils import can_preview, get_preview

 def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
     with gr.Row():
-        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2)
-        dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
+        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
+        dataset = gr.Dropdown(multiselect=True, scale=4)
         preview_btn = gr.Button(interactive=False, scale=1)

     preview_box, preview_count, preview_samples, close_btn = create_preview_box()

@@ -21,9 +21,8 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
     preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])

     with gr.Row():
-        max_samples = gr.Textbox(value="100000", interactive=True)
-        batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1, interactive=True)
-        quantization_bit = gr.Dropdown([8, 4])
+        max_samples = gr.Textbox(value="100000")
+        batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1)
         predict = gr.Checkbox(value=True)

     with gr.Row():

@@ -35,9 +34,18 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
     start_btn.click(
         runner.run_eval,
         [
-            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
-            top_elems["finetuning_type"], top_elems["template"],
-            dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
+            top_elems["lang"],
+            top_elems["model_name"],
+            top_elems["checkpoints"],
+            top_elems["finetuning_type"],
+            top_elems["quantization_bit"],
+            top_elems["template"],
+            top_elems["source_prefix"],
+            dataset_dir,
+            dataset,
+            max_samples,
+            batch_size,
+            predict
         ],
         [output_box]
     )

@@ -52,7 +60,6 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
         close_btn=close_btn,
         max_samples=max_samples,
         batch_size=batch_size,
-        quantization_bit=quantization_bit,
         predict=predict,
         start_btn=start_btn,
         stop_btn=stop_btn,
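The input list was rewritten one component per line to mirror the new run_eval signature exactly: Gradio binds click inputs to handler parameters purely by position, so the component order and the parameter order must match. A minimal sketch of that contract (handler and components are illustrative):

import gradio as gr

def handler(lang: str, batch_size: int) -> str:
    # values arrive positionally, in the order of the inputs list below
    return f"lang={lang}, batch_size={batch_size}"

with gr.Blocks() as demo:
    lang = gr.Dropdown(choices=["en", "zh"], value="en")
    batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1)
    output_box = gr.Textbox()
    start_btn = gr.Button("Start")
    # [lang, batch_size] must line up with handler(lang, batch_size)
    start_btn.click(handler, [lang, batch_size], [output_box])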

View File

@@ -11,7 +11,6 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
     with gr.Row():
         load_btn = gr.Button()
         unload_btn = gr.Button()
-        quantization_bit = gr.Dropdown([8, 4])

     info_box = gr.Markdown()

@@ -21,9 +20,13 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
     load_btn.click(
         chat_model.load_model,
         [
-            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
-            top_elems["finetuning_type"], top_elems["template"],
-            quantization_bit
+            top_elems["lang"],
+            top_elems["model_name"],
+            top_elems["checkpoints"],
+            top_elems["finetuning_type"],
+            top_elems["quantization_bit"],
+            top_elems["template"],
+            top_elems["source_prefix"]
         ],
         [info_box]
     ).then(

@@ -39,7 +42,6 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
     )

     return dict(
-        quantization_bit=quantization_bit,
         info_box=info_box,
         load_btn=load_btn,
         unload_btn=unload_btn,

View File

@@ -12,8 +12,8 @@ from llmtuner.webui.utils import can_preview, get_preview, gen_plot

 def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
     with gr.Row():
-        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2)
-        dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
+        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
+        dataset = gr.Dropdown(multiselect=True, scale=4)
         preview_btn = gr.Button(interactive=False, scale=1)

     preview_box, preview_count, preview_samples, close_btn = create_preview_box()

@@ -23,22 +23,21 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
     preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])

     with gr.Row():
-        learning_rate = gr.Textbox(value="5e-5", interactive=True)
-        num_train_epochs = gr.Textbox(value="3.0", interactive=True)
-        max_samples = gr.Textbox(value="100000", interactive=True)
-        quantization_bit = gr.Dropdown([8, 4])
+        learning_rate = gr.Textbox(value="5e-5")
+        num_train_epochs = gr.Textbox(value="3.0")
+        max_samples = gr.Textbox(value="100000")

     with gr.Row():
-        batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1, interactive=True)
-        gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1, interactive=True)
+        batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1)
+        gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1)
         lr_scheduler_type = gr.Dropdown(
-            value="cosine", choices=[scheduler.value for scheduler in SchedulerType], interactive=True
+            value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
         )
         fp16 = gr.Checkbox(value=True)

     with gr.Row():
-        logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5, interactive=True)
-        save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10, interactive=True)
+        logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
+        save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10)

     with gr.Row():
         start_btn = gr.Button()

@@ -55,11 +54,25 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
     start_btn.click(
         runner.run_train,
         [
-            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
-            top_elems["finetuning_type"], top_elems["template"],
-            dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
-            fp16, quantization_bit, batch_size, gradient_accumulation_steps,
-            lr_scheduler_type, logging_steps, save_steps, output_dir
+            top_elems["lang"],
+            top_elems["model_name"],
+            top_elems["checkpoints"],
+            top_elems["finetuning_type"],
+            top_elems["quantization_bit"],
+            top_elems["template"],
+            top_elems["source_prefix"],
+            dataset_dir,
+            dataset,
+            learning_rate,
+            num_train_epochs,
+            max_samples,
+            batch_size,
+            gradient_accumulation_steps,
+            lr_scheduler_type,
+            fp16,
+            logging_steps,
+            save_steps,
+            output_dir
         ],
         [output_box]
     )

@@ -79,7 +92,6 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
         learning_rate=learning_rate,
         num_train_epochs=num_train_epochs,
         max_samples=max_samples,
-        quantization_bit=quantization_bit,
         batch_size=batch_size,
         gradient_accumulation_steps=gradient_accumulation_steps,
         lr_scheduler_type=lr_scheduler_type,

View File

@@ -6,29 +6,40 @@ from gradio.components import Component
 from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
 from llmtuner.extras.template import templates
 from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
+from llmtuner.webui.utils import can_quantize


 def create_top() -> Dict[str, Component]:
     available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]

     with gr.Row():
-        lang = gr.Dropdown(choices=["en", "zh"], value="en", interactive=True, scale=1)
+        lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1)
         model_name = gr.Dropdown(choices=available_models, scale=3)
         model_path = gr.Textbox(scale=3)

     with gr.Row():
-        finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1)
-        template = gr.Dropdown(value="default", choices=list(templates.keys()), interactive=True, scale=1)
-        checkpoints = gr.Dropdown(multiselect=True, interactive=True, scale=4)
+        finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
+        checkpoints = gr.Dropdown(multiselect=True, scale=5)
         refresh_btn = gr.Button(scale=1)

+    with gr.Row():
+        quantization_bit = gr.Dropdown([8, 4], scale=1)
+        template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=2)
+        source_prefix = gr.Textbox(scale=4)
+
     model_name.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
     ).then(
         get_model_path, [model_name], [model_path]
     )  # do not save config since the below line will save

     model_path.change(save_config, [model_name, model_path])

-    finetuning_type.change(list_checkpoint, [model_name, finetuning_type], [checkpoints])
+    finetuning_type.change(
+        list_checkpoint, [model_name, finetuning_type], [checkpoints]
+    ).then(
+        can_quantize, [finetuning_type], [quantization_bit]
+    )

     refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])

     return dict(

@@ -38,5 +49,7 @@ def create_top() -> Dict[str, Component]:
         finetuning_type=finetuning_type,
         template=template,
         checkpoints=checkpoints,
-        refresh_btn=refresh_btn
+        refresh_btn=refresh_btn,
+        quantization_bit=quantization_bit,
+        source_prefix=source_prefix
     )
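The .then(...) chain added here runs callbacks in sequence: changing the finetuning method first refreshes the checkpoint list, then enables or clears the quantization dropdown via can_quantize. A minimal sketch of the same Gradio pattern (the handler bodies are illustrative):

import gradio as gr

def refresh(method: str) -> dict:
    # step 1: recompute the choices that depend on the selected method
    return gr.update(choices=[f"{method}-ckpt-1", f"{method}-ckpt-2"], value=[])

def gate(method: str) -> dict:
    # step 2: only LoRA supports 4/8-bit quantized loading in this UI
    return gr.update(interactive=(method == "lora"))

with gr.Blocks() as demo:
    method = gr.Dropdown(choices=["lora", "full", "freeze"], value="lora")
    checkpoints = gr.Dropdown(multiselect=True)
    quantization_bit = gr.Dropdown([8, 4])
    method.change(refresh, [method], [checkpoints]).then(gate, [method], [quantization_bit])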

View File

@@ -25,6 +25,14 @@ LOCALES = {
             "info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
         }
     },
+    "finetuning_type": {
+        "en": {
+            "label": "Finetuning method"
+        },
+        "zh": {
+            "label": "微调方法"
+        }
+    },
     "checkpoints": {
         "en": {
             "label": "Checkpoints"

@@ -33,14 +41,6 @@ LOCALES = {
             "label": "模型断点"
         }
     },
-    "template": {
-        "en": {
-            "label": "Prompt template"
-        },
-        "zh": {
-            "label": "提示模板"
-        }
-    },
     "refresh_btn": {
         "en": {
             "value": "Refresh checkpoints"

@@ -49,6 +49,36 @@ LOCALES = {
             "value": "刷新断点"
         }
     },
+    "quantization_bit": {
+        "en": {
+            "label": "Quantization bit (optional)",
+            "info": "Enable 4/8-bit model quantization."
+        },
+        "zh": {
+            "label": "量化等级(非必填)",
+            "info": "启用 4/8 比特模型量化。"
+        }
+    },
+    "template": {
+        "en": {
+            "label": "Prompt template",
+            "info": "The template used in constructing prompts."
+        },
+        "zh": {
+            "label": "提示模板",
+            "info": "构建提示词时使用的模板"
+        }
+    },
+    "source_prefix": {
+        "en": {
+            "label": "Source prefix (optional)",
+            "info": "A sequence used as the prefix of each samples."
+        },
+        "zh": {
+            "label": "前缀序列(非必填)",
+            "info": "作为每个输入样本前缀的序列"
+        }
+    },
     "dataset_dir": {
         "en": {
             "label": "Data dir",

@@ -99,68 +129,6 @@ LOCALES = {
             "value": "关闭"
         }
     },
-    "max_samples": {
-        "en": {
-            "label": "Max samples",
-            "info": "Maximum samples per dataset."
-        },
-        "zh": {
-            "label": "最大样本数",
-            "info": "每个数据集最多使用的样本数。"
-        }
-    },
-    "batch_size": {
-        "en": {
-            "label": "Batch size",
-            "info": "Number of samples to process per GPU."
-        },
-        "zh": {
-            "label": "批处理大小",
-            "info": "每块 GPU 上处理的样本数量。"
-        }
-    },
-    "quantization_bit": {
-        "en": {
-            "label": "Quantization bit",
-            "info": "Enable 4/8-bit model quantization."
-        },
-        "zh": {
-            "label": "量化",
-            "info": "启用 4/8 比特模型量化。"
-        }
-    },
-    "start_btn": {
-        "en": {
-            "value": "Start"
-        },
-        "zh": {
-            "value": "开始"
-        }
-    },
-    "stop_btn": {
-        "en": {
-            "value": "Abort"
-        },
-        "zh": {
-            "value": "中断"
-        }
-    },
-    "output_box": {
-        "en": {
-            "value": "Ready."
-        },
-        "zh": {
-            "value": "准备就绪。"
-        }
-    },
-    "finetuning_type": {
-        "en": {
-            "label": "Finetuning method"
-        },
-        "zh": {
-            "label": "微调方法"
-        }
-    },
     "learning_rate": {
         "en": {
             "label": "Learning rate",

@@ -181,6 +149,26 @@ LOCALES = {
             "info": "需要执行的训练总轮数。"
         }
     },
+    "max_samples": {
+        "en": {
+            "label": "Max samples",
+            "info": "Maximum samples per dataset."
+        },
+        "zh": {
+            "label": "最大样本数",
+            "info": "每个数据集最多使用的样本数。"
+        }
+    },
+    "batch_size": {
+        "en": {
+            "label": "Batch size",
+            "info": "Number of samples to process per GPU."
+        },
+        "zh": {
+            "label": "批处理大小",
+            "info": "每块 GPU 上处理的样本数量。"
+        }
+    },
     "gradient_accumulation_steps": {
         "en": {
             "label": "Gradient accumulation",

@@ -231,6 +219,22 @@ LOCALES = {
             "info": "每两次断点保存间的更新步数。"
         }
     },
+    "start_btn": {
+        "en": {
+            "value": "Start"
+        },
+        "zh": {
+            "value": "开始"
+        }
+    },
+    "stop_btn": {
+        "en": {
+            "value": "Abort"
+        },
+        "zh": {
+            "value": "中断"
+        }
+    },
     "output_dir": {
         "en": {
             "label": "Checkpoint name",

@@ -241,6 +245,14 @@ LOCALES = {
             "info": "保存模型断点的文件夹名称。"
         }
     },
+    "output_box": {
+        "en": {
+            "value": "Ready."
+        },
+        "zh": {
+            "value": "准备就绪。"
+        }
+    },
     "loss_viewer": {
         "en": {
             "label": "Loss"

@@ -257,14 +269,6 @@ LOCALES = {
             "label": "保存预测结果"
         }
     },
-    "info_box": {
-        "en": {
-            "value": "Model unloaded, please load a model first."
-        },
-        "zh": {
-            "value": "模型未加载,请先加载模型。"
-        }
-    },
     "load_btn": {
         "en": {
             "value": "Load model"

@@ -281,6 +285,14 @@ LOCALES = {
             "value": "卸载模型"
         }
     },
+    "info_box": {
+        "en": {
+            "value": "Model unloaded, please load a model first."
+        },
+        "zh": {
+            "value": "模型未加载,请先加载模型。"
+        }
+    },
     "query": {
         "en": {
             "placeholder": "Input..."

@@ -305,12 +317,12 @@ LOCALES = {
             "value": "清空历史"
         }
     },
-    "max_new_tokens": {
+    "max_length": {
         "en": {
-            "label": "Maximum new tokens"
+            "label": "Maximum length"
         },
         "zh": {
-            "label": "最大生成长度"
+            "label": "最大长度"
         }
     },
     "top_p": {

View File

@@ -3,7 +3,7 @@ import os
 import threading
 import time
 import transformers
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import DEFAULT_MODULE  # will be deprecated

@@ -59,10 +59,26 @@ class Runner:
         return finish_info if finish_info is not None else ALERTS["info_finished"][lang]

     def run_train(
-        self, lang, model_name, checkpoints, finetuning_type, template,
-        dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
-        fp16, quantization_bit, batch_size, gradient_accumulation_steps,
-        lr_scheduler_type, logging_steps, save_steps, output_dir
+        self,
+        lang: str,
+        model_name: str,
+        checkpoints: List[str],
+        finetuning_type: str,
+        quantization_bit: str,
+        template: str,
+        source_prefix: str,
+        dataset_dir: str,
+        dataset: List[str],
+        learning_rate: str,
+        num_train_epochs: str,
+        max_samples: str,
+        batch_size: int,
+        gradient_accumulation_steps: int,
+        lr_scheduler_type: str,
+        fp16: bool,
+        logging_steps: int,
+        save_steps: int,
+        output_dir: str
     ):
         model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
         if error:

@@ -79,24 +95,25 @@ class Runner:
         args = dict(
             model_name_or_path=model_name_or_path,
             do_train=True,
-            finetuning_type=finetuning_type,
-            lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
-            prompt_template=template,
-            dataset=",".join(dataset),
-            dataset_dir=dataset_dir,
-            max_samples=int(max_samples),
-            output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir),
-            checkpoint_dir=checkpoint_dir,
             overwrite_cache=True,
+            lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
+            checkpoint_dir=checkpoint_dir,
+            finetuning_type=finetuning_type,
+            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            prompt_template=template,
+            source_prefix=source_prefix,
+            dataset_dir=dataset_dir,
+            dataset=",".join(dataset),
+            learning_rate=float(learning_rate),
+            num_train_epochs=float(num_train_epochs),
+            max_samples=int(max_samples),
             per_device_train_batch_size=batch_size,
             gradient_accumulation_steps=gradient_accumulation_steps,
             lr_scheduler_type=lr_scheduler_type,
+            fp16=fp16,
             logging_steps=logging_steps,
             save_steps=save_steps,
-            learning_rate=float(learning_rate),
-            num_train_epochs=float(num_train_epochs),
-            fp16=fp16,
-            quantization_bit=int(quantization_bit) if quantization_bit else None
+            output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
         )
         model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)

@@ -120,8 +137,19 @@ class Runner:
             yield self.finalize(lang)

     def run_eval(
-        self, lang, model_name, checkpoints, finetuning_type, template,
-        dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
+        self,
+        lang: str,
+        model_name: str,
+        checkpoints: List[str],
+        finetuning_type: str,
+        quantization_bit: str,
+        template: str,
+        source_prefix: str,
+        dataset_dir: str,
+        dataset: List[str],
+        max_samples: str,
+        batch_size: int,
+        predict: bool
     ):
         model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
         if error:

@@ -140,17 +168,18 @@ class Runner:
         args = dict(
             model_name_or_path=model_name_or_path,
             do_eval=True,
-            finetuning_type=finetuning_type,
-            prompt_template=template,
-            dataset=",".join(dataset),
-            dataset_dir=dataset_dir,
-            max_samples=int(max_samples),
-            output_dir=output_dir,
-            checkpoint_dir=checkpoint_dir,
             overwrite_cache=True,
             predict_with_generate=True,
+            checkpoint_dir=checkpoint_dir,
+            finetuning_type=finetuning_type,
+            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            prompt_template=template,
+            source_prefix=source_prefix,
+            dataset_dir=dataset_dir,
+            dataset=",".join(dataset),
+            max_samples=int(max_samples),
             per_device_eval_batch_size=batch_size,
-            quantization_bit=int(quantization_bit) if quantization_bit else None
+            output_dir=output_dir
         )

         if predict:
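Both methods coerce the raw widget values before calling the argument parser: textboxes deliver strings, so learning_rate, num_train_epochs and max_samples are cast, the multiselect dataset list is joined into a comma-separated string, and an empty quantization_bit becomes None. A small sketch of these coercions with hypothetical UI values:

# hypothetical values as they arrive from the widgets above
ui = dict(learning_rate="5e-5", num_train_epochs="3.0", max_samples="100000",
          dataset=["data_a", "data_b"], quantization_bit="")

coerced = dict(
    learning_rate=float(ui["learning_rate"]),        # 5e-05
    num_train_epochs=float(ui["num_train_epochs"]),  # 3.0
    max_samples=int(ui["max_samples"]),              # 100000
    dataset=",".join(ui["dataset"]),                 # "data_a,data_b"
    quantization_bit=int(ui["quantization_bit"]) if ui["quantization_bit"] else None,  # None
)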

View File

@@ -3,7 +3,7 @@ import json
 import gradio as gr
 import matplotlib.figure
 import matplotlib.pyplot as plt
-from typing import Tuple
+from typing import Any, Dict, Tuple
 from datetime import datetime

 from llmtuner.extras.ploting import smooth

@@ -23,7 +23,7 @@ def get_time() -> str:
     return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

-def can_preview(dataset_dir: str, dataset: list) -> dict:
+def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
     with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
         dataset_info = json.load(f)
     if (

@@ -36,7 +36,7 @@ def can_preview(dataset_dir: str, dataset: list) -> dict:
     return gr.update(interactive=False)

-def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
+def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]:
     with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
         dataset_info = json.load(f)
     data_file = dataset_info[dataset[0]]["file_name"]

@@ -45,6 +45,13 @@ def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
     return len(data), data[:2], gr.update(visible=True)

+def can_quantize(finetuning_type: str) -> Dict[str, Any]:
+    if finetuning_type != "lora":
+        return gr.update(value="", interactive=False)
+    else:
+        return gr.update(interactive=True)

 def get_eval_results(path: os.PathLike) -> str:
     with open(path, "r", encoding="utf-8") as f:
         result = json.dumps(json.load(f), indent=4)
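can_quantize backs the new top-bar dropdown: 4/8-bit loading is only wired up for LoRA here, so switching to any other method clears the selection and disables the control, while switching back re-enables it. The returned dict is a standard Gradio update payload, e.g.:

can_quantize("full")  # -> update that clears the value and sets interactive=False
can_quantize("lora")  # -> update that sets interactive=True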
@@ -66,6 +73,10 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
         if log_info.get("loss", None):
             steps.append(log_info["current_steps"])
             losses.append(log_info["loss"])

+    if len(losses) == 0:
+        return None
+
     ax.plot(steps, losses, alpha=0.4, label="original")
     ax.plot(steps, smooth(losses), label="smoothed")
     ax.legend()
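The new guard covers the start of a run: until the first "loss" entry is logged, steps and losses are empty, and returning None keeps the Gradio plot blank instead of drawing (and smoothing) an empty series. A hedged mirror of the pattern:

import matplotlib.pyplot as plt

def plot_loss(steps: list, losses: list):
    if len(losses) == 0:
        return None  # nothing logged yet: give the UI no figure at all
    fig, ax = plt.subplots()
    ax.plot(steps, losses, alpha=0.4, label="original")
    ax.legend()
    return fig

assert plot_loss([], []) is None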