diff --git a/src/llmtuner/tuner/dpo/trainer.py b/src/llmtuner/tuner/dpo/trainer.py index 8a9f8dd6..647bcee2 100644 --- a/src/llmtuner/tuner/dpo/trainer.py +++ b/src/llmtuner/tuner/dpo/trainer.py @@ -1,5 +1,4 @@ import torch -import deepspeed # type: ignore from copy import deepcopy from collections import defaultdict from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union @@ -76,6 +75,8 @@ class CustomDPOTrainer(DPOTrainer): # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0) if config_kwargs["zero_optimization"]["stage"] != 3: config_kwargs["zero_optimization"]["stage"] = 0 + # lazy load + import deepspeed # type: ignore model, *_ = deepspeed.initialize(model=model, config=config_kwargs) model.eval() return model