fix eval and pred loss

This commit is contained in:
hiyouga 2023-07-14 13:11:57 +08:00
parent a04115ec27
commit c30db9f1f0
1 changed files with 4 additions and 2 deletions

View File

@ -79,11 +79,13 @@ class Seq2SeqPeftTrainer(PeftTrainer):
Subclass and override to inject custom behavior. Subclass and override to inject custom behavior.
""" """
input_ids = inputs["input_ids"] prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
inputs["labels"] = torch.cat((inputs["labels"], torch.zeros_like(inputs["input_ids"])[:, label_len:]), dim=-1)
loss, generated_tokens, labels = super().prediction_step( loss, generated_tokens, labels = super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
) )
generated_tokens = generated_tokens[:, input_ids.size(-1):] if generated_tokens is not None else None generated_tokens = generated_tokens[:, prompt_len:] if generated_tokens is not None else None
return (loss, generated_tokens, labels) return (loss, generated_tokens, labels)
def save_predictions( def save_predictions(