10x generate in ppo w/ zero3

https://github.com/huggingface/trl/pull/1483
This commit is contained in:
hiyouga 2024-05-29 00:23:23 +08:00
parent 7c8e01bb74
commit 65cd8bdbdb
1 changed file with 5 additions and 4 deletions

@@ -13,6 +13,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from trl import PPOConfig, PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
+from trl.models.utils import unwrap_model_for_generation
 
 from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
 from ...extras.logging import get_logger
@@ -322,10 +323,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         for k, v in batch.items():
             batch[k] = v[:, start_index:]
 
-        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        generate_output: torch.Tensor = unwrapped_model.generate(
-            generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
-        )
+        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
+            generate_output: torch.Tensor = unwrapped_model.generate(
+                generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
+            )
 
         if self.model_args.upcast_layernorm:
             restore_layernorm(self.model, layernorm_params)
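
For reference, the commit adopts the unwrap_model_for_generation context manager introduced by the TRL PR linked above (huggingface/trl#1483). The sketch below is not part of the commit; it only illustrates the typical call pattern, assuming a small placeholder model, a placeholder prompt, and the TRL/accelerate APIs as of the versions current at the commit date.

# Minimal usage sketch (assumptions labeled, not from this repository): speeding up
# rollout generation when the policy is sharded with DeepSpeed ZeRO-3.
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from trl.models.utils import unwrap_model_for_generation

accelerator = Accelerator()  # launch with a ZeRO-3 DeepSpeed config to see the speedup
model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder policy model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = accelerator.prepare(model)  # in a real PPO run this is done by the trainer

batch = tokenizer(["Summarize: PPO needs fast rollouts."], return_tensors="pt").to(accelerator.device)
generation_config = GenerationConfig(max_new_tokens=32, do_sample=True)  # placeholder settings

# Under ZeRO-3 the context manager gathers the sharded weights once for the whole
# generate() call instead of re-gathering them layer by layer at every decoding step;
# without ZeRO-3 it simply unwraps the model, so the call site stays the same.
with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
    generate_output: torch.Tensor = unwrapped_model.generate(
        generation_config=generation_config, **batch
    )

print(tokenizer.batch_decode(generate_output, skip_special_tokens=True))

Keeping generate() inside the with block means the gathered full weights exist only for the duration of the rollout; once the context exits, training continues on the ZeRO-3 shards as before.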