update ppo trainer

hiyouga 2023-08-02 18:46:41 +08:00
parent 286f7be346
commit b5ba87952a
2 changed files with 46 additions and 44 deletions


@@ -47,7 +47,6 @@ class PeftTrainer(Seq2SeqTrainer):
logger.info(f"Saving model checkpoint to {output_dir}")
model = unwrap_model(self.model)
if isinstance(model, PreTrainedModelWrapper):
# Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
model_state_dict = state_dict or model.state_dict()
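
For context on the custom state dict referenced above: trl's PreTrainedModelWrapper subclasses such as AutoModelForCausalLMWithValueHead return a combined state dict in which the value head's parameters sit under a "v_head." prefix on top of the base model's keys, which is why the saver has to treat the wrapper specially. Below is a minimal sketch of splitting such a state dict, assuming that prefix convention; the helper name is hypothetical and not part of this commit.

from typing import Dict, Tuple
import torch

def split_value_head_state_dict(
    state_dict: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    # Keys like "v_head.summary.weight" belong to the value head; everything else
    # belongs to the wrapped (base or PEFT) model.
    v_head_state_dict = {
        k[len("v_head."):]: v for k, v in state_dict.items() if k.startswith("v_head.")
    }
    model_state_dict = {
        k: v for k, v in state_dict.items() if not k.startswith("v_head.")
    }
    return model_state_dict, v_head_state_dict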


@@ -2,10 +2,9 @@ import os
import math
import torch
from tqdm import tqdm
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
from transformers import TrainerState, TrainerControl
from transformers.modeling_utils import PreTrainedModel
from trl import PPOTrainer
from trl.core import LengthSampler
@@ -18,6 +17,7 @@ from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.callbacks import LogCallback
from llmtuner.hparams import FinetuningArguments
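
The quoted "AutoModelForCausalLMWithValueHead" annotations used later in this file rely on the standard TYPE_CHECKING pattern: the trl import is evaluated only by static type checkers, and string annotations keep it off the runtime import path. A minimal illustration follows; the function is hypothetical.

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # only evaluated by type checkers, never at runtime
    from trl import AutoModelForCausalLMWithValueHead

def describe(model: "AutoModelForCausalLMWithValueHead") -> str:
    # The string annotation is resolved lazily, so trl is not imported here.
    return model.__class__.__name__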
@@ -43,7 +43,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
self.log_callback = callbacks[0]
self.state = TrainerState()
self.control = TrainerControl()
self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
self._remove_log()
def ppo_train(self, max_target_length: int) -> None:
@@ -83,7 +82,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
"logits_processor": get_logits_processor()
}
length_sampler = LengthSampler(max_target_length // 2, max_target_length)
unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
steps_trained = 0
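
LengthSampler from trl.core is a small callable that draws a fresh target length on every call, so each generation step uses a different max_new_tokens between max_target_length // 2 and max_target_length. A quick illustration with an assumed max_target_length of 128; the exact inclusivity of the upper bound follows trl's implementation.

from trl.core import LengthSampler

max_target_length = 128
length_sampler = LengthSampler(max_target_length // 2, max_target_length)

# Each call samples a new response length, later passed as max_new_tokens.
sampled = [length_sampler() for _ in range(3)]
print(sampled)  # e.g. [71, 98, 64] -- roughly uniform over the configured range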
@@ -95,38 +94,22 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
batch = next(dataiter)
steps_trained += 1
# Cast to inference mode
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
unwrapped_model.eval()
# Get responses
query_tensors = batch["input_ids"]
response_tensors = self.generate(
batch, length_sampler, return_prompt=False, **gen_kwargs
).detach().cpu() # move to cpu
# Get inputs
queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
rewards = self.get_rewards(queries, responses, unwrapped_model)
queries, responses = [], []
for i in range(len(query_tensors)):
query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
queries.append(query_tensors[i, query_length:]) # remove padding from left
responses.append(response_tensors[i, :response_length]) # remove padding from right
# Compute rewards
replace_model(unwrapped_model, target="reward")
with torch.no_grad():
_, _, values = self.model(
**self.prepare_model_inputs(queries, responses),
output_hidden_states=True,
return_dict=True
)
rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
replace_model(unwrapped_model, target="default")
# Run PPO step
# Cast to training mode
unwrapped_model.gradient_checkpointing_enable()
unwrapped_model.config.use_cache = False
stats = self.step(queries, responses, rewards)
unwrapped_model.train()
# Run PPO step
stats = self.step(queries, responses, rewards)
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
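
The replace_model(unwrapped_model, target=...) calls switch the shared backbone between the policy ("default") and the reward model ("reward"). Below is a rough sketch of the idea, assuming the reward model is realized as a second LoRA adapter plus its own value-head weights registered on the wrapper; this is an illustration, not the exact implementation in llmtuner.tuner.ppo.utils.

def replace_model_sketch(model, target: str) -> None:
    # Assumes "reward_head_weight" / "reward_head_bias" were attached to the wrapper
    # when the reward model was loaded, and that the value head follows trl's
    # ValueHead layout ("summary.weight" / "summary.bias").
    if target == "reward":
        # Stash the policy's value head so target="default" can restore it later.
        valuehead_state_dict = model.v_head.state_dict()
        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])

    model.pretrained_model.set_adapter(target)  # switch the active LoRA adapter
    model.v_head.load_state_dict({
        "summary.weight": getattr(model, "{}_head_weight".format(target)),
        "summary.bias": getattr(model, "{}_head_bias".format(target)),
    })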
@@ -155,37 +138,57 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
steps_trained = 0
@torch.no_grad()
def generate(
def get_inputs(
self,
inputs: Dict[str, torch.Tensor],
length_sampler: Optional[Callable] = None,
return_prompt: Optional[bool] = True,
**generation_kwargs
) -> torch.Tensor:
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
r"""
Generates model's responses given queries.
Subclass and override to inject custom behavior.
"""
self.model, layer_norm_params = cast_layernorm_dtype(self.model)
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
unwrapped_model = self.accelerator.unwrap_model(self.model)
response = unwrapped_model.generate(**inputs, **generation_kwargs)
self.model, layer_norm_params = cast_layernorm_dtype(self.model)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
response: torch.Tensor = unwrapped_model.generate(**inputs, **generation_kwargs)
self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
# Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
if unwrapped_model.pretrained_model.generation_config._from_model_config:
unwrapped_model.pretrained_model.generation_config._from_model_config = False
self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
queries, responses = [], []
query, response = inputs["input_ids"], response[:, inputs["input_ids"].size(-1):].detach().cpu()
for i in range(len(query)):
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
queries.append(query[i, query_length:]) # remove padding from left
responses.append(response[i, :response_length]) # remove padding from right
if not return_prompt and not self.is_encoder_decoder:
return response[:, inputs["input_ids"].size(1):]
return response
return queries, responses
@torch.no_grad()
def get_rewards(
self,
queries: List[torch.Tensor],
responses: List[torch.Tensor],
unwrapped_model: "AutoModelForCausalLMWithValueHead"
) -> List[torch.Tensor]:
r"""
Computes scores using given reward model.
"""
replace_model(unwrapped_model, target="reward")
_, _, values = self.model(
**self.prepare_model_inputs(queries, responses),
output_hidden_states=True,
return_dict=True
)
rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
replace_model(unwrapped_model, target="default")
return rewards
def save_model(self, output_dir: Optional[str] = None) -> None:
r"""