fix seq2seq predictions
This commit is contained in:
parent
cb26f78923
commit
65e9ce2cdd
|
@ -89,9 +89,9 @@ huggingface-cli login
|
|||
|
||||
And **powerful GPUs**!
|
||||
|
||||
If you want to enable LoRA(QLoRA) or Freeze quantization on Windows, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
|
||||
If you want to enable quantized LoRA (QLoRA) on the Windows platform, you should install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
|
||||
|
||||
```
|
||||
```bash
|
||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
|
||||
```
|
||||
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import os
|
||||
import json
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Sequence, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from transformers.trainer import PredictionOutput
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
@ -34,11 +36,10 @@ class ComputeMetrics:
|
|||
preds, labels = eval_preds
|
||||
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||
|
||||
for pred, label in zip(preds, labels):
|
||||
pred_pad_len, label_pad_len = np.sum(pred == IGNORE_INDEX), np.sum(label == IGNORE_INDEX)
|
||||
pred = pred[len(label) - label_pad_len : len(pred) - pred_pad_len] # remove prompts
|
||||
label = label[:len(label) - label_pad_len]
|
||||
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
|
||||
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
|
||||
|
||||
for pred, label in zip(preds, labels):
|
||||
hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
|
||||
reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))
|
||||
|
||||
|
@ -63,6 +64,25 @@ class Seq2SeqPeftTrainer(PeftTrainer):
|
|||
Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
|
||||
"""
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
r"""
|
||||
Removes the prompt part in the generated tokens.
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
"""
|
||||
input_ids = inputs["input_ids"]
|
||||
loss, generated_tokens, labels = super().prediction_step(
|
||||
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
|
||||
)
|
||||
generated_tokens = generated_tokens[:, input_ids.size(-1):] if generated_tokens is not None else None
|
||||
return (loss, generated_tokens, labels)
|
||||
|
||||
def save_predictions(
|
||||
self,
|
||||
predict_results: PredictionOutput
|
||||
|
@ -77,13 +97,13 @@ class Seq2SeqPeftTrainer(PeftTrainer):
|
|||
|
||||
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
|
||||
logger.info(f"Saving prediction results to {output_prediction_file}")
|
||||
|
||||
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
|
||||
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
|
||||
|
||||
with open(output_prediction_file, "w", encoding="utf-8") as writer:
|
||||
res: List[str] = []
|
||||
for pred, label in zip(predict_results.predictions, predict_results.label_ids):
|
||||
pred_pad_len, label_pad_len = np.sum(pred == IGNORE_INDEX), np.sum(label == IGNORE_INDEX)
|
||||
pred = pred[len(label) - label_pad_len : len(pred) - pred_pad_len] # remove prompts
|
||||
label = label[:len(label) - label_pad_len]
|
||||
|
||||
for pred, label in zip(preds, labels):
|
||||
pred = self.tokenizer.decode(pred, skip_special_tokens=True)
|
||||
label = self.tokenizer.decode(label, skip_special_tokens=True)
|
||||
|
||||
|
|
Loading…
Reference in New Issue