fix bug in web ui
This commit is contained in:
parent
7537dd434f
commit
6efa38be46
|
@ -3,6 +3,9 @@ import logging
|
|||
|
||||
|
||||
class LoggerHandler(logging.Handler):
|
||||
r"""
|
||||
Logger handler used in Web UI.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -19,16 +22,10 @@ class LoggerHandler(logging.Handler):
|
|||
self.log += "\n\n"
|
||||
|
||||
|
||||
def reset_logging():
|
||||
r"""
|
||||
Removes basic config of root logger
|
||||
"""
|
||||
root = logging.getLogger()
|
||||
list(map(root.removeHandler, root.handlers))
|
||||
list(map(root.removeFilter, root.filters))
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
r"""
|
||||
Gets a standard logger with a stream hander to stdout.
|
||||
"""
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S"
|
||||
|
@ -41,3 +38,12 @@ def get_logger(name: str) -> logging.Logger:
|
|||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def reset_logging() -> None:
|
||||
r"""
|
||||
Removes basic config of root logger. (unused in script)
|
||||
"""
|
||||
root = logging.getLogger()
|
||||
list(map(root.removeHandler, root.handlers))
|
||||
list(map(root.removeFilter, root.filters))
|
||||
|
|
|
@ -202,7 +202,6 @@ def load_model_and_tokenizer(
|
|||
# Prepare model with valuehead for RLHF
|
||||
if stage in ["rm", "ppo"]:
|
||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
reset_logging()
|
||||
vhead_path = (
|
||||
model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
|
||||
)
|
||||
|
|
|
@ -236,11 +236,12 @@ class Runner:
|
|||
yield from self._launch(data, do_train=False)
|
||||
|
||||
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
|
||||
self.running = True
|
||||
lang = self.running_data[self.manager.get_elem_by_name("top.lang")]
|
||||
output_dir = self.running_data[self.manager.get_elem_by_name(
|
||||
lang = get("top.lang")
|
||||
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
|
||||
"{}.output_dir".format("train" if self.do_train else "eval")
|
||||
)]
|
||||
))
|
||||
while self.thread.is_alive():
|
||||
time.sleep(2)
|
||||
if self.aborted:
|
||||
|
|
Loading…
Reference in New Issue