This commit is contained in:
hiyouga 2024-02-03 23:14:31 +08:00
parent 901faa16cc
commit 4ecadc3512
2 changed files with 2 additions and 6 deletions

View File

@ -22,12 +22,8 @@ def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...`
text_examples = [examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))]
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
for i in range(len(tokenized_examples["input_ids"])):
tokenized_examples["input_ids"][i] += [tokenizer.eos_token_id]
tokenized_examples["attention_mask"][i] += [1]
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len

View File

@ -110,7 +110,7 @@ def _configure_attn_implementation(model_args: "ModelArguments", config_kwargs:
logger.warning("FlashAttention2 is not installed.")
config_kwargs["attn_implementation"] = None
else:
config_kwargs["attn_implementation"] = "eager"
config_kwargs["attn_implementation"] = "eager"
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: