fix #761
This commit is contained in:
parent
8ea32e4046
commit
b34797a845
|
@ -14,7 +14,6 @@ class ChatModel:
|
|||
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
self.model = dispatch_model(self.model)
|
||||
self.model = self.model.eval() # enable evaluation mode
|
||||
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
||||
self.system_prompt = data_args.system_prompt
|
||||
|
||||
|
|
|
@ -175,6 +175,7 @@ def load_model_and_tokenizer(
|
|||
# Initialize adapters
|
||||
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
|
||||
model = model.train() if is_trainable else model.eval()
|
||||
|
||||
# Prepare model with valuehead for RLHF
|
||||
if stage == "rm" or stage == "ppo":
|
||||
|
|
|
@ -99,6 +99,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||
# Cast to inference mode
|
||||
unwrapped_model.gradient_checkpointing_disable()
|
||||
unwrapped_model.config.use_cache = True
|
||||
unwrapped_model, layer_norm_params = cast_layernorm_dtype(unwrapped_model, self.compute_dtype)
|
||||
self.model.eval()
|
||||
|
||||
# Get inputs
|
||||
queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
|
||||
|
@ -108,6 +110,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||
# Cast to training mode
|
||||
unwrapped_model.gradient_checkpointing_enable()
|
||||
unwrapped_model.config.use_cache = False
|
||||
unwrapped_model, _ = cast_layernorm_dtype(unwrapped_model, self.compute_dtype, layer_norm_params)
|
||||
self.model.train()
|
||||
|
||||
# Run PPO step
|
||||
stats = self.step(queries, responses, rewards)
|
||||
|
@ -157,10 +161,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||
if length_sampler is not None:
|
||||
generation_kwargs["max_new_tokens"] = length_sampler()
|
||||
|
||||
self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
|
||||
self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)
|
||||
|
||||
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
|
||||
|
|
Loading…
Reference in New Issue