From b988ce0a0c164213ad2e52efadd6aa5b71fd39c5 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sun, 4 Feb 2024 00:47:37 +0800 Subject: [PATCH] fix #2189 --- src/llmtuner/train/ppo/utils.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/llmtuner/train/ppo/utils.py b/src/llmtuner/train/ppo/utils.py index f3cd8236..e6bdb89c 100644 --- a/src/llmtuner/train/ppo/utils.py +++ b/src/llmtuner/train/ppo/utils.py @@ -1,7 +1,9 @@ import json +from contextlib import nullcontext from typing import TYPE_CHECKING, Dict, List, Literal, Optional import torch +from transformers.integrations import is_deepspeed_zero3_enabled from ...extras.packages import is_requests_available @@ -23,18 +25,22 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch. def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None: - if target == "reward": # save default head temporarily - valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict() - setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone()) - setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone()) + if is_deepspeed_zero3_enabled(): + import deepspeed # type: ignore - model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active - model.v_head.load_state_dict( - { - "summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(), - "summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone(), - } - ) + params = [model.v_head.summary.weight, model.v_head.summary.bias] + context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0) + else: + context_maybe_zero3 = nullcontext() + + with context_maybe_zero3: + if target == "reward": # save default head temporarily + setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone()) + setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone()) + + model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active + model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone() + model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone() def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]: