update ppo trainer

parent 286f7be346
commit b5ba87952a
@@ -47,7 +47,6 @@ class PeftTrainer(Seq2SeqTrainer):
         logger.info(f"Saving model checkpoint to {output_dir}")

         model = unwrap_model(self.model)
-
         if isinstance(model, PreTrainedModelWrapper):
             # Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
             model_state_dict = state_dict or model.state_dict()
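Note on the hunk above: trl's value-head wrapper flattens the backbone weights and the `v_head.*` weights into a single state dict (see the linked modeling_value_head.py#L200), which is why `_save` special-cases `PreTrainedModelWrapper`. A minimal sketch of splitting such a combined state dict before writing it to disk, assuming only trl's `v_head.` key prefix; the helper name is illustrative and not part of this commit:

from typing import Dict, Tuple

import torch


def split_value_head_state_dict(
    state_dict: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    # trl's custom state_dict stores value-head weights under a "v_head." prefix next to
    # the plain backbone keys; split them so the backbone can go through the usual
    # save_pretrained() path and the value head can be stored separately.
    backbone, v_head = {}, {}
    for name, param in state_dict.items():
        (v_head if name.startswith("v_head.") else backbone)[name] = param
    return backbone, v_head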
@@ -2,10 +2,9 @@ import os
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple

 from transformers import TrainerState, TrainerControl
-from transformers.modeling_utils import PreTrainedModel

 from trl import PPOTrainer
 from trl.core import LengthSampler
@@ -18,6 +17,7 @@ from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model

 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments
+    from trl import AutoModelForCausalLMWithValueHead
     from llmtuner.extras.callbacks import LogCallback
     from llmtuner.hparams import FinetuningArguments

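The new `AutoModelForCausalLMWithValueHead` import sits under `if TYPE_CHECKING:`, so it is only seen by static type checkers, and the class is referenced through quoted annotations elsewhere in the diff. A small self-contained illustration of that pattern, using `collections.Counter` purely as an example import:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers (e.g. mypy); skipped at runtime, so the import
    # costs nothing and cannot create circular-import problems.
    from collections import Counter


def most_common_word(counts: "Counter") -> str:
    # The quoted annotation is never evaluated at runtime, so the guarded import above
    # is enough for type checking this signature.
    return counts.most_common(1)[0][0]


if __name__ == "__main__":
    from collections import Counter  # real import where an instance is actually built
    print(most_common_word(Counter("to be or not to be".split())))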
@@ -43,7 +43,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         self.log_callback = callbacks[0]
         self.state = TrainerState()
         self.control = TrainerControl()
         self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
-        self._remove_log()

     def ppo_train(self, max_target_length: int) -> None:
@@ -83,7 +82,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             "logits_processor": get_logits_processor()
         }
         length_sampler = LengthSampler(max_target_length // 2, max_target_length)
-        unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)

         dataiter = iter(self.dataloader)
         steps_trained = 0
@@ -95,38 +94,22 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             batch = next(dataiter)
             steps_trained += 1

+            # Cast to inference mode
             unwrapped_model.gradient_checkpointing_disable()
             unwrapped_model.config.use_cache = True
-            unwrapped_model.eval()

-            # Get responses
-            query_tensors = batch["input_ids"]
-            response_tensors = self.generate(
-                batch, length_sampler, return_prompt=False, **gen_kwargs
-            ).detach().cpu() # move to cpu
+            # Get inputs
+            queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
+            rewards = self.get_rewards(queries, responses, unwrapped_model)

-            queries, responses = [], []
-            for i in range(len(query_tensors)):
-                query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
-                response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
-                queries.append(query_tensors[i, query_length:]) # remove padding from left
-                responses.append(response_tensors[i, :response_length]) # remove padding from right
-
-            # Compute rewards
-            replace_model(unwrapped_model, target="reward")
-            with torch.no_grad():
-                _, _, values = self.model(
-                    **self.prepare_model_inputs(queries, responses),
-                    output_hidden_states=True,
-                    return_dict=True
-                )
-            rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
-            replace_model(unwrapped_model, target="default")
-
-            # Run PPO step
+            # Cast to training mode
             unwrapped_model.gradient_checkpointing_enable()
             unwrapped_model.config.use_cache = False
-            stats = self.step(queries, responses, rewards)
-            unwrapped_model.train()
+
+            # Run PPO step
+            stats = self.step(queries, responses, rewards)
             loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
             reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
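The rewritten loop above alternates between an inference configuration for generation (gradient checkpointing off, KV cache on) and a training configuration for the PPO update (checkpointing on, cache off). A minimal sketch of that toggle as a context manager over a standard transformers model; the `generation_mode` helper is illustrative, not part of the commit:

from contextlib import contextmanager

from transformers import PreTrainedModel


@contextmanager
def generation_mode(model: PreTrainedModel):
    # Generation benefits from the KV cache and does not need gradient checkpointing;
    # the PPO optimization step is the opposite, so restore the training setup on exit.
    model.gradient_checkpointing_disable()
    model.config.use_cache = True
    try:
        yield model
    finally:
        model.gradient_checkpointing_enable()
        model.config.use_cache = False

In the diff this corresponds to the "# Cast to inference mode" block before `get_inputs`/`get_rewards` and the "# Cast to training mode" block before `self.step`.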
@@ -155,37 +138,57 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         steps_trained = 0

     @torch.no_grad()
-    def generate(
+    def get_inputs(
         self,
         inputs: Dict[str, torch.Tensor],
         length_sampler: Optional[Callable] = None,
-        return_prompt: Optional[bool] = True,
         **generation_kwargs
-    ) -> torch.Tensor:
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.

         Subclass and override to inject custom behavior.
         """
-        self.model, layer_norm_params = cast_layernorm_dtype(self.model)
-
         if length_sampler is not None:
             generation_kwargs["max_new_tokens"] = length_sampler()
-
-        unwrapped_model = self.accelerator.unwrap_model(self.model)
-
-        response = unwrapped_model.generate(**inputs, **generation_kwargs)
+        self.model, layer_norm_params = cast_layernorm_dtype(self.model)
+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+        response: torch.Tensor = unwrapped_model.generate(**inputs, **generation_kwargs)
+        self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)

         # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
         # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
         if unwrapped_model.pretrained_model.generation_config._from_model_config:
             unwrapped_model.pretrained_model.generation_config._from_model_config = False
-
-        self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
+        queries, responses = [], []
+        query, response = inputs["input_ids"], response[:, inputs["input_ids"].size(-1):].detach().cpu()
+        for i in range(len(query)):
+            query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
+            response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
+            queries.append(query[i, query_length:]) # remove padding from left
+            responses.append(response[i, :response_length]) # remove padding from right

-        if not return_prompt and not self.is_encoder_decoder:
-            return response[:, inputs["input_ids"].size(1):]
-        return response
+        return queries, responses
+
+    @torch.no_grad()
+    def get_rewards(
+        self,
+        queries: List[torch.Tensor],
+        responses: List[torch.Tensor],
+        unwrapped_model: "AutoModelForCausalLMWithValueHead"
+    ) -> List[torch.Tensor]:
+        r"""
+        Computes scores using given reward model.
+        """
+        replace_model(unwrapped_model, target="reward")
+        _, _, values = self.model(
+            **self.prepare_model_inputs(queries, responses),
+            output_hidden_states=True,
+            return_dict=True
+        )
+        rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
+        replace_model(unwrapped_model, target="default")
+        return rewards

     def save_model(self, output_dir: Optional[str] = None) -> None:
         r"""
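In `get_inputs` above, queries arrive left-padded from the collator while generated responses are right-padded, and `(tensor != pad_token_id).nonzero()` locates the first and last real tokens so both sides can be trimmed before PPO sees them. A tiny standalone illustration with a made-up pad id of 0 (values chosen for the example only):

import torch

pad_token_id = 0  # toy value for the example

# left-padded query and right-padded response, as produced by padded generation
query = torch.tensor([0, 0, 32, 45, 67])
response = torch.tensor([11, 22, 33, 0, 0])

query_start = (query != pad_token_id).nonzero()[0]           # index of first real token -> 2
response_end = (response != pad_token_id).nonzero()[-1] + 1   # one past last real token -> 3

print(query[query_start:])      # tensor([32, 45, 67])
print(response[:response_end])  # tensor([11, 22, 33])

`get_rewards` then scores each trimmed pair with the value head swapped to the reward weights and takes `values[:, -1]`, the value at the last token, as the scalar reward for that sequence.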