diff --git a/src/train_sft.py b/src/train_sft.py index da104fdd..30ca2e2c 100644 --- a/src/train_sft.py +++ b/src/train_sft.py @@ -80,15 +80,19 @@ def main(): # Evaluation if training_args.do_eval: metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) + if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled + metrics.pop("eval_loss", None) trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) # Predict if training_args.do_predict: predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) + if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled + predict_results.metrics.pop("predict_loss", None) trainer.log_metrics("predict", predict_results.metrics) trainer.save_metrics("predict", predict_results.metrics) - trainer.save_predictions(predict_results, tokenizer) + trainer.save_predictions(predict_results) def _mp_fn(index): diff --git a/src/utils/seq2seq.py b/src/utils/seq2seq.py index 90a9810e..6c772a25 100644 --- a/src/utils/seq2seq.py +++ b/src/utils/seq2seq.py @@ -32,18 +32,12 @@ class ComputeMetrics: Uses the model predictions to compute metrics. """ preds, labels = eval_preds - - if isinstance(preds, tuple): - preds = preds[0] - - # Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them. - preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) - labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) - - preds = preds[:, labels.shape[1]:] # remove prompts score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} for pred, label in zip(preds, labels): + pred = pred[len(label) - np.sum(label == IGNORE_INDEX) : len(pred) - np.sum(pred == IGNORE_INDEX)] # remove prompts + label = label[:len(label) - np.sum(label == IGNORE_INDEX)] + hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True))) reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True))) @@ -70,8 +64,7 @@ class Seq2SeqPeftTrainer(PeftTrainer): def save_predictions( self, - predict_results: PredictionOutput, - tokenizer: PreTrainedTokenizer + predict_results: PredictionOutput ) -> None: r""" Saves model predictions to `output_dir`. @@ -81,17 +74,17 @@ class Seq2SeqPeftTrainer(PeftTrainer): if not self.is_world_process_zero(): return - preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id) - labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id) - - preds = preds[:, labels.shape[1]:] # remove prompts - preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds] - labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels] - output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") logger.info(f"Saving prediction results to {output_prediction_file}") with open(output_prediction_file, "w", encoding="utf-8") as writer: res: List[str] = [] - for pred, label in zip(preds, labels): + for pred, label in zip(predict_results.predictions, predict_results.label_ids): + pred = pred[len(label) - np.sum(label == IGNORE_INDEX) : len(pred) - np.sum(pred == IGNORE_INDEX)] # remove prompts + label = label[:len(label) - np.sum(label == IGNORE_INDEX)] + + pred = self.tokenizer.decode(pred, skip_special_tokens=True) + label = self.tokenizer.decode(label, skip_special_tokens=True) + res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) + writer.write("\n".join(res))