fix system prompt

This commit is contained in:
hiyouga 2023-08-16 01:35:52 +08:00
parent 273135f595
commit 7407d9daa1
15 changed files with 170 additions and 152 deletions

View File

@ -47,15 +47,15 @@ def create_app(chat_model: ChatModel) -> FastAPI:
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
if request.messages[-1].role != Role.USER:
if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
raise HTTPException(status_code=400, detail="Invalid request")
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
prefix = prev_messages.pop(0).content
system = prev_messages.pop(0).content
else:
prefix = None
system = None
history = []
if len(prev_messages) % 2 == 0:
@ -64,11 +64,11 @@ def create_app(chat_model: ChatModel) -> FastAPI:
history.append([prev_messages[i].content, prev_messages[i+1].content])
if request.stream:
generate = predict(query, history, prefix, request)
generate = predict(query, history, system, request)
return EventSourceResponse(generate, media_type="text/event-stream")
response, (prompt_length, response_length) = chat_model.chat(
query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
)
usage = ChatCompletionResponseUsage(
@ -85,7 +85,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role=Role.ASSISTANT),
@ -95,7 +95,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
yield chunk.json(exclude_unset=True, ensure_ascii=False)
for new_text in chat_model.stream_chat(
query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
):
if len(new_text) == 0:
continue

View File

@ -16,19 +16,19 @@ class ChatModel:
self.model = dispatch_model(self.model)
self.model = self.model.eval() # enable evaluation mode
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
self.source_prefix = data_args.source_prefix
self.system_prompt = data_args.system_prompt
def process_args(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None,
system: Optional[str] = None,
**input_kwargs
) -> Tuple[Dict[str, Any], int]:
prefix = prefix or self.source_prefix
system = system or self.system_prompt
prompt, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, query=query, resp="", history=history, prefix=prefix
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
)
input_ids = torch.tensor([prompt], device=self.model.device)
prompt_length = len(input_ids[0])
@ -68,10 +68,10 @@ class ChatModel:
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None,
system: Optional[str] = None,
**input_kwargs
) -> Tuple[str, Tuple[int, int]]:
gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
generation_output = self.model.generate(**gen_kwargs)
outputs = generation_output.tolist()[0][prompt_length:]
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
@ -83,10 +83,10 @@ class ChatModel:
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None,
system: Optional[str] = None,
**input_kwargs
) -> Generator[str, None, None]:
gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer

View File

@ -92,14 +92,13 @@ def get_dataset(
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
if dataset_attr.source_prefix: # add prefix
if dataset_attr.system_prompt: # add system prompt
if data_args.streaming:
features = dataset.features
features["prefix"] = Value(dtype="string", id=None)
dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
features["system"] = Value(dtype="string", id=None)
dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt}, features=features)
else:
prefix_data = [dataset_attr.source_prefix] * len(dataset)
dataset = dataset.add_column("prefix", prefix_data)
dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))
all_datasets.append(dataset)

View File

@ -27,8 +27,8 @@ def preprocess_dataset(
query, response = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
history = examples["history"][i] if "history" in examples else None
prefix = examples["prefix"][i] if "prefix" in examples else None
yield query, response, history, prefix
system = examples["system"][i] if "system" in examples else None
yield query, response, history, system
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build grouped texts with format `X1 X2 X3 ...` (without <eos>)
@ -56,10 +56,10 @@ def preprocess_dataset(
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
max_length = data_args.max_source_length + data_args.max_target_length
for query, response, history, prefix in construct_example(examples):
for query, response, history, system in construct_example(examples):
input_ids, labels = [], []
for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, prefix):
for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
if len(target_ids) > data_args.max_target_length:
@ -78,11 +78,11 @@ def preprocess_dataset(
return model_inputs
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build inputs with format `<bos> X` and labels with format `<bos> Y`
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for query, response, history, prefix in construct_example(examples):
source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, prefix)
for query, response, history, system in construct_example(examples):
source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, system)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
@ -98,9 +98,9 @@ def preprocess_dataset(
def preprocess_pairwise_dataset(examples):
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for query, response, history, prefix in construct_example(examples):
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, prefix)
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, prefix)
for query, response, history, system in construct_example(examples):
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
if len(prompt_ids) > data_args.max_source_length:
prompt_ids = prompt_ids[:data_args.max_source_length]

View File

@ -16,9 +16,11 @@ class Template:
prefix: List[Union[str, Dict[str, str]]]
prompt: List[Union[str, Dict[str, str]]]
system: str
sep: List[Union[str, Dict[str, str]]]
stop_words: List[str]
use_history: bool
bos_after_prefix: bool
def encode_oneturn(
self,
@ -26,18 +28,18 @@ class Template:
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None
system: Optional[str] = None
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
prefix, history = self._format(query, resp, history, prefix)
encoded_pairs = self._encode(tokenizer, prefix, history)
system, history = self._format(query, resp, history, system)
encoded_pairs = self._encode(tokenizer, system, history)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids = prompt_ids + query_ids + resp_ids
prompt_ids = prompt_ids + encoded_pairs[-1][0]
return prompt_ids, encoded_pairs[-1][1]
prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
return prompt_ids, answer_ids
def encode_multiturn(
self,
@ -45,13 +47,13 @@ class Template:
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None
system: Optional[str] = None
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
prefix, history = self._format(query, resp, history, prefix)
encoded_pairs = self._encode(tokenizer, prefix, history)
system, history = self._format(query, resp, history, system)
encoded_pairs = self._encode(tokenizer, system, history)
return encoded_pairs
def _format(
@ -59,15 +61,15 @@ class Template:
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None
) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]:
system: Optional[str] = None
) -> Tuple[str, List[Tuple[str, str]]]:
r"""
Aligns inputs to the standard format.
"""
prefix = [prefix] if prefix else self.prefix # use prefix if provided
system = system or self.system # use system if provided
history = history if (history and self.use_history) else []
history = history + [(query, resp)]
return prefix, history
return system, history
def _get_special_ids(
self,
@ -88,7 +90,7 @@ class Template:
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
prefix: List[Union[str, Dict[str, str]]],
system: str,
history: List[Tuple[str, str]]
) -> List[Tuple[List[int], List[int]]]:
r"""
@ -101,8 +103,12 @@ class Template:
encoded_pairs = []
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0:
if prefix: # has prefix
prefix_ids = bos_ids + self._convert_inputs_to_ids(tokenizer, context=prefix) + sep_ids
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
if len(prefix_ids) != 0: # has prefix
if self.bos_after_prefix:
prefix_ids = prefix_ids + bos_ids + sep_ids
else:
prefix_ids = bos_ids + prefix_ids + sep_ids
else:
prefix_ids = bos_ids
else:
@ -117,8 +123,9 @@ class Template:
self,
tokenizer: "PreTrainedTokenizer",
context: List[Union[str, Dict[str, str]]],
query: Optional[str] = "",
idx: Optional[str] = ""
system: Optional[str] = None,
query: Optional[str] = None,
idx: Optional[str] = None
) -> List[int]:
r"""
Converts context to token ids.
@ -131,13 +138,15 @@ class Template:
token_ids = []
for elem in context:
if isinstance(elem, str):
elem = elem.replace("{{query}}", query, 1)
elem = elem.replace("{{idx}}", idx, 1)
elem = elem.replace("{{system}}", system, 1) if system is not None else elem
elem = elem.replace("{{query}}", query, 1) if query is not None else elem
elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
token_ids = token_ids + tokenizer.encode(elem, **kwargs)
elif isinstance(elem, dict):
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
else:
raise NotImplementedError
return token_ids
@ -147,7 +156,7 @@ class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
prefix: List[Union[str, Dict[str, str]]],
system: str,
history: List[Tuple[str, str]]
) -> List[Tuple[List[int], List[int]]]:
r"""
@ -157,10 +166,9 @@ class Llama2Template(Template):
"""
bos_ids, eos_ids = self._get_special_ids(tokenizer)
encoded_pairs = []
assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single string."
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0: # llama2 template has not sep_ids
query = prefix[0] + query
if turn_idx == 0: # llama2 template has no sep_ids
query = self.prefix[0].replace("{{system}}", system) + query
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
@ -174,17 +182,21 @@ def register_template(
name: str,
prefix: List[Union[str, Dict[str, str]]],
prompt: List[Union[str, Dict[str, str]]],
system: str,
sep: List[Union[str, Dict[str, str]]],
stop_words: List[str],
use_history: bool
stop_words: Optional[List[str]] = [],
use_history: Optional[bool] = True,
bos_after_prefix: Optional[bool] = False
) -> None:
template_class = Llama2Template if "llama2" in name else Template
templates[name] = template_class(
prefix=prefix,
prompt=prompt,
system=system,
sep=sep,
stop_words=stop_words,
use_history=use_history
use_history=use_history,
bos_after_prefix=bos_after_prefix
)
@ -201,7 +213,7 @@ def get_template_and_fix_tokenizer(
if tokenizer.eos_token_id is not None:
additional_special_tokens.append(tokenizer.eos_token)
tokenizer.eos_token = template.stop_words[0]
tokenizer.eos_token = additional_special_tokens[0] # use the first stop word as eos token
additional_special_tokens.pop(0)
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
@ -229,8 +241,8 @@ register_template(
prompt=[
"{{query}}"
],
system="",
sep=[],
stop_words=[],
use_history=False
)
@ -241,17 +253,18 @@ Default template.
register_template(
name="default",
prefix=[
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
"{{system}}"
],
prompt=[
"Human: {{query}}\nAssistant: "
],
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
sep=[
"\n"
],
stop_words=[],
use_history=True
]
)
@ -263,21 +276,22 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
register_template(
name="llama2",
prefix=[
"<<SYS>>\nYou are a helpful, respectful and honest assistant. "
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST] "
],
system=(
"You are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST] "
],
sep=[],
stop_words=[],
use_history=True
"If you don't know the answer to a question, please don't share false information."
),
sep=[]
)
@ -288,14 +302,13 @@ Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
register_template(
name="llama2_zh",
prefix=[
"<<SYS>>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<</SYS>>\n\n"
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST] "
],
sep=[],
stop_words=[],
use_history=True
system="You are a helpful assistant. 你是一个乐于助人的助手。",
sep=[]
)
@ -306,17 +319,18 @@ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
register_template(
name="alpaca",
prefix=[
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request."
"{{system}}"
],
prompt=[
"### Instruction:\n{{query}}\n\n### Response:\n"
],
system=(
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request."
),
sep=[
"\n\n"
],
stop_words=[],
use_history=True
]
)
@ -327,15 +341,16 @@ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
register_template(
name="vicuna",
prefix=[
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
"{{system}}"
],
prompt=[
"USER: {{query}} ASSISTANT: "
],
sep=[],
stop_words=[],
use_history=True
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
sep=[]
)
@ -344,15 +359,16 @@ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_template(
name="belle",
prefix=[],
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\n\nBelle: "
],
system="",
sep=[
"\n\n"
],
stop_words=[],
use_history=True
]
)
@ -361,15 +377,16 @@ Supports: https://github.com/CVI-SZU/Linly
"""
register_template(
name="linly",
prefix=[],
prefix=[
"{{system}}"
],
prompt=[
"User: {{query}}\nBot: "
],
system="",
sep=[
"\n"
],
stop_words=[],
use_history=True
]
)
@ -378,15 +395,16 @@ Supports: https://github.com/Neutralzz/BiLLa
"""
register_template(
name="billa",
prefix=[],
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\nAssistant: "
],
system="",
sep=[
"\n"
],
stop_words=[],
use_history=True
]
)
@ -395,18 +413,19 @@ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
"""
register_template(
name="ziya",
prefix=[],
prefix=[
"{{system}}"
],
prompt=[
{"token": "<human>"},
":{{query}}\n",
{"token": "<bot>"},
":"
],
system="",
sep=[
"\n"
],
stop_words=[],
use_history=True
]
)
@ -416,17 +435,18 @@ Supports: https://huggingface.co/qhduan/aquilachat-7b
register_template(
name="aquila",
prefix=[
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
"{{system}}"
],
prompt=[
"Human: {{query}}###Assistant: "
],
system=(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
),
sep=[
"###"
],
stop_words=[],
use_history=True
]
)
@ -435,19 +455,21 @@ Supports: https://huggingface.co/internlm/internlm-chat-7b
"""
register_template(
name="intern",
prefix=[],
prefix=[
"{{system}}"
],
prompt=[
"<|User|>:{{query}}",
{"token": "<eoh>"},
"\n<|Bot|>:"
],
system="",
sep=[
"\n"
],
stop_words=[
"<eoa>"
],
use_history=True
]
)
@ -457,17 +479,19 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
register_template(
name="baichuan",
prefix=[
{"token": "<reserved_102>"} # user token (a little difference in the first turn)
"{{system}}",
{"token": "<reserved_102>"} # user token
],
prompt=[
"{{query}}",
{"token": "<reserved_103>"} # assistant token
],
system="",
sep=[],
stop_words=[
"<reserved_102>" # user token
],
use_history=True
bos_after_prefix=True
)
@ -479,7 +503,7 @@ register_template(
name="starchat",
prefix=[
{"token": "<|system|>"},
"\n",
"\n{{system}}",
{"token": "<|end|>"}
],
prompt=[
@ -489,13 +513,13 @@ register_template(
"\n",
{"token": "<|assistant|>"}
],
system="",
sep=[
"\n"
],
stop_words=[
"<|end|>"
],
use_history=True
]
)
@ -506,7 +530,7 @@ register_template(
name="chatml",
prefix=[
{"token": "<|im_start|>"},
"system\nYou are a helpful assistant.",
"system\n{{system}}",
{"token": "<|im_end|>"}
],
prompt=[
@ -517,13 +541,13 @@ register_template(
{"token": "<|im_start|>"},
"assistant\n"
],
system="You are a helpful assistant.",
sep=[
"\n"
],
stop_words=[
"<|im_end|>"
],
use_history=True
]
)
@ -534,14 +558,14 @@ register_template(
name="chatglm2",
prefix=[
{"token": "[gMASK]"},
{"token": "sop"}
{"token": "sop"},
"{{system}}"
],
prompt=[
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
],
system="",
sep=[
"\n\n"
],
stop_words=[],
use_history=True
]
)

