From cf1853978aeb998be2e59ff0be2b060fb673e625 Mon Sep 17 00:00:00 2001 From: wql Date: Thu, 26 Sep 2024 16:21:33 +0800 Subject: [PATCH] feat: update num token --- src/llamafactory/train/sft/workflow.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index 502ac2f7..47e2bb3b 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -119,10 +119,21 @@ def run_sft( print("********************************************************") print("********************************************************") print("********************************************************") - - + print("predict_results.metrics") print(predict_results.metrics) print("********************************************************") + + + num_train_tokens = trainer.num_tokens(trainer.get_test_dataloader(dataset_module["eval_dataset"])) + print("num_train_tokens") + print(num_train_tokens) + print("********************************************************") + predict_results.metrics["num_train_tokens"] = num_train_tokens + print("new predict_results.metrics") + print(predict_results.metrics) + print("********************************************************") + + time.sleep(100) if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled predict_results.metrics.pop("predict_loss", None)