From b19c14870d30c57fbea81e9cfa737d762922c54b Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 28 Mar 2024 18:31:17 +0800 Subject: [PATCH] fix #3010 --- src/llmtuner/extras/callbacks.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py index 086dea6d..985b0292 100644 --- a/src/llmtuner/extras/callbacks.py +++ b/src/llmtuner/extras/callbacks.py @@ -58,9 +58,17 @@ class LogCallback(TrainerCallback): self.in_training = True self.start_time = time.time() self.max_steps = state.max_steps - if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir: - logger.warning("Previous log file in this folder will be deleted.") - os.remove(os.path.join(args.output_dir, LOG_FILE_NAME)) + + if args.save_on_each_node: + if not state.is_local_process_zero: + return + else: + if not state.is_world_process_zero: + return + + if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir: + logger.warning("Previous log file in this folder will be deleted.") + os.remove(os.path.join(args.output_dir, LOG_FILE_NAME)) def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): r""" @@ -112,8 +120,12 @@ class LogCallback(TrainerCallback): r""" Event called after logging the last logs. """ - if not state.is_local_process_zero: - return + if args.save_on_each_node: + if not state.is_local_process_zero: + return + else: + if not state.is_world_process_zero: + return logs = dict( current_steps=self.cur_steps,