fix system prompt
This commit is contained in:
parent
273135f595
commit
7407d9daa1
|
@ -47,15 +47,15 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
|||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
if request.messages[-1].role != Role.USER:
|
||||
if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
|
||||
raise HTTPException(status_code=400, detail="Invalid request")
|
||||
|
||||
query = request.messages[-1].content
|
||||
prev_messages = request.messages[:-1]
|
||||
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
|
||||
prefix = prev_messages.pop(0).content
|
||||
system = prev_messages.pop(0).content
|
||||
else:
|
||||
prefix = None
|
||||
system = None
|
||||
|
||||
history = []
|
||||
if len(prev_messages) % 2 == 0:
|
||||
|
@ -64,11 +64,11 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
|||
history.append([prev_messages[i].content, prev_messages[i+1].content])
|
||||
|
||||
if request.stream:
|
||||
generate = predict(query, history, prefix, request)
|
||||
generate = predict(query, history, system, request)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
|
||||
response, (prompt_length, response_length) = chat_model.chat(
|
||||
query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
|
||||
query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
|
||||
)
|
||||
|
||||
usage = ChatCompletionResponseUsage(
|
||||
|
@ -85,7 +85,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
|||
|
||||
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
|
||||
|
||||
async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
|
||||
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(role=Role.ASSISTANT),
|
||||
|
@ -95,7 +95,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
|||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
for new_text in chat_model.stream_chat(
|
||||
query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
|
||||
query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
|
||||
):
|
||||
if len(new_text) == 0:
|
||||
continue
|
||||
|
|
|
@ -16,19 +16,19 @@ class ChatModel:
|
|||
self.model = dispatch_model(self.model)
|
||||
self.model = self.model.eval() # enable evaluation mode
|
||||
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
||||
self.source_prefix = data_args.source_prefix
|
||||
self.system_prompt = data_args.system_prompt
|
||||
|
||||
def process_args(
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
prefix = prefix or self.source_prefix
|
||||
system = system or self.system_prompt
|
||||
|
||||
prompt, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, query=query, resp="", history=history, prefix=prefix
|
||||
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
|
||||
)
|
||||
input_ids = torch.tensor([prompt], device=self.model.device)
|
||||
prompt_length = len(input_ids[0])
|
||||
|
@ -68,10 +68,10 @@ class ChatModel:
|
|||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Tuple[str, Tuple[int, int]]:
|
||||
gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
|
||||
gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
|
||||
generation_output = self.model.generate(**gen_kwargs)
|
||||
outputs = generation_output.tolist()[0][prompt_length:]
|
||||
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
|
||||
|
@ -83,10 +83,10 @@ class ChatModel:
|
|||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Generator[str, None, None]:
|
||||
gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
|
||||
gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
|
||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
|
||||
|
|
|
@ -92,14 +92,13 @@ def get_dataset(
|
|||
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
||||
|
||||
if dataset_attr.source_prefix: # add prefix
|
||||
if dataset_attr.system_prompt: # add system prompt
|
||||
if data_args.streaming:
|
||||
features = dataset.features
|
||||
features["prefix"] = Value(dtype="string", id=None)
|
||||
dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
|
||||
features["system"] = Value(dtype="string", id=None)
|
||||
dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt}, features=features)
|
||||
else:
|
||||
prefix_data = [dataset_attr.source_prefix] * len(dataset)
|
||||
dataset = dataset.add_column("prefix", prefix_data)
|
||||
dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))
|
||||
|
||||
all_datasets.append(dataset)
|
||||
|
||||
|
|
|
@ -27,8 +27,8 @@ def preprocess_dataset(
|
|||
query, response = examples["prompt"][i], examples["response"][i]
|
||||
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
|
||||
history = examples["history"][i] if "history" in examples else None
|
||||
prefix = examples["prefix"][i] if "prefix" in examples else None
|
||||
yield query, response, history, prefix
|
||||
system = examples["system"][i] if "system" in examples else None
|
||||
yield query, response, history, system
|
||||
|
||||
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
||||
# build grouped texts with format `X1 X2 X3 ...` (without <eos>)
|
||||
|
@ -56,10 +56,10 @@ def preprocess_dataset(
|
|||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
max_length = data_args.max_source_length + data_args.max_target_length
|
||||
|
||||
for query, response, history, prefix in construct_example(examples):
|
||||
for query, response, history, system in construct_example(examples):
|
||||
input_ids, labels = [], []
|
||||
|
||||
for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, prefix):
|
||||
for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
if len(target_ids) > data_args.max_target_length:
|
||||
|
@ -78,11 +78,11 @@ def preprocess_dataset(
|
|||
return model_inputs
|
||||
|
||||
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
||||
# build inputs with format `<bos> X` and labels with format `<bos> Y`
|
||||
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
|
||||
for query, response, history, prefix in construct_example(examples):
|
||||
source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, prefix)
|
||||
for query, response, history, system in construct_example(examples):
|
||||
source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, system)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
|
@ -98,9 +98,9 @@ def preprocess_dataset(
|
|||
def preprocess_pairwise_dataset(examples):
|
||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
|
||||
for query, response, history, prefix in construct_example(examples):
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, prefix)
|
||||
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, prefix)
|
||||
for query, response, history, system in construct_example(examples):
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
|
||||
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
|
||||
|
||||
if len(prompt_ids) > data_args.max_source_length:
|
||||
prompt_ids = prompt_ids[:data_args.max_source_length]
|
||||
|
|
|
@ -16,9 +16,11 @@ class Template:
|
|||
|
||||
prefix: List[Union[str, Dict[str, str]]]
|
||||
prompt: List[Union[str, Dict[str, str]]]
|
||||
system: str
|
||||
sep: List[Union[str, Dict[str, str]]]
|
||||
stop_words: List[str]
|
||||
use_history: bool
|
||||
bos_after_prefix: bool
|
||||
|
||||
def encode_oneturn(
|
||||
self,
|
||||
|
@ -26,18 +28,18 @@ class Template:
|
|||
query: str,
|
||||
resp: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None
|
||||
system: Optional[str] = None
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
r"""
|
||||
Returns a single pair of token ids representing prompt and response respectively.
|
||||
"""
|
||||
prefix, history = self._format(query, resp, history, prefix)
|
||||
encoded_pairs = self._encode(tokenizer, prefix, history)
|
||||
system, history = self._format(query, resp, history, system)
|
||||
encoded_pairs = self._encode(tokenizer, system, history)
|
||||
prompt_ids = []
|
||||
for query_ids, resp_ids in encoded_pairs[:-1]:
|
||||
prompt_ids = prompt_ids + query_ids + resp_ids
|
||||
prompt_ids = prompt_ids + encoded_pairs[-1][0]
|
||||
return prompt_ids, encoded_pairs[-1][1]
|
||||
prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
|
||||
return prompt_ids, answer_ids
|
||||
|
||||
def encode_multiturn(
|
||||
self,
|
||||
|
@ -45,13 +47,13 @@ class Template:
|
|||
query: str,
|
||||
resp: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None
|
||||
system: Optional[str] = None
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Returns multiple pairs of token ids representing prompts and responses respectively.
|
||||
"""
|
||||
prefix, history = self._format(query, resp, history, prefix)
|
||||
encoded_pairs = self._encode(tokenizer, prefix, history)
|
||||
system, history = self._format(query, resp, history, system)
|
||||
encoded_pairs = self._encode(tokenizer, system, history)
|
||||
return encoded_pairs
|
||||
|
||||
def _format(
|
||||
|
@ -59,15 +61,15 @@ class Template:
|
|||
query: str,
|
||||
resp: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None
|
||||
) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]:
|
||||
system: Optional[str] = None
|
||||
) -> Tuple[str, List[Tuple[str, str]]]:
|
||||
r"""
|
||||
Aligns inputs to the standard format.
|
||||
"""
|
||||
prefix = [prefix] if prefix else self.prefix # use prefix if provided
|
||||
system = system or self.system # use system if provided
|
||||
history = history if (history and self.use_history) else []
|
||||
history = history + [(query, resp)]
|
||||
return prefix, history
|
||||
return system, history
|
||||
|
||||
def _get_special_ids(
|
||||
self,
|
||||
|
@ -88,7 +90,7 @@ class Template:
|
|||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
prefix: List[Union[str, Dict[str, str]]],
|
||||
system: str,
|
||||
history: List[Tuple[str, str]]
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
|
@ -101,8 +103,12 @@ class Template:
|
|||
encoded_pairs = []
|
||||
for turn_idx, (query, resp) in enumerate(history):
|
||||
if turn_idx == 0:
|
||||
if prefix: # has prefix
|
||||
prefix_ids = bos_ids + self._convert_inputs_to_ids(tokenizer, context=prefix) + sep_ids
|
||||
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
|
||||
if len(prefix_ids) != 0: # has prefix
|
||||
if self.bos_after_prefix:
|
||||
prefix_ids = prefix_ids + bos_ids + sep_ids
|
||||
else:
|
||||
prefix_ids = bos_ids + prefix_ids + sep_ids
|
||||
else:
|
||||
prefix_ids = bos_ids
|
||||
else:
|
||||
|
@ -117,8 +123,9 @@ class Template:
|
|||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
context: List[Union[str, Dict[str, str]]],
|
||||
query: Optional[str] = "",
|
||||
idx: Optional[str] = ""
|
||||
system: Optional[str] = None,
|
||||
query: Optional[str] = None,
|
||||
idx: Optional[str] = None
|
||||
) -> List[int]:
|
||||
r"""
|
||||
Converts context to token ids.
|
||||
|
@ -131,13 +138,15 @@ class Template:
|
|||
token_ids = []
|
||||
for elem in context:
|
||||
if isinstance(elem, str):
|
||||
elem = elem.replace("{{query}}", query, 1)
|
||||
elem = elem.replace("{{idx}}", idx, 1)
|
||||
elem = elem.replace("{{system}}", system, 1) if system is not None else elem
|
||||
elem = elem.replace("{{query}}", query, 1) if query is not None else elem
|
||||
elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
|
||||
token_ids = token_ids + tokenizer.encode(elem, **kwargs)
|
||||
elif isinstance(elem, dict):
|
||||
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return token_ids
|
||||
|
||||
|
||||
|
@ -147,7 +156,7 @@ class Llama2Template(Template):
|
|||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
prefix: List[Union[str, Dict[str, str]]],
|
||||
system: str,
|
||||
history: List[Tuple[str, str]]
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
|
@ -157,10 +166,9 @@ class Llama2Template(Template):
|
|||
"""
|
||||
bos_ids, eos_ids = self._get_special_ids(tokenizer)
|
||||
encoded_pairs = []
|
||||
assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single string."
|
||||
for turn_idx, (query, resp) in enumerate(history):
|
||||
if turn_idx == 0: # llama2 template has not sep_ids
|
||||
query = prefix[0] + query
|
||||
if turn_idx == 0: # llama2 template has no sep_ids
|
||||
query = self.prefix[0].replace("{{system}}", system) + query
|
||||
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
|
||||
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
|
||||
encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
|
||||
|
@ -174,17 +182,21 @@ def register_template(
|
|||
name: str,
|
||||
prefix: List[Union[str, Dict[str, str]]],
|
||||
prompt: List[Union[str, Dict[str, str]]],
|
||||
system: str,
|
||||
sep: List[Union[str, Dict[str, str]]],
|
||||
stop_words: List[str],
|
||||
use_history: bool
|
||||
stop_words: Optional[List[str]] = [],
|
||||
use_history: Optional[bool] = True,
|
||||
bos_after_prefix: Optional[bool] = False
|
||||
) -> None:
|
||||
template_class = Llama2Template if "llama2" in name else Template
|
||||
templates[name] = template_class(
|
||||
prefix=prefix,
|
||||
prompt=prompt,
|
||||
system=system,
|
||||
sep=sep,
|
||||
stop_words=stop_words,
|
||||
use_history=use_history
|
||||
use_history=use_history,
|
||||
bos_after_prefix=bos_after_prefix
|
||||
)
|
||||
|
||||
|
||||
|
@ -201,7 +213,7 @@ def get_template_and_fix_tokenizer(
|
|||
if tokenizer.eos_token_id is not None:
|
||||
additional_special_tokens.append(tokenizer.eos_token)
|
||||
|
||||
tokenizer.eos_token = template.stop_words[0]
|
||||
tokenizer.eos_token = additional_special_tokens[0] # use the first stop word as eos token
|
||||
additional_special_tokens.pop(0)
|
||||
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
|
||||
|
||||
|
@ -229,8 +241,8 @@ register_template(
|
|||
prompt=[
|
||||
"{{query}}"
|
||||
],
|
||||
system="",
|
||||
sep=[],
|
||||
stop_words=[],
|
||||
use_history=False
|
||||
)
|
||||
|
||||
|
@ -241,17 +253,18 @@ Default template.
|
|||
register_template(
|
||||
name="default",
|
||||
prefix=[
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}\nAssistant: "
|
||||
],
|
||||
system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
sep=[
|
||||
"\n"
|
||||
],
|
||||
stop_words=[],
|
||||
use_history=True
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
@ -263,21 +276,22 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
|
|||
register_template(
|
||||
name="llama2",
|
||||
prefix=[
|
||||
"<<SYS>>\nYou are a helpful, respectful and honest assistant. "
|
||||
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
|
||||
],
|
||||
prompt=[
|
||||
"[INST] {{query}} [/INST] "
|
||||
],
|
||||
system=(
|
||||
"You are a helpful, respectful and honest assistant. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, "
|
||||
"racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n"
|
||||
"If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
|
||||
],
|
||||
prompt=[
|
||||
"[INST] {{query}} [/INST] "
|
||||
],
|
||||
sep=[],
|
||||
stop_words=[],
|
||||
use_history=True
|
||||
"If you don't know the answer to a question, please don't share false information."
|
||||
),
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
|
@ -288,14 +302,13 @@ Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
|
|||
register_template(
|
||||
name="llama2_zh",
|
||||
prefix=[
|
||||
"<<SYS>>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<</SYS>>\n\n"
|
||||
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
|
||||
],
|
||||
prompt=[
|
||||
"[INST] {{query}} [/INST] "
|
||||
],
|
||||
sep=[],
|
||||
stop_words=[],
|
||||
use_history=True
|
||||
system="You are a helpful assistant. 你是一个乐于助人的助手。",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
|
@ -306,17 +319,18 @@ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
|
|||
register_template(
|
||||
name="alpaca",
|
||||
prefix=[
|
||||
"Below is an instruction that describes a task. "
|
||||
"Write a response that appropriately completes the request."
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"### Instruction:\n{{query}}\n\n### Response:\n"
|
||||
],
|
||||
system=(
|
||||
"Below is an instruction that describes a task. "
|
||||
"Write a response that appropriately completes the request."
|
||||
),
|
||||
sep=[
|
||||
"\n\n"
|
||||
],
|
||||
stop_words=[],
|
||||
use_history=True
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
@ -327,15 +341,16 @@ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
|
|||
register_template(
|
||||
name="vicuna",
|
||||
prefix=[
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"USER: {{query}} ASSISTANT: "
|
||||
],
|
||||
sep=[],
|
||||
stop_words=[],
|
||||
use_history=True
|
||||
system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
|
@ -344,15 +359,16 @@ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
|
|||
"""
|
||||
register_template(
|
||||
name="belle",
|
||||
prefix=[],
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}\n\nBelle: "
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n\n"
|
||||
],
|
||||
stop_words=[],
|
||||
use_history=True
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
@ -361,15 +377,16 @@ Supports: https://github.com/CVI-SZU/Linly
|
|||
"""
|
||||
register_template(
|
||||
name="linly",
|
||||
prefix=[],
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"User: {{query}}\nBot: "
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n"
|
||||
],
|
||||
stop_words=[],
|
||||
use_history=True
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
@ -378,15 +395,16 @@ Supports: https://github.com/Neutralzz/BiLLa
|
|||
"""
|
||||
register_template(
|
||||
name="billa",
|
||||
prefix=[],
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}\nAssistant: "
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n"
|
||||
],
|
||||
stop_words=[],
|
||||
use_history=True
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
@ -395,18 +413,19 @@ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
|
|||
"""
|
||||
register_template(
|
||||
name="ziya",
|
||||
prefix=[],
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<human>"},
|
||||
":{{query}}\n",
|
||||
{"token": "<bot>"},
|
||||
":"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n"
|
||||
],
|
||||
stop_words=[],
|
||||
use_history=True
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
@ -416,17 +435,18 @@ Supports: https://huggingface.co/qhduan/aquilachat-7b
|
|||
register_template(
|
||||
name="aquila",
|
||||
prefix=[
|
||||
"A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}###Assistant: "
|
||||
],
|
||||
system=(
|
||||
"A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||
),
|
||||
sep=[
|
||||
"###"
|
||||
],
|
||||
stop_words=[],
|
||||
use_history=True
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
@ -435,19 +455,21 @@ Supports: https://huggingface.co/internlm/internlm-chat-7b
|
|||
"""
|
||||
register_template(
|
||||
name="intern",
|
||||
prefix=[],
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"<|User|>:{{query}}",
|
||||
{"token": "<eoh>"},
|
||||
"\n<|Bot|>:"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<eoa>"
|
||||
],
|
||||
use_history=True
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
@ -457,17 +479,19 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
|
|||
register_template(
|
||||
name="baichuan",
|
||||
prefix=[
|
||||
{"token": "<reserved_102>"} # user token (a little difference in the first turn)
|
||||
"{{system}}",
|
||||
{"token": "<reserved_102>"} # user token
|
||||
],
|
||||
prompt=[
|
||||
"{{query}}",
|
||||
{"token": "<reserved_103>"} # assistant token
|
||||
],
|
||||
system="",
|
||||
sep=[],
|
||||
stop_words=[
|
||||
"<reserved_102>" # user token
|
||||
],
|
||||
use_history=True
|
||||
bos_after_prefix=True
|
||||
)
|
||||
|
||||
|
||||
|
@ -479,7 +503,7 @@ register_template(
|
|||
name="starchat",
|
||||
prefix=[
|
||||
{"token": "<|system|>"},
|
||||
"\n",
|
||||
"\n{{system}}",
|
||||
{"token": "<|end|>"}
|
||||
],
|
||||
prompt=[
|
||||
|
@ -489,13 +513,13 @@ register_template(
|
|||
"\n",
|
||||
{"token": "<|assistant|>"}
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|end|>"
|
||||
],
|
||||
use_history=True
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
@ -506,7 +530,7 @@ register_template(
|
|||
name="chatml",
|
||||
prefix=[
|
||||
{"token": "<|im_start|>"},
|
||||
"system\nYou are a helpful assistant.",
|
||||
"system\n{{system}}",
|
||||
{"token": "<|im_end|>"}
|
||||
],
|
||||
prompt=[
|
||||
|
@ -517,13 +541,13 @@ register_template(
|
|||
{"token": "<|im_start|>"},
|
||||
"assistant\n"
|
||||
],
|
||||
system="You are a helpful assistant.",
|
||||
sep=[
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|im_end|>"
|
||||
],
|
||||
use_history=True
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
@ -534,14 +558,14 @@ register_template(
|
|||
name="chatglm2",
|
||||
prefix=[
|
||||
{"token": "[gMASK]"},
|
||||
{"token": "sop"}
|
||||
{"token": "sop"},
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n\n"
|
||||
],
|
||||
stop_words=[],
|
||||
use_history=True
|
||||
]
|
||||
)
|
||||
|
|
|
@ -10,7 +10,7 @@ class DatasetAttr:
|
|||
load_from: str
|
||||
dataset_name: Optional[str] = None
|
||||
dataset_sha1: Optional[str] = None
|
||||
source_prefix: Optional[str] = None
|
||||
system_prompt: Optional[str] = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.dataset_name
|
||||
|
@ -86,9 +86,9 @@ class DataArguments:
|
|||
default=True,
|
||||
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
|
||||
)
|
||||
source_prefix: Optional[str] = field(
|
||||
system_prompt: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
|
||||
metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
|
||||
)
|
||||
val_size: Optional[float] = field(
|
||||
default=0,
|
||||
|
@ -100,12 +100,9 @@ class DataArguments:
|
|||
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
if self.source_prefix is not None:
|
||||
prefix_list = self.source_prefix.split("|")
|
||||
prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
|
||||
assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
|
||||
else:
|
||||
prefix_list = [None] * len(dataset_names)
|
||||
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
|
||||
prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
|
||||
assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
|
||||
|
||||
if self.interleave_probs is not None:
|
||||
self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
|
||||
|
@ -126,12 +123,11 @@ class DataArguments:
|
|||
dataset_sha1=dataset_info[name].get("file_sha1", None)
|
||||
)
|
||||
|
||||
dataset_attr.source_prefix = prefix_list[i]
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
|
||||
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
|
||||
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
|
||||
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
|
||||
|
||||
dataset_attr.system_prompt = prompt_list[i]
|
||||
self.dataset_list.append(dataset_attr)
|
||||
|
|
|
@ -56,6 +56,5 @@ def run_pt(
|
|||
perplexity = float("inf")
|
||||
|
||||
metrics["perplexity"] = perplexity
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
|
|
@ -26,7 +26,7 @@ class WebChatModel(ChatModel):
|
|||
finetuning_type: str,
|
||||
quantization_bit: str,
|
||||
template: str,
|
||||
source_prefix: str
|
||||
system_prompt: str
|
||||
):
|
||||
if self.model is not None:
|
||||
yield ALERTS["err_exists"][lang]
|
||||
|
@ -55,7 +55,7 @@ class WebChatModel(ChatModel):
|
|||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
|
||||
template=template,
|
||||
source_prefix=source_prefix
|
||||
system_prompt=system_prompt
|
||||
)
|
||||
super().__init__(args)
|
||||
|
||||
|
@ -73,7 +73,7 @@ class WebChatModel(ChatModel):
|
|||
chatbot: List[Tuple[str, str]],
|
||||
query: str,
|
||||
history: List[Tuple[str, str]],
|
||||
prefix: str,
|
||||
system: str,
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float
|
||||
|
@ -81,7 +81,7 @@ class WebChatModel(ChatModel):
|
|||
chatbot.append([query, ""])
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
):
|
||||
response += new_text
|
||||
response = self.postprocess(response)
|
||||
|
|
|
@ -17,7 +17,7 @@ def create_chat_box(
|
|||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
prefix = gr.Textbox(show_label=False)
|
||||
system = gr.Textbox(show_label=False)
|
||||
query = gr.Textbox(show_label=False, lines=8)
|
||||
submit_btn = gr.Button(variant="primary")
|
||||
|
||||
|
@ -31,7 +31,7 @@ def create_chat_box(
|
|||
|
||||
submit_btn.click(
|
||||
chat_model.predict,
|
||||
[chatbot, query, history, prefix, max_new_tokens, top_p, temperature],
|
||||
[chatbot, query, history, system, max_new_tokens, top_p, temperature],
|
||||
[chatbot, history],
|
||||
show_progress=True
|
||||
).then(
|
||||
|
@ -41,7 +41,7 @@ def create_chat_box(
|
|||
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
|
||||
|
||||
return chat_box, chatbot, history, dict(
|
||||
prefix=prefix,
|
||||
system=system,
|
||||
query=query,
|
||||
submit_btn=submit_btn,
|
||||
clear_btn=clear_btn,
|
||||
|
|
|
@ -52,7 +52,7 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
|
|||
top_elems["finetuning_type"],
|
||||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["source_prefix"],
|
||||
top_elems["system_prompt"],
|
||||
dataset_dir,
|
||||
dataset,
|
||||
max_source_length,
|
||||
|
|
|
@ -28,7 +28,7 @@ def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"
|
|||
top_elems["finetuning_type"],
|
||||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["source_prefix"]
|
||||
top_elems["system_prompt"]
|
||||
],
|
||||
[info_box]
|
||||
).then(
|
||||
|
|
|
@ -28,7 +28,7 @@ def create_top() -> Dict[str, "Component"]:
|
|||
with gr.Row():
|
||||
quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
|
||||
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
|
||||
source_prefix = gr.Textbox(scale=2)
|
||||
system_prompt = gr.Textbox(scale=2)
|
||||
|
||||
lang.change(save_config, [lang, model_name, model_path])
|
||||
|
||||
|
@ -62,5 +62,5 @@ def create_top() -> Dict[str, "Component"]:
|
|||
advanced_tab=advanced_tab,
|
||||
quantization_bit=quantization_bit,
|
||||
template=template,
|
||||
source_prefix=source_prefix
|
||||
system_prompt=system_prompt
|
||||
)
|
||||
|
|
|
@ -101,7 +101,7 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
|
|||
top_elems["finetuning_type"],
|
||||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["source_prefix"],
|
||||
top_elems["system_prompt"],
|
||||
training_stage,
|
||||
dataset_dir,
|
||||
dataset,
|
||||
|
|
|
@ -77,7 +77,7 @@ LOCALES = {
|
|||
"info": "构建提示词时使用的模板"
|
||||
}
|
||||
},
|
||||
"source_prefix": {
|
||||
"system_prompt": {
|
||||
"en": {
|
||||
"label": "System prompt (optional)",
|
||||
"info": "A sequence used as the default system prompt."
|
||||
|
@ -455,7 +455,7 @@ LOCALES = {
|
|||
"value": "模型未加载,请先加载模型。"
|
||||
}
|
||||
},
|
||||
"prefix": {
|
||||
"system": {
|
||||
"en": {
|
||||
"placeholder": "System prompt (optional)"
|
||||
},
|
||||
|
|
|
@ -69,7 +69,7 @@ class Runner:
|
|||
finetuning_type: str,
|
||||
quantization_bit: str,
|
||||
template: str,
|
||||
source_prefix: str,
|
||||
system_prompt: str,
|
||||
training_stage: str,
|
||||
dataset_dir: str,
|
||||
dataset: List[str],
|
||||
|
@ -114,7 +114,7 @@ class Runner:
|
|||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
|
||||
template=template,
|
||||
source_prefix=source_prefix,
|
||||
system_prompt=system_prompt,
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=",".join(dataset),
|
||||
max_source_length=max_source_length,
|
||||
|
@ -170,7 +170,7 @@ class Runner:
|
|||
finetuning_type: str,
|
||||
quantization_bit: str,
|
||||
template: str,
|
||||
source_prefix: str,
|
||||
system_prompt: str,
|
||||
dataset_dir: str,
|
||||
dataset: List[str],
|
||||
max_source_length: int,
|
||||
|
@ -198,7 +198,7 @@ class Runner:
|
|||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
|
||||
template=template,
|
||||
source_prefix=source_prefix,
|
||||
system_prompt=system_prompt,
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=",".join(dataset),
|
||||
max_source_length=max_source_length,
|
||||
|
|
Loading…
Reference in New Issue