diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index c063b214..6bf5b7c0 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -13,6 +13,7 @@ from ..trainer_utils import create_custom_optimzer, create_custom_scheduler if TYPE_CHECKING: + from torch.utils.data import Dataset from transformers import ProcessorMixin from transformers.trainer import PredictionOutput @@ -94,7 +95,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding return padded_tensor.contiguous() # in contiguous memory - def save_predictions(self, predict_results: "PredictionOutput") -> None: + def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None: r""" Saves model predictions to `output_dir`. @@ -120,6 +121,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): (preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1 ) # move pad token to last + decoded_inputs = self.tokenizer.batch_decode( + dataset["input_ids"], skip_special_tokens=True, clean_up_tokenization_spaces=False + ) decoded_labels = self.tokenizer.batch_decode( labels, skip_special_tokens=True, clean_up_tokenization_spaces=False ) @@ -127,6 +131,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): with open(output_prediction_file, "w", encoding="utf-8") as writer: res: List[str] = [] - for label, pred in zip(decoded_labels, decoded_preds): - res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) + for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds): + res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False)) writer.write("\n".join(res)) diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index f09b5173..a989b3f7 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -93,7 +93,7 @@ def run_sft( predict_results.metrics.pop("predict_loss", None) trainer.log_metrics("predict", predict_results.metrics) trainer.save_metrics("predict", predict_results.metrics) - trainer.save_predictions(predict_results) + trainer.save_predictions(dataset, predict_results) # Create model card create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)