Update trainer.py

This commit is contained in:
hoshi-hiyouga 2024-05-15 14:13:26 +08:00 committed by GitHub
parent c309605ff5
commit aa4a8933dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 11 additions and 9 deletions

View File

@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import ProcessorMixin, Seq2SeqTrainer
from transformers import Seq2SeqTrainer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
@ -13,6 +13,7 @@ from ..utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers.trainer import PredictionOutput
from ...hparams import FinetuningArguments
@ -26,7 +27,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
"""
def __init__(self, finetuning_args: "FinetuningArguments", processor: "ProcessorMixin", **kwargs) -> None:
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
) -> None:
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
self.processor = processor
@ -46,6 +49,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
super()._save(output_dir, state_dict)
if self.processor is not None:
output_dir = output_dir if output_dir is not None else self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir)
def prediction_step(
self,
model: "torch.nn.Module",
@ -121,10 +130,3 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
for label, pred in zip(decoded_labels, decoded_preds):
res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
writer.write("\n".join(res))
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
super().save_model(output_dir, _internal_call)
if self.processor is not None:
if output_dir is None:
output_dir = self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir)