fix bug in web ui
This commit is contained in:
parent
7537dd434f
commit
6efa38be46
|
@ -3,6 +3,9 @@ import logging
|
||||||
|
|
||||||
|
|
||||||
class LoggerHandler(logging.Handler):
|
class LoggerHandler(logging.Handler):
|
||||||
|
r"""
|
||||||
|
Logger handler used in Web UI.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -19,16 +22,10 @@ class LoggerHandler(logging.Handler):
|
||||||
self.log += "\n\n"
|
self.log += "\n\n"
|
||||||
|
|
||||||
|
|
||||||
def reset_logging():
|
|
||||||
r"""
|
|
||||||
Removes basic config of root logger
|
|
||||||
"""
|
|
||||||
root = logging.getLogger()
|
|
||||||
list(map(root.removeHandler, root.handlers))
|
|
||||||
list(map(root.removeFilter, root.filters))
|
|
||||||
|
|
||||||
|
|
||||||
def get_logger(name: str) -> logging.Logger:
|
def get_logger(name: str) -> logging.Logger:
|
||||||
|
r"""
|
||||||
|
Gets a standard logger with a stream hander to stdout.
|
||||||
|
"""
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(
|
||||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
datefmt="%m/%d/%Y %H:%M:%S"
|
datefmt="%m/%d/%Y %H:%M:%S"
|
||||||
|
@ -41,3 +38,12 @@ def get_logger(name: str) -> logging.Logger:
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
|
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
def reset_logging() -> None:
|
||||||
|
r"""
|
||||||
|
Removes basic config of root logger. (unused in script)
|
||||||
|
"""
|
||||||
|
root = logging.getLogger()
|
||||||
|
list(map(root.removeHandler, root.handlers))
|
||||||
|
list(map(root.removeFilter, root.filters))
|
||||||
|
|
|
@ -202,7 +202,6 @@ def load_model_and_tokenizer(
|
||||||
# Prepare model with valuehead for RLHF
|
# Prepare model with valuehead for RLHF
|
||||||
if stage in ["rm", "ppo"]:
|
if stage in ["rm", "ppo"]:
|
||||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||||
reset_logging()
|
|
||||||
vhead_path = (
|
vhead_path = (
|
||||||
model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
|
model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
|
||||||
)
|
)
|
||||||
|
|
|
@ -236,11 +236,12 @@ class Runner:
|
||||||
yield from self._launch(data, do_train=False)
|
yield from self._launch(data, do_train=False)
|
||||||
|
|
||||||
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||||
|
get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
|
||||||
self.running = True
|
self.running = True
|
||||||
lang = self.running_data[self.manager.get_elem_by_name("top.lang")]
|
lang = get("top.lang")
|
||||||
output_dir = self.running_data[self.manager.get_elem_by_name(
|
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
|
||||||
"{}.output_dir".format("train" if self.do_train else "eval")
|
"{}.output_dir".format("train" if self.do_train else "eval")
|
||||||
)]
|
))
|
||||||
while self.thread.is_alive():
|
while self.thread.is_alive():
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
if self.aborted:
|
if self.aborted:
|
||||||
|
|
Loading…
Reference in New Issue