ppo support rm server
This commit is contained in:
parent
7df4f3ab20
commit
747db40172
|
@ -167,7 +167,8 @@ class ChatModel:
|
|||
|
||||
scores = []
|
||||
for i in range(input_ids.size(0)):
|
||||
length = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
|
||||
scores.append(values[i, length-1].nan_to_num().item())
|
||||
end_indexes = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()
|
||||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
||||
scores.append(values[i, end_index].nan_to_num().item())
|
||||
|
||||
return scores
|
||||
|
|
|
@ -18,6 +18,7 @@ _flash_attn2_available = is_package_available("flash_attn") and get_package_vers
|
|||
_jieba_available = is_package_available("jieba")
|
||||
_matplotlib_available = is_package_available("matplotlib")
|
||||
_nltk_available = is_package_available("nltk")
|
||||
_requests_available = is_package_available("requests")
|
||||
_rouge_available = is_package_available("rouge_chinese")
|
||||
_starlette_available = is_package_available("sse_starlette")
|
||||
_uvicorn_available = is_package_available("uvicorn")
|
||||
|
@ -43,6 +44,10 @@ def is_nltk_available():
|
|||
return _nltk_available
|
||||
|
||||
|
||||
def is_requests_available():
|
||||
return _requests_available
|
||||
|
||||
|
||||
def is_rouge_available():
|
||||
return _rouge_available
|
||||
|
||||
|
|
|
@ -3,9 +3,9 @@ import sys
|
|||
import math
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
|
||||
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
|
||||
from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.trainer_pt_utils import remove_dummy_checkpoint
|
||||
|
@ -16,7 +16,7 @@ from trl.core import PPODecorators, logprobs_from_logits
|
|||
from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
|
||||
from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model
|
||||
from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
|
@ -200,7 +200,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
r"""
|
||||
Generates model's responses given queries.
|
||||
"""
|
||||
|
@ -208,7 +208,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
layernorm_params = dump_layernorm(self.model)
|
||||
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
response: torch.Tensor = unwrapped_model.generate(
|
||||
generate_output: torch.Tensor = unwrapped_model.generate(
|
||||
generation_config=self.generation_config,
|
||||
logits_processor=get_logits_processor(),
|
||||
**batch
|
||||
|
@ -217,7 +217,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
if self.finetuning_args.upcast_layernorm:
|
||||
restore_layernorm(self.model, layernorm_params)
|
||||
|
||||
query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
|
||||
query = batch["input_ids"].detach().cpu()
|
||||
response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
|
||||
queries, responses = [], []
|
||||
for i in range(len(query)):
|
||||
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||
|
@ -242,17 +243,26 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
) -> List[torch.Tensor]:
|
||||
r"""
|
||||
Computes scores using given reward model.
|
||||
|
||||
Both inputs and outputs are put on CPU.
|
||||
"""
|
||||
if self.reward_model is None:
|
||||
if self.finetuning_args.reward_model_type == "api":
|
||||
token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
|
||||
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
|
||||
return get_rewards_from_server(self.reward_model, messages)
|
||||
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
reward_model = self.model
|
||||
else:
|
||||
reward_model = self.reward_model
|
||||
|
||||
batch = self.prepare_model_inputs(queries, responses)
|
||||
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
reward_model = self.reward_model if self.reward_model is not None else self.model
|
||||
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
|
||||
|
||||
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
|
||||
if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
rewards = []
|
||||
|
@ -261,7 +271,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
||||
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
|
||||
|
||||
if self.reward_model is None:
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="default")
|
||||
|
||||
return rewards
|
||||
|
|
|
@ -1,10 +1,24 @@
|
|||
import json
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
|
||||
|
||||
from llmtuner.extras.packages import is_requests_available
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
if is_requests_available():
|
||||
import requests
|
||||
|
||||
|
||||
def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
payload = {"model": "model", "messages": messages}
|
||||
response = requests.post(server_url, json=payload, headers=headers)
|
||||
rewards = json.loads(response.text)["scores"]
|
||||
return torch.Tensor(rewards)
|
||||
|
||||
|
||||
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
|
||||
if target == "reward": # save default head temporarily
|
||||
|
|
|
@ -76,7 +76,9 @@ def create_reward_model(
|
|||
Creates reward model for PPO training.
|
||||
"""
|
||||
if finetuning_args.reward_model_type == "api":
|
||||
raise NotImplementedError
|
||||
assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
|
||||
logger.info("Use reward server {}".format(finetuning_args.reward_model))
|
||||
return finetuning_args.reward_model
|
||||
elif finetuning_args.reward_model_type == "lora":
|
||||
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
|
||||
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
|
||||
|
@ -102,6 +104,6 @@ def create_reward_model(
|
|||
reward_model, _ = load_model_and_tokenizer(
|
||||
reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
|
||||
)
|
||||
logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
|
||||
logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
|
||||
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
|
||||
return reward_model
|
||||
|
|
Loading…
Reference in New Issue