fix callback log multigpu #3559
commit 9585838ebe
parent 5e6f808e3c
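The hunks below swap every `args.should_log` gate for `args.should_save`, so the progress counters, the writer thread pool, and the log file all belong to the same process. A minimal sketch of why the two flags can disagree, assuming the stock transformers defaults (log_on_each_node=True, save_on_each_node=False); the output_dir value is only a placeholder:

from transformers import TrainingArguments

args = TrainingArguments(output_dir="dummy_out")  # placeholder path

# Both are plain properties of TrainingArguments:
# should_log is True on the local main process of every node by default,
# while should_save is True only on the global main process.
print(args.should_log, args.should_save)  # single process: True True

On a multi-node multi-GPU run the two gates therefore select different ranks, so a callback that updated its counters under should_log but created its thread pool under should_save could end up with the two halves living on different processes; gating everything on should_save keeps a single owner.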
@@ -70,11 +70,9 @@ class LogCallback(TrainerCallback):
         r"""
         Event called at the beginning of training.
         """
-        if args.should_log:
+        if args.should_save:
             self.do_train = True
             self._reset(max_steps=state.max_steps)
-
-        if args.should_save:
             os.makedirs(args.output_dir, exist_ok=True)
             self.thread_pool = ThreadPoolExecutor(max_workers=1)
 
@@ -98,7 +96,7 @@ class LogCallback(TrainerCallback):
         r"""
         Event called at the end of a training step.
         """
-        if args.should_log:
+        if args.should_save:
             self._timing(cur_steps=state.global_step)
 
         if self.aborted:
@@ -119,7 +117,7 @@ class LogCallback(TrainerCallback):
         Event called after a prediction step.
         """
         eval_dataloader = kwargs.pop("eval_dataloader", None)
-        if args.should_log and has_length(eval_dataloader) and not self.do_train:
+        if args.should_save and has_length(eval_dataloader) and not self.do_train:
             if self.max_steps == 0:
                 self.max_steps = len(eval_dataloader)
 
@@ -131,8 +129,11 @@ class LogCallback(TrainerCallback):
 
     def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
         r"""
-        Event called after logging the last logs, `args.should_log` has been applied.
+        Event called after logging the last logs.
         """
+        if not args.should_save:
+            return
+
         logs = dict(
             current_steps=self.cur_steps,
             total_steps=self.max_steps,
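This hunk moves the gate to the top of on_log: non-saving ranks return immediately, which is what lets the final hunk drop its per-branch args.should_save check. An abbreviated sketch of the resulting shape (signature shortened here):

def on_log(self, args, state, control, **kwargs) -> None:
    if not args.should_save:  # non-saving ranks skip all log bookkeeping
        return
    # ... assemble `logs` and submit it to the single-worker pool ...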
@@ -148,12 +149,12 @@ class LogCallback(TrainerCallback):
             remaining_time=self.remaining_time,
         )
         logs = {k: v for k, v in logs.items() if v is not None}
-        if self.webui_mode and "loss" in logs and "learning_rate" in logs and "epoch" in logs:
+        if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
             logger.info(
                 "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
                     logs["loss"], logs["learning_rate"], logs["epoch"]
                 )
             )
 
-        if args.should_save and self.thread_pool is not None:
+        if self.thread_pool is not None:
             self.thread_pool.submit(self._write_log, args.output_dir, logs)
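Two side notes. The doubled braces in the format string above are literal braces, so the web UI receives a dict-shaped line such as {'loss': 0.1234, 'learning_rate': 5.0000e-05, 'epoch': 1.00}. And the thread pool being checked here is the single-worker executor created in on_train_begin; a self-contained sketch of that write pattern, where the _write_log body and the trainer_log.jsonl file name are reconstructions rather than the repository's code:

import json
import os
from concurrent.futures import ThreadPoolExecutor

def _write_log(output_dir: str, logs: dict) -> None:
    # Reconstruction: append one JSON object per on_log call.
    with open(os.path.join(output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
        f.write(json.dumps(logs) + "\n")

os.makedirs("dummy_out", exist_ok=True)  # mirrors on_train_begin
pool = ThreadPoolExecutor(max_workers=1)  # one worker keeps appends ordered
pool.submit(_write_log, "dummy_out", {"current_steps": 1, "total_steps": 10})
pool.shutdown(wait=True)  # block until the pending write has flushed

With max_workers=1, submissions run in order on one background thread, so on_log returns quickly while the JSONL file never sees interleaved writes.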