diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index 502ac2f7..47e2bb3b 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -119,10 +119,21 @@ def run_sft( print("********************************************************") print("********************************************************") print("********************************************************") - - + print("predict_results.metrics") print(predict_results.metrics) print("********************************************************") + + + num_train_tokens = trainer.num_tokens(trainer.get_test_dataloader(dataset_module["eval_dataset"])) + print("num_train_tokens") + print(num_train_tokens) + print("********************************************************") + predict_results.metrics["num_train_tokens"] = num_train_tokens + print("new predict_results.metrics") + print(predict_results.metrics) + print("********************************************************") + + time.sleep(100) if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled predict_results.metrics.pop("predict_loss", None)