Fix SFT encoding: add special tokens only when encoding the first source turn (i == 0) of a dialog
This commit is contained in:
parent
1af031c02b
commit
b2f7cb4465
|
@@ -505,7 +505,7 @@ def preprocess_data(
|
||||||
input_ids, labels = [], []
|
input_ids, labels = [], []
|
||||||
|
|
||||||
for i in range(len(dialog) // 2):
|
for i in range(len(dialog) // 2):
|
||||||
source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=True)
|
source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0))
|
||||||
target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
|
target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
|
||||||
|
|
||||||
if len(source_ids) > data_args.max_source_length:
|
if len(source_ids) > data_args.max_source_length:
|
||||||
|
|
Loading…
Reference in New Issue