From 747db4017291b0eb91946f57011bb31659056037 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sun, 3 Dec 2023 21:38:51 +0800
Subject: [PATCH] ppo support rm server

---
 src/llmtuner/chat/chat_model.py   |  5 +++--
 src/llmtuner/extras/packages.py   |  5 +++++
 src/llmtuner/train/ppo/trainer.py | 30 ++++++++++++++++++++----------
 src/llmtuner/train/ppo/utils.py   | 16 +++++++++++++++-
 src/llmtuner/train/utils.py       |  6 ++++--
 5 files changed, 47 insertions(+), 15 deletions(-)

diff --git a/src/llmtuner/chat/chat_model.py b/src/llmtuner/chat/chat_model.py
index 500009fe..6e4c28e7 100644
--- a/src/llmtuner/chat/chat_model.py
+++ b/src/llmtuner/chat/chat_model.py
@@ -167,7 +167,8 @@ class ChatModel:
 
         scores = []
         for i in range(input_ids.size(0)):
-            length = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
-            scores.append(values[i, length-1].nan_to_num().item())
+            end_indexes = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()
+            end_index = end_indexes[-1].item() if len(end_indexes) else 0
+            scores.append(values[i, end_index].nan_to_num().item())
 
         return scores
diff --git a/src/llmtuner/extras/packages.py b/src/llmtuner/extras/packages.py
index 22cab732..22d725c2 100644
--- a/src/llmtuner/extras/packages.py
+++ b/src/llmtuner/extras/packages.py
@@ -18,6 +18,7 @@ _flash_attn2_available = is_package_available("flash_attn") and get_package_vers
 _jieba_available = is_package_available("jieba")
 _matplotlib_available = is_package_available("matplotlib")
 _nltk_available = is_package_available("nltk")
+_requests_available = is_package_available("requests")
 _rouge_available = is_package_available("rouge_chinese")
 _starlette_available = is_package_available("sse_starlette")
 _uvicorn_available = is_package_available("uvicorn")
@@ -43,6 +44,10 @@ def is_nltk_available():
     return _nltk_available
 
 
+def is_requests_available():
+    return _requests_available
+
+
 def is_rouge_available():
     return _rouge_available
 
diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py
index b81aa7ff..ade5a41c 100644
--- a/src/llmtuner/train/ppo/trainer.py
+++ b/src/llmtuner/train/ppo/trainer.py
@@ -3,9 +3,9 @@ import sys
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
-from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
+from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.trainer_pt_utils import remove_dummy_checkpoint
@@ -16,7 +16,7 @@ from trl.core import PPODecorators, logprobs_from_logits
 from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
-from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model
+from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -200,7 +200,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         )
 
     @torch.no_grad()
-    def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
         """
@@ -208,7 +208,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             layernorm_params = dump_layernorm(self.model)
 
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        response: torch.Tensor = unwrapped_model.generate(
+        generate_output: torch.Tensor = unwrapped_model.generate(
             generation_config=self.generation_config,
             logits_processor=get_logits_processor(),
             **batch
@@ -217,7 +217,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if self.finetuning_args.upcast_layernorm:
             restore_layernorm(self.model, layernorm_params)
 
-        query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
+        query = batch["input_ids"].detach().cpu()
+        response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
         queries, responses = [], []
         for i in range(len(query)):
             query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
@@ -242,17 +243,26 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
     ) -> List[torch.Tensor]:
         r"""
         Computes scores using given reward model.
+
+        Both inputs and outputs are put on CPU.
         """
-        if self.reward_model is None:
+        if self.finetuning_args.reward_model_type == "api":
+            token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
+            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
+            return get_rewards_from_server(self.reward_model, messages)
+
+        if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="reward")
+            reward_model = self.model
+        else:
+            reward_model = self.reward_model
 
         batch = self.prepare_model_inputs(queries, responses)
 
         with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
-            reward_model = self.reward_model if self.reward_model is not None else self.model
             _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
 
-        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
+        if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
             values = torch.transpose(values, 0, 1)
 
         rewards = []
@@ -261,7 +271,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             end_index = end_indexes[-1].item() if len(end_indexes) else 0
             rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
 
-        if self.reward_model is None:
+        if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="default")
 
         return rewards
diff --git a/src/llmtuner/train/ppo/utils.py b/src/llmtuner/train/ppo/utils.py
index 74453a39..12e9bfcb 100644
--- a/src/llmtuner/train/ppo/utils.py
+++ b/src/llmtuner/train/ppo/utils.py
@@ -1,10 +1,24 @@
+import json
 import torch
-from typing import TYPE_CHECKING, Dict, Literal, Optional
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional
+
+from llmtuner.extras.packages import is_requests_available
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
 
+if is_requests_available():
+    import requests
+
+
+def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
+    headers = {"Content-Type": "application/json"}
+    payload = {"model": "model", "messages": messages}
+    response = requests.post(server_url, json=payload, headers=headers)
+    rewards = json.loads(response.text)["scores"]
+    return torch.Tensor(rewards)
+
 
 def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
     if target == "reward": # save default head temporarily
diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py
index 61700b53..e7fc279b 100644
--- a/src/llmtuner/train/utils.py
+++ b/src/llmtuner/train/utils.py
@@ -76,7 +76,9 @@ def create_reward_model(
     Creates reward model for PPO training.
     """
     if finetuning_args.reward_model_type == "api":
-        raise NotImplementedError
+        assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
+        logger.info("Use reward server {}".format(finetuning_args.reward_model))
+        return finetuning_args.reward_model
     elif finetuning_args.reward_model_type == "lora":
         model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
         for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
@@ -102,6 +104,6 @@ def create_reward_model(
         reward_model, _ = load_model_and_tokenizer(
             reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
         )
-        logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
+        logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
         logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
         return reward_model