From e7a7a22f195d08095cce665cfff5d6b794e871ab Mon Sep 17 00:00:00 2001 From: wql Date: Wed, 4 Sep 2024 16:08:57 +0800 Subject: [PATCH] feat: add save_state in prediction --- src/llamafactory/train/sft/workflow.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index 5da99557..ef3bc897 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -117,6 +117,7 @@ def run_sft( predict_results.metrics.pop("predict_loss", None) trainer.log_metrics("predict", predict_results.metrics) trainer.save_metrics("predict", predict_results.metrics) + trainer.save_state() trainer.save_predictions(dataset_module["eval_dataset"], predict_results) # Create model card