This commit is contained in:
hiyouga 2024-02-04 00:47:37 +08:00
parent 38e63bfd28
commit b988ce0a0c
1 changed files with 17 additions and 11 deletions

View File

@ -1,7 +1,9 @@
import json import json
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional from typing import TYPE_CHECKING, Dict, List, Literal, Optional
import torch import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.packages import is_requests_available from ...extras.packages import is_requests_available
@ -23,18 +25,22 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None: def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
if target == "reward": # save default head temporarily if is_deepspeed_zero3_enabled():
valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict() import deepspeed # type: ignore
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active params = [model.v_head.summary.weight, model.v_head.summary.bias]
model.v_head.load_state_dict( context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
{ else:
"summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(), context_maybe_zero3 = nullcontext()
"summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone(),
} with context_maybe_zero3:
) if target == "reward": # save default head temporarily
setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone()
model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone()
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]: def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]: