modify some style

This commit is contained in:
BUAADreamer 2024-04-25 22:04:09 +08:00
parent 2d4ded535f
commit c27f7fbf62
4 changed files with 50 additions and 8 deletions

View File

@ -324,10 +324,7 @@ def get_preprocess_and_print_func(
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
else:
preprocess_func = partial(
preprocess_unsupervised_dataset,
tokenizer=tokenizer,
template=template,
data_args=data_args,
preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)

View File

@ -11,6 +11,7 @@ if TYPE_CHECKING:
from .formatter import SLOTS, Formatter
logger = get_logger(__name__)
@ -103,9 +104,7 @@ class Template:
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
def _convert_elements_to_ids(
self,
tokenizer: "PreTrainedTokenizer",
elements: List[Union[str, Dict[str, str]]],
self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
) -> List[int]:
r"""
Converts elements to token ids.
@ -391,6 +390,7 @@ _register_template(
),
)
_register_template(
name="aquila",
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
@ -403,6 +403,7 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="atom",
format_user=StringFormatter(
@ -411,18 +412,21 @@ _register_template(
format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
)
_register_template(
name="baichuan",
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
efficient_eos=True,
)
_register_template(
name="baichuan2",
format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
efficient_eos=True,
)
_register_template(
name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
@ -431,11 +435,13 @@ _register_template(
force_system=True,
)
_register_template(
name="bluelm",
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)
_register_template(
name="breeze",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
@ -447,6 +453,7 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
@ -456,6 +463,7 @@ _register_template(
force_system=True,
)
_register_template(
name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
@ -470,6 +478,7 @@ _register_template(
force_system=True,
)
_register_template(
name="chatglm3_system",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
@ -489,6 +498,7 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="chatml",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@ -499,6 +509,7 @@ _register_template(
replace_eos=True,
)
_register_template(
name="chatml_de",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@ -510,12 +521,14 @@ _register_template(
replace_eos=True,
)
_register_template(
name="codegeex2",
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="cohere",
format_user=StringFormatter(
@ -530,6 +543,7 @@ _register_template(
force_system=True,
)
_register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
@ -537,6 +551,7 @@ _register_template(
force_system=True,
)
_register_template(
name="dbrx",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@ -562,6 +577,7 @@ _register_template(
replace_eos=True,
)
_register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
@ -569,6 +585,7 @@ _register_template(
force_system=True,
)
_register_template(
name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
@ -584,6 +601,7 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="default",
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
@ -591,12 +609,14 @@ _register_template(
format_separator=EmptyFormatter(slots=["\n"]),
)
_register_template(
name="empty",
format_user=StringFormatter(slots=["{{content}}"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
)
_register_template(
name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@ -604,12 +624,14 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="fewshot",
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
)
_register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
@ -622,6 +644,7 @@ _register_template(
force_system=True,
)
_register_template(
name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
@ -630,6 +653,7 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="intern2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@ -646,6 +670,7 @@ _register_template(
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
)
_register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
@ -662,6 +687,7 @@ _register_template(
),
)
_register_template(
name="llama2_zh",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
@ -669,6 +695,7 @@ _register_template(
default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
)
_register_template(
name="llama3",
format_user=StringFormatter(
@ -695,6 +722,7 @@ _register_template(
replace_eos=True,
)
_register_template(
name="mistral",
format_user=StringFormatter(slots=[" [INST] {{content}} [/INST]"]),
@ -702,6 +730,7 @@ _register_template(
force_system=True,
)
_register_template(
name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
@ -710,6 +739,7 @@ _register_template(
force_system=True,
)
_register_template(
name="openchat",
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
@ -718,6 +748,7 @@ _register_template(
force_system=True,
)
_register_template(
name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
@ -725,6 +756,7 @@ _register_template(
force_system=True,
)
_register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
@ -736,6 +768,7 @@ _register_template(
replace_eos=True,
)
_register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@ -747,6 +780,7 @@ _register_template(
replace_eos=True,
)
_register_template(
name="solar",
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
@ -754,6 +788,7 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="starchat",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
@ -764,6 +799,7 @@ _register_template(
force_system=True,
)
_register_template(
name="vicuna",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
@ -773,6 +809,7 @@ _register_template(
),
)
_register_template(
name="xuanyuan",
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
@ -783,11 +820,13 @@ _register_template(
),
)
_register_template(
name="xverse",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
)
_register_template(
name="yayi",
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
@ -807,6 +846,7 @@ _register_template(
stop_words=["<|End|>"],
)
_register_template(
name="yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@ -815,6 +855,7 @@ _register_template(
replace_eos=True,
)
_register_template(
name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
@ -823,6 +864,7 @@ _register_template(
replace_eos=True,
)
_register_template(
name="zephyr",
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
@ -831,12 +873,14 @@ _register_template(
default_system="You are a friendly chatbot who always responds in the style of a pirate",
)
_register_template(
name="ziya",
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
format_separator=EmptyFormatter(slots=["\n"]),
)
_register_template(
name="llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT: "]),

View File

@ -43,7 +43,7 @@ def run_sft(
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
label_pad_token_id=(IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id),
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
)
# Override the decoding parameters of Seq2SeqTrainer

View File

@ -19,6 +19,7 @@ from .sft import run_sft
if TYPE_CHECKING:
from transformers import TrainerCallback
logger = get_logger(__name__)