From fd8cc490084aba9b5155eaaaf26129efd2871fa3 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Mon, 15 Jul 2024 22:32:07 +0800 Subject: [PATCH] fix #4820 --- src/llamafactory/train/sft/metric.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py index 86f8bb15..a53d8efb 100644 --- a/src/llamafactory/train/sft/metric.py +++ b/src/llamafactory/train/sft/metric.py @@ -55,7 +55,15 @@ def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]: def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor": - logits = logits[0] if isinstance(logits, (list, tuple)) else logits + if isinstance(logits, (list, tuple)): + if logits[0].dim() == 3: # (batch_size, seq_len, vocab_size) + logits = logits[0] + else: # moe models have aux loss + logits = logits[1] + + if logits.dim() != 3: + raise ValueError("Cannot process the logits.") + return torch.argmax(logits, dim=-1)