diff --git a/src/llamafactory/train/rm/metric.py b/src/llamafactory/train/rm/metric.py
index b77c58d0..7c9dfeb4 100644
--- a/src/llamafactory/train/rm/metric.py
+++ b/src/llamafactory/train/rm/metric.py
@@ -26,8 +26,16 @@ if TYPE_CHECKING:
 
 @dataclass
 class ComputeAccuracy:
-    def __post_init__(self):
+    def _dump(self) -> Optional[Dict[str, float]]:
+        result = None
+        if hasattr(self, "score_dict"):
+            result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
+
         self.score_dict = {"accuracy": []}
+        return result
+
+    def __post_init__(self):
+        self._dump()
 
     def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
         chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
@@ -38,4 +46,4 @@ class ComputeAccuracy:
             self.score_dict["accuracy"].append(chosen_scores[i] > rejected_scores[i])
 
         if compute_result:
-            return {"accuracy": float(np.mean(self.score_dict["accuracy"]))}
+            return self._dump()
diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py
index efd90369..69327379 100644
--- a/src/llamafactory/train/sft/metric.py
+++ b/src/llamafactory/train/sft/metric.py
@@ -59,8 +59,16 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
 
 @dataclass
 class ComputeAccuracy:
-    def __post_init__(self):
+    def _dump(self) -> Optional[Dict[str, float]]:
+        result = None
+        if hasattr(self, "score_dict"):
+            result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
+
         self.score_dict = {"accuracy": []}
+        return result
+
+    def __post_init__(self):
+        self._dump()
 
     def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
         preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
@@ -70,7 +78,7 @@ class ComputeAccuracy:
             self.score_dict["accuracy"].append(np.mean(pred[label_mask] == label[label_mask]))
 
         if compute_result:
-            return {"accuracy": float(np.mean(self.score_dict["accuracy"]))}
+            return self._dump()
 
 
 @dataclass
@@ -81,8 +89,16 @@ class ComputeSimilarity:
 
     tokenizer: "PreTrainedTokenizer"
 
-    def __post_init__(self):
+    def _dump(self) -> Optional[Dict[str, float]]:
+        result = None
+        if hasattr(self, "score_dict"):
+            result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
+
         self.score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
+        return result
+
+    def __post_init__(self):
+        self._dump()
 
     def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
         preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
@@ -111,4 +127,4 @@ class ComputeSimilarity:
             self.score_dict["bleu-4"].append(round(bleu_score * 100, 4))
 
         if compute_result:
-            return {k: float(np.mean(v)) for k, v in self.score_dict.items()}
+            return self._dump()
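
For context, a minimal usage sketch (not part of the patch; the score values are made up, and it assumes transformers and llamafactory are installed) of how the patched ComputeAccuracy behaves once this change is applied: with batched evaluation, the metric object is called once per evaluation batch with compute_result=False to buffer per-example comparisons, and with compute_result=True on the final batch, where _dump() returns the averaged metrics and resets the internal buffer for the next evaluation loop.

import numpy as np
from transformers import EvalPrediction

from llamafactory.train.rm.metric import ComputeAccuracy

metric = ComputeAccuracy()

# intermediate batch: per-example comparisons are buffered, nothing is returned
metric(
    EvalPrediction(predictions=(np.array([1.0, 2.0]), np.array([0.5, 2.5])), label_ids=np.zeros(2)),
    compute_result=False,
)

# final batch: the aggregated accuracy is returned and the buffer is reset via _dump()
result = metric(
    EvalPrediction(predictions=(np.array([3.0]), np.array([1.0])), label_ids=np.zeros(1)),
    compute_result=True,
)
print(result)  # {'accuracy': 0.666...} over all three buffered comparisons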