diff --git a/src/llmtuner/extras/logging.py b/src/llmtuner/extras/logging.py
index d6f185e6..d01c14a4 100644
--- a/src/llmtuner/extras/logging.py
+++ b/src/llmtuner/extras/logging.py
@@ -3,6 +3,9 @@ import logging
 
 
 class LoggerHandler(logging.Handler):
+    r"""
+    Logger handler used in Web UI.
+    """
 
     def __init__(self):
         super().__init__()
@@ -19,16 +22,10 @@ class LoggerHandler(logging.Handler):
         self.log += "\n\n"
 
 
-def reset_logging():
-    r"""
-    Removes basic config of root logger
-    """
-    root = logging.getLogger()
-    list(map(root.removeHandler, root.handlers))
-    list(map(root.removeFilter, root.filters))
-
-
 def get_logger(name: str) -> logging.Logger:
+    r"""
+    Gets a standard logger with a stream handler to stdout.
+    """
     formatter = logging.Formatter(
         fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S"
@@ -41,3 +38,12 @@ def get_logger(name: str) -> logging.Logger:
     logger.addHandler(handler)
 
     return logger
+
+
+def reset_logging() -> None:
+    r"""
+    Removes the basic config of the root logger. (currently unused)
+    """
+    root = logging.getLogger()
+    list(map(root.removeHandler, root.handlers))
+    list(map(root.removeFilter, root.filters))
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index df43599c..20b9b5d4 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -202,7 +202,6 @@ def load_model_and_tokenizer(
     # Prepare model with valuehead for RLHF
     if stage in ["rm", "ppo"]:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
-        reset_logging()
         vhead_path = (
             model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
         )
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index 687cdb12..7789fc4d 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -236,11 +236,12 @@ class Runner:
         yield from self._launch(data, do_train=False)
 
     def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
         self.running = True
-        lang = self.running_data[self.manager.get_elem_by_name("top.lang")]
-        output_dir = self.running_data[self.manager.get_elem_by_name(
+        lang = get("top.lang")
+        output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
             "{}.output_dir".format("train" if self.do_train else "eval")
-        )]
+        ))
         while self.thread.is_alive():
             time.sleep(2)
             if self.aborted:
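
Note on the `logging.py` and `loader.py` changes: `reset_logging()` is kept but moved below `get_logger()` and marked unused, and its call site in `load_model_and_tokenizer()` is removed, so loading a value-head model for RM/PPO no longer strips handlers that other components (such as the Web UI's `LoggerHandler`) have attached to the root logger. A minimal sketch of the intended call pattern, assuming the package layout in this diff; the Web UI wiring shown is an illustration, not code from this patch:

```python
import logging

from llmtuner.extras.logging import LoggerHandler, get_logger

# Module-level logger with a stream handler to stdout, as defined above.
logger = get_logger(__name__)

# Hypothetical Web UI wiring: attach a LoggerHandler to the root logger so
# that emitted records accumulate in handler.log for display in the browser.
ui_handler = LoggerHandler()
ui_handler.setLevel(logging.INFO)
logging.getLogger().addHandler(ui_handler)

# With reset_logging() no longer called during model loading, this record
# reaches both stdout and ui_handler.log.
logger.info("model loaded")
```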
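Note on the `runner.py` change: besides deduplicating the `self.running_data[...]` lookups behind a local `get` helper, `monitor()` now resolves `output_dir` through `get_save_dir(...)` instead of using the raw textbox value, so the monitor watches the same directory the trainer actually writes to. A rough sketch of the assumed resolution behavior; the `get_save_dir` body and the `DEFAULT_SAVE_DIR` constant are inferred from the call shape, not confirmed by this diff:

```python
import os

DEFAULT_SAVE_DIR = "saves"  # assumed project-level constant


def get_save_dir(*paths: str) -> str:
    # Assumed behavior: nest checkpoints under
    # saves/<model_name>/<finetuning_type>/<output_dir>
    return os.path.join(DEFAULT_SAVE_DIR, *paths)


# With the patch applied, monitor() would watch a path like:
print(get_save_dir("LLaMA2-7B", "lora", "train_run"))
# saves/LLaMA2-7B/lora/train_run
```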