fix unittest
This commit is contained in:
parent
608de799a2
commit
e80006795f
|
@ -33,6 +33,8 @@ TRAIN_ARGS = {
|
||||||
"stage": "sft",
|
"stage": "sft",
|
||||||
"do_predict": True,
|
"do_predict": True,
|
||||||
"finetuning_type": "full",
|
"finetuning_type": "full",
|
||||||
|
"eval_dataset": "system_chat",
|
||||||
|
"dataset_dir": "REMOTE:" + DEMO_DATA,
|
||||||
"template": "llama3",
|
"template": "llama3",
|
||||||
"cutoff_len": 8192,
|
"cutoff_len": 8192,
|
||||||
"overwrite_cache": True,
|
"overwrite_cache": True,
|
||||||
|
@ -45,7 +47,7 @@ TRAIN_ARGS = {
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_samples", [16])
|
@pytest.mark.parametrize("num_samples", [16])
|
||||||
def test_unsupervised_data(num_samples: int):
|
def test_unsupervised_data(num_samples: int):
|
||||||
train_dataset = load_train_dataset(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)
|
train_dataset = load_train_dataset(**TRAIN_ARGS)
|
||||||
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
|
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
|
||||||
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
|
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
|
||||||
indexes = random.choices(range(len(original_data)), k=num_samples)
|
indexes = random.choices(range(len(original_data)), k=num_samples)
|
||||||
|
|
Loading…
Reference in New Issue