Fix memory leak in the PPO trainer

hiyouga 2023-08-02 17:41:34 +08:00
parent c689857bbb
commit 286f7be346
2 changed files with 6 additions and 4 deletions
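
The leaked memory comes from keeping tensors that still live on the GPU (and, for outputs produced with gradients enabled, still reference the autograd graph) alive across PPO iterations; the fix detaches them and moves them to host memory before they are stored. A minimal illustration of the pattern, separate from the trainer code below (`model` and `batches` are placeholder names, not objects from this repository):

import torch

def collect(model: torch.nn.Module, batches):
    stored = []
    for batch in batches:
        out = model(batch)
        # Leaky variant: stored.append(out) would keep every step's GPU
        # buffers (and any attached graph) alive for the whole loop.
        # Fixed variant, in the spirit of this commit: drop the graph and
        # move the result to host memory before keeping a reference to it.
        stored.append(out.detach().cpu())
    return stored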

@@ -100,7 +100,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             # Get responses
             query_tensors = batch["input_ids"]
-            response_tensors = self.generate(batch, length_sampler, return_prompt=False, **gen_kwargs)
+            response_tensors = self.generate(
+                batch, length_sampler, return_prompt=False, **gen_kwargs
+            ).detach().cpu() # move to cpu
             queries, responses = [], []
             for i in range(len(query_tensors)):
@@ -112,12 +114,12 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             # Compute rewards
             replace_model(unwrapped_model, target="reward")
             with torch.no_grad():
-                _, _, values = self.model(
+                _, _, values: torch.Tensor = self.model(
                     **self.prepare_model_inputs(queries, responses),
                     output_hidden_states=True,
                     return_dict=True
                 )
-            rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
+            rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
             replace_model(unwrapped_model, target="default")
             # Run PPO step
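
One way to check the effect of a change like this (not part of the commit) is to log allocator statistics once per PPO step: with the responses and reward values detached and moved to CPU, allocated GPU memory should stay roughly flat across steps instead of growing. A small sketch:

import torch

def log_cuda_usage(step: int) -> None:
    # Report how much GPU memory the caching allocator currently holds;
    # a steady upward trend across PPO steps points to retained tensors.
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 2**20
        reserved = torch.cuda.memory_reserved() / 2**20
        print(f"step {step}: allocated={allocated:.1f} MiB, reserved={reserved:.1f} MiB")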

@@ -1,5 +1,5 @@
 # Inspired by:
-# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
+# https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
 import math
 from typing import TYPE_CHECKING