View File

@ -10,7 +10,7 @@ class DatasetAttr:
load_from: str
dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None
source_prefix: Optional[str] = None
system_prompt: Optional[str] = None
def __repr__(self) -> str:
return self.dataset_name
@ -86,9 +86,9 @@ class DataArguments:
default=True,
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
)
source_prefix: Optional[str] = field(
system_prompt: Optional[str] = field(
default=None,
metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
)
val_size: Optional[float] = field(
default=0,
@ -100,12 +100,9 @@ class DataArguments:
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
dataset_info = json.load(f)
if self.source_prefix is not None:
prefix_list = self.source_prefix.split("|")
prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
else:
prefix_list = [None] * len(dataset_names)
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
if self.interleave_probs is not None:
self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
@ -126,12 +123,11 @@ class DataArguments:
dataset_sha1=dataset_info[name].get("file_sha1", None)
)
dataset_attr.source_prefix = prefix_list[i]
if "columns" in dataset_info[name]:
dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
dataset_attr.system_prompt = prompt_list[i]
self.dataset_list.append(dataset_attr)

View File

@ -56,6 +56,5 @@ def run_pt(
perplexity = float("inf")
metrics["perplexity"] = perplexity
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

View File

@ -26,7 +26,7 @@ class WebChatModel(ChatModel):
finetuning_type: str,
quantization_bit: str,
template: str,
source_prefix: str
system_prompt: str
):
if self.model is not None:
yield ALERTS["err_exists"][lang]
@ -55,7 +55,7 @@ class WebChatModel(ChatModel):
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
source_prefix=source_prefix
system_prompt=system_prompt
)
super().__init__(args)
@ -73,7 +73,7 @@ class WebChatModel(ChatModel):
chatbot: List[Tuple[str, str]],
query: str,
history: List[Tuple[str, str]],
prefix: str,
system: str,
max_new_tokens: int,
top_p: float,
temperature: float
@ -81,7 +81,7 @@ class WebChatModel(ChatModel):
chatbot.append([query, ""])
response = ""
for new_text in self.stream_chat(
query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
):
response += new_text
response = self.postprocess(response)

