fix import error
This commit is contained in:
parent
57146c101f
commit
2de1a7610a
|
@ -3,6 +3,7 @@
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
from typing import TYPE_CHECKING, Optional, List
|
from typing import TYPE_CHECKING, Optional, List
|
||||||
|
from transformers import Seq2SeqTrainingArguments
|
||||||
|
|
||||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||||
from llmtuner.extras.constants import IGNORE_INDEX
|
from llmtuner.extras.constants import IGNORE_INDEX
|
||||||
|
@ -12,7 +13,7 @@ from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
|
||||||
from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
|
from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
from transformers import TrainerCallback
|
||||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Optional, List
|
from typing import TYPE_CHECKING, Optional, List
|
||||||
|
from transformers import Seq2SeqTrainingArguments
|
||||||
|
|
||||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||||
from llmtuner.extras.ploting import plot_loss
|
from llmtuner.extras.ploting import plot_loss
|
||||||
|
@ -12,7 +13,7 @@ from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
|
||||||
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
|
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
from transformers import TrainerCallback
|
||||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
|
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Optional, List
|
from typing import TYPE_CHECKING, Optional, List
|
||||||
from transformers import DataCollatorForSeq2Seq
|
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
|
||||||
|
|
||||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||||
from llmtuner.extras.constants import IGNORE_INDEX
|
from llmtuner.extras.constants import IGNORE_INDEX
|
||||||
|
@ -12,7 +12,7 @@ from llmtuner.tuner.sft.metric import ComputeMetrics
|
||||||
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
|
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
from transformers import TrainerCallback
|
||||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue