update ppo and demo in webui

parent ff52b1779c
commit 7537dd434f
@@ -120,10 +120,12 @@ register_model_group(
 
 register_model_group(
     models={
-        "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
-        "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
-        "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
-        "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b"
+        "ChineseLLaMA2-1.3B": "hfl/chinese-llama-2-1.3b",
+        "ChineseLLaMA2-7B": "hfl/chinese-llama-2-7b",
+        "ChineseLLaMA2-13B": "hfl/chinese-llama-2-13b",
+        "ChineseLLaMA2-1.3B-Chat": "hfl/chinese-alpaca-2-1.3b",
+        "ChineseLLaMA2-7B-Chat": "hfl/chinese-alpaca-2-7b",
+        "ChineseLLaMA2-13B-Chat": "hfl/chinese-alpaca-2-13b"
     },
     template="llama2_zh"
 )
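The ChineseLLaMA2 entries move from the ziqingyang hub namespace to hfl and gain 1.3B variants. A standalone sketch of the registry pattern this call site uses, assuming nothing about the real register_model_group beyond what the hunk shows (SUPPORTED_MODELS and DEFAULT_TEMPLATE are illustrative names):

```python
# Illustrative sketch only: assumed registry tables and helper,
# mirroring the call shape in the hunk above.
from typing import Dict, Optional

SUPPORTED_MODELS: Dict[str, str] = {}   # UI display name -> Hugging Face hub ID
DEFAULT_TEMPLATE: Dict[str, str] = {}   # UI display name -> chat template name

def register_model_group(models: Dict[str, str], template: Optional[str] = None) -> None:
    for name, path in models.items():
        SUPPORTED_MODELS[name] = path
        if template is not None:
            DEFAULT_TEMPLATE[name] = template

register_model_group(
    models={
        "ChineseLLaMA2-7B": "hfl/chinese-llama-2-7b",
        "ChineseLLaMA2-7B-Chat": "hfl/chinese-alpaca-2-7b",
    },
    template="llama2_zh",
)

print(SUPPORTED_MODELS["ChineseLLaMA2-7B-Chat"])  # hfl/chinese-alpaca-2-7b
print(DEFAULT_TEMPLATE["ChineseLLaMA2-7B-Chat"])  # llama2_zh
```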
@@ -25,9 +25,13 @@ class WebChatModel(ChatModel):
         self.model = None
         self.tokenizer = None
         self.generating_args = GeneratingArguments()
-        if not lazy_init:
+
+        if not lazy_init:  # read arguments from command line
             super().__init__()
 
+        if demo_mode:  # load openchat 3.5 by default
+            super().__init__(dict(model_name_or_path="openchat/openchat_3.5", template="openchat"))
+
     @property
     def loaded(self) -> bool:
         return self.model is not None
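With demo_mode set, the Web UI now eagerly loads openchat/openchat_3.5 with the matching openchat chat template instead of waiting for the user to load a model. A self-contained sketch of that control flow, assuming lazy_init and demo_mode are constructor flags (the signature itself is outside this hunk):

```python
# Standalone illustration of the lazy-vs-demo initialization above;
# SketchChatModel stands in for WebChatModel(ChatModel).
from typing import Optional

class SketchChatModel:
    def __init__(self, lazy_init: bool = True, demo_mode: bool = False) -> None:
        self.model: Optional[str] = None
        if not lazy_init:  # read arguments from command line
            self._load("cli-configured-model")
        if demo_mode:  # load openchat 3.5 by default
            self._load("openchat/openchat_3.5")

    def _load(self, name: str) -> None:
        self.model = name  # the real class calls ChatModel.__init__ here

    @property
    def loaded(self) -> bool:
        return self.model is not None

print(SketchChatModel(demo_mode=True).loaded)  # True
print(SketchChatModel().loaded)                # False (lazy)
```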
@@ -75,6 +79,11 @@ class WebChatModel(ChatModel):
 
     def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
         lang = data[self.manager.get_elem_by_name("top.lang")]
+
+        if self.demo_mode:
+            yield ALERTS["err_demo"][lang]
+            return
+
         yield ALERTS["info_unloading"][lang]
         self.model = None
         self.tokenizer = None
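unload_model now short-circuits in demo mode, so a public demo cannot drop the shared model; it yields a localized error instead. The hunk implies ALERTS is a two-level table, message key then language code. A sketch with placeholder strings (the real table lives in the web UI's locales module):

```python
# Assumed shape of the ALERTS table: the keys match the hunk,
# the strings here are placeholders, not the repo's actual wording.
ALERTS = {
    "err_demo": {
        "en": "This action is not available in demo mode.",
        "zh": "演示模式下无法使用此功能。",
    },
    "info_unloading": {
        "en": "Unloading model...",
        "zh": "卸载模型中……",
    },
}

lang = "en"
print(ALERTS["err_demo"][lang])  # This action is not available in demo mode.
```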
@@ -38,7 +38,7 @@ def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
         with gr.Tab("Train"):
             engine.manager.all_elems["train"] = create_train_tab(engine)
 
-        with gr.Tab("Evaluate"):
+        with gr.Tab("Evaluate & Predict"):
             engine.manager.all_elems["eval"] = create_eval_tab(engine)
 
         with gr.Tab("Chat"):
@@ -31,7 +31,6 @@ class Runner:
         self.thread: "Thread" = None
         self.do_train = True
         self.running_data: Dict["Component", Any] = None
-        self.monitor_inputs: Dict[str, str] = None
         """ State """
         self.aborted = False
         self.running = False
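This removal pairs with the last two hunks below: the Runner keeps do_train and the submitted Gradio data in running_data, so monitor() can look up the language and output directory directly instead of caching them in a separate monitor_inputs dict at launch time.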
@@ -75,6 +74,7 @@ class Runner:
 
     def _finalize(self, lang: str, finish_info: str) -> str:
         self.thread = None
+        self.running_data = None
         self.running = False
         torch_gc()
         if self.aborted:
@@ -87,9 +87,9 @@ class Runner:
         user_config = load_config()
 
         if get("top.checkpoints"):
-            checkpoint_dir = ",".join([
-                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
-            ])
+            checkpoint_dir = ",".join([get_save_dir(
+                get("top.model_name"), get("top.finetuning_type"), ckpt
+            ) for ckpt in get("top.checkpoints")])
         else:
             checkpoint_dir = None
 
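Both layouts of the comprehension produce the same comma-joined string; the new form just keeps each get_save_dir call together. A runnable check with a stubbed get_save_dir (the real helper builds paths under the configured save root):

```python
import os

# stub: the real get_save_dir joins the configured save root with these parts
def get_save_dir(model_name: str, finetuning_type: str, ckpt: str) -> str:
    return os.path.join("saves", model_name, finetuning_type, ckpt)

checkpoints = ["checkpoint-100", "checkpoint-200"]
checkpoint_dir = ",".join([get_save_dir(
    "LLaMA2-7B", "lora", ckpt
) for ckpt in checkpoints])
print(checkpoint_dir)
# saves/LLaMA2-7B/lora/checkpoint-100,saves/LLaMA2-7B/lora/checkpoint-200
```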
@@ -139,7 +139,10 @@ class Runner:
             args["upcast_layernorm"] = True
 
         if args["stage"] == "ppo":
-            args["reward_model"] = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.reward_model"))
+            args["reward_model"] = get_save_dir(
+                get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
+            )
+            args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"
 
         if args["stage"] == "dpo":
             args["dpo_beta"] = get("train.dpo_beta")
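This is the PPO update named in the commit title: besides pointing reward_model at the reward checkpoint directory, the trainer is now told how to load it, as a LoRA adapter when the run itself uses LoRA, otherwise as a full model. A sketch of the resulting arguments with stubbed lookups (values are examples, get/get_save_dir are stand-ins for the web UI helpers):

```python
import os

# stubs standing in for the web UI's get() accessor and get_save_dir() helper
ui = {"top.model_name": "LLaMA2-7B", "top.finetuning_type": "lora",
      "train.reward_model": "reward-checkpoint"}
get = ui.__getitem__

def get_save_dir(*parts: str) -> str:
    return os.path.join("saves", *parts)

args = {"stage": "ppo"}
if args["stage"] == "ppo":
    args["reward_model"] = get_save_dir(
        get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
    )
    args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"

print(args)
# {'stage': 'ppo', 'reward_model': 'saves/LLaMA2-7B/lora/reward-checkpoint',
#  'reward_model_type': 'lora'}
```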
@@ -157,9 +160,9 @@ class Runner:
         user_config = load_config()
 
         if get("top.checkpoints"):
-            checkpoint_dir = ",".join([
-                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
-            ])
+            checkpoint_dir = ",".join([get_save_dir(
+                get("top.model_name"), get("top.finetuning_type"), ckpt
+            ) for ckpt in get("top.checkpoints")])
             output_dir = get_save_dir(
                 get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
             )
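The evaluation output directory encodes the selected checkpoints, for example:

```python
checkpoints = ["checkpoint-100", "checkpoint-200"]
print("eval_" + "_".join(checkpoints))  # eval_checkpoint-100_checkpoint-200
```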
@@ -216,7 +219,6 @@ class Runner:
         args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
         run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
         self.do_train, self.running_data = do_train, data
-        self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
         self.thread = Thread(target=run_exp, kwargs=run_kwargs)
         self.thread.start()
         yield from self.monitor()
@@ -235,7 +237,10 @@ class Runner:
 
     def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
         self.running = True
-        lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
+        lang = self.running_data[self.manager.get_elem_by_name("top.lang")]
+        output_dir = self.running_data[self.manager.get_elem_by_name(
+            "{}.output_dir".format("train" if self.do_train else "eval")
+        )]
         while self.thread.is_alive():
             time.sleep(2)
             if self.aborted:
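run_exp executes in a background Thread while monitor() polls it every two seconds, yielding progress back to Gradio; with monitor_inputs gone, the loop reads the language and output directory from the stored running_data, picking the train or eval output_dir field based on do_train. A self-contained sketch of this launch-and-poll pattern (generic names, no Gradio):

```python
# Generic illustration of the Runner's launch-and-poll pattern above.
import time
from threading import Thread
from typing import Generator

def run_exp() -> None:  # stand-in for the real training entry point
    time.sleep(5)

def launch_and_monitor() -> Generator[str, None, None]:
    thread = Thread(target=run_exp)
    thread.start()
    while thread.is_alive():  # same polling loop as Runner.monitor()
        time.sleep(2)
        yield "running..."    # the real monitor yields logs and progress
    yield "finished"

for status in launch_and_monitor():
    print(status)
```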