support rope scaling, fix #475 #476 #478

Author: hiyouga
Date:   2023-08-12 20:46:27 +08:00
Parent: 2eb0eca65f
Commit: fa940c17b8
12 changed files with 267 additions and 277 deletions

View File

@@ -1,8 +1,7 @@
 import torch
-from types import MethodType
 from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
-from transformers import PreTrainedModel, TextIteratorStreamer
+from transformers import TextIteratorStreamer

 from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria
 from llmtuner.extras.template import get_template_and_fix_tokenizer
@@ -15,10 +14,9 @@ class ChatModel:
         model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
         self.model = dispatch_model(self.model)
-        self.model = self.model.eval() # change to eval mode
+        self.model = self.model.eval() # enable evaluation mode
         self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
         self.source_prefix = data_args.source_prefix
-        self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)

     def process_args(
         self,

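The removed line above relied on `types.MethodType` to swap a model's custom `generate` for the stock `PreTrainedModel.generate` (this logic moves into the loader later in this commit). A small illustration of the rebinding mechanism, with hypothetical classes standing in for the model types:

```python
# Sketch of the MethodType rebinding used to disable a remote-code model's
# custom generate (e.g. Qwen). Base/Custom are illustrative stand-ins for
# PreTrainedModel and the remote-code model class.
from types import MethodType

class Base:
    def greet(self):
        return "base implementation"

class Custom(Base):
    def greet(self):
        return "custom implementation"

obj = Custom()
obj.greet = MethodType(Base.greet, obj)  # rebind the base method to this instance
print(obj.greet())  # -> base implementation
```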
View File

@@ -1,7 +1,5 @@
 IGNORE_INDEX = -100

-SFT_SCRIPT_PREFIX = "CUDA_VISIBLE_DEVICES=0 python "
-
 LOG_FILE_NAME = "trainer_log.jsonl"

 VALUE_HEAD_FILE_NAME = "value_head.bin"

View File

@@ -8,6 +8,9 @@ class LoggerHandler(logging.Handler):
         super().__init__()
         self.log = ""

+    def reset(self):
+        self.log = ""
+
     def emit(self, record):
         if record.name == "httpx":
             return

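The new `reset()` matters because the handler now lives for the whole session (the Runner below creates it once and reuses it). A runnable sketch of the accumulating-handler shape this diff implies; only `__init__`, `reset`, and the httpx filter are taken from the source, the rest is assumed:

```python
# Assumed shape of the accumulating log handler; reset() lets a long-lived
# Runner clear the buffer between jobs instead of re-registering handlers.
import logging

class BufferHandler(logging.Handler):
    def __init__(self):
        super().__init__()
        self.log = ""

    def reset(self):
        self.log = ""

    def emit(self, record):
        if record.name == "httpx":
            return  # drop noisy HTTP client records
        self.log += self.format(record) + "\n\n"

handler = BufferHandler()
handler.setFormatter(logging.Formatter("%(levelname)s - %(message)s"))
logging.getLogger("demo").addHandler(handler)
logging.getLogger("demo").warning("hello")
print(handler.log)   # WARNING - hello
handler.reset()      # buffer is empty again for the next run
```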
View File

@@ -43,9 +43,9 @@ class ModelArguments:
         default=True,
         metadata={"help": "Whether to use double quantization in int4 training or not."}
     )
-    compute_dtype: Optional[torch.dtype] = field(
+    rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
         default=None,
-        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
+        metadata={"help": "Adopt scaled rotary positional embeddings."}
     )
     checkpoint_dir: Optional[str] = field(
         default=None,
@@ -63,8 +63,19 @@ class ModelArguments:
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
+    compute_dtype: Optional[torch.dtype] = field(
+        default=None,
+        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
+    )
+    model_max_length: Optional[int] = field(
+        default=None,
+        metadata={"help": "Used in rope scaling. Do not specify this argument manually."}
+    )

     def __post_init__(self):
+        if self.compute_dtype is not None or self.model_max_length is not None:
+            raise ValueError("These arguments cannot be specified.")
+
         if self.checkpoint_dir is not None: # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]

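Of the new fields, only `rope_scaling` is meant to be set by the user; `compute_dtype` and `model_max_length` are derived internally, and the `__post_init__` guard rejects manual values. A minimal sketch (a trimmed-down stand-in, not the project's full `ModelArguments`) of how such a field parses from the command line:

```python
# Minimal sketch: a trimmed-down stand-in for ModelArguments showing how the
# new rope_scaling flag parses via HfArgumentParser (Literal fields are
# mapped to argparse choices, so invalid values are rejected at parse time).
from dataclasses import dataclass, field
from typing import Literal, Optional

from transformers import HfArgumentParser


@dataclass
class TinyModelArguments:
    rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
        default=None,
        metadata={"help": "Adopt scaled rotary positional embeddings."}
    )


parser = HfArgumentParser(TinyModelArguments)
(args,) = parser.parse_args_into_dataclasses(["--rope_scaling", "linear"])
print(args.rope_scaling)  # -> linear
```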
View File

@@ -1,5 +1,7 @@
 import os
+import math
 import torch
+from types import MethodType
 from typing import TYPE_CHECKING, Literal, Optional, Tuple
 from transformers import (
@@ -66,15 +68,56 @@ def load_model_and_tokenizer(
         **config_kwargs
     )

-    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
+    if finetuning_args.finetuning_type == "full" and model_args.checkpoint_dir is not None:
         model_to_load = model_args.checkpoint_dir[0]
     else:
         model_to_load = model_args.model_name_or_path

     config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
-    is_mergeable = True

+    if hasattr(config, "fp16") and hasattr(config, "bf16"): # fix Qwen config
+        if model_args.compute_dtype == torch.bfloat16:
+            setattr(config, "bf16", True)
+        else:
+            setattr(config, "fp16", True)
+
+    # Set RoPE scaling
+    if model_args.rope_scaling is not None:
+        if hasattr(config, "use_dynamic_ntk"): # for Qwen models
+            if is_trainable:
+                logger.warning("Qwen model does not support RoPE scaling in training.")
+            else:
+                setattr(config, "use_dynamic_ntk", True)
+                setattr(config, "use_logn_attn", True)
+                logger.info("Using dynamic NTK scaling.")
+        elif hasattr(config, "rope_scaling"): # for LLaMA models
+            if is_trainable:
+                if model_args.rope_scaling == "dynamic":
+                    logger.warning(
+                        "Dynamic NTK may not work well with fine-tuning. "
+                        "See: https://github.com/huggingface/transformers/pull/24653"
+                    )
+
+                current_max_length = getattr(config, "max_position_embeddings", None)
+                if current_max_length and model_args.model_max_length <= current_max_length:
+                    logger.warning("Input length is smaller than max length. Consider increasing the input length.")
+                    scaling_factor = 1.0
+                else:
+                    scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
+            else:
+                scaling_factor = 2.0
+
+            setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
+            logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
+                model_args.rope_scaling, scaling_factor
+            ))
+        else:
+            logger.warning("Current model does not support RoPE scaling.")

     # Quantization configurations (using bitsandbytes library).
+    is_mergeable = True
     if model_args.quantization_bit is not None:
         if model_args.quantization_bit == 8:
             require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
@@ -95,7 +138,7 @@ def load_model_and_tokenizer(
         config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

-    # Load and prepare pretrained models (without valuehead).
+    # Load and prepare pre-trained models (without valuehead).
     model = AutoModelForCausalLM.from_pretrained(
         model_to_load,
         config=config,
@@ -104,6 +147,10 @@ def load_model_and_tokenizer(
         **config_kwargs
     )

+    # Disable custom generate method (for Qwen)
+    if "GenerationMixin" not in str(model.generate.__func__):
+        model.generate = MethodType(PreTrainedModel.generate, model)
+
     # Register auto class to save the custom code files.
     if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
         config.__class__.register_for_auto_class()
@@ -116,10 +163,10 @@ def load_model_and_tokenizer(
     model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)

-    if stage == "rm" or stage == "ppo": # add value head
-        model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+    # Prepare model with valuehead for RLHF
+    if stage == "rm" or stage == "ppo":
+        model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
         reset_logging()
         if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
             logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
             if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
@@ -133,9 +180,11 @@ def load_model_and_tokenizer(
             model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
             assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."

+    # Prepare model for inference
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
-        model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
+        infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
+        model = model.to(infer_dtype) if model_args.quantization_bit is None else model

     trainable_params, all_param = count_parameters(model)
     logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(

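To make the scaling-factor rule above concrete, here is a worked example under assumed numbers (a LLaMA-style config with a native context of 2048 and a requested training length of 8192); the `math.ceil` rounding stretches the context to a whole multiple of the original window:

```python
# Worked example of the linear-scaling rule above (the numbers are
# assumptions, not taken from the commit): native context 2048, target 8192.
import math

max_position_embeddings = 2048  # native context window of the base model
model_max_length = 8192         # max_source_length + max_target_length

if model_max_length <= max_position_embeddings:
    scaling_factor = 1.0  # nothing to stretch
else:
    scaling_factor = float(math.ceil(model_max_length / max_position_embeddings))

rope_scaling = {"type": "linear", "factor": scaling_factor}
print(rope_scaling)  # {'type': 'linear', 'factor': 4.0}
```

At inference time (`is_trainable` false) the factor is simply fixed at 2.0, doubling the usable context without retraining.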
View File

@@ -151,14 +151,17 @@ def get_train_args(
     training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning

-    if model_args.quantization_bit is not None:
     if training_args.fp16:
         model_args.compute_dtype = torch.float16
     elif training_args.bf16:
+        if not torch.cuda.is_bf16_supported():
+            raise ValueError("Current device does not support bf16 training.")
         model_args.compute_dtype = torch.bfloat16
     else:
         model_args.compute_dtype = torch.float32

+    model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
+
     # Log on each process the small summary:
     logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, 16-bits training: {}".format(
         training_args.local_rank, training_args.device, training_args.n_gpu,

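The dtype selection above now runs unconditionally (it used to apply only when quantizing) and fails fast when `--bf16` is requested on hardware without bfloat16 support. A small sketch of the same decision, with the flag names borrowed from HF `TrainingArguments`:

```python
# Sketch of the compute-dtype rule above; fp16/bf16 mirror the HF
# TrainingArguments flags. torch.cuda.is_bf16_supported() assumes a CUDA
# build of torch, so the bf16 branch is only reached on GPU setups.
import torch

def pick_compute_dtype(fp16: bool, bf16: bool) -> torch.dtype:
    if fp16:
        return torch.float16
    if bf16:
        if not torch.cuda.is_bf16_supported():
            raise ValueError("Current device does not support bf16 training.")
        return torch.bfloat16
    return torch.float32

# model_max_length is plain arithmetic, e.g. 2048 source + 2048 target -> 4096.
print(pick_compute_dtype(fp16=True, bf16=False))  # torch.float16
```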
View File

@@ -35,7 +35,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):

 def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
-    model_args, _, training_args, finetuning_args, _ = get_train_args(args)
+    model_args, _, training_args, finetuning_args, _, _ = get_train_args(args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
     model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
     try:

View File

@@ -14,13 +14,13 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
     with gr.Row():
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
         dataset = gr.Dropdown(multiselect=True, scale=4)
-        preview_btn = gr.Button(interactive=False, scale=1)
+        data_preview_btn = gr.Button(interactive=False, scale=1)

     preview_box, preview_count, preview_samples, close_btn = create_preview_box()

     dataset_dir.change(list_dataset, [dataset_dir], [dataset])
-    dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
-    preview_btn.click(
+    dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
+    data_preview_btn.click(
         get_preview,
         [dataset_dir, dataset],
         [preview_count, preview_samples, preview_box],
@@ -35,6 +35,7 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
             predict = gr.Checkbox(value=True)

     with gr.Row():
+        cmd_preview_btn = gr.Button()
         start_btn = gr.Button()
         stop_btn = gr.Button()
@@ -44,9 +45,7 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
         with gr.Box():
             output_box = gr.Markdown()

-    start_btn.click(
-        runner.run_eval,
-        [
+    input_list = [
         top_elems["lang"],
         top_elems["model_name"],
         top_elems["checkpoints"],
@@ -61,18 +60,21 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
         max_samples,
         batch_size,
         predict
-        ],
-        [
+    ]
+
+    output_list = [
         output_box,
         process_bar
     ]
-    )
+
+    cmd_preview_btn.click(runner.preview_eval, input_list, output_list)
+    start_btn.click(runner.run_eval, input_list, output_list)

     stop_btn.click(runner.set_abort, queue=False)

     return dict(
         dataset_dir=dataset_dir,
         dataset=dataset,
-        preview_btn=preview_btn,
+        data_preview_btn=data_preview_btn,
         preview_count=preview_count,
         preview_samples=preview_samples,
         close_btn=close_btn,
@@ -81,6 +83,7 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
         max_samples=max_samples,
         batch_size=batch_size,
         predict=predict,
+        cmd_preview_btn=cmd_preview_btn,
         start_btn=start_btn,
         stop_btn=stop_btn,
         output_box=output_box

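The refactor above replaces two hand-maintained copies of the component list with one `input_list`/`output_list` pair that both buttons share. A self-contained sketch of the same wiring pattern (the handlers and component names are illustrative stand-ins for `runner.preview_eval`/`runner.run_eval`, not the tab's real ones):

```python
# Minimal Gradio sketch of the shared input_list/output_list pattern above.
import gradio as gr

def preview(model_name: str) -> str:
    return "would run: python train.py --model {}".format(model_name)

def run(model_name: str) -> str:
    return "running with model {}".format(model_name)

with gr.Blocks() as demo:
    model_name = gr.Textbox(label="model_name")
    cmd_preview_btn = gr.Button("Preview command")
    start_btn = gr.Button("Start")
    output_box = gr.Markdown()

    input_list = [model_name]   # defined once ...
    output_list = [output_box]
    cmd_preview_btn.click(preview, input_list, output_list)  # ... reused here
    start_btn.click(run, input_list, output_list)            # ... and here
```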
View File

@@ -16,13 +16,13 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
     with gr.Row():
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
         dataset = gr.Dropdown(multiselect=True, scale=4)
-        preview_btn = gr.Button(interactive=False, scale=1)
+        data_preview_btn = gr.Button(interactive=False, scale=1)

     preview_box, preview_count, preview_samples, close_btn = create_preview_box()

     dataset_dir.change(list_dataset, [dataset_dir], [dataset])
-    dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
-    preview_btn.click(
+    dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
+    data_preview_btn.click(
         get_preview,
         [dataset_dir, dataset],
         [preview_count, preview_samples, preview_box],
@@ -61,15 +61,12 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
             resume_lora_training = gr.Checkbox(value=True, scale=1)

     with gr.Row():
-        preview_script_btn = gr.Button()
+        cmd_preview_btn = gr.Button()
         start_btn = gr.Button()
         stop_btn = gr.Button()

     with gr.Row():
         with gr.Column(scale=3):
-            with gr.Box():
-                preview_script_box = gr.Textbox()
-
             with gr.Row():
                 output_dir = gr.Textbox()
@@ -82,9 +79,7 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
         with gr.Column(scale=1):
             loss_viewer = gr.Plot()

-    preview_script_btn.click(
-        runner.preview_sft_script,
-        [
+    input_list = [
         top_elems["lang"],
         top_elems["model_name"],
         top_elems["checkpoints"],
@@ -114,50 +109,15 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
         lora_target,
         resume_lora_training,
         output_dir
-        ],
-        [
-            preview_script_box
-        ]
-    )
+    ]

-    start_btn.click(
-        runner.run_train,
-        [
-            top_elems["lang"],
-            top_elems["model_name"],
-            top_elems["checkpoints"],
-            top_elems["finetuning_type"],
-            top_elems["quantization_bit"],
-            top_elems["template"],
-            top_elems["source_prefix"],
-            dataset_dir,
-            dataset,
-            max_source_length,
-            max_target_length,
-            learning_rate,
-            num_train_epochs,
-            max_samples,
-            batch_size,
-            gradient_accumulation_steps,
-            lr_scheduler_type,
-            max_grad_norm,
-            val_size,
-            logging_steps,
-            save_steps,
-            warmup_steps,
-            compute_type,
-            padding_side,
-            lora_rank,
-            lora_dropout,
-            lora_target,
-            resume_lora_training,
-            output_dir
-        ],
-        [
+    output_list = [
         output_box,
         process_bar
     ]
-    )
+
+    cmd_preview_btn.click(runner.preview_train, input_list, output_list)
+    start_btn.click(runner.run_train, input_list, output_list)

     stop_btn.click(runner.set_abort, queue=False)

     process_bar.change(
@@ -167,7 +127,7 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
     return dict(
         dataset_dir=dataset_dir,
         dataset=dataset,
-        preview_btn=preview_btn,
+        data_preview_btn=data_preview_btn,
         preview_count=preview_count,
         preview_samples=preview_samples,
         close_btn=close_btn,
@@ -192,11 +152,10 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
         lora_dropout=lora_dropout,
         lora_target=lora_target,
         resume_lora_training=resume_lora_training,
+        cmd_preview_btn=cmd_preview_btn,
         start_btn=start_btn,
         stop_btn=stop_btn,
         output_dir=output_dir,
         output_box=output_box,
-        loss_viewer=loss_viewer,
-        preview_script_btn=preview_script_btn,
-        preview_script_box=preview_script_box
+        loss_viewer=loss_viewer
     )

View File

@@ -105,12 +105,12 @@ LOCALES = {
             "label": "数据集"
         }
     },
-    "preview_btn": {
+    "data_preview_btn": {
         "en": {
-            "value": "Preview"
+            "value": "Preview dataset"
         },
         "zh": {
-            "value": "预览"
+            "value": "预览数据集"
         }
     },
     "preview_count": {
@@ -335,6 +335,14 @@ LOCALES = {
             "info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。"
         }
     },
+    "cmd_preview_btn": {
+        "en": {
+            "value": "Preview command"
+        },
+        "zh": {
+            "value": "预览命令"
+        }
+    },
     "start_btn": {
         "en": {
             "value": "Start"
@@ -500,22 +508,6 @@ LOCALES = {
         "zh": {
             "value": "开始导出"
         }
-    },
-    "preview_script_btn": {
-        "en": {
-            "value": "preview train script"
-        },
-        "zh": {
-            "value": "预览训练脚本命令"
-        }
-    },
-    "preview_script_box": {
-        "en": {
-            "label": "SFT Script Preview",
-        },
-        "zh": {
-            "label": "训练命令预览",
-        }
     }
 }

View File

@@ -5,16 +5,16 @@ import threading
 import time
 import transformers
 from transformers.trainer import TRAINING_ARGS_NAME
-from typing import Generator, List, Tuple
+from typing import Any, Dict, Generator, List, Tuple

 from llmtuner.extras.callbacks import LogCallback
-from llmtuner.extras.constants import DEFAULT_MODULE, SFT_SCRIPT_PREFIX
+from llmtuner.extras.constants import DEFAULT_MODULE
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
 from llmtuner.tuner import run_exp
 from llmtuner.webui.common import get_model_path, get_save_dir
 from llmtuner.webui.locales import ALERTS
-from llmtuner.webui.utils import get_eval_results, update_process_bar
+from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar


 class Runner:
@@ -22,39 +22,36 @@ class Runner:
     def __init__(self):
         self.aborted = False
         self.running = False
+        self.logger_handler = LoggerHandler()
+        self.logger_handler.setLevel(logging.INFO)
+        logging.root.addHandler(self.logger_handler)
+        transformers.logging.add_handler(self.logger_handler)

     def set_abort(self):
         self.aborted = True
         self.running = False

-    def initialize(
+    def _initialize(
         self, lang: str, model_name: str, dataset: List[str]
-    ) -> Tuple[str, str, LoggerHandler, LogCallback]:
+    ) -> str:
         if self.running:
-            return None, ALERTS["err_conflict"][lang], None, None
+            return ALERTS["err_conflict"][lang]

         if not model_name:
-            return None, ALERTS["err_no_model"][lang], None, None
+            return ALERTS["err_no_model"][lang]

-        model_name_or_path = get_model_path(model_name)
-        if not model_name_or_path:
-            return None, ALERTS["err_no_path"][lang], None, None
+        if not get_model_path(model_name):
+            return ALERTS["err_no_path"][lang]

         if len(dataset) == 0:
-            return None, ALERTS["err_no_dataset"][lang], None, None
+            return ALERTS["err_no_dataset"][lang]

         self.aborted = False
-        self.running = True
-
-        logger_handler = LoggerHandler()
-        logger_handler.setLevel(logging.INFO)
-        logging.root.addHandler(logger_handler)
-        transformers.logging.add_handler(logger_handler)
-        trainer_callback = LogCallback(self)
-
-        return model_name_or_path, "", logger_handler, trainer_callback
+        self.logger_handler.reset()
+        self.trainer_callback = LogCallback(self)
+        return ""

-    def finalize(
+    def _finalize(
         self, lang: str, finish_info: str
     ) -> str:
         self.running = False
@@ -64,7 +61,7 @@ class Runner:
         else:
             return finish_info

-    def run_train(
+    def _parse_train_args(
         self,
         lang: str,
         model_name: str,
@@ -95,52 +92,19 @@ class Runner:
         lora_target: str,
         resume_lora_training: bool,
         output_dir: str
-    ) -> Generator[str, None, None]:
-        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
-        if error:
-            yield error, gr.update(visible=False)
-            return
-
-        output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
-        args = self._build_args(batch_size, checkpoints, compute_type, dataset, dataset_dir, finetuning_type,
-                                gradient_accumulation_steps, learning_rate, logging_steps, lora_dropout, lora_rank,
-                                lora_target, lr_scheduler_type, max_grad_norm, max_samples, max_source_length,
-                                max_target_length, model_name, model_name_or_path, num_train_epochs, output_dir,
-                                padding_side, quantization_bit, resume_lora_training, save_steps, source_prefix,
-                                template, val_size, warmup_steps)
-        run_kwargs = dict(args=args, callbacks=[trainer_callback])
-        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
-        thread.start()
-
-        while thread.is_alive():
-            time.sleep(2)
-            if self.aborted:
-                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
-            else:
-                yield logger_handler.log, update_process_bar(trainer_callback)
-
-        if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
-            finish_info = ALERTS["info_finished"][lang]
-        else:
-            finish_info = ALERTS["err_failed"][lang]
-
-        yield self.finalize(lang, finish_info), gr.update(visible=False)
-
-    def _build_args(self, batch_size, checkpoints, compute_type, dataset, dataset_dir, finetuning_type,
-                    gradient_accumulation_steps, learning_rate, logging_steps, lora_dropout, lora_rank, lora_target,
-                    lr_scheduler_type, max_grad_norm, max_samples, max_source_length, max_target_length, model_name,
-                    model_name_or_path, num_train_epochs, output_dir, padding_side, quantization_bit,
-                    resume_lora_training, save_steps, source_prefix, template, val_size, warmup_steps):
+    ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
         if checkpoints:
             checkpoint_dir = ",".join(
-                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
+                [os.path.join(get_save_dir(model_name), finetuning_type, ckpt) for ckpt in checkpoints]
             )
         else:
             checkpoint_dir = None

+        output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
+
         args = dict(
             stage="sft",
-            model_name_or_path=model_name_or_path,
+            model_name_or_path=get_model_path(model_name),
             do_train=True,
             overwrite_cache=True,
             checkpoint_dir=checkpoint_dir,
@@ -171,14 +135,16 @@ class Runner:
             resume_lora_training=resume_lora_training,
             output_dir=output_dir
         )

         if val_size > 1e-6:
             args["val_size"] = val_size
             args["evaluation_strategy"] = "steps"
             args["eval_steps"] = save_steps
             args["load_best_model_at_end"] = True

-        return args
+        return lang, model_name, dataset, output_dir, args

-    def run_eval(
+    def _parse_eval_args(
         self,
         lang: str,
         model_name: str,
@@ -194,12 +160,7 @@ class Runner:
         max_samples: str,
         batch_size: int,
         predict: bool
-    ) -> Generator[str, None, None]:
-        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
-        if error:
-            yield error, gr.update(visible=False)
-            return
-
+    ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
         if checkpoints:
             checkpoint_dir = ",".join(
                 [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
@@ -211,7 +172,7 @@ class Runner:
         args = dict(
             stage="sft",
-            model_name_or_path=model_name_or_path,
+            model_name_or_path=get_model_path(model_name),
             do_eval=True,
             overwrite_cache=True,
             predict_with_generate=True,
@@ -233,7 +194,33 @@ class Runner:
             args.pop("do_eval", None)
             args["do_predict"] = True

-        run_kwargs = dict(args=args, callbacks=[trainer_callback])
+        return lang, model_name, dataset, output_dir, args
+
+    def preview_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        lang, model_name, dataset, _, args = self._parse_train_args(*args)
+        error = self._initialize(lang, model_name, dataset)
+        if error:
+            yield error, gr.update(visible=False)
+        else:
+            yield gen_cmd(args), gr.update(visible=False)
+
+    def preview_eval(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        lang, model_name, dataset, _, args = self._parse_eval_args(*args)
+        error = self._initialize(lang, model_name, dataset)
+        if error:
+            yield error, gr.update(visible=False)
+        else:
+            yield gen_cmd(args), gr.update(visible=False)
+
+    def run_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        lang, model_name, dataset, output_dir, args = self._parse_train_args(*args)
+        error = self._initialize(lang, model_name, dataset)
+        if error:
+            yield error, gr.update(visible=False)
+            return
+
+        self.running = True
+        run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
         thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
         thread.start()
@@ -242,60 +229,37 @@ class Runner:
             if self.aborted:
                 yield ALERTS["info_aborting"][lang], gr.update(visible=False)
             else:
-                yield logger_handler.log, update_process_bar(trainer_callback)
+                yield self.logger_handler.log, update_process_bar(self.trainer_callback)

+        if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
+            finish_info = ALERTS["info_finished"][lang]
+        else:
+            finish_info = ALERTS["err_failed"][lang]
+
+        yield self._finalize(lang, finish_info), gr.update(visible=False)
+
+    def run_eval(self, *args) -> Generator[str, None, None]:
+        lang, model_name, dataset, output_dir, args = self._parse_eval_args(*args)
+        error = self._initialize(lang, model_name, dataset)
+        if error:
+            yield error, gr.update(visible=False)
+            return
+
+        self.running = True
+        run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
+        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
+        thread.start()
+
+        while thread.is_alive():
+            time.sleep(2)
+            if self.aborted:
+                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
+            else:
+                yield self.logger_handler.log, update_process_bar(self.trainer_callback)
+
         if os.path.exists(os.path.join(output_dir, "all_results.json")):
             finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
         else:
             finish_info = ALERTS["err_failed"][lang]

-        yield self.finalize(lang, finish_info), gr.update(visible=False)
-
-    def preview_sft_script(
-        self,
-        lang: str,
-        model_name: str,
-        checkpoints: List[str],
-        finetuning_type: str,
-        quantization_bit: str,
-        template: str,
-        source_prefix: str,
-        dataset_dir: str,
-        dataset: List[str],
-        max_source_length: int,
-        max_target_length: int,
-        learning_rate: str,
-        num_train_epochs: str,
-        max_samples: str,
-        batch_size: int,
-        gradient_accumulation_steps: int,
-        lr_scheduler_type: str,
-        max_grad_norm: str,
-        val_size: float,
-        logging_steps: int,
-        save_steps: int,
-        warmup_steps: int,
-        compute_type: str,
-        padding_side: str,
-        lora_rank: int,
-        lora_dropout: float,
-        lora_target: str,
-        resume_lora_training: bool,
-        output_dir: str
-    ):
-        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
-        output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
-        args = self._build_args(batch_size, checkpoints, compute_type, dataset, dataset_dir, finetuning_type,
-                                gradient_accumulation_steps, learning_rate, logging_steps, lora_dropout, lora_rank,
-                                lora_target, lr_scheduler_type, max_grad_norm, max_samples, max_source_length,
-                                max_target_length, model_name, model_name_or_path, num_train_epochs, output_dir,
-                                padding_side, quantization_bit, resume_lora_training, save_steps, source_prefix,
-                                template, val_size, warmup_steps)
-
-        script_lines = [SFT_SCRIPT_PREFIX]
-        for param_key, param_value in args.items():
-            # filter None
-            if param_value:
-                script_lines.append(" --" + param_key + " " + str(param_value) + " ")
-        script_str = "\\\n".join(script_lines)
-
-        return gr.update(value=script_str)
+        yield self._finalize(lang, finish_info), gr.update(visible=False)

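The net effect of the Runner refactor: argument parsing moves out of `run_train`/`run_eval` into the `_parse_*_args` builders, so the new preview paths can render a command without launching anything, while the run paths keep the thread-plus-generator loop. A schematic of that split (heavily simplified, with hypothetical signatures; the real methods take ~30 UI values and stream updates from a background thread):

```python
# Schematic of the refactored Runner above: one builder feeds both the
# preview path and the run path.
from typing import Any, Dict, Tuple

class MiniRunner:
    def _parse_train_args(self, lang: str, model_name: str) -> Tuple[str, Dict[str, Any]]:
        # one builder produces the args dict for both paths
        return lang, dict(stage="sft", model_name_or_path=model_name, do_train=True)

    def preview_train(self, lang: str, model_name: str) -> str:
        _, args = self._parse_train_args(lang, model_name)
        # preview only renders the command; nothing is executed
        return " ".join("--{} {}".format(k, v) for k, v in args.items())

    def run_train(self, lang: str, model_name: str) -> str:
        _, args = self._parse_train_args(lang, model_name)
        # the real run_train hands args to run_exp() in a background thread
        return "launched: {}".format(args["model_name_or_path"])

runner = MiniRunner()
print(runner.preview_train("en", "llama-7b"))
```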
View File

@@ -62,6 +62,16 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
         return gr.update(interactive=True)

+def gen_cmd(args: Dict[str, Any]) -> str:
+    cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
+    for k, v in args.items():
+        if v is not None and v is not False and v != "":
+            cmd_lines.append(" --{} {} ".format(k, str(v)))
+    cmd_text = "\\\n".join(cmd_lines)
+    cmd_text = "```bash\n{}\n```".format(cmd_text)
+    return cmd_text
+
+
 def get_eval_results(path: os.PathLike) -> str:
     with open(path, "r", encoding="utf-8") as f:
         result = json.dumps(json.load(f), indent=4)
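For reference, a usage sketch of `gen_cmd`: values that are `None`, `False`, or empty are dropped, and the result is a Markdown-fenced shell command ready for the `output_box`. The definition is repeated from the diff so the snippet runs on its own:

```python
# Usage sketch for gen_cmd above (definition copied from the diff so the
# snippet is self-contained).
from typing import Any, Dict

def gen_cmd(args: Dict[str, Any]) -> str:
    cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
    for k, v in args.items():
        if v is not None and v is not False and v != "":
            cmd_lines.append(" --{} {} ".format(k, str(v)))
    cmd_text = "\\\n".join(cmd_lines)
    return "```bash\n{}\n```".format(cmd_text)

print(gen_cmd(dict(stage="sft", do_train=True, checkpoint_dir=None, quantization_bit=4)))
# Prints a bash-fenced Markdown block, roughly:
#   CUDA_VISIBLE_DEVICES=0 python  \
#    --stage sft  \
#    --do_train True  \
#    --quantization_bit 4
```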