From cff89a2e8907f3fe89406006105cb6494e2ee993 Mon Sep 17 00:00:00 2001 From: "-.-" Date: Wed, 10 Jul 2024 12:05:51 +0800 Subject: [PATCH] fix src/llamafactory/train/callbacks.py --- src/llamafactory/train/callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index 97eb6d1c..e1e3de99 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -79,7 +79,7 @@ def fix_valuehead_checkpoint( if name.startswith("v_head."): v_head_state_dict[name] = param else: - decoder_state_dict[name.replace("pretrained_model.", "", count=1)] = param + decoder_state_dict[name.replace("pretrained_model.", "",1)] = param model.pretrained_model.save_pretrained( output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization