commit b34797a845 (parent 8ea32e4046)
Author: hiyouga
Date: 2023-09-08 20:22:18 +08:00
3 changed files with 5 additions and 3 deletions


@@ -14,7 +14,6 @@ class ChatModel:
 model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
 self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
 self.model = dispatch_model(self.model)
-self.model = self.model.eval()  # enable evaluation mode
 self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
 self.system_prompt = data_args.system_prompt
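
The hunk above is the ChatModel constructor: the explicit .eval() call is dropped because load_model_and_tokenizer now returns the model already in evaluation mode (next hunk). A minimal usage sketch follows; the import path, argument names, and the chat() signature are assumptions based on the surrounding API, not taken verbatim from this commit.

# Hedged usage sketch: names below are illustrative, not verbatim from the repo.
from llmtuner import ChatModel  # assumed public import path

chat_model = ChatModel(dict(
    model_name_or_path="path/to/base-model",  # placeholder path
    template="default",                       # prompt template name
    checkpoint_dir=None,                      # optional adapter checkpoint
))

# The loader now puts the model into eval mode, so no extra .eval() is needed here.
response, _ = chat_model.chat("Hello!", history=[])
print(response)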


@@ -175,6 +175,7 @@ def load_model_and_tokenizer(
 # Initialize adapters
 model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
 model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
+model = model.train() if is_trainable else model.eval()
 # Prepare model with valuehead for RLHF
 if stage == "rm" or stage == "ppo":
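
The added line puts the model into the right mode at load time: train() during fine-tuning, eval() for pure inference. In PyTorch this toggles Module.training, which switches the behaviour of mode-dependent layers such as Dropout; it does not affect gradient tracking. A standalone illustration (not code from this repository):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

model.train()                  # training mode: dropout is active
assert model.training is True

model.eval()                   # evaluation mode: dropout becomes a no-op
assert model.training is False

with torch.no_grad():          # eval() alone does not disable gradients
    out = model(torch.randn(1, 4))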


@@ -99,6 +99,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
 # Cast to inference mode
 unwrapped_model.gradient_checkpointing_disable()
 unwrapped_model.config.use_cache = True
+unwrapped_model, layer_norm_params = cast_layernorm_dtype(unwrapped_model, self.compute_dtype)
+self.model.eval()
 # Get inputs
 queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
@@ -108,6 +110,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
 # Cast to training mode
 unwrapped_model.gradient_checkpointing_enable()
 unwrapped_model.config.use_cache = False
+unwrapped_model, _ = cast_layernorm_dtype(unwrapped_model, self.compute_dtype, layer_norm_params)
+self.model.train()
 # Run PPO step
 stats = self.step(queries, responses, rewards)
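
These two hunks bracket generation with a dtype round trip: LayerNorm parameters that are kept in float32 for stable training are cast to the faster compute dtype before generation, then cast back before the PPO update, while eval()/train() switch dropout behaviour accordingly. Only the call signature of cast_layernorm_dtype is visible in the diff; the sketch below is an assumption about what such a helper does, not the repository's implementation.

from typing import List, Optional, Tuple
import torch

LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]  # assumed name patterns

def cast_layernorm_dtype(
    model: torch.nn.Module,
    compute_dtype: torch.dtype,
    layer_norm_params: Optional[List[str]] = None,
) -> Tuple[torch.nn.Module, List[str]]:
    casted: List[str] = []
    for name, param in model.named_parameters():
        if param.ndim == 1 and any(key in name for key in LAYERNORM_NAMES):
            if layer_norm_params is None:
                # First call: remember the names and cast to the compute dtype for generation.
                casted.append(name)
                param.data = param.data.to(compute_dtype)
            elif name in layer_norm_params:
                # Second call: restore the remembered parameters to float32 for training.
                param.data = param.data.to(torch.float32)
    return model, casted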
@@ -157,10 +161,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
 if length_sampler is not None:
     generation_kwargs["max_new_tokens"] = length_sampler()
-self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
 unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
 response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
-self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)
 # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
 # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
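
get_inputs() draws a fresh max_new_tokens from length_sampler for every batch before generating with the unwrapped model. An illustrative setup for that sampler, assuming trl's LengthSampler; the concrete bounds and sampling flags are placeholders, not values from this commit.

from trl.core import LengthSampler

length_sampler = LengthSampler(16, 256)   # draws a response length in [16, 256)

generation_kwargs = {
    "do_sample": True,
    "top_k": 0,
    "top_p": 1.0,
}
# Mirrors the branch in the hunk above: refreshed once per batch.
generation_kwargs["max_new_tokens"] = length_sampler()
print(generation_kwargs["max_new_tokens"])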