Fix memory leak in PPO trainer
This commit is contained in:
parent c689857bbb
commit 286f7be346
|
@@ -100,7 +100,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||||
|
|
||||||
# Get responses
|
# Get responses
|
||||||
query_tensors = batch["input_ids"]
|
query_tensors = batch["input_ids"]
|
||||||
response_tensors = self.generate(batch, length_sampler, return_prompt=False, **gen_kwargs)
|
response_tensors = self.generate(
|
||||||
|
batch, length_sampler, return_prompt=False, **gen_kwargs
|
||||||
|
).detach().cpu() # move to cpu
|
||||||
|
|
||||||
queries, responses = [], []
|
queries, responses = [], []
|
||||||
for i in range(len(query_tensors)):
|
for i in range(len(query_tensors)):
|
||||||
|
@@ -112,12 +114,12 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||||
# Compute rewards
|
# Compute rewards
|
||||||
replace_model(unwrapped_model, target="reward")
|
replace_model(unwrapped_model, target="reward")
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
_, _, values = self.model(
|
_, _, values: torch.Tensor = self.model(
|
||||||
**self.prepare_model_inputs(queries, responses),
|
**self.prepare_model_inputs(queries, responses),
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
return_dict=True
|
return_dict=True
|
||||||
)
|
)
|
||||||
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
|
rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
|
||||||
replace_model(unwrapped_model, target="default")
|
replace_model(unwrapped_model, target="default")
|
||||||
|
|
||||||
# Run PPO step
|
# Run PPO step
|
||||||
|
|
|
@@ -1,5 +1,5 @@
|
||||||
# Inspired by:
|
# Inspired by:
|
||||||
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
|
# https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
Loading…
Reference in New Issue