From b2f7cb446591e3722b5be8d250ddfe0caa226384 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 11 Jul 2023 19:50:33 +0800
Subject: [PATCH] fix sft encode

---
NOTE(review): line breaks were restored from a copy of this patch that had
been collapsed onto a single line. The two blank context lines follow from
the hunk header's line counts; the absolute indentation of the hunk lines
is assumed -- verify against src/utils/common.py before applying.

 src/utils/common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/utils/common.py b/src/utils/common.py
index 917bd867..648f226f 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -505,7 +505,7 @@ def preprocess_data(
             input_ids, labels = [], []
 
             for i in range(len(dialog) // 2):
-                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=True)
+                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0))
                 target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
 
                 if len(source_ids) > data_args.max_source_length: