add eval acc

hiyouga 2024-07-01 03:51:20 +08:00
parent fc2c15d713
commit 1856a08e87
3 changed files with 31 additions and 17 deletions

File 1 of 3

@@ -17,9 +17,11 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict
 
 import numpy as np
+import torch
+from transformers import EvalPrediction
 from transformers.utils import is_jieba_available, is_nltk_available
 
 from ...extras.constants import IGNORE_INDEX
@@ -42,6 +44,22 @@ if is_rouge_available():
     from rouge_chinese import Rouge
 
 
+def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
+    preds, labels = eval_preds.predictions, eval_preds.label_ids
+    accuracies = []
+    for i in range(len(preds)):
+        pred, label = preds[i, 1:], labels[i, :-1]
+        label_mask = label != IGNORE_INDEX
+        accuracies.append(np.mean(pred[label_mask] == label[label_mask]))
+
+    return {"accuracy": float(np.mean(accuracies))}
+
+
+def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
+    logits = logits[0] if isinstance(logits, (list, tuple)) else logits
+    return torch.argmax(logits, dim=-1)
+
+
 @dataclass
 class ComputeMetrics:
     r"""
@@ -50,11 +68,11 @@ class ComputeMetrics:
     tokenizer: "PreTrainedTokenizer"
 
-    def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
+    def __call__(self, eval_preds: "EvalPrediction") -> Dict[str, float]:
         r"""
         Uses the model predictions to compute metrics.
         """
-        preds, labels = eval_preds
+        preds, labels = eval_preds.predictions, eval_preds.label_ids
         score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
 
         preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)

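Note on the new metric pair above: eval_logit_processor is meant for the Trainer's preprocess_logits_for_metrics hook, so each position's logit vector is reduced to an argmax token id before predictions are accumulated, and compute_accuracy then scores those ids against the labels, ignoring positions marked with IGNORE_INDEX (-100 in this repo) and offsetting by one token because a causal LM's logits at position t score the token that follows it. A simplified, self-contained sketch of the masking and averaging, with toy ids, omitting the per-sequence loop and the one-token shift:

    import numpy as np
    import torch

    IGNORE_INDEX = -100  # sentinel for prompt/pad positions that carry no supervision

    # what eval_logit_processor does: collapse (batch, seq, vocab) logits into (batch, seq) ids
    logits = torch.randn(2, 6, 32000)
    pred_ids = torch.argmax(logits, dim=-1).numpy()

    # what compute_accuracy does in spirit: count matches only where labels are supervised
    labels = np.array([[-100, -100, 11, 52, 9, 2],
                       [-100, 11, 52, 9, 2, -100]])
    mask = labels != IGNORE_INDEX
    accuracy = float(np.mean(pred_ids[mask] == labels[mask]))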
File 2 of 3

@@ -135,21 +135,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         for i in range(len(preds)):
             pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
-            if len(pad_len):
-                preds[i] = np.concatenate(
-                    (preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1
-                )  # move pad token to last
+            if len(pad_len):  # move pad token to last
+                preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
 
-        decoded_inputs = self.tokenizer.batch_decode(
-            dataset["input_ids"], skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-        decoded_labels = self.tokenizer.batch_decode(
-            labels, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
 
         with open(output_prediction_file, "w", encoding="utf-8") as writer:
             res: List[str] = []
             for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
                 res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))
             writer.write("\n".join(res))

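The pad rotation above is easier to see on a toy array: np.nonzero(...) yields the indices of non-pad tokens, and slicing at the first one rotates any leading pad ids to the end so each decoded prediction starts with real content. A standalone sketch, assuming a pad_token_id of 0 purely for illustration:

    import numpy as np

    pad_token_id = 0  # placeholder id; the real value comes from the tokenizer
    pred = np.array([0, 0, 0, 42, 17, 99])  # left-padded generated ids

    pad_len = np.nonzero(pred != pad_token_id)[0]  # indices of non-pad tokens
    if len(pad_len):  # move pad token to last
        pred = np.concatenate((pred[pad_len[0]:], pred[:pad_len[0]]), axis=-1)

    print(pred)  # [42 17 99  0  0  0]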
File 3 of 3

@@ -25,7 +25,7 @@ from ...extras.misc import get_logits_processor
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
-from .metric import ComputeMetrics
+from .metric import ComputeMetrics, compute_accuracy, eval_logit_processor
 from .trainer import CustomSeq2SeqTrainer
@@ -72,7 +72,8 @@ def run_sft(
         finetuning_args=finetuning_args,
         data_collator=data_collator,
         callbacks=callbacks,
-        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
+        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else compute_accuracy,
+        preprocess_logits_for_metrics=None if training_args.predict_with_generate else eval_logit_processor,
         **tokenizer_module,
         **split_dataset(dataset, data_args, training_args),
     )
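Passing eval_logit_processor as preprocess_logits_for_metrics is what keeps evaluation memory in check when predict_with_generate is off: without it, the Trainer would accumulate full vocab-sized logits for the whole eval set before calling compute_accuracy, while argmax ids take a tiny fraction of that. A back-of-the-envelope sketch with illustrative sizes, not measured numbers:

    # rough memory comparison for accumulating eval predictions
    vocab, seq_len, num_examples = 32000, 1024, 1000
    full_logits_gb = num_examples * seq_len * vocab * 4 / 1e9  # float32 logits
    argmax_ids_gb = num_examples * seq_len * 8 / 1e9           # int64 ids after torch.argmax
    print(f"{full_logits_gb:.1f} GB vs {argmax_ids_gb:.3f} GB")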
@@ -91,7 +92,7 @@ def run_sft(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
 
     # Evaluation
     if training_args.do_eval:
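The plot key is "eval_accuracy" rather than "accuracy" because Trainer.evaluate() prefixes every key returned by compute_metrics with its metric_key_prefix (default "eval") before logging. A toy rendition of that renaming, not the Trainer's actual code:

    metrics = {"accuracy": 0.8125}  # what compute_accuracy returns; the value here is made up
    metric_key_prefix = "eval"
    logged = {f"{metric_key_prefix}_{k}": v for k, v in metrics.items()}
    print(logged)  # {'eval_accuracy': 0.8125}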