update webui

This commit is contained in:
hiyouga 2023-08-09 00:26:11 +08:00
parent eecc4b2131
commit 3a720aac66
5 changed files with 24 additions and 6 deletions

View File

@@ -4,6 +4,7 @@ datasets>=2.12.0
accelerate>=0.21.0
peft>=0.4.0
trl>=0.4.7
scipy
sentencepiece
tiktoken
jieba

View File

@@ -19,7 +19,6 @@ class ChatModel:
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
self.source_prefix = data_args.source_prefix
self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
def process_args(

View File

@@ -185,6 +185,7 @@ def get_template_and_fix_tokenizer(
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words))
return template

View File

@@ -513,6 +513,10 @@ ALERTS = {
"en": "Please provide export dir.",
"zh": "请填写导出目录"
},
"err_failed": {
"en": "Failed.",
"zh": "训练出错。"
},
"info_aborting": {
"en": "Aborted, wait for terminating...",
"zh": "训练中断,正在等待线程结束……"

View File

@@ -3,6 +3,7 @@ import os
import threading
import time
import transformers
from transformers.trainer import TRAINING_ARGS_NAME
from typing import Generator, List, Optional, Tuple
from llmtuner.extras.callbacks import LogCallback
@@ -53,14 +54,14 @@ class Runner:
return model_name_or_path, "", logger_handler, trainer_callback
def finalize(
self, lang: str, finish_info: Optional[str] = None
self, lang: str, finish_info: str
) -> str:
self.running = False
torch_gc()
if self.aborted:
return ALERTS["info_aborted"][lang]
else:
return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
return finish_info
def run_train(
self,
@@ -104,6 +105,8 @@ class Runner:
else:
checkpoint_dir = None
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
args = dict(
stage="sft",
model_name_or_path=model_name_or_path,
@@ -133,7 +136,7 @@ class Runner:
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
output_dir=output_dir
)
if dev_ratio > 1e-6:
@@ -153,7 +156,12 @@ class Runner:
else:
yield format_info(logger_handler.log, trainer_callback)
yield self.finalize(lang)
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
finish_info = ALERTS["info_finished"][lang]
else:
finish_info = ALERTS["err_failed"][lang]
yield self.finalize(lang, finish_info)
def run_eval(
self,
@@ -221,4 +229,9 @@ class Runner:
else:
yield format_info(logger_handler.log, trainer_callback)
yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))
if os.path.exists(os.path.join(output_dir, "all_results.json")):
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
else:
finish_info = ALERTS["err_failed"][lang]
yield self.finalize(lang, finish_info)