feat: update num token

This commit is contained in:
wql 2024-09-26 16:21:33 +08:00
parent 1f3703ae96
commit cf1853978a
1 changed files with 13 additions and 2 deletions

View File

@ -119,10 +119,21 @@ def run_sft(
print("********************************************************") print("********************************************************")
print("********************************************************") print("********************************************************")
print("********************************************************") print("********************************************************")
print("predict_results.metrics")
print(predict_results.metrics) print(predict_results.metrics)
print("********************************************************") print("********************************************************")
num_train_tokens = trainer.num_tokens(trainer.get_test_dataloader(dataset_module["eval_dataset"]))
print("num_train_tokens")
print(num_train_tokens)
print("********************************************************")
predict_results.metrics["num_train_tokens"] = num_train_tokens
print("new predict_results.metrics")
print(predict_results.metrics)
print("********************************************************")
time.sleep(100) time.sleep(100)
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
predict_results.metrics.pop("predict_loss", None) predict_results.metrics.pop("predict_loss", None)