fix callback log multigpu #3559
This commit is contained in:
parent
5e6f808e3c
commit
9585838ebe
|
@ -70,11 +70,9 @@ class LogCallback(TrainerCallback):
|
||||||
r"""
|
r"""
|
||||||
Event called at the beginning of training.
|
Event called at the beginning of training.
|
||||||
"""
|
"""
|
||||||
if args.should_log:
|
if args.should_save:
|
||||||
self.do_train = True
|
self.do_train = True
|
||||||
self._reset(max_steps=state.max_steps)
|
self._reset(max_steps=state.max_steps)
|
||||||
|
|
||||||
if args.should_save:
|
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
||||||
|
|
||||||
|
@ -98,7 +96,7 @@ class LogCallback(TrainerCallback):
|
||||||
r"""
|
r"""
|
||||||
Event called at the end of a training step.
|
Event called at the end of a training step.
|
||||||
"""
|
"""
|
||||||
if args.should_log:
|
if args.should_save:
|
||||||
self._timing(cur_steps=state.global_step)
|
self._timing(cur_steps=state.global_step)
|
||||||
|
|
||||||
if self.aborted:
|
if self.aborted:
|
||||||
|
@ -119,7 +117,7 @@ class LogCallback(TrainerCallback):
|
||||||
Event called after a prediction step.
|
Event called after a prediction step.
|
||||||
"""
|
"""
|
||||||
eval_dataloader = kwargs.pop("eval_dataloader", None)
|
eval_dataloader = kwargs.pop("eval_dataloader", None)
|
||||||
if args.should_log and has_length(eval_dataloader) and not self.do_train:
|
if args.should_save and has_length(eval_dataloader) and not self.do_train:
|
||||||
if self.max_steps == 0:
|
if self.max_steps == 0:
|
||||||
self.max_steps = len(eval_dataloader)
|
self.max_steps = len(eval_dataloader)
|
||||||
|
|
||||||
|
@ -131,8 +129,11 @@ class LogCallback(TrainerCallback):
|
||||||
|
|
||||||
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
|
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
|
||||||
r"""
|
r"""
|
||||||
Event called after logging the last logs, `args.should_log` has been applied.
|
Event called after logging the last logs.
|
||||||
"""
|
"""
|
||||||
|
if not args.should_save:
|
||||||
|
return
|
||||||
|
|
||||||
logs = dict(
|
logs = dict(
|
||||||
current_steps=self.cur_steps,
|
current_steps=self.cur_steps,
|
||||||
total_steps=self.max_steps,
|
total_steps=self.max_steps,
|
||||||
|
@ -148,12 +149,12 @@ class LogCallback(TrainerCallback):
|
||||||
remaining_time=self.remaining_time,
|
remaining_time=self.remaining_time,
|
||||||
)
|
)
|
||||||
logs = {k: v for k, v in logs.items() if v is not None}
|
logs = {k: v for k, v in logs.items() if v is not None}
|
||||||
if self.webui_mode and "loss" in logs and "learning_rate" in logs and "epoch" in logs:
|
if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
|
||||||
logger.info(
|
logger.info(
|
||||||
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
|
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
|
||||||
logs["loss"], logs["learning_rate"], logs["epoch"]
|
logs["loss"], logs["learning_rate"], logs["epoch"]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.should_save and self.thread_pool is not None:
|
if self.thread_pool is not None:
|
||||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||||
|
|
Loading…
Reference in New Issue