forked from p04798526/LLaMA-Factory-Mirror
fix ppo dataset bug #4012
This commit is contained in:
parent
fad2591e31
commit
149610c636
|
@ -130,7 +130,7 @@ def get_dataset(
|
|||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "kto"],
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
|
|
|
@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
|||
def get_preprocess_and_print_func(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "kto"],
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
|
|
|
@ -18,7 +18,7 @@ def preprocess_pretrain_dataset(
|
|||
if data_args.template == "gemma":
|
||||
text_examples = [tokenizer.bos_token + example for example in text_examples]
|
||||
|
||||
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
|
||||
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
|
||||
else:
|
||||
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
|
||||
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||
|
|
|
@ -29,7 +29,7 @@ def run_ppo(
|
|||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
||||
|
||||
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
||||
|
|
Loading…
Reference in New Issue