View File

@ -17,7 +17,7 @@ def create_chat_box(
with gr.Row():
with gr.Column(scale=4):
prefix = gr.Textbox(show_label=False)
system = gr.Textbox(show_label=False)
query = gr.Textbox(show_label=False, lines=8)
submit_btn = gr.Button(variant="primary")
@ -31,7 +31,7 @@ def create_chat_box(
submit_btn.click(
chat_model.predict,
[chatbot, query, history, prefix, max_new_tokens, top_p, temperature],
[chatbot, query, history, system, max_new_tokens, top_p, temperature],
[chatbot, history],
show_progress=True
).then(
@ -41,7 +41,7 @@ def create_chat_box(
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
return chat_box, chatbot, history, dict(
prefix=prefix,
system=system,
query=query,
submit_btn=submit_btn,
clear_btn=clear_btn,

View File

@ -52,7 +52,7 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["source_prefix"],
top_elems["system_prompt"],
dataset_dir,
dataset,
max_source_length,

View File

@ -28,7 +28,7 @@ def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["source_prefix"]
top_elems["system_prompt"]
],
[info_box]
).then(

View File

@ -28,7 +28,7 @@ def create_top() -> Dict[str, "Component"]:
with gr.Row():
quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
source_prefix = gr.Textbox(scale=2)
system_prompt = gr.Textbox(scale=2)
lang.change(save_config, [lang, model_name, model_path])
@ -62,5 +62,5 @@ def create_top() -> Dict[str, "Component"]:
advanced_tab=advanced_tab,
quantization_bit=quantization_bit,
template=template,
source_prefix=source_prefix
system_prompt=system_prompt
)

View File

@ -101,7 +101,7 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["source_prefix"],
top_elems["system_prompt"],
training_stage,
dataset_dir,
dataset,

View File

@ -77,7 +77,7 @@ LOCALES = {
"info": "构建提示词时使用的模板"
}
},
"source_prefix": {
"system_prompt": {
"en": {
"label": "System prompt (optional)",
"info": "A sequence used as the default system prompt."
@ -455,7 +455,7 @@ LOCALES = {
"value": "模型未加载,请先加载模型。"
}
},
"prefix": {
"system": {
"en": {
"placeholder": "System prompt (optional)"
},

View File

@ -69,7 +69,7 @@ class Runner:
finetuning_type: str,
quantization_bit: str,
template: str,
source_prefix: str,
system_prompt: str,
training_stage: str,
dataset_dir: str,
dataset: List[str],
@ -114,7 +114,7 @@ class Runner:
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
source_prefix=source_prefix,
system_prompt=system_prompt,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
max_source_length=max_source_length,
@ -170,7 +170,7 @@ class Runner:
finetuning_type: str,
quantization_bit: str,
template: str,
source_prefix: str,
system_prompt: str,
dataset_dir: str,
dataset: List[str],
max_source_length: int,
@ -198,7 +198,7 @@ class Runner:
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
source_prefix=source_prefix,
system_prompt=system_prompt,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
max_source_length=max_source_length,