fix src/llamafactory/train/callbacks.py

This commit is contained in:
-.- 2024-07-10 12:05:51 +08:00
parent 51942acee8
commit cff89a2e89
1 changed files with 1 additions and 1 deletions

View File

@ -79,7 +79,7 @@ def fix_valuehead_checkpoint(
if name.startswith("v_head."): if name.startswith("v_head."):
v_head_state_dict[name] = param v_head_state_dict[name] = param
else: else:
decoder_state_dict[name.replace("pretrained_model.", "", count=1)] = param decoder_state_dict[name.replace("pretrained_model.", "",1)] = param
model.pretrained_model.save_pretrained( model.pretrained_model.save_pretrained(
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization