This commit is contained in:
hiyouga 2024-03-28 18:31:17 +08:00
parent 8c77b10912
commit b19c14870d
1 changed files with 17 additions and 5 deletions

View File

@ -58,9 +58,17 @@ class LogCallback(TrainerCallback):
self.in_training = True
self.start_time = time.time()
self.max_steps = state.max_steps
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
if args.save_on_each_node:
if not state.is_local_process_zero:
return
else:
if not state.is_world_process_zero:
return
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
@ -112,8 +120,12 @@ class LogCallback(TrainerCallback):
r"""
Event called after logging the last logs.
"""
if not state.is_local_process_zero:
return
if args.save_on_each_node:
if not state.is_local_process_zero:
return
else:
if not state.is_world_process_zero:
return
logs = dict(
current_steps=self.cur_steps,