fix bleu score

This commit is contained in:
hiyouga 2023-07-05 00:11:21 +08:00
parent 395ed1cf1b
commit cac87fd553
1 changed files with 10 additions and 8 deletions

View File

@ -39,9 +39,12 @@ class ComputeMetrics:
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
for pred, label in zip(preds, labels):
hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
for pred, label in zip(decoded_preds, decoded_labels):
hypothesis = list(jieba.cut(pred))
reference = list(jieba.cut(label))
if len(" ".join(hypothesis).split()) == 0:
result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
@ -101,12 +104,11 @@ class Seq2SeqPeftTrainer(PeftTrainer):
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for pred, label in zip(preds, labels):
pred = self.tokenizer.decode(pred, skip_special_tokens=True)
label = self.tokenizer.decode(label, skip_special_tokens=True)
for pred, label in zip(decoded_preds, decoded_labels):
res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
writer.write("\n".join(res))