fix #617
This commit is contained in: parent 02d69b6fde · commit 5235b15c91

This commit removes the generating_args parameter from DPOPeftTrainer and run_dpo.
In the DPO trainer module (the file that defines DPOPeftTrainer):

@@ -10,7 +10,7 @@ from llmtuner.tuner.core.trainer import PeftModelMixin
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
-    from llmtuner.hparams import FinetuningArguments, GeneratingArguments
+    from llmtuner.hparams import FinetuningArguments
 
 
 class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
@@ -18,12 +18,10 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
     def __init__(
         self,
         finetuning_args: "FinetuningArguments",
-        generating_args: "GeneratingArguments",
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
         **kwargs
     ):
         self.finetuning_args = finetuning_args
-        self.generating_args = generating_args
         self.ref_model = ref_model
         self.use_dpo_data_collator = True # hack to avoid warning
         self.label_pad_token_id = IGNORE_INDEX
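The net effect in the trainer: DPOPeftTrainer now consumes only finetuning_args and ref_model itself and forwards everything else to DPOTrainer via **kwargs. A minimal sketch of that consume-and-forward pattern, using toy stand-in classes rather than the real llmtuner/trl types:

from typing import Optional

class BaseTrainer:  # toy stand-in for trl.DPOTrainer, not llmtuner code
    def __init__(self, model=None, args=None, **kwargs):
        self.model = model
        self.args = args

class CustomTrainer(BaseTrainer):  # toy stand-in for DPOPeftTrainer
    def __init__(self, finetuning_args, ref_model: Optional[object] = None, **kwargs):
        self.finetuning_args = finetuning_args  # consumed by the subclass
        self.ref_model = ref_model              # consumed by the subclass
        super().__init__(**kwargs)              # model, args, etc. pass through to the base

trainer = CustomTrainer(finetuning_args={"lora_rank": 8}, model="model", args="training_args")
assert trainer.model == "model" and trainer.ref_model is None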
In the DPO workflow module (the file that defines run_dpo and imports DPOPeftTrainer):

@@ -13,7 +13,7 @@ from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 
 
 def run_dpo(
@@ -21,7 +21,6 @@ def run_dpo(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    generating_args: "GeneratingArguments",
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
@@ -38,7 +37,6 @@ def run_dpo(
     # Initialize our Trainer
     trainer = DPOPeftTrainer(
         finetuning_args=finetuning_args,
-        generating_args=generating_args,
         ref_model=ref_model,
         model=model,
         args=training_args,
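For callers, run_dpo no longer accepts a generating_args keyword. A hedged sketch of the resulting contract (the leading model_args parameter is assumed from the get_dataset(model_args, data_args) context line; all names are illustrative):

# Hypothetical stand-in for the post-commit signature (body elided)
def run_dpo(model_args, data_args, training_args, finetuning_args, callbacks=None):
    pass

run_dpo("model", "data", "training", "finetuning")  # fine after this commit

try:
    run_dpo("model", "data", "training", "finetuning", generating_args="removed")
except TypeError as exc:
    print(exc)  # run_dpo() got an unexpected keyword argument 'generating_args'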