fix unittest
This commit is contained in:
parent
608de799a2
commit
e80006795f
|
@ -33,6 +33,8 @@ TRAIN_ARGS = {
|
|||
"stage": "sft",
|
||||
"do_predict": True,
|
||||
"finetuning_type": "full",
|
||||
"eval_dataset": "system_chat",
|
||||
"dataset_dir": "REMOTE:" + DEMO_DATA,
|
||||
"template": "llama3",
|
||||
"cutoff_len": 8192,
|
||||
"overwrite_cache": True,
|
||||
|
@ -45,7 +47,7 @@ TRAIN_ARGS = {
|
|||
|
||||
@pytest.mark.parametrize("num_samples", [16])
|
||||
def test_unsupervised_data(num_samples: int):
|
||||
train_dataset = load_train_dataset(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)
|
||||
train_dataset = load_train_dataset(**TRAIN_ARGS)
|
||||
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
|
||||
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
|
||||
indexes = random.choices(range(len(original_data)), k=num_samples)
|
||||
|
|
Loading…
Reference in New Issue