fix #2376
This commit is contained in:
parent 901faa16cc
commit 4ecadc3512
|
@@ -22,12 +22,8 @@ def preprocess_pretrain_dataset(
     examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
 ) -> Dict[str, List[List[int]]]:
     # build grouped texts with format `X1 X2 X3 ...`
-    text_examples = [examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))]
+    text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
     tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
-    for i in range(len(tokenized_examples["input_ids"])):
-        tokenized_examples["input_ids"][i] += [tokenizer.eos_token_id]
-        tokenized_examples["attention_mask"][i] += [1]
-
     concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
     total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
     block_size = data_args.cutoff_len
|
||||
|
|
|
@@ -110,7 +110,7 @@ def _configure_attn_implementation(model_args: "ModelArguments", config_kwargs:
|
|||
logger.warning("FlashAttention2 is not installed.")
|
||||
config_kwargs["attn_implementation"] = None
|
||||
else:
|
||||
- config_kwargs["attn_implementation"] = "eager"
+ config_kwargs["attn_implementation"] = "eager"
|
||||
|
||||
|
||||
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
|
|
Loading…
Reference in New Issue