forked from p04798526/LLaMA-Factory-Mirror
Merge pull request #4746 from yzoaim/fix
fix src/llamafactory/train/callbacks.py
This commit is contained in:
commit
40c3b88b68
|
@ -79,7 +79,7 @@ def fix_valuehead_checkpoint(
|
|||
if name.startswith("v_head."):
|
||||
v_head_state_dict[name] = param
|
||||
else:
|
||||
decoder_state_dict[name.replace("pretrained_model.", "", count=1)] = param
|
||||
decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param
|
||||
|
||||
model.pretrained_model.save_pretrained(
|
||||
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
|
||||
|
|
Loading…
Reference in New Issue