diff --git a/src/utils/ppo.py b/src/utils/ppo.py
index 10c80d22..5e754e48 100644
--- a/src/utils/ppo.py
+++ b/src/utils/ppo.py
@@ -214,51 +214,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             return response[:, inputs["input_ids"].size(1):]
         return response
 
-    @PPODecorators.empty_cuda_cache()
-    def batched_forward_pass(
-        self,
-        model: AutoModelForCausalLMWithValueHead,
-        queries: torch.Tensor,
-        responses: torch.Tensor,
-        model_inputs: dict,
-    ):
-        r"""
-        Calculates model outputs in multiple batches.
-
-        Subclass and override to inject custom behavior.
-        """
-        bs = len(model_inputs["input_ids"])
-        fbs = self.config.mini_batch_size
-        all_logprobs = []
-        all_logits = []
-        all_masks = []
-        all_values = []
-
-        for i in range(int(bs / fbs)):
-            input_kwargs = {k: v[i * fbs : (i + 1) * fbs] for k, v in model_inputs.items()}
-            input_ids: torch.Tensor = input_kwargs["input_ids"] # left-padded sequences
-            logits, _, values = model(**input_kwargs)
-            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
-
-            masks = torch.zeros_like(input_ids)
-            for j in range(fbs):
-                start = (input_ids[j] == self.tokenizer.bos_token_id).nonzero()[0].item()
-                masks[j][start:] = 1
-                if len(masks[j][start:]) < 2:
-                    raise ValueError("Responses are too short. Make sure they are at least 4 tokens long.")
-
-            all_logits.append(logits)
-            all_values.append(values)
-            all_logprobs.append(logprobs)
-            all_masks.append(masks)
-
-        return (
-            torch.cat(all_logprobs),
-            torch.cat(all_logits)[:, :-1],
-            torch.cat(all_values)[:, :-1],
-            torch.cat(all_masks)[:, :-1],
-        )
-
     def save_model(self, output_dir: Optional[str] = None) -> None:
         r"""
         Saves model checkpoint.
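
Note: with this override removed, PPOPeftTrainer falls back to the batched_forward_pass inherited from trl's PPOTrainer. For reference, the mask construction the deleted code performed can be reproduced in isolation; the snippet below is a minimal sketch with made-up input_ids and a hypothetical bos_token_id, not part of this patch:

    import torch

    bos_token_id = 1  # hypothetical BOS id, for illustration only
    # two left-padded sequences (0 = pad); real tokens start at the BOS position
    input_ids = torch.tensor([
        [0, 0, 1, 5, 6, 7],
        [0, 1, 8, 9, 2, 3],
    ])

    masks = torch.zeros_like(input_ids)
    for j in range(input_ids.size(0)):
        # index of the first BOS token; everything from there onward is marked
        start = (input_ids[j] == bos_token_id).nonzero()[0].item()
        masks[j][start:] = 1

    print(masks)
    # tensor([[0, 0, 1, 1, 1, 1],
    #         [0, 1, 1, 1, 1, 1]])

The deleted override thus masked everything from the first BOS token onward, i.e. query tokens as well as response tokens, while excluding left padding.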