From 76c61905b20f69fac5c7a6c4ea9450bf33d3b1f2 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Thu, 6 Jun 2024 23:30:07 +0800
Subject: [PATCH] fix ppo+zero3 #3108

---
 src/llamafactory/train/ppo/trainer.py | 91 ++++++++++++++-------------
 src/llamafactory/train/ppo/utils.py   | 36 ++++++-----
 2 files changed, 66 insertions(+), 61 deletions(-)

diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 27353c72..b0c7e25d 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -2,9 +2,10 @@ import math
 import os
 import sys
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 import torch
+from accelerate.utils import DistributedDataParallelKwargs
 from tqdm import tqdm
 from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
 from transformers.optimization import get_scheduler
@@ -79,6 +80,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             project_kwargs={"logging_dir": training_args.logging_dir},
         )
 
+        # Add deepspeed config
+        ppo_config.accelerator_kwargs["kwargs_handlers"] = [
+            DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
+        ]
+        if training_args.deepspeed_plugin is not None:
+            ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
+
         # Create optimizer and scheduler
         if training_args.max_steps > 0:
             num_training_steps = training_args.max_steps
@@ -124,6 +132,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if self.args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")
 
+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+        self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm"
+
+        device_type = unwrapped_model.pretrained_model.device.type
+        self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
+
         if finetuning_args.reward_model_type == "full":
             if self.is_deepspeed_enabled:
                 if not (
@@ -184,7 +198,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         logger.info("  Total training steps = {}".format(max_steps))
         logger.info("  Number of trainable parameters = {}".format(count_parameters(self.model)[0]))
 
-        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         dataiter = iter(self.dataloader)
         loss_meter = AverageMeter()
         reward_meter = AverageMeter()
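
The accelerator wiring added in the @@ -79 hunk above is the core of the ZeRO-3 fix: trl's PPOTrainer builds its own Accelerator from PPOConfig.accelerator_kwargs, so the DDP flags and the DeepSpeed plugin prepared by transformers' TrainingArguments have to be forwarded explicitly. A minimal standalone sketch of that wiring follows; the helper name build_ppo_config is hypothetical, and the PPOConfig fields shown assume a trl version (circa 0.8/0.9) whose PPOConfig exposes accelerator_kwargs.

# Illustrative sketch only -- build_ppo_config is a hypothetical helper, not LLaMA-Factory API.
from accelerate.utils import DistributedDataParallelKwargs
from trl import PPOConfig


def build_ppo_config(training_args) -> "PPOConfig":
    ppo_config = PPOConfig(
        learning_rate=training_args.learning_rate,
        mini_batch_size=training_args.per_device_train_batch_size,
        batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
    )
    # PPOTrainer creates its own Accelerator from accelerator_kwargs, so the ZeRO-3 DeepSpeedPlugin
    # configured on TrainingArguments must be passed through here, or PPO runs without it.
    ppo_config.accelerator_kwargs["kwargs_handlers"] = [
        DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
    ]
    if training_args.deepspeed_plugin is not None:
        ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
    return ppo_config
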
@@ -197,29 +210,21 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 dataiter = iter(self.dataloader)
                 batch = next(dataiter)
 
-            # Cast to inference mode
-            unwrapped_model.gradient_checkpointing_disable()
-            unwrapped_model.config.use_cache = True
-            self.model.eval()
-
             # Get inputs
+            self.model.eval()
             self.tokenizer.padding_side = "right"  # change padding side
             queries, responses, rewards = [], [], []
             for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
                 mini_batch_queries, mini_batch_responses = self.get_inputs(
                     batch[idx : idx + self.config.mini_batch_size]
                 )
-                mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
+                mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
                 queries.extend(mini_batch_queries)
                 responses.extend(mini_batch_responses)
                 rewards.extend(mini_batch_rewards)
 
-            # Cast to training mode
-            unwrapped_model.gradient_checkpointing_enable()
-            unwrapped_model.config.use_cache = False
-            self.model.train()
-
             # Run PPO step
+            self.model.train()
             stats = self.step(queries, responses, rewards)
             self.tokenizer.padding_side = "left"  # restore padding side
             loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
@@ -311,25 +316,24 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             getattr(self.processor, "image_processor").save_pretrained(output_dir)
 
     @torch.no_grad()
-    def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:
         r"""
         Generates model's responses given queries.
         """
-        if self.model_args.upcast_layernorm:
-            layernorm_params = dump_layernorm(self.model)
-
         if batch["input_ids"].size(0) == 1:  # handle llama2 ppo with gradient accumulation > 1
             start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
             for k, v in batch.items():
                 batch[k] = v[:, start_index:]
 
         with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
+            if self.model_args.upcast_layernorm:
+                layernorm_params = dump_layernorm(unwrapped_model)
+
             generate_output: torch.Tensor = unwrapped_model.generate(
                 generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
             )
-
-        if self.model_args.upcast_layernorm:
-            restore_layernorm(self.model, layernorm_params)
+            if self.model_args.upcast_layernorm:
+                restore_layernorm(unwrapped_model, layernorm_params)
 
         query = batch["input_ids"].detach().cpu()
         response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
@@ -351,10 +355,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
     @torch.no_grad()
     def get_rewards(
         self,
-        queries: List[torch.Tensor],
-        responses: List[torch.Tensor],
-        unwrapped_model: "AutoModelForCausalLMWithValueHead",
-    ) -> List[torch.Tensor]:
+        queries: List["torch.Tensor"],
+        responses: List["torch.Tensor"],
+    ) -> List["torch.Tensor"]:
         r"""
         Computes scores using given reward model.
         """
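
For the get_inputs change above, the important detail is that both generation and the optional fp32-layernorm dump/restore now run on the model yielded by trl's unwrap_model_for_generation, which temporarily gathers ZeRO-3 parameter shards. A minimal sketch of that pattern, assuming the import paths shown (generate_responses itself is a hypothetical helper, not code from this patch):

# Minimal sketch of the generation pattern used in get_inputs; generate_responses is hypothetical.
import torch
from trl.models.utils import unwrap_model_for_generation  # assumed import path for trl's helper

from llamafactory.train.ppo.utils import dump_layernorm, restore_layernorm


@torch.no_grad()
def generate_responses(model, accelerator, generation_config, batch, upcast_layernorm=False):
    # Under ZeRO-3 the wrapped model only holds parameter shards; the context manager gathers the
    # full weights so that .generate() and the layernorm dump/restore operate on real tensors.
    with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
        layernorm_params = dump_layernorm(unwrapped_model) if upcast_layernorm else None
        generate_output = unwrapped_model.generate(generation_config=generation_config, **batch)
        if layernorm_params is not None:
            restore_layernorm(unwrapped_model, layernorm_params)

    query = batch["input_ids"].detach().cpu()
    response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
    return query, response
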
@@ -365,18 +368,22 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
             return get_rewards_from_server(self.reward_model, messages)
 
-        if self.finetuning_args.reward_model_type == "lora":
-            replace_model(unwrapped_model, target="reward")
-            reward_model = self.model
-        else:
-            reward_model = self.reward_model
-
         batch = self.prepare_model_inputs(queries, responses)
 
-        with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
-            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
+        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
+            if self.finetuning_args.reward_model_type == "lora":
+                replace_model(unwrapped_model, target="reward")
+                reward_model = self.model
+            else:
+                reward_model = self.reward_model
 
-        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":  # assume same architecture
+            with self.amp_context:  # support bf16
+                _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
+
+            if self.finetuning_args.reward_model_type == "lora":
+                replace_model(unwrapped_model, target="default")
+
+        if self.is_chatglm_model:  # assume same architecture
             values = torch.transpose(values, 0, 1)
 
         rewards = []
@@ -385,21 +392,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             end_index = end_indexes[-1].item() if len(end_indexes) else 0
             rewards.append(values[i, end_index].float().detach().cpu())  # use fp32 type
 
-        if self.finetuning_args.reward_model_type == "lora":
-            replace_model(unwrapped_model, target="default")
-
         return rewards
 
     @PPODecorators.empty_device_cache()
     def batched_forward_pass(
         self,
         model: "AutoModelForCausalLMWithValueHead",
-        queries: torch.Tensor,
-        responses: torch.Tensor,
-        model_inputs: dict,
+        queries: "torch.Tensor",
+        responses: "torch.Tensor",
+        model_inputs: Dict[str, Any],
         return_logits: bool = False,
-        response_masks: Optional[torch.Tensor] = None,
-    ):
+        response_masks: Optional["torch.Tensor"] = None,
+    ) -> Tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
         r"""
         Calculates model outputs in multiple batches.
 
@@ -421,11 +425,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             input_ids = input_kwargs["input_ids"]
             attention_mask = input_kwargs["attention_mask"]
 
-            with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
+            with self.amp_context:  # support bf16
                 logits, _, values = model(**input_kwargs)
 
-            unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-            if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
+            if self.is_chatglm_model:
                 values = torch.transpose(values, 0, 1)
 
             logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
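
A note on the autocast change that appears in get_rewards and batched_forward_pass above: the CUDA-only torch.cuda.amp.autocast is replaced by a single torch.autocast context created once from the unwrapped model's device type, so bf16/fp16 mixed precision also works on non-CUDA backends. A minimal sketch (make_amp_context is a hypothetical name used only for illustration):

# Minimal sketch; make_amp_context is a hypothetical helper, not part of this patch.
import torch


def make_amp_context(model: torch.nn.Module, compute_dtype: torch.dtype):
    device_type = next(model.parameters()).device.type  # e.g. "cuda", "cpu", "xpu"
    return torch.autocast(device_type, dtype=compute_dtype)


# Usage, mirroring `with self.amp_context:` in the trainer:
# with make_amp_context(reward_model, torch.bfloat16):
#     _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
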
+ """ headers = {"Content-Type": "application/json"} payload = {"model": "model", "messages": messages} response = requests.post(server_url, json=payload, headers=headers) @@ -26,25 +27,23 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch. def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None: - if is_deepspeed_zero3_enabled(): - import deepspeed # type: ignore + r""" + Replaces the default/reward modules in the model. The model is already unwrapped (and gathered). + """ + if target == "reward": # save default head temporarily + setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone()) + setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone()) - params = [model.v_head.summary.weight, model.v_head.summary.bias] - context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0) - else: - context_maybe_zero3 = nullcontext() - - with context_maybe_zero3: - if target == "reward": # save default head temporarily - setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone()) - setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone()) - - model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active - model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone() - model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone() + model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active + device = model.v_head.summary.weight.device + model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device) + model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device) def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]: + r""" + Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered). + """ layer_norm_params = {} for name, param in model.named_parameters(): if param.data.dtype == torch.float32: @@ -55,6 +54,9 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]: def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None: + r""" + Restores the layernorm parameters in the model. The model is already unwrapped (and gathered). + """ for name, param in model.named_parameters(): if name in layernorm_params: param.data = layernorm_params[name]