From 87390ae3b70f654d520b9aadb335c9650130a42c Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Mon, 13 Nov 2023 22:42:23 +0800
Subject: [PATCH] fix #424

---
 src/llmtuner/tuner/ppo/trainer.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py
index f9626b1d..3d591615 100644
--- a/src/llmtuner/tuner/ppo/trainer.py
+++ b/src/llmtuner/tuner/ppo/trainer.py
@@ -5,7 +5,7 @@
 import torch
 from tqdm import tqdm
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
-from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
+from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from trl import PPOTrainer
@@ -108,9 +108,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.model.eval()
 
             # Get inputs
-            queries, responses = self.get_inputs(batch)
             self.tokenizer.padding_side = "right" # change padding side
-            rewards = self.get_rewards(queries, responses, unwrapped_model)
+            queries, responses, rewards = [], [], []
+            for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
+                mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
+                mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
+                queries.extend(mini_batch_queries)
+                responses.extend(mini_batch_responses)
+                rewards.extend(mini_batch_rewards)
 
             # Cast to training mode
             unwrapped_model.gradient_checkpointing_enable()
@@ -165,7 +170,7 @@
         )
 
     @torch.no_grad()
-    def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
         """
@@ -219,7 +224,7 @@
         rewards = []
         for i in range(values.size(0)):
-            end_indexes = (batch["input_ids"][i] != self.tokenizer.eos_token_id).nonzero()
+            end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
             end_index = end_indexes[-1].item() if len(end_indexes) else 0
             rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
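
Note on the second hunk: it replaces a single full-batch rollout with a loop that generates and scores mini_batch_size samples at a time, so generation only ever holds one mini-batch in memory. Below is a minimal standalone sketch of the same pattern, assuming a plain dict-of-tensors batch sliced key by key (the patch slices the BatchEncoding directly) and hypothetical generate_responses/score_responses callables standing in for the trainer's get_inputs and get_rewards:

    import torch
    from typing import Callable, Dict, List, Tuple

    TensorBatch = Dict[str, torch.Tensor]

    def rollout_in_mini_batches(
        batch: TensorBatch,
        batch_size: int,
        mini_batch_size: int,
        generate_responses: Callable[[TensorBatch], Tuple[List[torch.Tensor], List[torch.Tensor]]],
        score_responses: Callable[[List[torch.Tensor], List[torch.Tensor]], List[torch.Tensor]],
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        queries, responses, rewards = [], [], []
        for idx in range(0, batch_size, mini_batch_size):
            # Slice every tensor along the batch dimension to form one mini-batch.
            mini_batch = {k: v[idx:idx + mini_batch_size] for k, v in batch.items()}
            mb_queries, mb_responses = generate_responses(mini_batch)
            rewards.extend(score_responses(mb_queries, mb_responses))
            queries.extend(mb_queries)
            responses.extend(mb_responses)
        return queries, responses, rewards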
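
Note on the last hunk: it fixes where the reward is read from the value head. With right padding, the response ends at the last position whose id is not pad_token_id; comparing against eos_token_id instead selects the last padding slot whenever the tokenizer's pad token differs from its EOS token (and stops one token before the closing EOS when the two ids coincide). A self-contained illustration with made-up token ids and values:

    import torch

    pad_token_id, eos_token_id = 0, 2
    # Right-padded sequence: query + response tokens, a closing EOS, then padding.
    input_ids = torch.tensor([5, 6, 7, 8, 2, 0, 0])
    # Per-position value-head outputs (made-up numbers for the example).
    values = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.9, 0.0, 0.0])

    # Old comparison: padding ids are "not EOS", so the last match is a pad slot.
    buggy_index = (input_ids != eos_token_id).nonzero()[-1].item()  # 6, inside the padding
    # Fixed comparison: the last non-pad token is the EOS that ends the response.
    fixed_index = (input_ids != pad_token_id).nonzero()[-1].item()  # 4, the EOS position

    print(values[buggy_index].item())  # 0.0, reward read from padding
    print(values[fixed_index].item())  # 0.9, reward read at the end of the response

The retained fallback end_index = 0 covers the degenerate case where every position is padding and nonzero() returns an empty tensor.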