fix ppo+zero3 #3108

hiyouga 2024-06-06 23:30:07 +08:00
parent 451b6693c0
commit 76c61905b2
2 changed files with 66 additions and 61 deletions


@@ -2,9 +2,10 @@ import math
import os
import sys
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import torch
from accelerate.utils import DistributedDataParallelKwargs
from tqdm import tqdm
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
from transformers.optimization import get_scheduler
@@ -79,6 +80,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
project_kwargs={"logging_dir": training_args.logging_dir},
)
# Add deepspeed config
ppo_config.accelerator_kwargs["kwargs_handlers"] = [
DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
]
if training_args.deepspeed_plugin is not None:
ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
# Create optimizer and scheduler
if training_args.max_steps > 0:
num_training_steps = training_args.max_steps
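The hunk above is the core of the ZeRO-3 fix: instead of letting trl build a bare Accelerator, the trainer forwards the DDP kwargs handler and the DeepSpeed plugin that transformers' TrainingArguments already constructed. trl's PPOTrainer passes PPOConfig.accelerator_kwargs on to accelerate.Accelerator, so everything placed in that dict becomes an Accelerator keyword argument. A minimal sketch of assembling that dict (build_accelerator_kwargs is a hypothetical helper, not code from this repo):

from typing import Any, Dict

from accelerate.utils import DistributedDataParallelKwargs


def build_accelerator_kwargs(training_args) -> Dict[str, Any]:
    r"""Sketch: collect kwargs that trl forwards to accelerate.Accelerator."""
    kwargs: Dict[str, Any] = {
        "kwargs_handlers": [
            DistributedDataParallelKwargs(
                find_unused_parameters=training_args.ddp_find_unused_parameters
            )
        ]
    }
    # Under DeepSpeed (e.g. ZeRO-3), reuse the plugin already built by
    # transformers' TrainingArguments so both trainers share one ZeRO config.
    if getattr(training_args, "deepspeed_plugin", None) is not None:
        kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
    return kwargs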
@@ -124,6 +132,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm"
device_type = unwrapped_model.pretrained_model.device.type
self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
if finetuning_args.reward_model_type == "full":
if self.is_deepspeed_enabled:
if not (
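Caching self.is_chatglm_model and self.amp_context in __init__ removes the repeated config lookups and the CUDA-only torch.cuda.amp.autocast calls used further down. torch.autocast takes the device type as a string, so one context object serves CUDA and CPU runs alike. A self-contained sketch of the pattern (toy model, bf16 chosen as an example compute dtype):

import torch

# Pick the device type once, then reuse a single autocast context everywhere,
# mirroring self.amp_context = torch.autocast(device_type, dtype=compute_dtype).
device_type = "cuda" if torch.cuda.is_available() else "cpu"
amp_context = torch.autocast(device_type, dtype=torch.bfloat16)

model = torch.nn.Linear(8, 8).to(device_type)
inputs = torch.randn(2, 8, device=device_type)

with amp_context:  # matmuls run in bf16 inside the context
    outputs = model(inputs)

print(outputs.dtype)  # torch.bfloat16 on hardware that supports it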
@@ -184,7 +198,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
logger.info(" Total training steps = {}".format(max_steps))
logger.info(" Number of trainable parameters = {}".format(count_parameters(self.model)[0]))
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
loss_meter = AverageMeter()
reward_meter = AverageMeter()
@@ -197,29 +210,21 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
dataiter = iter(self.dataloader)
batch = next(dataiter)
# Cast to inference mode
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
self.model.eval()
# Get inputs
self.model.eval()
self.tokenizer.padding_side = "right" # change padding side
queries, responses, rewards = [], [], []
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
mini_batch_queries, mini_batch_responses = self.get_inputs(
batch[idx : idx + self.config.mini_batch_size]
)
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
queries.extend(mini_batch_queries)
responses.extend(mini_batch_responses)
rewards.extend(mini_batch_rewards)
# Cast to training mode
unwrapped_model.gradient_checkpointing_enable()
unwrapped_model.config.use_cache = False
self.model.train()
# Run PPO step
self.model.train()
stats = self.step(queries, responses, rewards)
self.tokenizer.padding_side = "left" # restore padding side
loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
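With generation moved into unwrap_model_for_generation (see get_inputs below), the rollout loop no longer has to toggle gradient checkpointing and use_cache by hand; only the eval/train mode switches remain around the sampling and optimization phases. A stripped-down sketch of that loop shape (rollout_fn and update_fn are placeholders, not functions from this repo):

import torch


def ppo_iteration(model: torch.nn.Module, rollout_fn, update_fn):
    r"""Sketch: one PPO iteration = sample in eval mode, optimize in train mode."""
    model.eval()  # rollout: generate responses and score them without gradients
    with torch.no_grad():
        queries, responses, rewards = rollout_fn()

    model.train()  # update: run the PPO optimization step on the collected batch
    return update_fn(queries, responses, rewards)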
@@ -311,25 +316,24 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
getattr(self.processor, "image_processor").save_pretrained(output_dir)
@torch.no_grad()
def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:
r"""
Generates model's responses given queries.
"""
if self.model_args.upcast_layernorm:
layernorm_params = dump_layernorm(self.model)
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
for k, v in batch.items():
batch[k] = v[:, start_index:]
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
if self.model_args.upcast_layernorm:
layernorm_params = dump_layernorm(unwrapped_model)
generate_output: torch.Tensor = unwrapped_model.generate(
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
)
if self.model_args.upcast_layernorm:
restore_layernorm(self.model, layernorm_params)
restore_layernorm(unwrapped_model, layernorm_params)
query = batch["input_ids"].detach().cpu()
response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
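The with unwrap_model_for_generation(...) block is what makes .generate() work under ZeRO-3: each rank only holds a shard of every parameter, so trl's context manager gathers the full weights for the duration of generation and releases them afterwards (the exact import path varies by trl version, e.g. trl.models.utils). A rough sketch of the idea, not trl's actual implementation:

from contextlib import contextmanager


@contextmanager
def gathered_for_generation(model, accelerator):
    r"""Sketch: materialize ZeRO-3 sharded weights around a generate() call."""
    unwrapped = accelerator.unwrap_model(model)
    plugin = accelerator.state.deepspeed_plugin
    if plugin is not None and plugin.zero_stage == 3:
        import deepspeed  # type: ignore

        # Gather the full parameters on every rank; they are re-sharded on exit.
        with deepspeed.zero.GatheredParameters(list(unwrapped.parameters())):
            yield unwrapped
    else:
        yield unwrapped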
@@ -351,10 +355,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
@torch.no_grad()
def get_rewards(
self,
queries: List[torch.Tensor],
responses: List[torch.Tensor],
unwrapped_model: "AutoModelForCausalLMWithValueHead",
) -> List[torch.Tensor]:
queries: List["torch.Tensor"],
responses: List["torch.Tensor"],
) -> List["torch.Tensor"]:
r"""
Computes scores using given reward model.
@@ -365,18 +368,22 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
return get_rewards_from_server(self.reward_model, messages)
batch = self.prepare_model_inputs(queries, responses)
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="reward")
reward_model = self.model
else:
reward_model = self.reward_model
batch = self.prepare_model_inputs(queries, responses)
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
with self.amp_context: # support bf16
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default")
if self.is_chatglm_model: # assume same architecture
values = torch.transpose(values, 0, 1)
rewards = []
@@ -385,21 +392,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
end_index = end_indexes[-1].item() if len(end_indexes) else 0
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default")
return rewards
@PPODecorators.empty_device_cache()
def batched_forward_pass(
self,
model: "AutoModelForCausalLMWithValueHead",
queries: torch.Tensor,
responses: torch.Tensor,
model_inputs: dict,
queries: "torch.Tensor",
responses: "torch.Tensor",
model_inputs: Dict[str, Any],
return_logits: bool = False,
response_masks: Optional[torch.Tensor] = None,
):
response_masks: Optional["torch.Tensor"] = None,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
r"""
Calculates model outputs in multiple batches.
@@ -421,11 +425,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
input_ids = input_kwargs["input_ids"]
attention_mask = input_kwargs["attention_mask"]
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
with self.amp_context: # support bf16
logits, _, values = model(**input_kwargs)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
if self.is_chatglm_model:
values = torch.transpose(values, 0, 1)
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])


@@ -1,9 +1,7 @@
import json
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.packages import is_requests_available
@@ -18,6 +16,9 @@ if TYPE_CHECKING:
def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
r"""
Gets reward scores from the API server.
"""
headers = {"Content-Type": "application/json"}
payload = {"model": "model", "messages": messages}
response = requests.post(server_url, json=payload, headers=headers)
@@ -26,25 +27,23 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
params = [model.v_head.summary.weight, model.v_head.summary.bias]
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
r"""
Replaces the default/reward modules in the model. The model is already unwrapped (and gathered).
"""
if target == "reward": # save default head temporarily
setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone()
model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone()
device = model.v_head.summary.weight.device
model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
r"""
Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
"""
layer_norm_params = {}
for name, param in model.named_parameters():
if param.data.dtype == torch.float32:
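In the second changed file, replace_model drops its own deepspeed.zero.GatheredParameters wrapper: the trainer now calls it from inside unwrap_model_for_generation, so the value-head weights are already materialized, and the explicit .to(device) keeps the copied buffers on the value head's device. For contrast, the pattern that is no longer needed here, writing to a parameter that may still be ZeRO-3 sharded, looks roughly like this (overwrite_value_head is a hypothetical helper, not code from this repo):

import torch
from transformers.integrations import is_deepspeed_zero3_enabled


def overwrite_value_head(v_head: torch.nn.Linear, weight: torch.Tensor, bias: torch.Tensor) -> None:
    r"""Sketch: replace value-head weights when they may still be ZeRO-3 sharded."""
    if is_deepspeed_zero3_enabled():
        import deepspeed  # type: ignore

        params = [v_head.weight, v_head.bias]
        # Gather the shards first; modifier_rank=0 broadcasts the update on exit.
        with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
            v_head.weight.data = weight.detach().clone().to(v_head.weight.device)
            v_head.bias.data = bias.detach().clone().to(v_head.bias.device)
    else:
        v_head.weight.data = weight.detach().clone().to(v_head.weight.device)
        v_head.bias.data = bias.detach().clone().to(v_head.bias.device)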
@@ -55,6 +54,9 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
r"""
Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
"""
for name, param in model.named_parameters():
if name in layernorm_params:
param.data = layernorm_params[name]