fix ppo+zero3 #3108

This commit is contained in:
hiyouga 2024-06-06 23:30:07 +08:00
parent 451b6693c0
commit 76c61905b2
2 changed files with 66 additions and 61 deletions

Changed file 1 of 2: the custom PPO trainer

@@ -2,9 +2,10 @@ import math
 import os
 import sys
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

 import torch
+from accelerate.utils import DistributedDataParallelKwargs
 from tqdm import tqdm
 from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
 from transformers.optimization import get_scheduler
@@ -79,6 +80,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             project_kwargs={"logging_dir": training_args.logging_dir},
         )

+        # Add deepspeed config
+        ppo_config.accelerator_kwargs["kwargs_handlers"] = [
+            DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
+        ]
+        if training_args.deepspeed_plugin is not None:
+            ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
+
         # Create optimizer and scheduler
         if training_args.max_steps > 0:
             num_training_steps = training_args.max_steps
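trl's PPOTrainer builds its own accelerate.Accelerator and simply expands PPOConfig.accelerator_kwargs into it, so a ZeRO-3 setup configured on the HF TrainingArguments never reaches the PPO accelerator unless the DeepSpeed plugin is forwarded by hand, which is what the added block does. A minimal standalone sketch of the same wiring; it assumes a DeepSpeed launch and an existing stage-3 JSON config ("ds_z3_config.json" is an illustrative path), and only PPOConfig, Accelerator, DeepSpeedPlugin and DistributedDataParallelKwargs are real APIs here:

from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin, DistributedDataParallelKwargs
from trl import PPOConfig

ppo_config = PPOConfig(batch_size=8, mini_batch_size=2)

# DDP handler mirroring training_args.ddp_find_unused_parameters.
ppo_config.accelerator_kwargs["kwargs_handlers"] = [
    DistributedDataParallelKwargs(find_unused_parameters=False)
]

# Forward the ZeRO-3 plugin so the PPO accelerator wraps the model with DeepSpeed.
ppo_config.accelerator_kwargs["deepspeed_plugin"] = DeepSpeedPlugin(
    hf_ds_config="ds_z3_config.json"  # illustrative path to a stage-3 config
)

# Roughly what PPOTrainer does internally with these kwargs:
accelerator = Accelerator(
    log_with=ppo_config.log_with,
    gradient_accumulation_steps=ppo_config.gradient_accumulation_steps,
    **ppo_config.accelerator_kwargs,
)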
@@ -124,6 +132,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if self.args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")

+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+        self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm"
+
+        device_type = unwrapped_model.pretrained_model.device.type
+        self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
+
         if finetuning_args.reward_model_type == "full":
             if self.is_deepspeed_enabled:
                 if not (
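Caching self.amp_context and self.is_chatglm_model once in __init__ replaces the per-call torch.cuda.amp.autocast(...) and the repeated unwrap_model/config checks later in the file; the generic torch.autocast also follows the model's actual device instead of assuming CUDA. A small self-contained sketch of the reuse pattern (device_type and compute_dtype stand in for the unwrapped model's device type and model_args.compute_dtype):

import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"  # stands in for the model's device type
compute_dtype = torch.bfloat16                                # stands in for model_args.compute_dtype

# torch.autocast objects are reusable context managers, so one instance can be
# cached on the trainer and re-entered for every reward/policy forward pass.
amp_context = torch.autocast(device_type, dtype=compute_dtype)

x = torch.randn(4, 8, device=device_type)
w = torch.randn(8, 8, device=device_type)
with amp_context:        # e.g. reward scoring
    y = x @ w
with amp_context:        # e.g. batched_forward_pass
    z = y @ w
print(y.dtype, z.dtype)  # low precision (bfloat16 here) inside autocast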
@@ -184,7 +198,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         logger.info(" Total training steps = {}".format(max_steps))
         logger.info(" Number of trainable parameters = {}".format(count_parameters(self.model)[0]))

-        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         dataiter = iter(self.dataloader)
         loss_meter = AverageMeter()
         reward_meter = AverageMeter()
@@ -197,29 +210,21 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 dataiter = iter(self.dataloader)
                 batch = next(dataiter)

-            # Cast to inference mode
-            unwrapped_model.gradient_checkpointing_disable()
-            unwrapped_model.config.use_cache = True
-            self.model.eval()
-
             # Get inputs
+            self.model.eval()
             self.tokenizer.padding_side = "right"  # change padding side
             queries, responses, rewards = [], [], []
             for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
                 mini_batch_queries, mini_batch_responses = self.get_inputs(
                     batch[idx : idx + self.config.mini_batch_size]
                 )
-                mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
+                mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
                 queries.extend(mini_batch_queries)
                 responses.extend(mini_batch_responses)
                 rewards.extend(mini_batch_rewards)

-            # Cast to training mode
-            unwrapped_model.gradient_checkpointing_enable()
-            unwrapped_model.config.use_cache = False
-            self.model.train()
-
             # Run PPO step
+            self.model.train()
             stats = self.step(queries, responses, rewards)
             self.tokenizer.padding_side = "left"  # restore padding side
             loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
@@ -311,25 +316,24 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             getattr(self.processor, "image_processor").save_pretrained(output_dir)

     @torch.no_grad()
-    def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:
         r"""
         Generates model's responses given queries.
         """
-        if self.model_args.upcast_layernorm:
-            layernorm_params = dump_layernorm(self.model)
-
         if batch["input_ids"].size(0) == 1:  # handle llama2 ppo with gradient accumulation > 1
             start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
             for k, v in batch.items():
                 batch[k] = v[:, start_index:]

         with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
+            if self.model_args.upcast_layernorm:
+                layernorm_params = dump_layernorm(unwrapped_model)
+
             generate_output: torch.Tensor = unwrapped_model.generate(
                 generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
             )
-
-        if self.model_args.upcast_layernorm:
-            restore_layernorm(self.model, layernorm_params)
+            if self.model_args.upcast_layernorm:
+                restore_layernorm(unwrapped_model, layernorm_params)

         query = batch["input_ids"].detach().cpu()
         response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
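Under ZeRO-3 every rank only holds a shard of each parameter; unwrap_model_for_generation gathers the full weights for the duration of the with-block, so the layernorm dump/restore used by upcast_layernorm is only meaningful inside it, which is presumably why the calls moved from self.model to unwrapped_model. A toy, DeepSpeed-free sketch of the dump/downcast/restore pattern itself (TinyNet and these helpers are stand-ins, not the repo's code):

from typing import Dict

import torch
from torch import nn


class TinyNet(nn.Module):  # stand-in for the unwrapped policy model
    def __init__(self):
        super().__init__()
        self.ln = nn.LayerNorm(8)
        self.proj = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(self.ln(x))


def dump_layernorm(model: nn.Module) -> Dict[str, torch.Tensor]:
    # Stash fp32 params and downcast them in place (roughly what the repo helper does).
    saved = {}
    for name, param in model.named_parameters():
        if param.data.dtype == torch.float32:
            saved[name] = param.data.detach().clone()
            param.data = param.data.to(torch.bfloat16)
    return saved


def restore_layernorm(model: nn.Module, saved: Dict[str, torch.Tensor]) -> None:
    for name, param in model.named_parameters():
        if name in saved:
            param.data = saved[name]


model = TinyNet().to(torch.bfloat16)
model.ln.float()                  # upcast_layernorm keeps layernorms in fp32
saved = dump_layernorm(model)     # must see the full (gathered) weights to be useful
out = model(torch.randn(2, 8, dtype=torch.bfloat16))  # generate(...) in the real trainer
restore_layernorm(model, saved)   # undo before leaving the gathered context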
@@ -351,10 +355,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

     @torch.no_grad()
     def get_rewards(
         self,
-        queries: List[torch.Tensor],
-        responses: List[torch.Tensor],
-        unwrapped_model: "AutoModelForCausalLMWithValueHead",
-    ) -> List[torch.Tensor]:
+        queries: List["torch.Tensor"],
+        responses: List["torch.Tensor"],
+    ) -> List["torch.Tensor"]:
         r"""
         Computes scores using given reward model.
@@ -365,18 +368,22 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
             return get_rewards_from_server(self.reward_model, messages)

-        if self.finetuning_args.reward_model_type == "lora":
-            replace_model(unwrapped_model, target="reward")
-            reward_model = self.model
-        else:
-            reward_model = self.reward_model
-
         batch = self.prepare_model_inputs(queries, responses)

-        with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
-            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
+        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
+            if self.finetuning_args.reward_model_type == "lora":
+                replace_model(unwrapped_model, target="reward")
+                reward_model = self.model
+            else:
+                reward_model = self.reward_model
+
+            with self.amp_context:  # support bf16
+                _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
+
+            if self.finetuning_args.reward_model_type == "lora":
+                replace_model(unwrapped_model, target="default")

-        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":  # assume same architecture
+        if self.is_chatglm_model:  # assume same architecture
             values = torch.transpose(values, 0, 1)

         rewards = []
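get_rewards now performs the whole LoRA swap, score, swap-back sequence inside unwrap_model_for_generation, so the reward adapter and value head are exchanged while the ZeRO-3 parameters are gathered, and the model is back on the "default" adapter before the context exits. replace_model boils down to set_adapter plus a value-head weight swap; a toy PEFT-only sketch of the adapter-switching part (the tiny hub model "sshleifer/tiny-gpt2" is used purely for illustration, and the value-head handling is omitted):

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"), adapter_name="default")
model.add_adapter("reward", LoraConfig(task_type="CAUSAL_LM"))

input_ids = torch.tensor([[1, 2, 3]])
model.set_adapter("reward")     # score the rollout with the reward adapter active
_ = model(input_ids)
model.set_adapter("default")    # switch back to the policy adapter before the PPO step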
@@ -385,21 +392,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             end_index = end_indexes[-1].item() if len(end_indexes) else 0
             rewards.append(values[i, end_index].float().detach().cpu())  # use fp32 type

-        if self.finetuning_args.reward_model_type == "lora":
-            replace_model(unwrapped_model, target="default")
-
         return rewards

     @PPODecorators.empty_device_cache()
     def batched_forward_pass(
         self,
         model: "AutoModelForCausalLMWithValueHead",
-        queries: torch.Tensor,
-        responses: torch.Tensor,
-        model_inputs: dict,
+        queries: "torch.Tensor",
+        responses: "torch.Tensor",
+        model_inputs: Dict[str, Any],
         return_logits: bool = False,
-        response_masks: Optional[torch.Tensor] = None,
-    ):
+        response_masks: Optional["torch.Tensor"] = None,
+    ) -> Tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
         r"""
         Calculates model outputs in multiple batches.

@@ -421,11 +425,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             input_ids = input_kwargs["input_ids"]
             attention_mask = input_kwargs["attention_mask"]

-            with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
+            with self.amp_context:  # support bf16
                 logits, _, values = model(**input_kwargs)

-            unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-            if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
+            if self.is_chatglm_model:
                 values = torch.transpose(values, 0, 1)

             logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])

Changed file 2 of 2: the PPO helper utilities

@@ -1,9 +1,7 @@
 import json
-from contextlib import nullcontext
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional

 import torch
-from transformers.integrations import is_deepspeed_zero3_enabled

 from ...extras.packages import is_requests_available

@@ -18,6 +16,9 @@ if TYPE_CHECKING:


 def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
+    r"""
+    Gets reward scores from the API server.
+    """
     headers = {"Content-Type": "application/json"}
     payload = {"model": "model", "messages": messages}
     response = requests.post(server_url, json=payload, headers=headers)
@@ -26,25 +27,23 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:


 def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
-    if is_deepspeed_zero3_enabled():
-        import deepspeed  # type: ignore
-
-        params = [model.v_head.summary.weight, model.v_head.summary.bias]
-        context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
-    else:
-        context_maybe_zero3 = nullcontext()
-
-    with context_maybe_zero3:
-        if target == "reward":  # save default head temporarily
-            setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
-            setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
-
-        model.pretrained_model.set_adapter(target)  # set the LoRA adapter to be active
-        model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone()
-        model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone()
+    r"""
+    Replaces the default/reward modules in the model. The model is already unwrapped (and gathered).
+    """
+    if target == "reward":  # save default head temporarily
+        setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
+        setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
+
+    model.pretrained_model.set_adapter(target)  # set the LoRA adapter to be active
+    device = model.v_head.summary.weight.device
+    model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
+    model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)


 def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
+    r"""
+    Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
+    """
     layer_norm_params = {}
     for name, param in model.named_parameters():
         if param.data.dtype == torch.float32:
@@ -55,6 +54,9 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:


 def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
+    r"""
+    Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
+    """
     for name, param in model.named_parameters():
         if name in layernorm_params:
             param.data = layernorm_params[name]
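The rewritten replace_model no longer gathers the value-head parameters itself (the old deepspeed.zero.GatheredParameters block is gone) because its callers now run it inside unwrap_model_for_generation; it also copies the stashed head weights onto the device of the live v_head parameter before assignment, presumably because under ZeRO-3 or with offloading the gathered parameter and the stored head buffers need not share a device. A minimal torch-only sketch of that device-safe copy (the Linear head and the stored tensors are stand-ins for v_head.summary and the stashed "reward_head_weight"/"reward_head_bias" buffers):

import torch
from torch import nn

head = nn.Linear(8, 1)                       # stand-in for model.v_head.summary
stored_weight = torch.randn(1, 8)            # stashed head weight, possibly on another device
stored_bias = torch.zeros(1)

device = head.weight.device                  # device of the live (gathered) parameter
head.weight.data = stored_weight.detach().clone().to(device)
head.bias.data = stored_bias.detach().clone().to(device)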