diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index e1e3de99..e7ce09a2 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -79,7 +79,7 @@ def fix_valuehead_checkpoint( if name.startswith("v_head."): v_head_state_dict[name] = param else: - decoder_state_dict[name.replace("pretrained_model.", "",1)] = param + decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param model.pretrained_model.save_pretrained( output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization