From 9585838ebe1f7ce508ec490f91d30920f134be3f Mon Sep 17 00:00:00 2001 From: hiyouga Date: Fri, 3 May 2024 21:24:27 +0800 Subject: [PATCH] fix callback log multigpu #3559 --- src/llmtuner/extras/callbacks.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py index fbe6f373..76f07a42 100644 --- a/src/llmtuner/extras/callbacks.py +++ b/src/llmtuner/extras/callbacks.py @@ -70,11 +70,9 @@ class LogCallback(TrainerCallback): r""" Event called at the beginning of training. """ - if args.should_log: + if args.should_save: self.do_train = True self._reset(max_steps=state.max_steps) - - if args.should_save: os.makedirs(args.output_dir, exist_ok=True) self.thread_pool = ThreadPoolExecutor(max_workers=1) @@ -98,7 +96,7 @@ class LogCallback(TrainerCallback): r""" Event called at the end of a training step. """ - if args.should_log: + if args.should_save: self._timing(cur_steps=state.global_step) if self.aborted: @@ -119,7 +117,7 @@ class LogCallback(TrainerCallback): Event called after a prediction step. """ eval_dataloader = kwargs.pop("eval_dataloader", None) - if args.should_log and has_length(eval_dataloader) and not self.do_train: + if args.should_save and has_length(eval_dataloader) and not self.do_train: if self.max_steps == 0: self.max_steps = len(eval_dataloader) @@ -131,8 +129,11 @@ class LogCallback(TrainerCallback): def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None: r""" - Event called after logging the last logs, `args.should_log` has been applied. + Event called after logging the last logs. """ + if not args.should_save: + return + logs = dict( current_steps=self.cur_steps, total_steps=self.max_steps, @@ -148,12 +149,12 @@ class LogCallback(TrainerCallback): remaining_time=self.remaining_time, ) logs = {k: v for k, v in logs.items() if v is not None} - if self.webui_mode and "loss" in logs and "learning_rate" in logs and "epoch" in logs: + if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]): logger.info( "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format( logs["loss"], logs["learning_rate"], logs["epoch"] ) ) - if args.should_save and self.thread_pool is not None: + if self.thread_pool is not None: self.thread_pool.submit(self._write_log, args.output_dir, logs)