fix unittest

This commit is contained in:
hiyouga 2024-07-19 01:10:30 +08:00
parent 608de799a2
commit e80006795f
1 changed files with 3 additions and 1 deletions

View File

@ -33,6 +33,8 @@ TRAIN_ARGS = {
"stage": "sft", "stage": "sft",
"do_predict": True, "do_predict": True,
"finetuning_type": "full", "finetuning_type": "full",
"eval_dataset": "system_chat",
"dataset_dir": "REMOTE:" + DEMO_DATA,
"template": "llama3", "template": "llama3",
"cutoff_len": 8192, "cutoff_len": 8192,
"overwrite_cache": True, "overwrite_cache": True,
@ -45,7 +47,7 @@ TRAIN_ARGS = {
@pytest.mark.parametrize("num_samples", [16]) @pytest.mark.parametrize("num_samples", [16])
def test_unsupervised_data(num_samples: int): def test_unsupervised_data(num_samples: int):
train_dataset = load_train_dataset(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS) train_dataset = load_train_dataset(**TRAIN_ARGS)
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA) ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train") original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
indexes = random.choices(range(len(original_data)), k=num_samples) indexes = random.choices(range(len(original_data)), k=num_samples)