diff --git a/src/utils/seq2seq.py b/src/utils/seq2seq.py
index 05f25c85..53b12f38 100644
--- a/src/utils/seq2seq.py
+++ b/src/utils/seq2seq.py
@@ -39,9 +39,12 @@ class ComputeMetrics:
         preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
         labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
 
-        for pred, label in zip(preds, labels):
-            hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
-            reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))
+        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+        for pred, label in zip(decoded_preds, decoded_labels):
+            hypothesis = list(jieba.cut(pred))
+            reference = list(jieba.cut(label))
 
             if len(" ".join(hypothesis).split()) == 0:
                 result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
@@ -101,12 +104,11 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
         labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
 
+        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+
         with open(output_prediction_file, "w", encoding="utf-8") as writer:
             res: List[str] = []
-            for pred, label in zip(preds, labels):
-                pred = self.tokenizer.decode(pred, skip_special_tokens=True)
-                label = self.tokenizer.decode(label, skip_special_tokens=True)
-
+            for pred, label in zip(decoded_preds, decoded_labels):
                 res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
-
             writer.write("\n".join(res))
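
Both hunks make the same change: N per-sequence tokenizer.decode calls become one tokenizer.batch_decode call, which yields identical strings while hoisting the decode work out of the Python loop. Below is a minimal standalone sketch of that equivalence; the "gpt2" checkpoint is an arbitrary choice for illustration, while the patched code uses whatever tokenizer the trainer was constructed with.

# Standalone sketch, not part of the patch: demonstrates that batch_decode
# produces the same strings as decoding each sequence individually.
# "gpt2" is an assumed checkpoint used purely for illustration.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# A small batch of token-id sequences, analogous to `preds` / `labels` above.
batch = tokenizer(["hello world", "batch decoding"])["input_ids"]

# Old approach: decode each sequence inside the Python loop.
one_by_one = [tokenizer.decode(ids, skip_special_tokens=True) for ids in batch]

# New approach: decode the whole batch in a single call.
batched = tokenizer.batch_decode(batch, skip_special_tokens=True)

assert one_by_one == batched  # same strings, one call instead of N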