fix decoding in seq2seq

This commit is contained in:
hiyouga 2023-06-27 19:33:08 +08:00
parent 33f2141507
commit 1c732e2537
2 changed files with 17 additions and 20 deletions

View File

@ -80,15 +80,19 @@ def main():
# Evaluation # Evaluation
if training_args.do_eval: if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
metrics.pop("eval_loss", None)
trainer.log_metrics("eval", metrics) trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics) trainer.save_metrics("eval", metrics)
# Predict # Predict
if training_args.do_predict: if training_args.do_predict:
predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
predict_results.metrics.pop("predict_loss", None)
trainer.log_metrics("predict", predict_results.metrics) trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics) trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(predict_results, tokenizer) trainer.save_predictions(predict_results)
def _mp_fn(index): def _mp_fn(index):

View File

@ -32,18 +32,12 @@ class ComputeMetrics:
Uses the model predictions to compute metrics. Uses the model predictions to compute metrics.
""" """
preds, labels = eval_preds preds, labels = eval_preds
if isinstance(preds, tuple):
preds = preds[0]
# Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them.
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
preds = preds[:, labels.shape[1]:] # remove prompts
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
for pred, label in zip(preds, labels): for pred, label in zip(preds, labels):
pred = pred[len(label) - np.sum(label == IGNORE_INDEX) : len(pred) - np.sum(pred == IGNORE_INDEX)] # remove prompts
label = label[:len(label) - np.sum(label == IGNORE_INDEX)]
hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True))) hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True))) reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))
@ -70,8 +64,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
def save_predictions( def save_predictions(
self, self,
predict_results: PredictionOutput, predict_results: PredictionOutput
tokenizer: PreTrainedTokenizer
) -> None: ) -> None:
r""" r"""
Saves model predictions to `output_dir`. Saves model predictions to `output_dir`.
@ -81,17 +74,17 @@ class Seq2SeqPeftTrainer(PeftTrainer):
if not self.is_world_process_zero(): if not self.is_world_process_zero():
return return
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
preds = preds[:, labels.shape[1]:] # remove prompts
preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]
labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels]
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}") logger.info(f"Saving prediction results to {output_prediction_file}")
with open(output_prediction_file, "w", encoding="utf-8") as writer: with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = [] res: List[str] = []
for pred, label in zip(preds, labels): for pred, label in zip(predict_results.predictions, predict_results.label_ids):
pred = pred[len(label) - np.sum(label == IGNORE_INDEX) : len(pred) - np.sum(pred == IGNORE_INDEX)] # remove prompts
label = label[:len(label) - np.sum(label == IGNORE_INDEX)]
pred = self.tokenizer.decode(pred, skip_special_tokens=True)
label = self.tokenizer.decode(label, skip_special_tokens=True)
res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
writer.write("\n".join(res)) writer.write("\n".join(res))