From 2de1a7610a78e41680970b9f308741f98df489fa Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 23 Aug 2023 20:45:03 +0800 Subject: [PATCH] fix import error --- src/llmtuner/tuner/dpo/workflow.py | 3 ++- src/llmtuner/tuner/rm/workflow.py | 3 ++- src/llmtuner/tuner/sft/workflow.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/llmtuner/tuner/dpo/workflow.py b/src/llmtuner/tuner/dpo/workflow.py index f2bad7b0..31d82fbf 100644 --- a/src/llmtuner/tuner/dpo/workflow.py +++ b/src/llmtuner/tuner/dpo/workflow.py @@ -3,6 +3,7 @@ from copy import deepcopy from peft import PeftModel from typing import TYPE_CHECKING, Optional, List +from transformers import Seq2SeqTrainingArguments from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.constants import IGNORE_INDEX @@ -12,7 +13,7 @@ from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding from llmtuner.tuner.dpo.trainer import DPOPeftTrainer if TYPE_CHECKING: - from transformers import Seq2SeqTrainingArguments, TrainerCallback + from transformers import TrainerCallback from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments diff --git a/src/llmtuner/tuner/rm/workflow.py b/src/llmtuner/tuner/rm/workflow.py index 68dfff69..fd1e7a47 100644 --- a/src/llmtuner/tuner/rm/workflow.py +++ b/src/llmtuner/tuner/rm/workflow.py @@ -3,6 +3,7 @@ # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py from typing import TYPE_CHECKING, Optional, List +from transformers import Seq2SeqTrainingArguments from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.ploting import plot_loss @@ -12,7 +13,7 @@ from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding from llmtuner.tuner.rm.trainer import PairwisePeftTrainer if TYPE_CHECKING: - from transformers import Seq2SeqTrainingArguments, TrainerCallback + from transformers import TrainerCallback from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py index 0fa90071..511db1ba 100644 --- a/src/llmtuner/tuner/sft/workflow.py +++ b/src/llmtuner/tuner/sft/workflow.py @@ -1,7 +1,7 @@ # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py from typing import TYPE_CHECKING, Optional, List -from transformers import DataCollatorForSeq2Seq +from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.constants import IGNORE_INDEX @@ -12,7 +12,7 @@ from llmtuner.tuner.sft.metric import ComputeMetrics from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer if TYPE_CHECKING: - from transformers import Seq2SeqTrainingArguments, TrainerCallback + from transformers import TrainerCallback from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments