fix system prompt

hiyouga 2023-08-16 01:35:52 +08:00
parent 273135f595
commit 7407d9daa1
15 changed files with 170 additions and 152 deletions

View File

@@ -47,15 +47,15 @@ def create_app(chat_model: ChatModel) -> FastAPI:
    @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
    async def create_chat_completion(request: ChatCompletionRequest):
-         if request.messages[-1].role != Role.USER:
+         if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
            raise HTTPException(status_code=400, detail="Invalid request")
        query = request.messages[-1].content
        prev_messages = request.messages[:-1]
        if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
-             prefix = prev_messages.pop(0).content
+             system = prev_messages.pop(0).content
        else:
-             prefix = None
+             system = None
        history = []
        if len(prev_messages) % 2 == 0:
@@ -64,11 +64,11 @@ def create_app(chat_model: ChatModel) -> FastAPI:
                history.append([prev_messages[i].content, prev_messages[i+1].content])
        if request.stream:
-             generate = predict(query, history, prefix, request)
+             generate = predict(query, history, system, request)
            return EventSourceResponse(generate, media_type="text/event-stream")
        response, (prompt_length, response_length) = chat_model.chat(
-             query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+             query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
        )
        usage = ChatCompletionResponseUsage(
@@ -85,7 +85,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
        return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
-     async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
+     async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(role=Role.ASSISTANT),
@@ -95,7 +95,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
        yield chunk.json(exclude_unset=True, ensure_ascii=False)
        for new_text in chat_model.stream_chat(
-             query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+             query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
        ):
            if len(new_text) == 0:
                continue
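Illustrative only, not part of the commit: after this change, the first message with role "system" in an OpenAI-style request is routed into the renamed `system` argument. A minimal client sketch, assuming the API server is running locally on port 8000 (URL and payload values are placeholders):

import requests

payload = {
    "model": "llama2",
    "messages": [
        # This message is popped off and passed as chat_model.chat(..., system=...)
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "What does this commit change?"}
    ],
    "temperature": 0.7,
    "max_tokens": 256
}
response = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(response.json()["choices"][0]["message"]["content"])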

View File

@@ -16,19 +16,19 @@ class ChatModel:
        self.model = dispatch_model(self.model)
        self.model = self.model.eval() # enable evaluation mode
        self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
-         self.source_prefix = data_args.source_prefix
+         self.system_prompt = data_args.system_prompt
    def process_args(
        self,
        query: str,
        history: Optional[List[Tuple[str, str]]] = None,
-         prefix: Optional[str] = None,
+         system: Optional[str] = None,
        **input_kwargs
    ) -> Tuple[Dict[str, Any], int]:
-         prefix = prefix or self.source_prefix
+         system = system or self.system_prompt
        prompt, _ = self.template.encode_oneturn(
-             tokenizer=self.tokenizer, query=query, resp="", history=history, prefix=prefix
+             tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
        )
        input_ids = torch.tensor([prompt], device=self.model.device)
        prompt_length = len(input_ids[0])
@@ -68,10 +68,10 @@ class ChatModel:
        self,
        query: str,
        history: Optional[List[Tuple[str, str]]] = None,
-         prefix: Optional[str] = None,
+         system: Optional[str] = None,
        **input_kwargs
    ) -> Tuple[str, Tuple[int, int]]:
-         gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
+         gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
        generation_output = self.model.generate(**gen_kwargs)
        outputs = generation_output.tolist()[0][prompt_length:]
        response = self.tokenizer.decode(outputs, skip_special_tokens=True)
@@ -83,10 +83,10 @@ class ChatModel:
        self,
        query: str,
        history: Optional[List[Tuple[str, str]]] = None,
-         prefix: Optional[str] = None,
+         system: Optional[str] = None,
        **input_kwargs
    ) -> Generator[str, None, None]:
-         gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
+         gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
        streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs["streamer"] = streamer
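Illustrative only: a minimal call against the renamed keyword. The import path and model path are assumptions; the return shape matches the tuple unpacked in the API server above.

from llmtuner import ChatModel  # import path is an assumption

chat_model = ChatModel(dict(
    model_name_or_path="path/to/model",  # placeholder
    template="default"
))
# `system` overrides data_args.system_prompt for this call only.
response, (prompt_length, response_length) = chat_model.chat(
    query="Summarize this repository.",
    history=[],
    system="You are a terse assistant."
)
print(response)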

View File

@@ -92,14 +92,13 @@ def get_dataset(
            if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
                dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
-         if dataset_attr.source_prefix: # add prefix
+         if dataset_attr.system_prompt: # add system prompt
            if data_args.streaming:
                features = dataset.features
-                 features["prefix"] = Value(dtype="string", id=None)
-                 dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
+                 features["system"] = Value(dtype="string", id=None)
+                 dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt}, features=features)
            else:
-                 prefix_data = [dataset_attr.source_prefix] * len(dataset)
-                 dataset = dataset.add_column("prefix", prefix_data)
+                 dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))
        all_datasets.append(dataset)
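Illustrative only: outside the loader, the non-streaming branch amounts to attaching a constant string column with the Hugging Face Datasets API (toy rows and prompt below are placeholders):

from datasets import Dataset

dataset = Dataset.from_dict({
    "prompt": ["What is LoRA?"],
    "response": ["A parameter-efficient fine-tuning method."]
})
system_prompt = "You are a helpful assistant."  # stands in for dataset_attr.system_prompt
dataset = dataset.add_column("system", [system_prompt] * len(dataset))
print(dataset[0]["system"])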

View File

@@ -27,8 +27,8 @@ def preprocess_dataset(
            query, response = examples["prompt"][i], examples["response"][i]
            query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
            history = examples["history"][i] if "history" in examples else None
-             prefix = examples["prefix"][i] if "prefix" in examples else None
-             yield query, response, history, prefix
+             system = examples["system"][i] if "system" in examples else None
+             yield query, response, history, system
    def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
        # build grouped texts with format `X1 X2 X3 ...` (without <eos>)
@@ -56,10 +56,10 @@ def preprocess_dataset(
        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
        max_length = data_args.max_source_length + data_args.max_target_length
-         for query, response, history, prefix in construct_example(examples):
+         for query, response, history, system in construct_example(examples):
            input_ids, labels = [], []
-             for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, prefix):
+             for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
                if len(source_ids) > data_args.max_source_length:
                    source_ids = source_ids[:data_args.max_source_length]
                if len(target_ids) > data_args.max_target_length:
@@ -78,11 +78,11 @@ def preprocess_dataset(
        return model_inputs
    def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
-         # build inputs with format `<bos> X` and labels with format `<bos> Y`
+         # build inputs with format `<bos> X` and labels with format `Y <eos>`
        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-         for query, response, history, prefix in construct_example(examples):
-             source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, prefix)
+         for query, response, history, system in construct_example(examples):
+             source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, system)
            if len(source_ids) > data_args.max_source_length:
                source_ids = source_ids[:data_args.max_source_length]
@@ -98,9 +98,9 @@ def preprocess_dataset(
    def preprocess_pairwise_dataset(examples):
        # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
        model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
-         for query, response, history, prefix in construct_example(examples):
-             prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, prefix)
-             _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, prefix)
+         for query, response, history, system in construct_example(examples):
+             prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
+             _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
            if len(prompt_ids) > data_args.max_source_length:
                prompt_ids = prompt_ids[:data_args.max_source_length]
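Illustrative only: the 4-tuple yielded by construct_example now carries the system prompt in the last slot, read from the "system" column added by the loader. A standalone mock of that generator contract with made-up rows:

def construct_example_like(examples):
    # Mirrors the yielded (query, response, history, system) tuple; "system" replaces the old "prefix" column.
    for i in range(len(examples["prompt"])):
        query = examples["prompt"][i]
        response = examples["response"][i]
        history = examples["history"][i] if "history" in examples else None
        system = examples["system"][i] if "system" in examples else None
        yield query, response, history, system

batch = {"prompt": ["Hi"], "response": ["Hello!"], "system": ["You are friendly."]}
print(next(iter(construct_example_like(batch))))  # ('Hi', 'Hello!', None, 'You are friendly.')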

View File

@@ -16,9 +16,11 @@ class Template:
    prefix: List[Union[str, Dict[str, str]]]
    prompt: List[Union[str, Dict[str, str]]]
+     system: str
    sep: List[Union[str, Dict[str, str]]]
    stop_words: List[str]
    use_history: bool
+     bos_after_prefix: bool
    def encode_oneturn(
        self,
@@ -26,18 +28,18 @@ class Template:
        query: str,
        resp: str,
        history: Optional[List[Tuple[str, str]]] = None,
-         prefix: Optional[str] = None
+         system: Optional[str] = None
    ) -> Tuple[List[int], List[int]]:
        r"""
        Returns a single pair of token ids representing prompt and response respectively.
        """
-         prefix, history = self._format(query, resp, history, prefix)
-         encoded_pairs = self._encode(tokenizer, prefix, history)
+         system, history = self._format(query, resp, history, system)
+         encoded_pairs = self._encode(tokenizer, system, history)
        prompt_ids = []
        for query_ids, resp_ids in encoded_pairs[:-1]:
            prompt_ids = prompt_ids + query_ids + resp_ids
-         prompt_ids = prompt_ids + encoded_pairs[-1][0]
-         return prompt_ids, encoded_pairs[-1][1]
+         prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
+         return prompt_ids, answer_ids
    def encode_multiturn(
        self,
@@ -45,13 +47,13 @@ class Template:
        query: str,
        resp: str,
        history: Optional[List[Tuple[str, str]]] = None,
-         prefix: Optional[str] = None
+         system: Optional[str] = None
    ) -> List[Tuple[List[int], List[int]]]:
        r"""
        Returns multiple pairs of token ids representing prompts and responses respectively.
        """
-         prefix, history = self._format(query, resp, history, prefix)
-         encoded_pairs = self._encode(tokenizer, prefix, history)
+         system, history = self._format(query, resp, history, system)
+         encoded_pairs = self._encode(tokenizer, system, history)
        return encoded_pairs
    def _format(
@@ -59,15 +61,15 @@ class Template:
        query: str,
        resp: str,
        history: Optional[List[Tuple[str, str]]] = None,
-         prefix: Optional[str] = None
-     ) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]:
+         system: Optional[str] = None
+     ) -> Tuple[str, List[Tuple[str, str]]]:
        r"""
        Aligns inputs to the standard format.
        """
-         prefix = [prefix] if prefix else self.prefix # use prefix if provided
+         system = system or self.system # use system if provided
        history = history if (history and self.use_history) else []
        history = history + [(query, resp)]
-         return prefix, history
+         return system, history
    def _get_special_ids(
        self,
@@ -88,7 +90,7 @@ class Template:
    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
-         prefix: List[Union[str, Dict[str, str]]],
+         system: str,
        history: List[Tuple[str, str]]
    ) -> List[Tuple[List[int], List[int]]]:
        r"""
@@ -101,8 +103,12 @@ class Template:
        encoded_pairs = []
        for turn_idx, (query, resp) in enumerate(history):
            if turn_idx == 0:
-                 if prefix: # has prefix
-                     prefix_ids = bos_ids + self._convert_inputs_to_ids(tokenizer, context=prefix) + sep_ids
+                 prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
+                 if len(prefix_ids) != 0: # has prefix
+                     if self.bos_after_prefix:
+                         prefix_ids = prefix_ids + bos_ids + sep_ids
+                     else:
+                         prefix_ids = bos_ids + prefix_ids + sep_ids
                else:
                    prefix_ids = bos_ids
            else:
@@ -117,8 +123,9 @@ class Template:
        self,
        tokenizer: "PreTrainedTokenizer",
        context: List[Union[str, Dict[str, str]]],
-         query: Optional[str] = "",
-         idx: Optional[str] = ""
+         system: Optional[str] = None,
+         query: Optional[str] = None,
+         idx: Optional[str] = None
    ) -> List[int]:
        r"""
        Converts context to token ids.
@@ -131,13 +138,15 @@ class Template:
        token_ids = []
        for elem in context:
            if isinstance(elem, str):
-                 elem = elem.replace("{{query}}", query, 1)
-                 elem = elem.replace("{{idx}}", idx, 1)
+                 elem = elem.replace("{{system}}", system, 1) if system is not None else elem
+                 elem = elem.replace("{{query}}", query, 1) if query is not None else elem
+                 elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
                token_ids = token_ids + tokenizer.encode(elem, **kwargs)
            elif isinstance(elem, dict):
                token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
            else:
                raise NotImplementedError
        return token_ids
@@ -147,7 +156,7 @@ class Llama2Template(Template):
    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
-         prefix: List[Union[str, Dict[str, str]]],
+         system: str,
        history: List[Tuple[str, str]]
    ) -> List[Tuple[List[int], List[int]]]:
        r"""
@@ -157,10 +166,9 @@ class Llama2Template(Template):
        """
        bos_ids, eos_ids = self._get_special_ids(tokenizer)
        encoded_pairs = []
-         assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single string."
        for turn_idx, (query, resp) in enumerate(history):
-             if turn_idx == 0: # llama2 template has not sep_ids
-                 query = prefix[0] + query
+             if turn_idx == 0: # llama2 template has no sep_ids
+                 query = self.prefix[0].replace("{{system}}", system) + query
            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
            encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
@@ -174,17 +182,21 @@ def register_template(
    name: str,
    prefix: List[Union[str, Dict[str, str]]],
    prompt: List[Union[str, Dict[str, str]]],
+     system: str,
    sep: List[Union[str, Dict[str, str]]],
-     stop_words: List[str],
-     use_history: bool
+     stop_words: Optional[List[str]] = [],
+     use_history: Optional[bool] = True,
+     bos_after_prefix: Optional[bool] = False
) -> None:
    template_class = Llama2Template if "llama2" in name else Template
    templates[name] = template_class(
        prefix=prefix,
        prompt=prompt,
+         system=system,
        sep=sep,
        stop_words=stop_words,
-         use_history=use_history
+         use_history=use_history,
+         bos_after_prefix=bos_after_prefix
    )
@@ -201,7 +213,7 @@ def get_template_and_fix_tokenizer(
        if tokenizer.eos_token_id is not None:
            additional_special_tokens.append(tokenizer.eos_token)
-         tokenizer.eos_token = template.stop_words[0]
+         tokenizer.eos_token = additional_special_tokens[0] # use the first stop word as eos token
        additional_special_tokens.pop(0)
        logger.info("Replace eos token: {}".format(tokenizer.eos_token))
@@ -229,8 +241,8 @@ register_template(
    prompt=[
        "{{query}}"
    ],
+     system="",
    sep=[],
-     stop_words=[],
    use_history=False
)
@@ -241,17 +253,18 @@ Default template.
register_template(
    name="default",
    prefix=[
-         "A chat between a curious user and an artificial intelligence assistant. "
-         "The assistant gives helpful, detailed, and polite answers to the user's questions."
+         "{{system}}"
    ],
    prompt=[
        "Human: {{query}}\nAssistant: "
    ],
+     system=(
+         "A chat between a curious user and an artificial intelligence assistant. "
+         "The assistant gives helpful, detailed, and polite answers to the user's questions."
+     ),
    sep=[
        "\n"
-     ],
-     stop_words=[],
-     use_history=True
+     ]
)
@@ -263,21 +276,22 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
register_template(
    name="llama2",
    prefix=[
-         "<<SYS>>\nYou are a helpful, respectful and honest assistant. "
+         "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
+     ],
+     prompt=[
+         "[INST] {{query}} [/INST] "
+     ],
+     system=(
+         "You are a helpful, respectful and honest assistant. "
        "Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, "
        "racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n"
        "If a question does not make any sense, or is not factually coherent, "
        "explain why instead of answering something not correct. "
-         "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
-     ],
-     prompt=[
-         "[INST] {{query}} [/INST] "
-     ],
-     sep=[],
-     stop_words=[],
-     use_history=True
+         "If you don't know the answer to a question, please don't share false information."
+     ),
+     sep=[]
)
@@ -288,14 +302,13 @@ Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
register_template(
    name="llama2_zh",
    prefix=[
-         "<<SYS>>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<</SYS>>\n\n"
+         "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
    ],
    prompt=[
        "[INST] {{query}} [/INST] "
    ],
-     sep=[],
-     stop_words=[],
-     use_history=True
+     system="You are a helpful assistant. 你是一个乐于助人的助手。",
+     sep=[]
)
@@ -306,17 +319,18 @@ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
register_template(
    name="alpaca",
    prefix=[
-         "Below is an instruction that describes a task. "
-         "Write a response that appropriately completes the request."
+         "{{system}}"
    ],
    prompt=[
        "### Instruction:\n{{query}}\n\n### Response:\n"
    ],
+     system=(
+         "Below is an instruction that describes a task. "
+         "Write a response that appropriately completes the request."
+     ),
    sep=[
        "\n\n"
-     ],
-     stop_words=[],
-     use_history=True
+     ]
)
@@ -327,15 +341,16 @@ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
register_template(
    name="vicuna",
    prefix=[
-         "A chat between a curious user and an artificial intelligence assistant. "
-         "The assistant gives helpful, detailed, and polite answers to the user's questions."
+         "{{system}}"
    ],
    prompt=[
        "USER: {{query}} ASSISTANT: "
    ],
-     sep=[],
-     stop_words=[],
-     use_history=True
+     system=(
+         "A chat between a curious user and an artificial intelligence assistant. "
+         "The assistant gives helpful, detailed, and polite answers to the user's questions."
+     ),
+     sep=[]
)
@@ -344,15 +359,16 @@ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_template(
    name="belle",
-     prefix=[],
+     prefix=[
+         "{{system}}"
+     ],
    prompt=[
        "Human: {{query}}\n\nBelle: "
    ],
+     system="",
    sep=[
        "\n\n"
-     ],
-     stop_words=[],
-     use_history=True
+     ]
)
@@ -361,15 +377,16 @@ Supports: https://github.com/CVI-SZU/Linly
"""
register_template(
    name="linly",
-     prefix=[],
+     prefix=[
+         "{{system}}"
+     ],
    prompt=[
        "User: {{query}}\nBot: "
    ],
+     system="",
    sep=[
        "\n"
-     ],
-     stop_words=[],
-     use_history=True
+     ]
)
@@ -378,15 +395,16 @@ Supports: https://github.com/Neutralzz/BiLLa
"""
register_template(
    name="billa",
-     prefix=[],
+     prefix=[
+         "{{system}}"
+     ],
    prompt=[
        "Human: {{query}}\nAssistant: "
    ],
+     system="",
    sep=[
        "\n"
-     ],
-     stop_words=[],
-     use_history=True
+     ]
)
@@ -395,18 +413,19 @@ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
"""
register_template(
    name="ziya",
-     prefix=[],
+     prefix=[
+         "{{system}}"
+     ],
    prompt=[
        {"token": "<human>"},
        ":{{query}}\n",
        {"token": "<bot>"},
        ":"
    ],
+     system="",
    sep=[
        "\n"
-     ],
-     stop_words=[],
-     use_history=True
+     ]
)
@@ -416,17 +435,18 @@ Supports: https://huggingface.co/qhduan/aquilachat-7b
register_template(
    name="aquila",
    prefix=[
-         "A chat between a curious human and an artificial intelligence assistant. "
-         "The assistant gives helpful, detailed, and polite answers to the human's questions."
+         "{{system}}"
    ],
    prompt=[
        "Human: {{query}}###Assistant: "
    ],
+     system=(
+         "A chat between a curious human and an artificial intelligence assistant. "
+         "The assistant gives helpful, detailed, and polite answers to the human's questions."
+     ),
    sep=[
        "###"
-     ],
-     stop_words=[],
-     use_history=True
+     ]
)
@@ -435,19 +455,21 @@ Supports: https://huggingface.co/internlm/internlm-chat-7b
"""
register_template(
    name="intern",
-     prefix=[],
+     prefix=[
+         "{{system}}"
+     ],
    prompt=[
        "<|User|>:{{query}}",
        {"token": "<eoh>"},
        "\n<|Bot|>:"
    ],
+     system="",
    sep=[
        "\n"
    ],
    stop_words=[
        "<eoa>"
-     ],
-     use_history=True
+     ]
)
@@ -457,17 +479,19 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
register_template(
    name="baichuan",
    prefix=[
-         {"token": "<reserved_102>"} # user token (a little difference in the first turn)
+         "{{system}}",
+         {"token": "<reserved_102>"} # user token
    ],
    prompt=[
        "{{query}}",
        {"token": "<reserved_103>"} # assistant token
    ],
+     system="",
    sep=[],
    stop_words=[
        "<reserved_102>" # user token
    ],
-     use_history=True
+     bos_after_prefix=True
)
@@ -479,7 +503,7 @@ register_template(
    name="starchat",
    prefix=[
        {"token": "<|system|>"},
-         "\n",
+         "\n{{system}}",
        {"token": "<|end|>"}
    ],
    prompt=[
@@ -489,13 +513,13 @@ register_template(
        "\n",
        {"token": "<|assistant|>"}
    ],
+     system="",
    sep=[
        "\n"
    ],
    stop_words=[
        "<|end|>"
-     ],
-     use_history=True
+     ]
)
@@ -506,7 +530,7 @@ register_template(
    name="chatml",
    prefix=[
        {"token": "<|im_start|>"},
-         "system\nYou are a helpful assistant.",
+         "system\n{{system}}",
        {"token": "<|im_end|>"}
    ],
    prompt=[
@@ -517,13 +541,13 @@ register_template(
        {"token": "<|im_start|>"},
        "assistant\n"
    ],
+     system="You are a helpful assistant.",
    sep=[
        "\n"
    ],
    stop_words=[
        "<|im_end|>"
-     ],
-     use_history=True
+     ]
)
@@ -534,14 +558,14 @@ register_template(
    name="chatglm2",
    prefix=[
        {"token": "[gMASK]"},
-         {"token": "sop"}
+         {"token": "sop"},
+         "{{system}}"
    ],
    prompt=[
        "[Round {{idx}}]\n\n问:{{query}}\n\n答:"
    ],
+     system="",
    sep=[
        "\n\n"
-     ],
-     stop_words=[],
-     use_history=True
+     ]
)
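Illustrative only: a hypothetical template registered with the new signature. The `{{system}}` placeholder in `prefix` is filled from the `system` default (or from a per-dataset or per-request override), while `stop_words`, `use_history` and `bos_after_prefix` fall back to their new defaults. The template name and strings below are made up, not part of this commit.

register_template(
    name="my_chat",  # hypothetical template name
    prefix=[
        "{{system}}"  # replaced by the system prompt at encode time
    ],
    prompt=[
        "### User:\n{{query}}\n### Assistant:\n"
    ],
    system="You are a helpful assistant.",  # default when no system prompt is supplied
    sep=[
        "\n"
    ]
)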

View File

@@ -10,7 +10,7 @@ class DatasetAttr:
    load_from: str
    dataset_name: Optional[str] = None
    dataset_sha1: Optional[str] = None
-     source_prefix: Optional[str] = None
+     system_prompt: Optional[str] = None
    def __repr__(self) -> str:
        return self.dataset_name
@@ -86,9 +86,9 @@ class DataArguments:
        default=True,
        metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
    )
-     source_prefix: Optional[str] = field(
+     system_prompt: Optional[str] = field(
        default=None,
-         metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
+         metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
    )
    val_size: Optional[float] = field(
        default=0,
@@ -100,12 +100,9 @@ class DataArguments:
        with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
            dataset_info = json.load(f)
-         if self.source_prefix is not None:
-             prefix_list = self.source_prefix.split("|")
-             prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
-             assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
-         else:
-             prefix_list = [None] * len(dataset_names)
+         prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
+         prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
+         assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
        if self.interleave_probs is not None:
            self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
@@ -126,12 +123,11 @@ class DataArguments:
                dataset_sha1=dataset_info[name].get("file_sha1", None)
            )
-             dataset_attr.source_prefix = prefix_list[i]
            if "columns" in dataset_info[name]:
                dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
                dataset_attr.query = dataset_info[name]["columns"].get("query", None)
                dataset_attr.response = dataset_info[name]["columns"].get("response", None)
                dataset_attr.history = dataset_info[name]["columns"].get("history", None)
+             dataset_attr.system_prompt = prompt_list[i]
            self.dataset_list.append(dataset_attr)
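Illustrative only: a standalone sketch of how the `|`-separated system_prompt value is distributed across datasets after this change. Dataset names and prompts below are placeholders; with a single prompt it is broadcast to every dataset, otherwise the counts must match.

system_prompt = "You are a coding assistant.|You are a medical assistant."  # e.g. passed via --system_prompt
dataset_names = ["code_sft", "medical_sft"]                                  # placeholder dataset names

prompt_list = system_prompt.split("|") if system_prompt else [None]
prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
print(list(zip(dataset_names, prompt_list)))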

View File

@@ -56,6 +56,5 @@ def run_pt(
            perplexity = float("inf")
        metrics["perplexity"] = perplexity
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

View File

@@ -26,7 +26,7 @@ class WebChatModel(ChatModel):
        finetuning_type: str,
        quantization_bit: str,
        template: str,
-         source_prefix: str
+         system_prompt: str
    ):
        if self.model is not None:
            yield ALERTS["err_exists"][lang]
@@ -55,7 +55,7 @@ class WebChatModel(ChatModel):
            finetuning_type=finetuning_type,
            quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
            template=template,
-             source_prefix=source_prefix
+             system_prompt=system_prompt
        )
        super().__init__(args)
@@ -73,7 +73,7 @@ class WebChatModel(ChatModel):
        chatbot: List[Tuple[str, str]],
        query: str,
        history: List[Tuple[str, str]],
-         prefix: str,
+         system: str,
        max_new_tokens: int,
        top_p: float,
        temperature: float
@@ -81,7 +81,7 @@ class WebChatModel(ChatModel):
        chatbot.append([query, ""])
        response = ""
        for new_text in self.stream_chat(
-             query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+             query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
        ):
            response += new_text
            response = self.postprocess(response)

View File

@@ -17,7 +17,7 @@ def create_chat_box(
    with gr.Row():
        with gr.Column(scale=4):
-             prefix = gr.Textbox(show_label=False)
+             system = gr.Textbox(show_label=False)
            query = gr.Textbox(show_label=False, lines=8)
            submit_btn = gr.Button(variant="primary")
@@ -31,7 +31,7 @@ def create_chat_box(
    submit_btn.click(
        chat_model.predict,
-         [chatbot, query, history, prefix, max_new_tokens, top_p, temperature],
+         [chatbot, query, history, system, max_new_tokens, top_p, temperature],
        [chatbot, history],
        show_progress=True
    ).then(
@@ -41,7 +41,7 @@ def create_chat_box(
    clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
    return chat_box, chatbot, history, dict(
-         prefix=prefix,
+         system=system,
        query=query,
        submit_btn=submit_btn,
        clear_btn=clear_btn,

View File

@@ -52,7 +52,7 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
            top_elems["finetuning_type"],
            top_elems["quantization_bit"],
            top_elems["template"],
-             top_elems["source_prefix"],
+             top_elems["system_prompt"],
            dataset_dir,
            dataset,
            max_source_length,

View File

@@ -28,7 +28,7 @@ def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"
            top_elems["finetuning_type"],
            top_elems["quantization_bit"],
            top_elems["template"],
-             top_elems["source_prefix"]
+             top_elems["system_prompt"]
        ],
        [info_box]
    ).then(

View File

@@ -28,7 +28,7 @@ def create_top() -> Dict[str, "Component"]:
    with gr.Row():
        quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
        template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
-         source_prefix = gr.Textbox(scale=2)
+         system_prompt = gr.Textbox(scale=2)
    lang.change(save_config, [lang, model_name, model_path])
@@ -62,5 +62,5 @@ def create_top() -> Dict[str, "Component"]:
        advanced_tab=advanced_tab,
        quantization_bit=quantization_bit,
        template=template,
-         source_prefix=source_prefix
+         system_prompt=system_prompt
    )

View File

@@ -101,7 +101,7 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
            top_elems["finetuning_type"],
            top_elems["quantization_bit"],
            top_elems["template"],
-             top_elems["source_prefix"],
+             top_elems["system_prompt"],
            training_stage,
            dataset_dir,
            dataset,

View File

@@ -77,7 +77,7 @@ LOCALES = {
            "info": "构建提示词时使用的模板"
        }
    },
-     "source_prefix": {
+     "system_prompt": {
        "en": {
            "label": "System prompt (optional)",
            "info": "A sequence used as the default system prompt."
@@ -455,7 +455,7 @@ LOCALES = {
            "value": "模型未加载,请先加载模型。"
        }
    },
-     "prefix": {
+     "system": {
        "en": {
            "placeholder": "System prompt (optional)"
        },

View File

@@ -69,7 +69,7 @@ class Runner:
        finetuning_type: str,
        quantization_bit: str,
        template: str,
-         source_prefix: str,
+         system_prompt: str,
        training_stage: str,
        dataset_dir: str,
        dataset: List[str],
@@ -114,7 +114,7 @@ class Runner:
            finetuning_type=finetuning_type,
            quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
            template=template,
-             source_prefix=source_prefix,
+             system_prompt=system_prompt,
            dataset_dir=dataset_dir,
            dataset=",".join(dataset),
            max_source_length=max_source_length,
@@ -170,7 +170,7 @@ class Runner:
        finetuning_type: str,
        quantization_bit: str,
        template: str,
-         source_prefix: str,
+         system_prompt: str,
        dataset_dir: str,
        dataset: List[str],
        max_source_length: int,
@@ -198,7 +198,7 @@ class Runner:
            finetuning_type=finetuning_type,
            quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
            template=template,
-             source_prefix=source_prefix,
+             system_prompt=system_prompt,
            dataset_dir=dataset_dir,
            dataset=",".join(dataset),
            max_source_length=max_source_length,