feat: add save_state in prediction

This commit is contained in:
wql 2024-09-04 16:08:57 +08:00
parent 7af7b40955
commit e7a7a22f19
1 changed files with 1 additions and 0 deletions

View File

@ -117,6 +117,7 @@ def run_sft(
predict_results.metrics.pop("predict_loss", None)
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_state()
trainer.save_predictions(dataset_module["eval_dataset"], predict_results)
# Create model card