This commit is contained in:
hiyouga 2023-08-21 18:16:11 +08:00
parent 02d69b6fde
commit 5235b15c91
2 changed files with 2 additions and 6 deletions

View File

@ -10,7 +10,7 @@ from llmtuner.tuner.core.trainer import PeftModelMixin
if TYPE_CHECKING:
from transformers import PreTrainedModel
from llmtuner.hparams import FinetuningArguments, GeneratingArguments
from llmtuner.hparams import FinetuningArguments
class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
@ -18,12 +18,10 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
def __init__(
self,
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
**kwargs
):
self.finetuning_args = finetuning_args
self.generating_args = generating_args
self.ref_model = ref_model
self.use_dpo_data_collator = True # hack to avoid warning
self.label_pad_token_id = IGNORE_INDEX

View File

@ -13,7 +13,7 @@ from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
def run_dpo(
@ -21,7 +21,6 @@ def run_dpo(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
@ -38,7 +37,6 @@ def run_dpo(
# Initialize our Trainer
trainer = DPOPeftTrainer(
finetuning_args=finetuning_args,
generating_args=generating_args,
ref_model=ref_model,
model=model,
args=training_args,