fix ppo+zero3 #3108
This commit is contained in:
parent
451b6693c0
commit
76c61905b2
|
@ -2,9 +2,10 @@ import math
|
|||
import os
|
||||
import sys
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
from tqdm import tqdm
|
||||
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
|
||||
from transformers.optimization import get_scheduler
|
||||
|
@ -79,6 +80,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
project_kwargs={"logging_dir": training_args.logging_dir},
|
||||
)
|
||||
|
||||
# Add deepspeed config
|
||||
ppo_config.accelerator_kwargs["kwargs_handlers"] = [
|
||||
DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
|
||||
]
|
||||
if training_args.deepspeed_plugin is not None:
|
||||
ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
|
||||
|
||||
# Create optimizer and scheduler
|
||||
if training_args.max_steps > 0:
|
||||
num_training_steps = training_args.max_steps
|
||||
|
@ -124,6 +132,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
if self.args.max_steps > 0:
|
||||
logger.info("max_steps is given, it will override any value given in num_train_epochs")
|
||||
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm"
|
||||
|
||||
device_type = unwrapped_model.pretrained_model.device.type
|
||||
self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
|
||||
|
||||
if finetuning_args.reward_model_type == "full":
|
||||
if self.is_deepspeed_enabled:
|
||||
if not (
|
||||
|
@ -184,7 +198,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
logger.info(" Total training steps = {}".format(max_steps))
|
||||
logger.info(" Number of trainable parameters = {}".format(count_parameters(self.model)[0]))
|
||||
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
dataiter = iter(self.dataloader)
|
||||
loss_meter = AverageMeter()
|
||||
reward_meter = AverageMeter()
|
||||
|
@ -197,29 +210,21 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
dataiter = iter(self.dataloader)
|
||||
batch = next(dataiter)
|
||||
|
||||
# Cast to inference mode
|
||||
unwrapped_model.gradient_checkpointing_disable()
|
||||
unwrapped_model.config.use_cache = True
|
||||
self.model.eval()
|
||||
|
||||
# Get inputs
|
||||
self.model.eval()
|
||||
self.tokenizer.padding_side = "right" # change padding side
|
||||
queries, responses, rewards = [], [], []
|
||||
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
|
||||
mini_batch_queries, mini_batch_responses = self.get_inputs(
|
||||
batch[idx : idx + self.config.mini_batch_size]
|
||||
)
|
||||
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
|
||||
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
|
||||
queries.extend(mini_batch_queries)
|
||||
responses.extend(mini_batch_responses)
|
||||
rewards.extend(mini_batch_rewards)
|
||||
|
||||
# Cast to training mode
|
||||
unwrapped_model.gradient_checkpointing_enable()
|
||||
unwrapped_model.config.use_cache = False
|
||||
self.model.train()
|
||||
|
||||
# Run PPO step
|
||||
self.model.train()
|
||||
stats = self.step(queries, responses, rewards)
|
||||
self.tokenizer.padding_side = "left" # restore padding side
|
||||
loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
|
||||
|
@ -311,25 +316,24 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:
|
||||
r"""
|
||||
Generates model's responses given queries.
|
||||
"""
|
||||
if self.model_args.upcast_layernorm:
|
||||
layernorm_params = dump_layernorm(self.model)
|
||||
|
||||
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
|
||||
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||
for k, v in batch.items():
|
||||
batch[k] = v[:, start_index:]
|
||||
|
||||
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
||||
if self.model_args.upcast_layernorm:
|
||||
layernorm_params = dump_layernorm(unwrapped_model)
|
||||
|
||||
generate_output: torch.Tensor = unwrapped_model.generate(
|
||||
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
|
||||
)
|
||||
|
||||
if self.model_args.upcast_layernorm:
|
||||
restore_layernorm(self.model, layernorm_params)
|
||||
if self.model_args.upcast_layernorm:
|
||||
restore_layernorm(unwrapped_model, layernorm_params)
|
||||
|
||||
query = batch["input_ids"].detach().cpu()
|
||||
response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
|
||||
|
@ -351,10 +355,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
@torch.no_grad()
|
||||
def get_rewards(
|
||||
self,
|
||||
queries: List[torch.Tensor],
|
||||
responses: List[torch.Tensor],
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead",
|
||||
) -> List[torch.Tensor]:
|
||||
queries: List["torch.Tensor"],
|
||||
responses: List["torch.Tensor"],
|
||||
) -> List["torch.Tensor"]:
|
||||
r"""
|
||||
Computes scores using given reward model.
|
||||
|
||||
|
@ -365,18 +368,22 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
|
||||
return get_rewards_from_server(self.reward_model, messages)
|
||||
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
reward_model = self.model
|
||||
else:
|
||||
reward_model = self.reward_model
|
||||
|
||||
batch = self.prepare_model_inputs(queries, responses)
|
||||
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
|
||||
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
reward_model = self.model
|
||||
else:
|
||||
reward_model = self.reward_model
|
||||
|
||||
if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
|
||||
with self.amp_context: # support bf16
|
||||
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
|
||||
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="default")
|
||||
|
||||
if self.is_chatglm_model: # assume same architecture
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
rewards = []
|
||||
|
@ -385,21 +392,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
||||
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
|
||||
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="default")
|
||||
|
||||
return rewards
|
||||
|
||||
@PPODecorators.empty_device_cache()
|
||||
def batched_forward_pass(
|
||||
self,
|
||||
model: "AutoModelForCausalLMWithValueHead",
|
||||
queries: torch.Tensor,
|
||||
responses: torch.Tensor,
|
||||
model_inputs: dict,
|
||||
queries: "torch.Tensor",
|
||||
responses: "torch.Tensor",
|
||||
model_inputs: Dict[str, Any],
|
||||
return_logits: bool = False,
|
||||
response_masks: Optional[torch.Tensor] = None,
|
||||
):
|
||||
response_masks: Optional["torch.Tensor"] = None,
|
||||
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
|
||||
r"""
|
||||
Calculates model outputs in multiple batches.
|
||||
|
||||
|
@ -421,11 +425,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
input_ids = input_kwargs["input_ids"]
|
||||
attention_mask = input_kwargs["attention_mask"]
|
||||
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
with self.amp_context: # support bf16
|
||||
logits, _, values = model(**input_kwargs)
|
||||
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
|
||||
if self.is_chatglm_model:
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
import json
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
|
||||
|
||||
import torch
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
from ...extras.packages import is_requests_available
|
||||
|
||||
|
@ -18,6 +16,9 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
|
||||
r"""
|
||||
Gets reward scores from the API server.
|
||||
"""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
payload = {"model": "model", "messages": messages}
|
||||
response = requests.post(server_url, json=payload, headers=headers)
|
||||
|
@ -26,25 +27,23 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.
|
|||
|
||||
|
||||
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed # type: ignore
|
||||
r"""
|
||||
Replaces the default/reward modules in the model. The model is already unwrapped (and gathered).
|
||||
"""
|
||||
if target == "reward": # save default head temporarily
|
||||
setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
|
||||
setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
|
||||
|
||||
params = [model.v_head.summary.weight, model.v_head.summary.bias]
|
||||
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
|
||||
else:
|
||||
context_maybe_zero3 = nullcontext()
|
||||
|
||||
with context_maybe_zero3:
|
||||
if target == "reward": # save default head temporarily
|
||||
setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
|
||||
setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
|
||||
|
||||
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
|
||||
model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone()
|
||||
model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone()
|
||||
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
|
||||
device = model.v_head.summary.weight.device
|
||||
model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
|
||||
model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
|
||||
|
||||
|
||||
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
|
||||
"""
|
||||
layer_norm_params = {}
|
||||
for name, param in model.named_parameters():
|
||||
if param.data.dtype == torch.float32:
|
||||
|
@ -55,6 +54,9 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
|
|||
|
||||
|
||||
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
|
||||
r"""
|
||||
Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
|
||||
"""
|
||||
for name, param in model.named_parameters():
|
||||
if name in layernorm_params:
|
||||
param.data = layernorm_params[name]
|
||||
|
|
Loading…
Reference in New Issue