Support safe ChatML template, fix qwen tok #351 #354

https://github.com/openai/openai-python/blob/main/chatml.md
This commit is contained in:
hoshi-hiyouga 2023-08-05 00:00:23 +08:00 committed by GitHub
commit f30fc3b030
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 306 additions and 149 deletions

View File

@ -30,10 +30,11 @@ class ChatModel:
) -> Tuple[Dict[str, Any], int]:
prefix = prefix or self.source_prefix
prompt = self.template.get_prompt(query, history, prefix, self.tokenizer.eos_token)
inputs = self.tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.model.device)
prompt_length = len(inputs["input_ids"][0])
prompt, _ = self.template.get_prompt(
tokenizer=self.tokenizer, query=query, resp="", history=history, prefix=prefix
)
input_ids = torch.tensor([prompt], device=self.model.device)
prompt_length = len(input_ids[0])
do_sample = input_kwargs.pop("do_sample", None)
temperature = input_kwargs.pop("temperature", None)
@ -45,7 +46,7 @@ class ChatModel:
gen_kwargs = self.generating_args.to_dict()
gen_kwargs.update(dict(
input_ids=inputs["input_ids"],
input_ids=input_ids,
do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
temperature=temperature or gen_kwargs["temperature"],
top_p=top_p or gen_kwargs["top_p"],

View File

@ -30,7 +30,7 @@ def preprocess_dataset(
yield query, response, history, prefix
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
# build grouped texts with format `X1 X2 X3 ...` (without <eos>)
tokenized_examples = tokenizer(examples["prompt"], add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
@ -55,20 +55,17 @@ def preprocess_dataset(
for query, response, history, prefix in construct_example(examples):
input_ids, labels = [], []
for i, (query_i, resp_i) in enumerate(template.get_dialog(query, response, history, prefix)):
source_ids = tokenizer.encode(text=query_i, add_special_tokens=(i == 0))
target_ids = tokenizer.encode(text=resp_i, add_special_tokens=False)
for source_ids, target_ids in template.get_dialog(tokenizer, query, response, history, prefix):
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
if len(target_ids) > data_args.max_target_length - 1: # eos token
target_ids = target_ids[:data_args.max_target_length - 1]
if len(target_ids) > data_args.max_target_length:
target_ids = target_ids[:data_args.max_target_length]
if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
break
input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
input_ids += source_ids + target_ids
labels += [IGNORE_INDEX] * len(source_ids) + target_ids
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
@ -81,10 +78,7 @@ def preprocess_dataset(
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for query, response, history, prefix in construct_example(examples):
prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
target_ids = tokenizer.encode(text=response, add_special_tokens=True)
source_ids, target_ids = template.get_prompt(tokenizer, query, response, history, prefix)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
@ -101,11 +95,8 @@ def preprocess_dataset(
# build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
model_inputs = {"accept_ids": [], "reject_ids": []}
for query, response, history, prefix in construct_example(examples):
prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
accept_ids = tokenizer.encode(text=response[0], add_special_tokens=False)
reject_ids = tokenizer.encode(text=response[1], add_special_tokens=False)
source_ids, accept_ids = template.get_prompt(tokenizer, query, response[0], history, prefix)
source_ids, reject_ids = template.get_prompt(tokenizer, query, response[1], history, prefix)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]

View File

@ -1,82 +1,153 @@
from typing import Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@dataclass
class Template:
prefix: str
prompt: str
sep: str
use_history: bool
prefix: List[Union[str, Dict[str, str]]]
prompt: List[Union[str, Dict[str, str]]]
sep: List[Union[str, Dict[str, str]]]
stop_words: List[str]
use_history: bool
def get_prompt(
self,
tokenizer: "PreTrainedTokenizer",
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = "",
eos_token: Optional[str] = "</s>"
) -> str:
prefix: Optional[str] = None
) -> Tuple[List[int], List[int]]:
r"""
Returns a string containing prompt without response.
Returns a single pair of token ids representing prompt and response respectively.
"""
return eos_token.join(map(lambda x: x[0] + x[1], self._format_example(query, history, prefix)))
prefix, history = self._format(query=query, resp=resp, history=history, prefix=prefix)
encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids = prompt_ids + query_ids + resp_ids
prompt_ids = prompt_ids + encoded_pairs[-1][0]
return prompt_ids, encoded_pairs[-1][1]
def get_dialog(
self,
tokenizer: "PreTrainedTokenizer",
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
prefix, history = self._format(query=query, resp=resp, history=history, prefix=prefix)
encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history)
return encoded_pairs
def _format(
self,
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = ""
) -> List[Tuple[str, str]]:
prefix: Optional[str] = None
) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]:
r"""
Returns a list containing prompt-response pairs.
Aligns inputs to a special format.
"""
result = self._format_example(query, history, prefix)
result[-1][-1] = resp
return result
def _format_example(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = ""
) -> List[Tuple[str, str]]:
prefix = prefix or self.prefix # use prefix if provided
prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
prefix = [prefix] if prefix is not None else self.prefix # use prefix if provided
prefix = prefix + self.sep if prefix else [] # add separator for non-empty prefix
history = history if (history and self.use_history) else []
history = history + [(query, "")]
return [
[(self.sep if i else prefix) + self.prompt.format(query=q), r]
for i, (q, r) in enumerate(history)
]
history = history + [(query, resp)]
return prefix, history
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
prefix: List[Union[str, Dict[str, str]]],
history: List[Tuple[str, str]]
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
"""
if tokenizer.bos_token and getattr(tokenizer, "add_bos_token", False): # bos token is optional
bos_token_id = [tokenizer.bos_token_id]
else:
bos_token_id = []
eos_token_id = [tokenizer.eos_token_id] # eos token is required
encoded_pairs = []
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0:
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix)
else:
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((bos_token_id + prefix_ids + query_ids, resp_ids + eos_token_id))
return encoded_pairs
def _convert_inputs_to_ids(
self,
tokenizer: "PreTrainedTokenizer",
context: List[Union[str, Dict[str, str]]],
query: Optional[str] = ""
) -> List[int]:
r"""
Converts context to token ids.
"""
token_ids = []
for elem in context:
if isinstance(elem, str):
subelems = elem.split("{{query}}")
if len(subelems) > 1:
elem = subelems[0] + query + subelems[1]
token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False)
elif isinstance(elem, dict):
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
else:
raise NotImplementedError
return token_ids
@dataclass
class Llama2Template(Template):
def _format_example(
def _encode(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = ""
) -> List[Tuple[str, str]]:
prefix = prefix or self.prefix # use prefix if provided
prefix = prefix if prefix.startswith("<<SYS>>") else "<<SYS>>\n{}\n<</SYS>>\n\n".format(prefix)
history = history if (history and self.use_history) else []
history = history + [(query, "")]
return [
[(self.sep if i else "") + self.prompt.format(query=(q if i else prefix + q)), r]
for i, (q, r) in enumerate(history)
]
tokenizer: "PreTrainedTokenizer",
prefix: List[Union[str, Dict[str, str]]],
history: List[Tuple[str, str]]
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
"""
encoded_pairs = []
assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str."
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0:
prefix_ids = []
query = prefix[0] + query
else:
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((prefix_ids + query_ids, resp_ids))
return encoded_pairs
templates: Dict[str, Template] = {}
def register_template(
name: str, prefix: str, prompt: str, sep: str, use_history: bool, stop_words: List[str]
name: str,
prefix: List[Union[str, Dict[str, str]]],
prompt: List[Union[str, Dict[str, str]]],
sep: List[Union[str, Dict[str, str]]],
stop_words: List[str],
use_history: bool
) -> None:
template_class = Llama2Template if name == "llama2" else Template
templates[name] = template_class(
@ -99,11 +170,13 @@ Supports language model inference without histories.
"""
register_template(
name="vanilla",
prefix="",
prompt="{query}",
sep="",
use_history=False,
stop_words=[]
prefix=[],
prompt=[
"{{query}}"
],
sep=[],
stop_words=[],
use_history=False
)
@ -112,12 +185,18 @@ Default template.
"""
register_template(
name="default",
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True,
stop_words=[]
prefix=[
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
],
prompt=[
"Human: {{query}}\nAssistant: "
],
sep=[
"\n"
],
stop_words=[],
use_history=True
)
@ -128,18 +207,24 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
"""
register_template(
name="llama2",
prefix="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
prompt="[INST] {query} [/INST] ",
sep="<s>",
use_history=True,
stop_words=[]
prefix=[
"<<SYS>>\nYou are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST] "
],
sep=[
{"token": "<s>"}
],
stop_words=[],
use_history=True
)
@ -149,12 +234,18 @@ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
"""
register_template(
name="alpaca",
prefix="Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.",
prompt="### Instruction:\n{query}\n\n### Response:\n",
sep="\n\n",
use_history=True,
stop_words=[]
prefix=[
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request."
],
prompt=[
"### Instruction:\n{{query}}\n\n### Response:\n"
],
sep=[
"\n\n"
],
stop_words=[],
use_history=True
)
@ -164,12 +255,16 @@ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
"""
register_template(
name="vicuna",
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="USER: {query} ASSISTANT: ",
sep="",
use_history=True,
stop_words=[]
prefix=[
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
],
prompt=[
"USER: {{query}} ASSISTANT: "
],
sep=[],
stop_words=[],
use_history=True
)
@ -178,11 +273,15 @@ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_template(
name="belle",
prefix="",
prompt="Human: {query}\n\nBelle: ",
sep="\n\n",
use_history=True,
stop_words=[]
prefix=[],
prompt=[
"Human: {{query}}\n\nBelle: "
],
sep=[
"\n\n"
],
stop_words=[],
use_history=True
)
@ -191,11 +290,15 @@ Supports: https://github.com/CVI-SZU/Linly
"""
register_template(
name="linly",
prefix="",
prompt="User: {query}\nBot: ",
sep="\n",
use_history=True,
stop_words=[]
prefix=[],
prompt=[
"User: {{query}}\nBot: "
],
sep=[
"\n"
],
stop_words=[],
use_history=True
)
@ -204,11 +307,15 @@ Supports: https://github.com/Neutralzz/BiLLa
"""
register_template(
name="billa",
prefix="",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True,
stop_words=[]
prefix=[],
prompt=[
"Human: {{query}}\nAssistant: "
],
sep=[
"\n"
],
stop_words=[],
use_history=True
)
@ -217,11 +324,18 @@ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
"""
register_template(
name="ziya",
prefix="",
prompt="<human>:{query}\n<bot>:",
sep="\n",
use_history=True,
stop_words=[]
prefix=[],
prompt=[
{"token": "<human>"},
":{{query}}\n",
{"token": "<bot>"},
":"
],
sep=[
"\n"
],
stop_words=[],
use_history=True
)
@ -230,12 +344,18 @@ Supports: https://huggingface.co/qhduan/aquilachat-7b
"""
register_template(
name="aquila",
prefix="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
prompt="Human: {query}###Assistant: ",
sep="###",
use_history=True,
stop_words=[]
prefix=[
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
],
prompt=[
"Human: {{query}}###Assistant: "
],
sep=[
"###"
],
stop_words=[],
use_history=True
)
@ -244,11 +364,23 @@ Supports: https://huggingface.co/internlm/internlm-chat-7b
"""
register_template(
name="intern",
prefix="",
prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
sep="<eoa>\n",
use_history=True,
stop_words=["<eoa>"]
prefix=[],
prompt=[
{"token": "<|User|>"},
":{{query}}",
{"token": "<eoh>"},
"\n",
{"token": "<|Bot|>"},
":"
],
sep=[
{"token": "<eoa>"},
"\n"
],
stop_words=[
"<eoa>"
],
use_history=True
)
@ -257,11 +389,15 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
register_template(
name="baichuan",
prefix="",
prompt="<reserved_102>{query}<reserved_103>",
sep="",
use_history=True,
stop_words=[]
prefix=[],
prompt=[
{"token": "<reserved_102>"},
"{{query}}",
{"token": "<reserved_103>"}
],
sep=[],
stop_words=[],
use_history=True
)
@ -271,11 +407,25 @@ Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
"""
register_template(
name="starchat",
prefix="<|system|>\n",
prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
sep="<|end|>\n",
use_history=True,
stop_words=["<|end|>"]
prefix=[
{"token": "<|system|>"},
"\n"
],
prompt=[
{"token": "<|user|>"},
"\n{{query}}",
{"token": "<|end|>"},
"\n",
{"token": "<|assistant|>"}
],
sep=[
{"token": "<|end|>"},
"\n"
],
stop_words=[
"<|end|>"
],
use_history=True
)
@ -284,9 +434,24 @@ Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
"""
register_template(
name="chatml",
prefix="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
prompt="<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n",
sep="<|im_end|>\n",
use_history=True,
stop_words=["<|im_end|>"]
prefix=[
{"token": "<|im_start|>"},
"system\nYou are a helpful assistant."
],
prompt=[
{"token": "<|im_start|>"},
"user\n{{query}}",
{"token": "<|im_end|>"},
"\n",
{"token": "<|im_start|>"},
"assistant\n"
],
sep=[
{"token": "<|im_end|>"},
"\n"
],
stop_words=[
"<|im_end|>"
],
use_history=True
)