diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py
index 6c7769ef..d3f79f05 100644
--- a/src/llmtuner/tuner/ppo/trainer.py
+++ b/src/llmtuner/tuner/ppo/trainer.py
@@ -100,7 +100,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
 
             # Get responses
             query_tensors = batch["input_ids"]
-            response_tensors = self.generate(batch, length_sampler, return_prompt=False, **gen_kwargs)
+            response_tensors = self.generate(
+                batch, length_sampler, return_prompt=False, **gen_kwargs
+            ).detach().cpu() # detach and move generated responses to CPU
 
             queries, responses = [], []
             for i in range(len(query_tensors)):
@@ -112,12 +114,12 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             # Compute rewards
             replace_model(unwrapped_model, target="reward")
             with torch.no_grad():
-                _, _, values = self.model(
+                _, _, values = self.model( # values: torch.Tensor
                     **self.prepare_model_inputs(queries, responses),
                     output_hidden_states=True,
                     return_dict=True
                 )
-            rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
+            rewards = [reward for reward in values[:, -1].float().detach().cpu()] # cast to fp32 and move to CPU
             replace_model(unwrapped_model, target="default")
 
             # Run PPO step
diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py
index 3a229c8c..eed0707f 100644
--- a/src/llmtuner/tuner/ppo/workflow.py
+++ b/src/llmtuner/tuner/ppo/workflow.py
@@ -1,5 +1,5 @@
 # Inspired by:
-# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
+# https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
 
 import math
 from typing import TYPE_CHECKING
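
For reference, the trainer-side reward change boils down to the pattern sketched below: take the value head's estimate for the last token of each sequence, cast it to fp32, and detach it to CPU so the reward list does not keep GPU memory alive across PPO steps. This is a minimal standalone illustration, not code from the PR; the `extract_rewards` helper and the toy tensor shapes are hypothetical.

```python
from typing import List

import torch


def extract_rewards(values: torch.Tensor) -> List[torch.Tensor]:
    # values: (batch_size, seq_len) value-head outputs from the reward model.
    # Keep only the last-token value per sequence, cast to fp32, and detach
    # to CPU so downstream bookkeeping holds no references to GPU tensors.
    return [reward for reward in values[:, -1].float().detach().cpu()]


# Toy usage with random numbers standing in for reward-model outputs.
values = torch.randn(4, 16)  # 4 sequences, 16 tokens each
rewards = extract_rewards(values)
print([r.item() for r in rewards])  # four scalars, fp32, on CPU
```

The same idea motivates detaching and offloading the generated response tensors in the first hunk: anything that is only needed for Python-side bookkeeping is cut from the autograd graph and moved off the GPU.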