From 815b92e698562bfae6eb9a6fa1b612a05d43ed67 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sun, 10 Sep 2023 14:22:03 +0800
Subject: [PATCH] fix #850

---
 src/llmtuner/tuner/sft/trainer.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/llmtuner/tuner/sft/trainer.py b/src/llmtuner/tuner/sft/trainer.py
index 66fe04a7..69507600 100644
--- a/src/llmtuner/tuner/sft/trainer.py
+++ b/src/llmtuner/tuner/sft/trainer.py
@@ -50,9 +50,11 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         loss, generated_tokens, labels = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
-        generated_tokens = (
-            generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
-        )
+        if generated_tokens is not None:
+            generated_tokens[:, :max(prompt_len, label_len)] = (
+                self.tokenizer.pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)])
+            )
+            generated_tokens = generated_tokens.contiguous()
         return loss, generated_tokens, labels
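
Note (illustrative only, not part of the applied patch): the hunk above stops slicing the prompt prefix off `generated_tokens` (which changes the sequence length) and instead overwrites those positions in place with the pad token, keeping the tensor shape intact. The standalone sketch below demonstrates that masking step; the values of `pad_token_id`, `prompt_len`, `label_len`, and the fake `generated_tokens` batch are assumptions for illustration.

    # Minimal sketch of the in-place masking used in the patch, with assumed values.
    import torch

    pad_token_id = 0                                   # assumed pad id
    prompt_len, label_len = 4, 3                       # assumed lengths
    generated_tokens = torch.randint(1, 100, (2, 8))   # fake batch of generated ids

    if generated_tokens is not None:
        cut = max(prompt_len, label_len)
        # Overwrite the prompt portion with pad tokens instead of slicing it off,
        # so the tensor keeps its original (batch, seq_len) shape.
        generated_tokens[:, :cut] = pad_token_id * torch.ones_like(generated_tokens[:, :cut])
        generated_tokens = generated_tokens.contiguous()

    print(generated_tokens)  # first `cut` columns are now pad_token_id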