This commit is contained in:
hiyouga 2023-11-13 22:42:23 +08:00
parent 442aefb925
commit 87390ae3b7
1 changed files with 10 additions and 5 deletions

View File

@ -5,7 +5,7 @@ import torch
from tqdm import tqdm
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from trl import PPOTrainer
@ -108,9 +108,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.model.eval()
# Get inputs
queries, responses = self.get_inputs(batch)
self.tokenizer.padding_side = "right" # change padding side
rewards = self.get_rewards(queries, responses, unwrapped_model)
queries, responses, rewards = [], [], []
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
queries.extend(mini_batch_queries)
responses.extend(mini_batch_responses)
rewards.extend(mini_batch_rewards)
# Cast to training mode
unwrapped_model.gradient_checkpointing_enable()
@ -165,7 +170,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
)
@torch.no_grad()
def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
r"""
Generates model's responses given queries.
"""
@ -219,7 +224,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
rewards = []
for i in range(values.size(0)):
end_indexes = (batch["input_ids"][i] != self.tokenizer.eos_token_id).nonzero()
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type