update webui
This commit is contained in:
parent
eecc4b2131
commit
3a720aac66
|
@ -4,6 +4,7 @@ datasets>=2.12.0
|
|||
accelerate>=0.21.0
|
||||
peft>=0.4.0
|
||||
trl>=0.4.7
|
||||
scipy
|
||||
sentencepiece
|
||||
tiktoken
|
||||
jieba
|
||||
|
|
|
@ -19,7 +19,6 @@ class ChatModel:
|
|||
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
||||
self.source_prefix = data_args.source_prefix
|
||||
self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
|
||||
self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
|
||||
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
|
||||
|
||||
def process_args(
|
||||
|
|
|
@ -185,6 +185,7 @@ def get_template_and_fix_tokenizer(
|
|||
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words))
|
||||
return template
|
||||
|
||||
|
||||
|
|
|
@ -513,6 +513,10 @@ ALERTS = {
|
|||
"en": "Please provide export dir.",
|
||||
"zh": "请填写导出目录"
|
||||
},
|
||||
"err_failed": {
|
||||
"en": "Failed.",
|
||||
"zh": "训练出错。"
|
||||
},
|
||||
"info_aborting": {
|
||||
"en": "Aborted, wait for terminating...",
|
||||
"zh": "训练中断,正在等待线程结束……"
|
||||
|
|
|
@ -3,6 +3,7 @@ import os
|
|||
import threading
|
||||
import time
|
||||
import transformers
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
|
@ -53,14 +54,14 @@ class Runner:
|
|||
return model_name_or_path, "", logger_handler, trainer_callback
|
||||
|
||||
def finalize(
|
||||
self, lang: str, finish_info: Optional[str] = None
|
||||
self, lang: str, finish_info: str
|
||||
) -> str:
|
||||
self.running = False
|
||||
torch_gc()
|
||||
if self.aborted:
|
||||
return ALERTS["info_aborted"][lang]
|
||||
else:
|
||||
return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
|
||||
return finish_info
|
||||
|
||||
def run_train(
|
||||
self,
|
||||
|
@ -104,6 +105,8 @@ class Runner:
|
|||
else:
|
||||
checkpoint_dir = None
|
||||
|
||||
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
|
||||
|
||||
args = dict(
|
||||
stage="sft",
|
||||
model_name_or_path=model_name_or_path,
|
||||
|
@ -133,7 +136,7 @@ class Runner:
|
|||
lora_rank=lora_rank,
|
||||
lora_dropout=lora_dropout,
|
||||
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
|
||||
output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
|
||||
output_dir=output_dir
|
||||
)
|
||||
|
||||
if dev_ratio > 1e-6:
|
||||
|
@ -153,7 +156,12 @@ class Runner:
|
|||
else:
|
||||
yield format_info(logger_handler.log, trainer_callback)
|
||||
|
||||
yield self.finalize(lang)
|
||||
if os.path.exists(os.path.join(output_dir), TRAINING_ARGS_NAME):
|
||||
finish_info = ALERTS["info_finished"][lang]
|
||||
else:
|
||||
finish_info = ALERTS["err_failed"][lang]
|
||||
|
||||
yield self.finalize(lang, finish_info)
|
||||
|
||||
def run_eval(
|
||||
self,
|
||||
|
@ -221,4 +229,9 @@ class Runner:
|
|||
else:
|
||||
yield format_info(logger_handler.log, trainer_callback)
|
||||
|
||||
yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))
|
||||
if os.path.exists(os.path.join(output_dir, "all_results.json")):
|
||||
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
|
||||
else:
|
||||
finish_info = ALERTS["err_failed"][lang]
|
||||
|
||||
yield self.finalize(lang, finish_info)
|
||||
|
|
Loading…
Reference in New Issue