remove unused code

hiyouga 2023-06-03 00:10:54 +08:00
parent 72a85ccc39
commit ed6161fa6a
1 changed file with 0 additions and 45 deletions

@@ -214,51 +214,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
            return response[:, inputs["input_ids"].size(1):]
        return response

    @PPODecorators.empty_cuda_cache()
    def batched_forward_pass(
        self,
        model: AutoModelForCausalLMWithValueHead,
        queries: torch.Tensor,
        responses: torch.Tensor,
        model_inputs: dict,
    ):
        r"""
        Calculates model outputs in multiple batches.

        Subclass and override to inject custom behavior.
        """
        bs = len(model_inputs["input_ids"])
        fbs = self.config.mini_batch_size
        all_logprobs = []
        all_logits = []
        all_masks = []
        all_values = []

        for i in range(int(bs / fbs)):
            input_kwargs = {k: v[i * fbs : (i + 1) * fbs] for k, v in model_inputs.items()}
            input_ids: torch.Tensor = input_kwargs["input_ids"]  # left-padded sequences
            logits, _, values = model(**input_kwargs)
            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])

            masks = torch.zeros_like(input_ids)
            for j in range(fbs):
                start = (input_ids[j] == self.tokenizer.bos_token_id).nonzero()[0].item()
                masks[j][start:] = 1
                if len(masks[j][start:]) < 2:
                    raise ValueError("Responses are too short. Make sure they are at least 4 tokens long.")

            all_logits.append(logits)
            all_values.append(values)
            all_logprobs.append(logprobs)
            all_masks.append(masks)

        return (
            torch.cat(all_logprobs),
            torch.cat(all_logits)[:, :-1],
            torch.cat(all_values)[:, :-1],
            torch.cat(all_masks)[:, :-1],
        )

    def save_model(self, output_dir: Optional[str] = None) -> None:
        r"""
        Saves model checkpoint.
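
For reference, the override removed by this commit built per-token masks by locating the BOS token in each left-padded sequence and marking every position from there onward. Below is a minimal, self-contained sketch of just that masking step; the toy tensors and the BOS id of 1 are made up for illustration and are not taken from this repository.

    import torch

    bos_token_id = 1  # assumed BOS id, for illustration only
    # two left-padded toy sequences (0 stands in for the pad token)
    input_ids = torch.tensor([
        [0, 0, 1, 5, 6, 7],
        [0, 1, 8, 9, 10, 11],
    ])

    masks = torch.zeros_like(input_ids)
    for j in range(input_ids.size(0)):
        # the first BOS occurrence marks where the real content begins
        start = (input_ids[j] == bos_token_id).nonzero()[0].item()
        masks[j][start:] = 1  # keep everything from BOS onward

    print(masks)
    # tensor([[0, 0, 1, 1, 1, 1],
    #         [0, 1, 1, 1, 1, 1]])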