From 39cd89ce17220dc50c8331299ae5af230fe40cc9 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 10 Jul 2024 13:32:20 +0800 Subject: [PATCH] Update callbacks.py --- src/llamafactory/train/callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index e1e3de99..e7ce09a2 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -79,7 +79,7 @@ def fix_valuehead_checkpoint( if name.startswith("v_head."): v_head_state_dict[name] = param else: - decoder_state_dict[name.replace("pretrained_model.", "",1)] = param + decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param model.pretrained_model.save_pretrained( output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization