forked from p04798526/LLaMA-Factory-Mirror
fix #1452
This commit is contained in:
parent
b3572659f5
commit
0e86527d7f
|
@ -1,5 +1,4 @@
|
|||
import torch
|
||||
import deepspeed # type: ignore
|
||||
from copy import deepcopy
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
||||
|
@ -76,6 +75,8 @@ class CustomDPOTrainer(DPOTrainer):
|
|||
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
||||
if config_kwargs["zero_optimization"]["stage"] != 3:
|
||||
config_kwargs["zero_optimization"]["stage"] = 0
|
||||
# lazy load
|
||||
import deepspeed # type: ignore
|
||||
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
||||
model.eval()
|
||||
return model
|
||||
|
|
Loading…
Reference in New Issue