feat: update num token
This commit is contained in:
parent
1f3703ae96
commit
cf1853978a
|
@ -119,10 +119,21 @@ def run_sft(
|
|||
print("********************************************************")
|
||||
print("********************************************************")
|
||||
print("********************************************************")
|
||||
|
||||
|
||||
print("predict_results.metrics")
|
||||
print(predict_results.metrics)
|
||||
print("********************************************************")
|
||||
|
||||
|
||||
num_train_tokens = trainer.num_tokens(trainer.get_test_dataloader(dataset_module["eval_dataset"]))
|
||||
print("num_train_tokens")
|
||||
print(num_train_tokens)
|
||||
print("********************************************************")
|
||||
predict_results.metrics["num_train_tokens"] = num_train_tokens
|
||||
print("new predict_results.metrics")
|
||||
print(predict_results.metrics)
|
||||
print("********************************************************")
|
||||
|
||||
|
||||
time.sleep(100)
|
||||
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
|
||||
predict_results.metrics.pop("predict_loss", None)
|
||||
|
|
Loading…
Reference in New Issue