fix callback log multigpu #3559

This commit is contained in:
hiyouga 2024-05-03 21:24:27 +08:00
parent 5e6f808e3c
commit 9585838ebe
1 changed files with 9 additions and 8 deletions

View File

@ -70,11 +70,9 @@ class LogCallback(TrainerCallback):
r""" r"""
Event called at the beginning of training. Event called at the beginning of training.
""" """
if args.should_log: if args.should_save:
self.do_train = True self.do_train = True
self._reset(max_steps=state.max_steps) self._reset(max_steps=state.max_steps)
if args.should_save:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
self.thread_pool = ThreadPoolExecutor(max_workers=1) self.thread_pool = ThreadPoolExecutor(max_workers=1)
@ -98,7 +96,7 @@ class LogCallback(TrainerCallback):
r""" r"""
Event called at the end of a training step. Event called at the end of a training step.
""" """
if args.should_log: if args.should_save:
self._timing(cur_steps=state.global_step) self._timing(cur_steps=state.global_step)
if self.aborted: if self.aborted:
@ -119,7 +117,7 @@ class LogCallback(TrainerCallback):
Event called after a prediction step. Event called after a prediction step.
""" """
eval_dataloader = kwargs.pop("eval_dataloader", None) eval_dataloader = kwargs.pop("eval_dataloader", None)
if args.should_log and has_length(eval_dataloader) and not self.do_train: if args.should_save and has_length(eval_dataloader) and not self.do_train:
if self.max_steps == 0: if self.max_steps == 0:
self.max_steps = len(eval_dataloader) self.max_steps = len(eval_dataloader)
@ -131,8 +129,11 @@ class LogCallback(TrainerCallback):
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None: def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
r""" r"""
Event called after logging the last logs, `args.should_log` has been applied. Event called after logging the last logs.
""" """
if not args.should_save:
return
logs = dict( logs = dict(
current_steps=self.cur_steps, current_steps=self.cur_steps,
total_steps=self.max_steps, total_steps=self.max_steps,
@ -148,12 +149,12 @@ class LogCallback(TrainerCallback):
remaining_time=self.remaining_time, remaining_time=self.remaining_time,
) )
logs = {k: v for k, v in logs.items() if v is not None} logs = {k: v for k, v in logs.items() if v is not None}
if self.webui_mode and "loss" in logs and "learning_rate" in logs and "epoch" in logs: if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
logger.info( logger.info(
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format( "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
logs["loss"], logs["learning_rate"], logs["epoch"] logs["loss"], logs["learning_rate"], logs["epoch"]
) )
) )
if args.should_save and self.thread_pool is not None: if self.thread_pool is not None:
self.thread_pool.submit(self._write_log, args.output_dir, logs) self.thread_pool.submit(self._write_log, args.output_dir, logs)