Release v0.1.6
This commit is contained in:
parent
156710a995
commit
a48cb0d474
|
@ -55,6 +55,7 @@
|
|||
| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
|
||||
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
|
||||
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
|
||||
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
|
||||
|
||||
- **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
|
||||
- For the "base" models, the `--template` argument can be any of `default`, `alpaca`, `vicuna`, etc. For the "chat" models, make sure to use the corresponding template.
|
||||
|
@ -408,6 +409,8 @@ Please follow the model licenses to use the corresponding model weights:
|
|||
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
|
||||
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
|
||||
- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
|
||||
- [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
|
||||
- [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)
|
||||
|
||||
## Citation
|
||||
|
||||
|
|
|
@ -6,4 +6,4 @@ from llmtuner.tuner import export_model, run_exp
|
|||
from llmtuner.webui import create_ui, create_web_demo
|
||||
|
||||
|
||||
__version__ = "0.1.5"
|
||||
__version__ = "0.1.6"
|
||||
|
|
|
@ -93,11 +93,13 @@ def get_dataset(
|
|||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
||||
|
||||
if dataset_attr.source_prefix: # add prefix
|
||||
features = None
|
||||
if data_args.streaming:
|
||||
features = dataset.features
|
||||
features["prefix"] = Value(dtype="string", id=None)
|
||||
dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
|
||||
dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
|
||||
else:
|
||||
prefix_data = [dataset_attr.source_prefix] * len(dataset)
|
||||
dataset = dataset.add_column("prefix", prefix_data)
|
||||
|
||||
all_datasets.append(dataset)
|
||||
|
||||
|
|
|
@ -19,7 +19,8 @@ def split_dataset(
|
|||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||
return {"train_dataset": train_set, "eval_dataset": val_set}
|
||||
else:
|
||||
dataset = dataset.train_test_split(test_size=data_args.val_size, seed=training_args.seed)
|
||||
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
|
||||
dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
|
||||
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
||||
else:
|
||||
if data_args.streaming:
|
||||
|
|
|
@ -37,7 +37,9 @@ SUPPORTED_MODELS = {
|
|||
"InternLM-7B": "internlm/internlm-7b",
|
||||
"InternLM-7B-Chat": "internlm/internlm-chat-7b",
|
||||
"Qwen-7B": "Qwen/Qwen-7B",
|
||||
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat"
|
||||
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
|
||||
"XVERSE-13B": "xverse/XVERSE-13B",
|
||||
"ChatGLM2-6B": "THUDM/chatglm2-6b"
|
||||
}
|
||||
|
||||
DEFAULT_MODULE = {
|
||||
|
@ -48,5 +50,7 @@ DEFAULT_MODULE = {
|
|||
"Falcon": "query_key_value",
|
||||
"Baichuan": "W_pack",
|
||||
"InternLM": "q_proj,v_proj",
|
||||
"Qwen": "c_attn"
|
||||
"Qwen": "c_attn",
|
||||
"XVERSE": "q_proj,v_proj",
|
||||
"ChatGLM2": "query_key_value"
|
||||
}
|
||||
|
|
|
@ -178,7 +178,7 @@ def register_template(
|
|||
stop_words: List[str],
|
||||
use_history: bool
|
||||
) -> None:
|
||||
template_class = Llama2Template if name == "llama2" else Template
|
||||
template_class = Llama2Template if "llama2" in name else Template
|
||||
templates[name] = template_class(
|
||||
prefix=prefix,
|
||||
prompt=prompt,
|
||||
|
@ -272,6 +272,23 @@ register_template(
|
|||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
|
||||
"""
|
||||
register_template(
|
||||
name="llama2_zh",
|
||||
prefix=[
|
||||
"<<SYS>>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<</SYS>>\n\n"
|
||||
],
|
||||
prompt=[
|
||||
"[INST] {{query}} [/INST] "
|
||||
],
|
||||
sep=[],
|
||||
stop_words=[],
|
||||
use_history=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
|
||||
https://github.com/ymcui/Chinese-LLaMA-Alpaca
|
||||
|
|
|
@ -57,6 +57,10 @@ class FinetuningArguments:
|
|||
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
|
||||
LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
|
||||
)
|
||||
resume_lora_training: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
|
||||
)
|
||||
dpo_beta: Optional[float] = field(
|
||||
default=0.1,
|
||||
metadata={"help": "The beta parameter for the DPO loss."}
|
||||
|
|
|
@ -55,10 +55,6 @@ class ModelArguments:
|
|||
default=None,
|
||||
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
|
||||
)
|
||||
resume_lora_training: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
|
||||
)
|
||||
plot_loss: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
||||
|
|
|
@ -65,7 +65,7 @@ def init_adapter(
|
|||
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
|
||||
"The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
|
||||
|
||||
if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
|
||||
if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning
|
||||
checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
||||
else:
|
||||
checkpoints_to_merge = model_args.checkpoint_dir
|
||||
|
|
|
@ -18,7 +18,7 @@ logger = get_logger(__name__)
|
|||
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
|
||||
model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args)
|
||||
callbacks = [LogCallback()] if callbacks is None else callbacks + [LogCallback()]
|
||||
callbacks = [LogCallback()] if callbacks is None else callbacks
|
||||
|
||||
if general_args.stage == "pt":
|
||||
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
|
|
|
@ -16,6 +16,6 @@ def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"
|
|||
|
||||
close_btn = gr.Button()
|
||||
|
||||
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box])
|
||||
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
|
||||
|
||||
return preview_box, preview_count, preview_samples, close_btn
|
||||
|
|
|
@ -20,7 +20,12 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
|
|||
|
||||
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
|
||||
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
|
||||
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
|
||||
preview_btn.click(
|
||||
get_preview,
|
||||
[dataset_dir, dataset],
|
||||
[preview_count, preview_samples, preview_box],
|
||||
queue=False
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
|
@ -33,6 +38,9 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
|
|||
start_btn = gr.Button()
|
||||
stop_btn = gr.Button()
|
||||
|
||||
with gr.Row():
|
||||
process_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Box():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
|
@ -54,7 +62,10 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
|
|||
batch_size,
|
||||
predict
|
||||
],
|
||||
[output_box]
|
||||
[
|
||||
output_box,
|
||||
process_bar
|
||||
]
|
||||
)
|
||||
stop_btn.click(runner.set_abort, queue=False)
|
||||
|
||||
|
|
|
@ -22,7 +22,12 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
|
|||
|
||||
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
|
||||
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
|
||||
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
|
||||
preview_btn.click(
|
||||
get_preview,
|
||||
[dataset_dir, dataset],
|
||||
[preview_count, preview_samples, preview_box],
|
||||
queue=False
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
|
@ -46,12 +51,14 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
|
|||
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
|
||||
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
|
||||
compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
|
||||
padding_side = gr.Radio(choices=["left", "right"], value="left")
|
||||
|
||||
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
|
||||
with gr.Row():
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
|
||||
lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
lora_target = gr.Textbox(scale=2)
|
||||
resume_lora_training = gr.Checkbox(value=True, scale=1)
|
||||
|
||||
with gr.Row():
|
||||
start_btn = gr.Button()
|
||||
|
@ -59,7 +66,11 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
|
|||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
output_dir = gr.Textbox()
|
||||
with gr.Row():
|
||||
output_dir = gr.Textbox()
|
||||
|
||||
with gr.Row():
|
||||
process_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Box():
|
||||
output_box = gr.Markdown()
|
||||
|
@ -93,16 +104,21 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
|
|||
save_steps,
|
||||
warmup_steps,
|
||||
compute_type,
|
||||
padding_side,
|
||||
lora_rank,
|
||||
lora_dropout,
|
||||
lora_target,
|
||||
resume_lora_training,
|
||||
output_dir
|
||||
],
|
||||
[output_box]
|
||||
[
|
||||
output_box,
|
||||
process_bar
|
||||
]
|
||||
)
|
||||
stop_btn.click(runner.set_abort, queue=False)
|
||||
|
||||
output_box.change(
|
||||
process_bar.change(
|
||||
gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
|
||||
)
|
||||
|
||||
|
@ -128,10 +144,12 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
|
|||
save_steps=save_steps,
|
||||
warmup_steps=warmup_steps,
|
||||
compute_type=compute_type,
|
||||
padding_side=padding_side,
|
||||
lora_tab=lora_tab,
|
||||
lora_rank=lora_rank,
|
||||
lora_dropout=lora_dropout,
|
||||
lora_target=lora_target,
|
||||
resume_lora_training=resume_lora_training,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
output_dir=output_dir,
|
||||
|
|
|
@ -43,7 +43,7 @@ def create_top() -> Dict[str, "Component"]:
|
|||
can_quantize, [finetuning_type], [quantization_bit]
|
||||
)
|
||||
|
||||
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
|
||||
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False)
|
||||
|
||||
return dict(
|
||||
lang=lang,
|
||||
|
|
|
@ -67,7 +67,7 @@ def create_web_demo() -> gr.Blocks:
|
|||
|
||||
demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
|
||||
|
||||
lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
|
||||
lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()), queue=False)
|
||||
|
||||
return demo
|
||||
|
||||
|
|
|
@ -277,6 +277,16 @@ LOCALES = {
|
|||
"info": "是否启用 FP16 或 BF16 混合精度训练。"
|
||||
}
|
||||
},
|
||||
"padding_side": {
|
||||
"en": {
|
||||
"label": "Padding side",
|
||||
"info": "The side on which the model should have padding applied."
|
||||
},
|
||||
"zh": {
|
||||
"label": "填充位置",
|
||||
"info": "使用左填充或右填充。"
|
||||
}
|
||||
},
|
||||
"lora_tab": {
|
||||
"en": {
|
||||
"label": "LoRA configurations"
|
||||
|
@ -315,6 +325,16 @@ LOCALES = {
|
|||
"info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
|
||||
}
|
||||
},
|
||||
"resume_lora_training": {
|
||||
"en": {
|
||||
"label": "Resume LoRA training",
|
||||
"info": "Whether to resume training from the last LoRA weights or create new lora weights."
|
||||
},
|
||||
"zh": {
|
||||
"label": "继续上次的训练",
|
||||
"info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。"
|
||||
}
|
||||
},
|
||||
"start_btn": {
|
||||
"en": {
|
||||
"value": "Start"
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import gradio as gr
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
|
@ -13,7 +14,7 @@ from llmtuner.extras.misc import torch_gc
|
|||
from llmtuner.tuner import run_exp
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
from llmtuner.webui.utils import format_info, get_eval_results
|
||||
from llmtuner.webui.utils import get_eval_results, update_process_bar
|
||||
|
||||
|
||||
class Runner:
|
||||
|
@ -88,14 +89,16 @@ class Runner:
|
|||
save_steps: int,
|
||||
warmup_steps: int,
|
||||
compute_type: str,
|
||||
padding_side: str,
|
||||
lora_rank: int,
|
||||
lora_dropout: float,
|
||||
lora_target: str,
|
||||
resume_lora_training: bool,
|
||||
output_dir: str
|
||||
) -> Generator[str, None, None]:
|
||||
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
||||
if error:
|
||||
yield error
|
||||
yield error, gr.update(visible=False)
|
||||
return
|
||||
|
||||
if checkpoints:
|
||||
|
@ -133,9 +136,11 @@ class Runner:
|
|||
warmup_steps=warmup_steps,
|
||||
fp16=(compute_type == "fp16"),
|
||||
bf16=(compute_type == "bf16"),
|
||||
padding_side=padding_side,
|
||||
lora_rank=lora_rank,
|
||||
lora_dropout=lora_dropout,
|
||||
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
|
||||
resume_lora_training=resume_lora_training,
|
||||
output_dir=output_dir
|
||||
)
|
||||
|
||||
|
@ -150,18 +155,18 @@ class Runner:
|
|||
thread.start()
|
||||
|
||||
while thread.is_alive():
|
||||
time.sleep(1)
|
||||
time.sleep(2)
|
||||
if self.aborted:
|
||||
yield ALERTS["info_aborting"][lang]
|
||||
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
|
||||
else:
|
||||
yield format_info(logger_handler.log, trainer_callback)
|
||||
yield logger_handler.log, update_process_bar(trainer_callback)
|
||||
|
||||
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
|
||||
finish_info = ALERTS["info_finished"][lang]
|
||||
else:
|
||||
finish_info = ALERTS["err_failed"][lang]
|
||||
|
||||
yield self.finalize(lang, finish_info)
|
||||
yield self.finalize(lang, finish_info), gr.update(visible=False)
|
||||
|
||||
def run_eval(
|
||||
self,
|
||||
|
@ -182,7 +187,7 @@ class Runner:
|
|||
) -> Generator[str, None, None]:
|
||||
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
||||
if error:
|
||||
yield error
|
||||
yield error, gr.update(visible=False)
|
||||
return
|
||||
|
||||
if checkpoints:
|
||||
|
@ -223,15 +228,15 @@ class Runner:
|
|||
thread.start()
|
||||
|
||||
while thread.is_alive():
|
||||
time.sleep(1)
|
||||
time.sleep(2)
|
||||
if self.aborted:
|
||||
yield ALERTS["info_aborting"][lang]
|
||||
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
|
||||
else:
|
||||
yield format_info(logger_handler.log, trainer_callback)
|
||||
yield logger_handler.log, update_process_bar(trainer_callback)
|
||||
|
||||
if os.path.exists(os.path.join(output_dir, "all_results.json")):
|
||||
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
|
||||
else:
|
||||
finish_info = ALERTS["err_failed"][lang]
|
||||
|
||||
yield self.finalize(lang, finish_info)
|
||||
yield self.finalize(lang, finish_info), gr.update(visible=False)
|
||||
|
|
|
@ -15,13 +15,18 @@ if TYPE_CHECKING:
|
|||
from llmtuner.extras.callbacks import LogCallback
|
||||
|
||||
|
||||
def format_info(log: str, callback: "LogCallback") -> str:
|
||||
info = log
|
||||
if callback.max_steps:
|
||||
info += "Running **{:d}/{:d}**: {} < {}\n".format(
|
||||
callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
|
||||
)
|
||||
return info
|
||||
def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
    r"""
    Builds the gradio update used to refresh the progress slider.

    Args:
        callback: object exposing `cur_steps`, `max_steps`, `elapsed_time`
            and `remaining_time`.

    Returns:
        `gr.update(visible=False)` while `callback.max_steps` is falsy;
        otherwise an update that makes the bar visible, sets its value to the
        completion percentage and labels it
        "Running cur/max: elapsed < remaining".
    """
    if not callback.max_steps:
        return gr.update(visible=False)

    # max_steps is non-zero past the guard above, so the division is always safe
    percentage = round(100 * callback.cur_steps / callback.max_steps, 0)
    label = "Running {:d}/{:d}: {} < {}".format(
        callback.cur_steps,
        callback.max_steps,
        callback.elapsed_time,
        callback.remaining_time
    )
    return gr.update(label=label, value=percentage, visible=True)
|
||||
|
||||
|
||||
def get_time() -> str:
|
||||
|
|
Loading…
Reference in New Issue