diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index 5da99557..ef3bc897 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -117,6 +117,7 @@ def run_sft( predict_results.metrics.pop("predict_loss", None) trainer.log_metrics("predict", predict_results.metrics) trainer.save_metrics("predict", predict_results.metrics) + trainer.save_state() trainer.save_predictions(dataset_module["eval_dataset"], predict_results) # Create model card