diff --git a/tests/data/processors/test_unsupervised.py b/tests/data/processors/test_unsupervised.py index 8713c772..1bfab53e 100644 --- a/tests/data/processors/test_unsupervised.py +++ b/tests/data/processors/test_unsupervised.py @@ -30,17 +30,18 @@ TINY_DATA = os.environ.get("TINY_DATA", "llamafactory/tiny-supervised-dataset") TRAIN_ARGS = { "model_name_or_path": TINY_LLAMA, - "stage": "sft", - "do_predict": True, + "stage": "ppo", + "do_train": True, "finetuning_type": "full", - "eval_dataset": "system_chat", + "reward_model": "", + "reward_model_type": "full", + "dataset": "system_chat", "dataset_dir": "REMOTE:" + DEMO_DATA, "template": "llama3", "cutoff_len": 8192, "overwrite_cache": True, "output_dir": "dummy_dir", "overwrite_output_dir": True, - "predict_with_generate": True, "fp16": True, }