From c30db9f1f0db5a6a660cdc60016755241762aae7 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Fri, 14 Jul 2023 13:11:57 +0800 Subject: [PATCH] fix eval and pred loss --- src/utils/seq2seq.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/utils/seq2seq.py b/src/utils/seq2seq.py index f9a9dc6b..cfa637d7 100644 --- a/src/utils/seq2seq.py +++ b/src/utils/seq2seq.py @@ -79,11 +79,13 @@ class Seq2SeqPeftTrainer(PeftTrainer): Subclass and override to inject custom behavior. """ - input_ids = inputs["input_ids"] + prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1) + inputs["labels"] = torch.cat((inputs["labels"], torch.zeros_like(inputs["input_ids"])[:, label_len:]), dim=-1) loss, generated_tokens, labels = super().prediction_step( model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys ) - generated_tokens = generated_tokens[:, input_ids.size(-1):] if generated_tokens is not None else None + generated_tokens = generated_tokens[:, prompt_len:] if generated_tokens is not None else None + return (loss, generated_tokens, labels) def save_predictions(