update ppo and demo in webui

This commit is contained in:
hiyouga 2023-11-16 14:55:26 +08:00
parent ff52b1779c
commit 7537dd434f
4 changed files with 32 additions and 16 deletions

View File

@ -120,10 +120,12 @@ register_model_group(
register_model_group(
models={
"ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
"ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
"ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
"ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b"
"ChineseLLaMA2-1.3B": "hfl/chinese-llama-2-1.3b",
"ChineseLLaMA2-7B": "hfl/chinese-llama-2-7b",
"ChineseLLaMA2-13B": "hfl/chinese-llama-2-13b",
"ChineseLLaMA2-1.3B-Chat": "hfl/chinese-alpaca-2-1.3b",
"ChineseLLaMA2-7B-Chat": "hfl/chinese-alpaca-2-7b",
"ChineseLLaMA2-13B-Chat": "hfl/chinese-alpaca-2-13b"
},
template="llama2_zh"
)

View File

@ -25,9 +25,13 @@ class WebChatModel(ChatModel):
self.model = None
self.tokenizer = None
self.generating_args = GeneratingArguments()
if not lazy_init:
if not lazy_init: # read arguments from command line
super().__init__()
if demo_mode: # load openchat 3.5 by default
super().__init__(dict(model_name_or_path="openchat/openchat_3.5", template="openchat"))
@property
def loaded(self) -> bool:
return self.model is not None
@ -75,6 +79,11 @@ class WebChatModel(ChatModel):
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
lang = data[self.manager.get_elem_by_name("top.lang")]
if self.demo_mode:
yield ALERTS["err_demo"][lang]
return
yield ALERTS["info_unloading"][lang]
self.model = None
self.tokenizer = None

View File

@ -38,7 +38,7 @@ def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
with gr.Tab("Train"):
engine.manager.all_elems["train"] = create_train_tab(engine)
with gr.Tab("Evaluate"):
with gr.Tab("Evaluate & Predict"):
engine.manager.all_elems["eval"] = create_eval_tab(engine)
with gr.Tab("Chat"):

View File

@ -31,7 +31,6 @@ class Runner:
self.thread: "Thread" = None
self.do_train = True
self.running_data: Dict["Component", Any] = None
self.monitor_inputs: Dict[str, str] = None
""" State """
self.aborted = False
self.running = False
@ -75,6 +74,7 @@ class Runner:
def _finalize(self, lang: str, finish_info: str) -> str:
self.thread = None
self.running_data = None
self.running = False
torch_gc()
if self.aborted:
@ -87,9 +87,9 @@ class Runner:
user_config = load_config()
if get("top.checkpoints"):
checkpoint_dir = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
])
checkpoint_dir = ",".join([get_save_dir(
get("top.model_name"), get("top.finetuning_type"), ckpt
) for ckpt in get("top.checkpoints")])
else:
checkpoint_dir = None
@ -139,7 +139,10 @@ class Runner:
args["upcast_layernorm"] = True
if args["stage"] == "ppo":
args["reward_model"] = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.reward_model"))
args["reward_model"] = get_save_dir(
get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
)
args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"
if args["stage"] == "dpo":
args["dpo_beta"] = get("train.dpo_beta")
@ -157,9 +160,9 @@ class Runner:
user_config = load_config()
if get("top.checkpoints"):
checkpoint_dir = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
])
checkpoint_dir = ",".join([get_save_dir(
get("top.model_name"), get("top.finetuning_type"), ckpt
) for ckpt in get("top.checkpoints")])
output_dir = get_save_dir(
get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
)
@ -216,7 +219,6 @@ class Runner:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
self.do_train, self.running_data = do_train, data
self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
self.thread.start()
yield from self.monitor()
@ -235,7 +237,10 @@ class Runner:
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
self.running = True
lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
lang = self.running_data[self.manager.get_elem_by_name("top.lang")]
output_dir = self.running_data[self.manager.get_elem_by_name(
"{}.output_dir".format("train" if self.do_train else "eval")
)]
while self.thread.is_alive():
time.sleep(2)
if self.aborted: