diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py index 72faef0a..c69608c0 100644 --- a/src/llamafactory/train/sft/metric.py +++ b/src/llamafactory/train/sft/metric.py @@ -48,7 +48,7 @@ def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]: preds, labels = eval_preds.predictions, eval_preds.label_ids accuracies = [] for i in range(len(preds)): - pred, label = preds[i, 1:], labels[i, :-1] + pred, label = preds[i, :-1], labels[i, 1:] label_mask = label != IGNORE_INDEX accuracies.append(np.mean(pred[label_mask] == label[label_mask]))