ppo support rm server

hiyouga 2023-12-03 21:38:51 +08:00
parent 7df4f3ab20
commit 747db40172
5 changed files with 47 additions and 15 deletions

View File

@@ -167,7 +167,8 @@ class ChatModel:
         scores = []
         for i in range(input_ids.size(0)):
-            length = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
-            scores.append(values[i, length-1].nan_to_num().item())
+            end_indexes = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()
+            end_index = end_indexes[-1].item() if len(end_indexes) else 0
+            scores.append(values[i, end_index].nan_to_num().item())
         return scores
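Note: the old code assumed every row has at least one non-pad token; on a fully padded row, `.nonzero()[-1]` raises IndexError. A minimal repro of the edge case the new fallback covers (a pad_token_id of 0 is assumed purely for illustration):

import torch

input_ids_row = torch.tensor([0, 0, 0, 0])                # fully padded row, pad_token_id = 0
end_indexes = (input_ids_row != 0).nonzero()              # empty tensor of shape (0, 1)
# end_indexes[-1] would raise IndexError here, so the patched code falls back to index 0
end_index = end_indexes[-1].item() if len(end_indexes) else 0
print(end_index)  # 0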

View File

@@ -18,6 +18,7 @@ _flash_attn2_available = is_package_available("flash_attn") and get_package_vers
 _jieba_available = is_package_available("jieba")
 _matplotlib_available = is_package_available("matplotlib")
 _nltk_available = is_package_available("nltk")
+_requests_available = is_package_available("requests")
 _rouge_available = is_package_available("rouge_chinese")
 _starlette_available = is_package_available("sse_starlette")
 _uvicorn_available = is_package_available("uvicorn")
@@ -43,6 +44,10 @@ def is_nltk_available():
     return _nltk_available
+def is_requests_available():
+    return _requests_available
 def is_rouge_available():
     return _rouge_available
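Note: `_requests_available` follows the same lazy optional-dependency pattern as the other flags. A sketch of that pattern, assuming `is_package_available` is a thin wrapper over `importlib` (the repo's actual helper may be implemented differently):

import importlib.util

def is_package_available(name: str) -> bool:
    # True if the package can be imported in the current environment
    return importlib.util.find_spec(name) is not None

_requests_available = is_package_available("requests")

def is_requests_available():
    return _requests_available

if is_requests_available():
    import requests  # only imported when the optional dependency is installed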

View File

@@ -3,9 +3,9 @@ import sys
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
-from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
+from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.trainer_pt_utils import remove_dummy_checkpoint
@@ -16,7 +16,7 @@ from trl.core import PPODecorators, logprobs_from_logits
 from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
-from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model
+from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -200,7 +200,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         )
     @torch.no_grad()
-    def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
         """
@@ -208,7 +208,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             layernorm_params = dump_layernorm(self.model)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        response: torch.Tensor = unwrapped_model.generate(
+        generate_output: torch.Tensor = unwrapped_model.generate(
             generation_config=self.generation_config,
             logits_processor=get_logits_processor(),
             **batch
@@ -217,7 +217,8 @@
         if self.finetuning_args.upcast_layernorm:
             restore_layernorm(self.model, layernorm_params)
-        query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
+        query = batch["input_ids"].detach().cpu()
+        response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
         queries, responses = [], []
         for i in range(len(query)):
             query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
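Note: for decoder-only models, `generate` returns the prompt tokens followed by the newly generated tokens, so slicing at `batch["input_ids"].size(-1)` keeps only the response, and left padding means the first non-pad index marks the real query start. A toy illustration with made-up token ids:

import torch

pad_token_id = 0
input_ids = torch.tensor([[0, 0, 11, 12]])                     # one left-padded query
generate_output = torch.tensor([[0, 0, 11, 12, 21, 22, 23]])   # prompt + completion

query = input_ids.detach().cpu()
response = generate_output[:, input_ids.size(-1):].detach().cpu()   # tensor([[21, 22, 23]])
query_length = (query[0] != pad_token_id).nonzero()[0].item()       # 2, start of the real query
print(query[0, query_length:].tolist(), response[0].tolist())       # [11, 12] [21, 22, 23]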
@@ -242,17 +243,26 @@
     ) -> List[torch.Tensor]:
         r"""
         Computes scores using given reward model.
+        Both inputs and outputs are put on CPU.
         """
-        if self.reward_model is None:
+        if self.finetuning_args.reward_model_type == "api":
+            token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
+            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
+            return get_rewards_from_server(self.reward_model, messages)
+        if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="reward")
+            reward_model = self.model
+        else:
+            reward_model = self.reward_model
         batch = self.prepare_model_inputs(queries, responses)
         with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
-            reward_model = self.reward_model if self.reward_model is not None else self.model
             _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
-            if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
+            if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
                 values = torch.transpose(values, 0, 1)
         rewards = []
@@ -261,7 +271,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             end_index = end_indexes[-1].item() if len(end_indexes) else 0
             rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
-        if self.reward_model is None:
+        if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="default")
         return rewards
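Note: in the api branch, each (query, response) pair is concatenated, decoded back to text, and the plain strings are sent to the reward server. A sketch of that data flow with toy token ids and a stand-in for `tokenizer.batch_decode` (any HF tokenizer would be used in practice):

import torch

def toy_batch_decode(batch_token_ids, skip_special_tokens=True):
    # stand-in for self.tokenizer.batch_decode, purely for illustration
    return [" ".join(str(t) for t in ids) for ids in batch_token_ids]

queries = [torch.tensor([11, 12]), torch.tensor([13])]
responses = [torch.tensor([21, 22]), torch.tensor([23, 24, 25])]

token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
messages = toy_batch_decode(token_ids, skip_special_tokens=True)
print(messages)  # one decoded string per (query, response) pair, ready to POST to the server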

View File

@@ -1,10 +1,24 @@
+import json
 import torch
-from typing import TYPE_CHECKING, Dict, Literal, Optional
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional
+from llmtuner.extras.packages import is_requests_available
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
+if is_requests_available():
+    import requests
+def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
+    headers = {"Content-Type": "application/json"}
+    payload = {"model": "model", "messages": messages}
+    response = requests.post(server_url, json=payload, headers=headers)
+    rewards = json.loads(response.text)["scores"]
+    return torch.Tensor(rewards)
 def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
     if target == "reward": # save default head temporarily
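Note: `get_rewards_from_server` only relies on the endpoint accepting a JSON body with a `messages` list and answering with {"scores": [...]}. A minimal stand-in server for local testing, built on the standard library (the project's real reward-serving endpoint may expose a richer API; the scoring rule here is a dummy):

import json
from http.server import BaseHTTPRequestHandler, HTTPServer

class DummyRewardHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        length = int(self.headers.get("Content-Length", 0))
        payload = json.loads(self.rfile.read(length))
        # dummy scoring: reward longer responses, purely for illustration
        scores = [float(len(message)) for message in payload["messages"]]
        body = json.dumps({"scores": scores}).encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(body)

if __name__ == "__main__":
    HTTPServer(("127.0.0.1", 8000), DummyRewardHandler).serve_forever()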

View File

@@ -76,7 +76,9 @@ def create_reward_model(
     Creates reward model for PPO training.
     """
     if finetuning_args.reward_model_type == "api":
-        raise NotImplementedError
+        assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
+        logger.info("Use reward server {}".format(finetuning_args.reward_model))
+        return finetuning_args.reward_model
     elif finetuning_args.reward_model_type == "lora":
         model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
         for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
@@ -102,6 +104,6 @@ def create_reward_model(
         reward_model, _ = load_model_and_tokenizer(
             reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
         )
-        logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
+        logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
         logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
         return reward_model
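Note: with reward_model_type == "api", `create_reward_model` returns the URL string itself, so `self.reward_model` in the PPO trainer is a plain str rather than a module. A hypothetical smoke test of the client path (URL and messages are made up; the server must follow the JSON contract above):

from llmtuner.train.ppo.utils import get_rewards_from_server

server_url = "http://127.0.0.1:8000"   # whatever create_reward_model returned
messages = ["Instruction ... Response A", "Instruction ... Response B"]
rewards = get_rewards_from_server(server_url, messages)
print(rewards)  # a tensor with one score per message, as returned by the server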