forked from p04798526/LLaMA-Factory-Mirror
10x faster generate in PPO with ZeRO-3
https://github.com/huggingface/trl/pull/1483
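Background for this change: under DeepSpeed ZeRO-3, model parameters are sharded across ranks, so a plain `model.generate()` re-gathers each layer's shards at every decoding step. The TRL PR linked above introduced the `unwrap_model_for_generation` context manager, which gathers the full weights once for the duration of generation and restores the partitioning afterwards; that one-time gather is where the ~10x speedup in the title comes from. A hedged usage sketch follows the diff below.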
parent 7c8e01bb74
commit 65cd8bdbdb
@@ -13,6 +13,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from trl import PPOConfig, PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
+from trl.models.utils import unwrap_model_for_generation
 
 from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
 from ...extras.logging import get_logger
@@ -322,10 +323,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         for k, v in batch.items():
             batch[k] = v[:, start_index:]
 
-        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        generate_output: torch.Tensor = unwrapped_model.generate(
-            generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
-        )
+        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
+            generate_output: torch.Tensor = unwrapped_model.generate(
+                generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
+            )
 
         if self.model_args.upcast_layernorm:
             restore_layernorm(self.model, layernorm_params)
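Below is a minimal sketch of the before/after pattern, assuming only that the installed TRL version ships `trl.models.utils.unwrap_model_for_generation` (the helper this commit imports); the surrounding function and argument names are illustrative, not LLaMA-Factory's actual code:

import torch
from trl.models.utils import unwrap_model_for_generation


def generate_with_gathered_weights(model, accelerator, generation_config, **batch):
    # Old pattern: accelerator.unwrap_model() only strips the distributed
    # wrapper; under ZeRO-3 the parameters stay sharded, so every decoding
    # step triggers a per-layer all-gather across ranks.
    #
    #   unwrapped_model = accelerator.unwrap_model(model)
    #   return unwrapped_model.generate(generation_config=generation_config, **batch)

    # New pattern: the context manager gathers the full weights once up front
    # and restores the ZeRO-3 partitioning (and DeepSpeed hooks) on exit.
    with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
        generate_output: torch.Tensor = unwrapped_model.generate(
            generation_config=generation_config, **batch
        )
    return generate_output

The trade-off is memory for speed: while the context manager is active, each rank temporarily holds a full unsharded copy of the weights. That raises peak memory during generation but pays off for long autoregressive rollouts such as PPO sampling.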