diff --git a/src/llmtuner/train/sft/trainer.py b/src/llmtuner/train/sft/trainer.py
index de741426..def427fd 100644
--- a/src/llmtuner/train/sft/trainer.py
+++ b/src/llmtuner/train/sft/trainer.py
@@ -1,5 +1,6 @@
 import json
 import os
+from types import MethodType
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -9,8 +10,7 @@ from transformers import Seq2SeqTrainer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
 from ..utils import create_custom_optimzer, create_custom_scheduler
-from types import MethodType
-from packaging import version
+
 
 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput
@@ -31,6 +31,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         self.finetuning_args = finetuning_args
         if finetuning_args.use_badam:
             from badam import clip_grad_norm_for_sparse_tensor
+
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
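
Note (not part of the diff): a minimal, self-contained sketch of the MethodType pattern the last hunk relies on, i.e. re-binding clip_grad_norm_ on the Accelerator instance so gradient clipping dispatches to BAdam-aware logic. The Accelerator class and clip_grad_norm_for_sparse_tensor below are simplified stand-ins for illustration, not the real accelerate/badam implementations.

    from types import MethodType


    class Accelerator:
        """Illustrative stand-in for accelerate.Accelerator."""

        def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
            print("default clipping")


    def clip_grad_norm_for_sparse_tensor(self, parameters, max_norm, norm_type=2):
        # `self` becomes the Accelerator instance once MethodType binds it below.
        print("BAdam-aware clipping on", type(self).__name__)


    accelerator = Accelerator()
    # Re-bind the method on this single instance, as the trainer's __init__ does.
    accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, accelerator)
    accelerator.clip_grad_norm_([], max_norm=1.0)  # prints "BAdam-aware clipping on Accelerator"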