This commit is contained in:
hiyouga 2023-08-18 13:07:35 +08:00
parent 53e33418d0
commit d75e377b0f
2 changed files with 9 additions and 4 deletions

View File

@ -182,9 +182,13 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
"""
replace_model(unwrapped_model, target="reward")
batch = self.prepare_model_inputs(queries, responses)
_, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
if values.size(0) != batch["input_ids"].size(0): # adapt chatglm2
with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
_, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
values = torch.transpose(values, 0, 1)
rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
replace_model(unwrapped_model, target="default")
return rewards
@ -220,7 +224,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
logits, _, values = model(**input_kwargs)
if values.size(0) != input_ids.size(0): # adapt chatglm2
if values.size(0) != input_ids.size(0): # adapt to chatglm2
values = torch.transpose(values, 0, 1)
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
@ -240,6 +244,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
all_logits.append(logits)
else:
del logits
all_values.append(values)
all_logprobs.append(logprobs)
all_masks.append(masks)

View File

@ -42,7 +42,7 @@ class PairwisePeftTrainer(PeftTrainer):
"""
batch_size = inputs["input_ids"].size(0) // 2
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
if values.size(0) != inputs["input_ids"].size(0): # adapt chatglm2
if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2
values = torch.transpose(values, 0, 1)
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()