fix memory leak of PPO trainer

hiyouga 2023-08-02 17:41:34 +08:00
parent c689857bbb
commit 286f7be346
2 changed files with 6 additions and 4 deletions


@@ -100,7 +100,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
 
             # Get responses
             query_tensors = batch["input_ids"]
-            response_tensors = self.generate(batch, length_sampler, return_prompt=False, **gen_kwargs)
+            response_tensors = self.generate(
+                batch, length_sampler, return_prompt=False, **gen_kwargs
+            ).detach().cpu() # move to cpu
 
             queries, responses = [], []
             for i in range(len(query_tensors)):
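
The change above detaches the generated responses and copies them to the CPU as soon as generation finishes, so the tensors kept for the rest of the step no longer pin GPU memory. A minimal sketch of that pattern, assuming a hypothetical policy object with a Hugging Face style generate method (this is not the trainer's actual API):

import torch

def generate_and_offload(policy, query_tensors: torch.Tensor, **gen_kwargs) -> torch.Tensor:
    # Hypothetical helper, not the trainer's actual method.
    with torch.no_grad():  # sampling does not need gradient tracking
        responses = policy.generate(input_ids=query_tensors, **gen_kwargs)
    # detach() drops any lingering autograd references; cpu() keeps only a
    # host-memory copy, so storing the responses across PPO steps does not
    # hold on to the GPU allocation produced during generation.
    return responses.detach().cpu()

Under torch.no_grad() the detach() is mostly a safety net; the cpu() copy is what lets the GPU allocation be released once no other references remain.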
@@ -112,12 +114,12 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             # Compute rewards
             replace_model(unwrapped_model, target="reward")
             with torch.no_grad():
-                _, _, values = self.model(
+                _, _, values: torch.Tensor = self.model(
                     **self.prepare_model_inputs(queries, responses),
                     output_hidden_states=True,
                     return_dict=True
                 )
-            rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
+            rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
             replace_model(unwrapped_model, target="default")
 
             # Run PPO step
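
The changed reward line above casts the value head's last-token score to fp32 and moves it off the GPU before the per-sample rewards are collected into a Python list. A minimal sketch of that step, assuming values has shape (batch_size, seq_len) in the model's compute dtype (e.g. fp16); the helper name is hypothetical:

from typing import List
import torch

def extract_rewards(values: torch.Tensor) -> List[torch.Tensor]:
    # Hypothetical helper mirroring the rewards line above.
    # values[:, -1] is the value-head score at the final position of each sequence;
    # float() upcasts from the compute dtype (e.g. fp16) to fp32 for stable PPO
    # statistics, and detach().cpu() ensures the reward list holds only small
    # host-side scalars instead of keeping GPU buffers alive across steps.
    return [score for score in values[:, -1].float().detach().cpu()]

Keeping only small fp32 scalars on the host means the reward list can be carried through the PPO step without retaining the reward model's fp16 outputs on the GPU.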


@@ -1,5 +1,5 @@
 # Inspired by:
-# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
+# https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
 
 import math
 from typing import TYPE_CHECKING