forked from p04798526/LLaMA-Factory-Mirror
fix ChatGLM RLHF
This commit is contained in:
parent
a7dd9611db
commit
af6c011fcb
|
@@ -182,6 +182,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||
replace_model(unwrapped_model, target="reward")
|
||||
batch = self.prepare_model_inputs(queries, responses)
|
||||
_, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
|
||||
if values.size(0) != batch["input_ids"].size(0):
|
||||
values = torch.transpose(values, 0, 1)
|
||||
rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
|
||||
replace_model(unwrapped_model, target="default")
|
||||
return rewards
|
||||
|
|
|
@@ -42,6 +42,8 @@ class PairwisePeftTrainer(PeftTrainer):
|
|||
"""
|
||||
batch_size = inputs["input_ids"].size(0) // 2
|
||||
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
|
||||
if values.size(0) != inputs["input_ids"].size(0):
|
||||
values = torch.transpose(values, 0, 1)
|
||||
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
|
||||
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
|
||||
return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
|
||||
|
|
Loading…
Reference in New Issue