fix #424
This commit is contained in:
parent
442aefb925
commit
87390ae3b7
|
@ -5,7 +5,7 @@ import torch
|
|||
from tqdm import tqdm
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
|
||||
from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
|
||||
from trl import PPOTrainer
|
||||
|
@ -108,9 +108,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
self.model.eval()
|
||||
|
||||
# Get inputs
|
||||
queries, responses = self.get_inputs(batch)
|
||||
self.tokenizer.padding_side = "right" # change padding side
|
||||
rewards = self.get_rewards(queries, responses, unwrapped_model)
|
||||
queries, responses, rewards = [], [], []
|
||||
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
|
||||
mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
|
||||
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
|
||||
queries.extend(mini_batch_queries)
|
||||
responses.extend(mini_batch_responses)
|
||||
rewards.extend(mini_batch_rewards)
|
||||
|
||||
# Cast to training mode
|
||||
unwrapped_model.gradient_checkpointing_enable()
|
||||
|
@ -165,7 +170,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
r"""
|
||||
Generates model's responses given queries.
|
||||
"""
|
||||
|
@ -219,7 +224,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
|
||||
rewards = []
|
||||
for i in range(values.size(0)):
|
||||
end_indexes = (batch["input_ids"][i] != self.tokenizer.eos_token_id).nonzero()
|
||||
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
|
||||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
||||
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
|
||||
|
||||
|
|
Loading…
Reference in New Issue