This commit is contained in:
hiyouga 2023-11-09 16:41:32 +08:00
parent b3572659f5
commit 0e86527d7f
1 changed file with 2 additions and 1 deletion

View File

@@ -1,5 +1,4 @@
import torch
import deepspeed # type: ignore
from copy import deepcopy
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
@@ -76,6 +75,8 @@ class CustomDPOTrainer(DPOTrainer):
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
if config_kwargs["zero_optimization"]["stage"] != 3:
config_kwargs["zero_optimization"]["stage"] = 0
# lazy load
import deepspeed # type: ignore
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model