From 821bb6660e57c29ebf6ac482e78dd2efb8d72437 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Tue, 2 Jul 2024 23:03:17 +0800
Subject: [PATCH] remove rlhf support for chatglm2&3

---
 src/llamafactory/train/ppo/trainer.py | 12 +-----------
 src/llamafactory/train/rm/trainer.py  |  8 +-------
 2 files changed, 2 insertions(+), 18 deletions(-)

diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 1c401938..7e0c0111 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -150,14 +150,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.callback_handler = CallbackHandler(
             callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
         )
-
         if self.args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")
 
-        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm"
-
-        self.amp_context = torch.autocast(self.current_device.type, dtype=self.model_args.compute_dtype)
+        self.amp_context = torch.autocast(self.current_device.type)
         warnings.simplefilter("ignore")  # remove gc warnings on ref model
 
         if finetuning_args.reward_model_type == "full":
@@ -403,9 +399,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="default")
 
-        if self.is_chatglm_model:  # assume same architecture
-            values = torch.transpose(values, 0, 1)
-
         rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
         return rewards.float().detach()  # use fp32 type
 
@@ -443,9 +436,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         with self.amp_context:  # support bf16
             logits, _, values = model(**input_kwargs, return_dict=True, use_cache=False)
 
-        if self.is_chatglm_model:
-            values = torch.transpose(values, 0, 1)
-
         logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
         masks = torch.zeros_like(attention_mask)
         masks[:, :-1] = attention_mask[:, 1:]
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index 267e88e2..63f925bb 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -31,7 +31,6 @@ from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, ProcessorMixin
     from transformers.trainer import PredictionOutput
-    from trl import AutoModelForCausalLMWithValueHead
 
     from ...hparams import FinetuningArguments
 
@@ -86,19 +85,14 @@ class PairwiseTrainer(Trainer):
         Note that the first element will be removed from the output tuple.
         See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842
         """
-        # Compute rewards
         _, _, values = model(**inputs, output_hidden_states=True, return_dict=True, use_cache=False)
-
-        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
-            values = torch.transpose(values, 0, 1)
-
         batch_size = inputs["input_ids"].size(0) // 2
         chosen_masks, rejected_masks = torch.split(inputs["attention_mask"], batch_size, dim=0)
         chosen_rewards, rejected_rewards = torch.split(values, batch_size, dim=0)
         chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
         rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
         chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
+
         loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
         if return_outputs:
             return loss, (loss, chosen_scores, rejected_scores)
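Reviewer note (not part of the patch): the removed branches only compensated for ChatGLM's value head emitting values in (seq_len, batch) order, hence the transpose; the retained code assumes (batch, seq_len), takes each sequence's score at its last non-padded token via gather on the attention-mask sum, and applies a pairwise -logsigmoid(chosen - rejected) loss. Below is a minimal standalone sketch of that retained logic using dummy tensors (values and attention_mask here are illustrative stand-ins, not the trainer's own variables).

import torch

# Dummy stand-ins: values is the per-token value-head output, shape (batch, seq_len);
# attention_mask marks real (non-padded) tokens. Illustrative data only.
values = torch.randn(4, 6)
attention_mask = torch.tensor(
    [[1, 1, 1, 0, 0, 0],
     [1, 1, 1, 1, 0, 0],
     [1, 1, 1, 1, 1, 0],
     [1, 1, 1, 1, 1, 1]]
)

# Score of each sequence = value at its last non-padded position
# (index = number of real tokens - 1), mirroring the gather calls kept in the patch.
last_token_idx = attention_mask.sum(dim=-1, keepdim=True) - 1
scores = values.gather(dim=-1, index=last_token_idx).squeeze(-1)

# Pairwise reward-model loss: first half of the batch is chosen, second half rejected.
chosen_scores, rejected_scores = torch.split(scores, scores.size(0) // 2, dim=0)
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
print(loss)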