diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py
index 47b7661f..41a7fe9a 100644
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -47,15 +47,15 @@ def create_app(chat_model: ChatModel) -> FastAPI:
 
     @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
     async def create_chat_completion(request: ChatCompletionRequest):
-        if request.messages[-1].role != Role.USER:
+        if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=400, detail="Invalid request")
 
         query = request.messages[-1].content
         prev_messages = request.messages[:-1]
         if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
-            prefix = prev_messages.pop(0).content
+            system = prev_messages.pop(0).content
         else:
-            prefix = None
+            system = None
 
         history = []
         if len(prev_messages) % 2 == 0:
@@ -64,11 +64,11 @@ def create_app(chat_model: ChatModel) -> FastAPI:
                 history.append([prev_messages[i].content, prev_messages[i+1].content])
 
         if request.stream:
-            generate = predict(query, history, prefix, request)
+            generate = predict(query, history, system, request)
             return EventSourceResponse(generate, media_type="text/event-stream")
 
         response, (prompt_length, response_length) = chat_model.chat(
-            query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+            query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
         )
 
         usage = ChatCompletionResponseUsage(
@@ -85,7 +85,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
 
         return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
 
-    async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
+    async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
             delta=DeltaMessage(role=Role.ASSISTANT),
@@ -95,7 +95,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
         yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
         for new_text in chat_model.stream_chat(
-            query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+            query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
        ):
             if len(new_text) == 0:
                 continue
diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py
index e207ee0b..bf602dd5 100644
--- a/src/llmtuner/chat/stream_chat.py
+++ b/src/llmtuner/chat/stream_chat.py
@@ -16,19 +16,19 @@ class ChatModel:
         self.model = dispatch_model(self.model)
         self.model = self.model.eval() # enable evaluation mode
         self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
-        self.source_prefix = data_args.source_prefix
+        self.system_prompt = data_args.system_prompt
 
     def process_args(
         self,
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = None,
+        system: Optional[str] = None,
         **input_kwargs
     ) -> Tuple[Dict[str, Any], int]:
-        prefix = prefix or self.source_prefix
+        system = system or self.system_prompt
 
         prompt, _ = self.template.encode_oneturn(
-            tokenizer=self.tokenizer, query=query, resp="", history=history, prefix=prefix
+            tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
         )
         input_ids = torch.tensor([prompt], device=self.model.device)
         prompt_length = len(input_ids[0])
@@ -68,10 +68,10 @@ class ChatModel:
         self,
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = None,
+        system: Optional[str] = None,
         **input_kwargs
     ) -> Tuple[str, Tuple[int, int]]:
-        gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
+        gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
         generation_output = self.model.generate(**gen_kwargs)
         outputs = generation_output.tolist()[0][prompt_length:]
         response = self.tokenizer.decode(outputs, skip_special_tokens=True)
@@ -83,10 +83,10 @@ class ChatModel:
         self,
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = None,
+        system: Optional[str] = None,
         **input_kwargs
     ) -> Generator[str, None, None]:
-        gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
+        gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
diff --git a/src/llmtuner/dsets/loader.py b/src/llmtuner/dsets/loader.py
index 7588443f..6e293c78 100644
--- a/src/llmtuner/dsets/loader.py
+++ b/src/llmtuner/dsets/loader.py
@@ -92,14 +92,13 @@ def get_dataset(
             if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
                 dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
 
-        if dataset_attr.source_prefix: # add prefix
+        if dataset_attr.system_prompt: # add system prompt
             if data_args.streaming:
                 features = dataset.features
-                features["prefix"] = Value(dtype="string", id=None)
-                dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
+                features["system"] = Value(dtype="string", id=None)
+                dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt}, features=features)
             else:
-                prefix_data = [dataset_attr.source_prefix] * len(dataset)
-                dataset = dataset.add_column("prefix", prefix_data)
+                dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))
 
         all_datasets.append(dataset)
diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py
index 64e0d8b1..3cc5a483 100644
--- a/src/llmtuner/dsets/preprocess.py
+++ b/src/llmtuner/dsets/preprocess.py
@@ -27,8 +27,8 @@ def preprocess_dataset(
             query, response = examples["prompt"][i], examples["response"][i]
             query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
             history = examples["history"][i] if "history" in examples else None
-            prefix = examples["prefix"][i] if "prefix" in examples else None
-            yield query, response, history, prefix
+            system = examples["system"][i] if "system" in examples else None
+            yield query, response, history, system
 
     def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
         # build grouped texts with format `X1 X2 X3 ...` (without <eos>)
@@ -56,10 +56,10 @@ def preprocess_dataset(
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
         max_length = data_args.max_source_length + data_args.max_target_length
 
-        for query, response, history, prefix in construct_example(examples):
+        for query, response, history, system in construct_example(examples):
             input_ids, labels = [], []
 
-            for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, prefix):
+            for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
                 if len(source_ids) > data_args.max_source_length:
                     source_ids = source_ids[:data_args.max_source_length]
                 if len(target_ids) > data_args.max_target_length:
@@ -78,11 +78,11 @@ def preprocess_dataset(
         return model_inputs
 
     def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
-        # build inputs with format `<bos> X` and labels with format `<bos> Y`
+        # build inputs with format `<bos> X` and labels with format `Y <eos>`
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
 
-        for query, response, history, prefix in construct_example(examples):
-            source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, prefix)
+        for query, response, history, system in construct_example(examples):
+            source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, system)
 
             if len(source_ids) > data_args.max_source_length:
                 source_ids = source_ids[:data_args.max_source_length]
@@ -98,9 +98,9 @@ def preprocess_dataset(
     def preprocess_pairwise_dataset(examples):
         # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
         model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
-        for query, response, history, prefix in construct_example(examples):
-            prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, prefix)
-            _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, prefix)
+        for query, response, history, system in construct_example(examples):
+            prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
+            _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
 
             if len(prompt_ids) > data_args.max_source_length:
                 prompt_ids = prompt_ids[:data_args.max_source_length]
diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py
index 8ad42ac8..25907382 100644
--- a/src/llmtuner/extras/template.py
+++ b/src/llmtuner/extras/template.py
@@ -16,9 +16,11 @@ class Template:
 
     prefix: List[Union[str, Dict[str, str]]]
     prompt: List[Union[str, Dict[str, str]]]
+    system: str
     sep: List[Union[str, Dict[str, str]]]
     stop_words: List[str]
     use_history: bool
+    bos_after_prefix: bool
 
     def encode_oneturn(
         self,
@@ -26,18 +28,18 @@ class Template:
         query: str,
         resp: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = None
+        system: Optional[str] = None
     ) -> Tuple[List[int], List[int]]:
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
         """
-        prefix, history = self._format(query, resp, history, prefix)
-        encoded_pairs = self._encode(tokenizer, prefix, history)
+        system, history = self._format(query, resp, history, system)
+        encoded_pairs = self._encode(tokenizer, system, history)
         prompt_ids = []
         for query_ids, resp_ids in encoded_pairs[:-1]:
             prompt_ids = prompt_ids + query_ids + resp_ids
-        prompt_ids = prompt_ids + encoded_pairs[-1][0]
-        return prompt_ids, encoded_pairs[-1][1]
+        prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
+        return prompt_ids, answer_ids
 
     def encode_multiturn(
         self,
@@ -45,13 +47,13 @@ class Template:
         query: str,
         resp: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = None
+        system: Optional[str] = None
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
         """
-        prefix, history = self._format(query, resp, history, prefix)
-        encoded_pairs = self._encode(tokenizer, prefix, history)
+        system, history = self._format(query, resp, history, system)
+        encoded_pairs = self._encode(tokenizer, system, history)
         return encoded_pairs
 
     def _format(
@@ -59,15 +61,15 @@ class Template:
         query: str,
         resp: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = None
-    ) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]:
+        system: Optional[str] = None
+    ) -> Tuple[str, List[Tuple[str, str]]]:
         r"""
         Aligns inputs to the standard format.
         """
-        prefix = [prefix] if prefix else self.prefix # use prefix if provided
+        system = system or self.system # use system if provided
         history = history if (history and self.use_history) else []
         history = history + [(query, resp)]
-        return prefix, history
+        return system, history
 
     def _get_special_ids(
         self,
@@ -88,7 +90,7 @@ class Template:
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        prefix: List[Union[str, Dict[str, str]]],
+        system: str,
         history: List[Tuple[str, str]]
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
@@ -101,8 +103,12 @@ class Template:
         encoded_pairs = []
         for turn_idx, (query, resp) in enumerate(history):
             if turn_idx == 0:
-                if prefix: # has prefix
-                    prefix_ids = bos_ids + self._convert_inputs_to_ids(tokenizer, context=prefix) + sep_ids
+                prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
+                if len(prefix_ids) != 0: # has prefix
+                    if self.bos_after_prefix:
+                        prefix_ids = prefix_ids + bos_ids + sep_ids
+                    else:
+                        prefix_ids = bos_ids + prefix_ids + sep_ids
                 else:
                     prefix_ids = bos_ids
             else:
@@ -117,8 +123,9 @@ class Template:
         self,
         tokenizer: "PreTrainedTokenizer",
         context: List[Union[str, Dict[str, str]]],
-        query: Optional[str] = "",
-        idx: Optional[str] = ""
+        system: Optional[str] = None,
+        query: Optional[str] = None,
+        idx: Optional[str] = None
     ) -> List[int]:
         r"""
         Converts context to token ids.
@@ -131,13 +138,15 @@ class Template:
         token_ids = []
         for elem in context:
             if isinstance(elem, str):
-                elem = elem.replace("{{query}}", query, 1)
-                elem = elem.replace("{{idx}}", idx, 1)
+                elem = elem.replace("{{system}}", system, 1) if system is not None else elem
+                elem = elem.replace("{{query}}", query, 1) if query is not None else elem
+                elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
                 token_ids = token_ids + tokenizer.encode(elem, **kwargs)
             elif isinstance(elem, dict):
                 token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
             else:
                 raise NotImplementedError
+
         return token_ids
 
@@ -147,7 +156,7 @@ class Llama2Template(Template):
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        prefix: List[Union[str, Dict[str, str]]],
+        system: str,
         history: List[Tuple[str, str]]
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
@@ -157,10 +166,9 @@ class Llama2Template(Template):
         """
         bos_ids, eos_ids = self._get_special_ids(tokenizer)
         encoded_pairs = []
-        assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single string."
         for turn_idx, (query, resp) in enumerate(history):
-            if turn_idx == 0: # llama2 template has not sep_ids
-                query = prefix[0] + query
+            if turn_idx == 0: # llama2 template has no sep_ids
+                query = self.prefix[0].replace("{{system}}", system) + query
             query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
             resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
             encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
@@ -174,17 +182,21 @@ def register_template(
     name: str,
     prefix: List[Union[str, Dict[str, str]]],
     prompt: List[Union[str, Dict[str, str]]],
+    system: str,
     sep: List[Union[str, Dict[str, str]]],
-    stop_words: List[str],
-    use_history: bool
+    stop_words: Optional[List[str]] = [],
+    use_history: Optional[bool] = True,
+    bos_after_prefix: Optional[bool] = False
 ) -> None:
     template_class = Llama2Template if "llama2" in name else Template
     templates[name] = template_class(
         prefix=prefix,
         prompt=prompt,
+        system=system,
         sep=sep,
         stop_words=stop_words,
-        use_history=use_history
+        use_history=use_history,
+        bos_after_prefix=bos_after_prefix
     )
 
@@ -201,7 +213,7 @@ def get_template_and_fix_tokenizer(
 
         if tokenizer.eos_token_id is not None:
             additional_special_tokens.append(tokenizer.eos_token)
 
-        tokenizer.eos_token = template.stop_words[0]
+        tokenizer.eos_token = additional_special_tokens[0] # use the first stop word as eos token
         additional_special_tokens.pop(0)
         logger.info("Replace eos token: {}".format(tokenizer.eos_token))
@@ -229,8 +241,8 @@ register_template(
     prompt=[
         "{{query}}"
     ],
+    system="",
     sep=[],
-    stop_words=[],
     use_history=False
 )
 
@@ -241,17 +253,18 @@ Default template.
 register_template(
     name="default",
     prefix=[
-        "A chat between a curious user and an artificial intelligence assistant. "
-        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+        "{{system}}"
     ],
     prompt=[
         "Human: {{query}}\nAssistant: "
     ],
+    system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
     sep=[
         "\n"
-    ],
-    stop_words=[],
-    use_history=True
+    ]
 )
 
@@ -263,21 +276,22 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
 register_template(
     name="llama2",
     prefix=[
-        "<<SYS>>\nYou are a helpful, respectful and honest assistant. "
+        "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
+    ],
+    prompt=[
+        "[INST] {{query}} [/INST] "
+    ],
+    system=(
+        "You are a helpful, respectful and honest assistant. "
         "Always answer as helpfully as possible, while being safe. "
         "Your answers should not include any harmful, unethical, "
         "racist, sexist, toxic, dangerous, or illegal content. "
         "Please ensure that your responses are socially unbiased and positive in nature.\n"
         "If a question does not make any sense, or is not factually coherent, "
         "explain why instead of answering something not correct. "
-        "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
-    ],
-    prompt=[
-        "[INST] {{query}} [/INST] "
-    ],
-    sep=[],
-    stop_words=[],
-    use_history=True
+        "If you don't know the answer to a question, please don't share false information."
+    ),
+    sep=[]
 )
 
@@ -288,14 +302,13 @@ Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
 register_template(
     name="llama2_zh",
     prefix=[
-        "<<SYS>>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<</SYS>>\n\n"
+        "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
     ],
     prompt=[
         "[INST] {{query}} [/INST] "
     ],
-    sep=[],
-    stop_words=[],
-    use_history=True
+    system="You are a helpful assistant. 你是一个乐于助人的助手。",
+    sep=[]
 )
 
@@ -306,17 +319,18 @@ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
 register_template(
     name="alpaca",
     prefix=[
-        "Below is an instruction that describes a task. "
-        "Write a response that appropriately completes the request."
+        "{{system}}"
     ],
     prompt=[
         "### Instruction:\n{{query}}\n\n### Response:\n"
     ],
+    system=(
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request."
+    ),
     sep=[
         "\n\n"
-    ],
-    stop_words=[],
-    use_history=True
+    ]
 )
 
@@ -327,15 +341,16 @@ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
 register_template(
     name="vicuna",
     prefix=[
-        "A chat between a curious user and an artificial intelligence assistant. "
-        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+        "{{system}}"
     ],
     prompt=[
         "USER: {{query}} ASSISTANT: "
     ],
-    sep=[],
-    stop_words=[],
-    use_history=True
+    system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    sep=[]
 )
 
@@ -344,15 +359,16 @@ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
 """
 register_template(
     name="belle",
-    prefix=[],
+    prefix=[
+        "{{system}}"
+    ],
     prompt=[
         "Human: {{query}}\n\nBelle: "
     ],
+    system="",
     sep=[
         "\n\n"
-    ],
-    stop_words=[],
-    use_history=True
+    ]
 )
 
@@ -361,15 +377,16 @@ Supports: https://github.com/CVI-SZU/Linly
 """
 register_template(
     name="linly",
-    prefix=[],
+    prefix=[
+        "{{system}}"
+    ],
     prompt=[
         "User: {{query}}\nBot: "
     ],
+    system="",
     sep=[
         "\n"
-    ],
-    stop_words=[],
-    use_history=True
+    ]
 )
 
@@ -378,15 +395,16 @@ Supports: https://github.com/Neutralzz/BiLLa
 """
 register_template(
     name="billa",
-    prefix=[],
+    prefix=[
+        "{{system}}"
+    ],
     prompt=[
         "Human: {{query}}\nAssistant: "
     ],
+    system="",
     sep=[
         "\n"
-    ],
-    stop_words=[],
-    use_history=True
+    ]
 )
 
@@ -395,18 +413,19 @@ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
 """
 register_template(
     name="ziya",
-    prefix=[],
+    prefix=[
+        "{{system}}"
+    ],
     prompt=[
         {"token": "<human>"},
         ":{{query}}\n",
         {"token": "<bot>"},
         ":"
     ],
+    system="",
     sep=[
         "\n"
-    ],
-    stop_words=[],
-    use_history=True
+    ]
 )
 
@@ -416,17 +435,18 @@ Supports: https://huggingface.co/qhduan/aquilachat-7b
 register_template(
     name="aquila",
     prefix=[
-        "A chat between a curious human and an artificial intelligence assistant. "
-        "The assistant gives helpful, detailed, and polite answers to the human's questions."
+        "{{system}}"
     ],
     prompt=[
         "Human: {{query}}###Assistant: "
     ],
+    system=(
+        "A chat between a curious human and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the human's questions."
+    ),
     sep=[
         "###"
-    ],
-    stop_words=[],
-    use_history=True
+    ]
 )
 
@@ -435,19 +455,21 @@ Supports: https://huggingface.co/internlm/internlm-chat-7b
 """
 register_template(
     name="intern",
-    prefix=[],
+    prefix=[
+        "{{system}}"
+    ],
     prompt=[
         "<|User|>:{{query}}",
         {"token": "<eoh>"},
         "\n<|Bot|>:"
     ],
+    system="",
     sep=[
         "\n"
     ],
     stop_words=[
         "<eoa>"
-    ],
-    use_history=True
+    ]
 )
 
@@ -457,17 +479,19 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
 register_template(
     name="baichuan",
     prefix=[
-        {"token": "<reserved_102>"} # user token (a little difference in the first turn)
+        "{{system}}",
+        {"token": "<reserved_102>"} # user token
     ],
     prompt=[
         "{{query}}",
         {"token": "<reserved_103>"} # assistant token
     ],
+    system="",
     sep=[],
     stop_words=[
         "<reserved_102>" # user token
     ],
-    use_history=True
+    bos_after_prefix=True
 )
 
@@ -479,7 +503,7 @@ register_template(
     name="starchat",
     prefix=[
         {"token": "<|system|>"},
-        "\n",
+        "\n{{system}}",
         {"token": "<|end|>"}
     ],
     prompt=[
@@ -489,13 +513,13 @@ register_template(
         "\n",
         {"token": "<|assistant|>"}
     ],
+    system="",
     sep=[
         "\n"
     ],
     stop_words=[
         "<|end|>"
-    ],
-    use_history=True
+    ]
 )
 
@@ -506,7 +530,7 @@ register_template(
     name="chatml",
     prefix=[
         {"token": "<|im_start|>"},
-        "system\nYou are a helpful assistant.",
+        "system\n{{system}}",
         {"token": "<|im_end|>"}
     ],
     prompt=[
@@ -517,13 +541,13 @@ register_template(
         {"token": "<|im_start|>"},
         "assistant\n"
     ],
+    system="You are a helpful assistant.",
     sep=[
         "\n"
     ],
     stop_words=[
         "<|im_end|>"
-    ],
-    use_history=True
+    ]
 )
 
@@ -534,14 +558,14 @@ register_template(
     name="chatglm2",
     prefix=[
         {"token": "[gMASK]"},
-        {"token": "sop"}
+        {"token": "sop"},
+        "{{system}}"
     ],
     prompt=[
         "[Round {{idx}}]\n\n问:{{query}}\n\n答:"
     ],
+    system="",
     sep=[
         "\n\n"
-    ],
-    stop_words=[],
-    use_history=True
+    ]
 )
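Note on the template change above: instead of hard-coding the system prompt inside `prefix`, every template now carries a `{{system}}` placeholder that `_convert_inputs_to_ids` fills from the per-call `system` string, with `_format` falling back to `Template.system` when none is given. The following is a minimal, string-level sketch of that substitution for the "default" template. It is illustrative only: `render_default_prompt` is not a function in this codebase, and it skips the token-level bos/eos/sep handling done by `Template._encode`.

# Illustrative sketch (not library code): shows where {{system}} and {{query}} land
# for the "default" template at the string level; the real encoding works on token ids.
from typing import List, Optional, Tuple

DEFAULT_SYSTEM = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
)

def render_default_prompt(
    query: str,
    history: Optional[List[Tuple[str, str]]] = None,
    system: Optional[str] = None
) -> str:
    system = system or DEFAULT_SYSTEM  # same fallback as Template._format
    prefix = "{{system}}".replace("{{system}}", system, 1)  # prefix=["{{system}}"]
    turns = (history or []) + [(query, "")]
    pieces = [prefix] + ["Human: {}\nAssistant: {}".format(q, r) for q, r in turns]
    return "\n".join(pieces)  # sep=["\n"] joins the prefix and each turn

if __name__ == "__main__":
    print(render_default_prompt("How are you?", system="You are a terse assistant."))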
diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py
index 7d1c982c..374d03c6 100644
--- a/src/llmtuner/hparams/data_args.py
+++ b/src/llmtuner/hparams/data_args.py
@@ -10,7 +10,7 @@ class DatasetAttr:
     load_from: str
     dataset_name: Optional[str] = None
     dataset_sha1: Optional[str] = None
-    source_prefix: Optional[str] = None
+    system_prompt: Optional[str] = None
 
     def __repr__(self) -> str:
         return self.dataset_name
@@ -86,9 +86,9 @@ class DataArguments:
         default=True,
         metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
     )
-    source_prefix: Optional[str] = field(
+    system_prompt: Optional[str] = field(
         default=None,
-        metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
+        metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
     )
     val_size: Optional[float] = field(
         default=0,
@@ -100,12 +100,9 @@ class DataArguments:
         with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
             dataset_info = json.load(f)
 
-        if self.source_prefix is not None:
-            prefix_list = self.source_prefix.split("|")
-            prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
-            assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
-        else:
-            prefix_list = [None] * len(dataset_names)
+        prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
+        prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
+        assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
 
         if self.interleave_probs is not None:
             self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
@@ -126,12 +123,11 @@ class DataArguments:
                     dataset_sha1=dataset_info[name].get("file_sha1", None)
                 )
 
-            dataset_attr.source_prefix = prefix_list[i]
-
             if "columns" in dataset_info[name]:
                 dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
                 dataset_attr.query = dataset_info[name]["columns"].get("query", None)
                 dataset_attr.response = dataset_info[name]["columns"].get("response", None)
                 dataset_attr.history = dataset_info[name]["columns"].get("history", None)
 
+            dataset_attr.system_prompt = prompt_list[i]
             self.dataset_list.append(dataset_attr)
diff --git a/src/llmtuner/tuner/pt/workflow.py b/src/llmtuner/tuner/pt/workflow.py
index f7bf6448..865ec218 100644
--- a/src/llmtuner/tuner/pt/workflow.py
+++ b/src/llmtuner/tuner/pt/workflow.py
@@ -56,6 +56,5 @@ def run_pt(
             perplexity = float("inf")
         metrics["perplexity"] = perplexity
 
-        trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
 
diff --git a/src/llmtuner/webui/chat.py b/src/llmtuner/webui/chat.py
index e73ea402..154efa5a 100644
--- a/src/llmtuner/webui/chat.py
+++ b/src/llmtuner/webui/chat.py
@@ -26,7 +26,7 @@ class WebChatModel(ChatModel):
         finetuning_type: str,
         quantization_bit: str,
         template: str,
-        source_prefix: str
+        system_prompt: str
     ):
         if self.model is not None:
             yield ALERTS["err_exists"][lang]
@@ -55,7 +55,7 @@ class WebChatModel(ChatModel):
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
             template=template,
-            source_prefix=source_prefix
+            system_prompt=system_prompt
         )
 
         super().__init__(args)
@@ -73,7 +73,7 @@ class WebChatModel(ChatModel):
         chatbot: List[Tuple[str, str]],
         query: str,
         history: List[Tuple[str, str]],
-        prefix: str,
+        system: str,
         max_new_tokens: int,
         top_p: float,
         temperature: float
@@ -81,7 +81,7 @@ class WebChatModel(ChatModel):
         chatbot.append([query, ""])
         response = ""
         for new_text in self.stream_chat(
-            query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
         ):
             response += new_text
             response = self.postprocess(response)
diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py
index 6fcfc652..928a568c 100644
--- a/src/llmtuner/webui/components/chatbot.py
+++ b/src/llmtuner/webui/components/chatbot.py
@@ -17,7 +17,7 @@ def create_chat_box(
 
         with gr.Row():
             with gr.Column(scale=4):
-                prefix = gr.Textbox(show_label=False)
+                system = gr.Textbox(show_label=False)
                 query = gr.Textbox(show_label=False, lines=8)
                 submit_btn = gr.Button(variant="primary")
 
@@ -31,7 +31,7 @@ def create_chat_box(
 
         submit_btn.click(
             chat_model.predict,
-            [chatbot, query, history, prefix, max_new_tokens, top_p, temperature],
+            [chatbot, query, history, system, max_new_tokens, top_p, temperature],
             [chatbot, history],
             show_progress=True
         ).then(
@@ -41,7 +41,7 @@ def create_chat_box(
     clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
 
     return chat_box, chatbot, history, dict(
-        prefix=prefix,
+        system=system,
         query=query,
         submit_btn=submit_btn,
         clear_btn=clear_btn,
diff --git a/src/llmtuner/webui/components/eval.py b/src/llmtuner/webui/components/eval.py
index 48372b4c..cbc71daf 100644
--- a/src/llmtuner/webui/components/eval.py
+++ b/src/llmtuner/webui/components/eval.py
@@ -52,7 +52,7 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
                 top_elems["finetuning_type"],
                 top_elems["quantization_bit"],
                 top_elems["template"],
-                top_elems["source_prefix"],
+                top_elems["system_prompt"],
                 dataset_dir,
                 dataset,
                 max_source_length,
diff --git a/src/llmtuner/webui/components/infer.py b/src/llmtuner/webui/components/infer.py
index 40e0323e..14aef162 100644
--- a/src/llmtuner/webui/components/infer.py
+++ b/src/llmtuner/webui/components/infer.py
@@ -28,7 +28,7 @@ def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"
                 top_elems["finetuning_type"],
                 top_elems["quantization_bit"],
                 top_elems["template"],
-                top_elems["source_prefix"]
+                top_elems["system_prompt"]
             ],
             [info_box]
         ).then(
diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py
index 7f3c6faa..62c1f9c9 100644
--- a/src/llmtuner/webui/components/top.py
+++ b/src/llmtuner/webui/components/top.py
@@ -28,7 +28,7 @@ def create_top() -> Dict[str, "Component"]:
         with gr.Row():
             quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
             template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
-            source_prefix = gr.Textbox(scale=2)
+            system_prompt = gr.Textbox(scale=2)
 
     lang.change(save_config, [lang, model_name, model_path])
 
@@ -62,5 +62,5 @@ def create_top() -> Dict[str, "Component"]:
         advanced_tab=advanced_tab,
         quantization_bit=quantization_bit,
         template=template,
-        source_prefix=source_prefix
+        system_prompt=system_prompt
     )
diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py
index 6aaeecbb..aab512ee 100644
--- a/src/llmtuner/webui/components/train.py
+++ b/src/llmtuner/webui/components/train.py
@@ -101,7 +101,7 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
                 top_elems["finetuning_type"],
                 top_elems["quantization_bit"],
                 top_elems["template"],
-                top_elems["source_prefix"],
+                top_elems["system_prompt"],
                 training_stage,
                 dataset_dir,
                 dataset,
diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py
index 7a58c4c7..c4032f39 100644
--- a/src/llmtuner/webui/locales.py
+++ b/src/llmtuner/webui/locales.py
@@ -77,7 +77,7 @@ LOCALES = {
             "info": "构建提示词时使用的模板"
         }
     },
-    "source_prefix": {
+    "system_prompt": {
         "en": {
             "label": "System prompt (optional)",
             "info": "A sequence used as the default system prompt."
@@ -455,7 +455,7 @@ LOCALES = {
             "value": "模型未加载,请先加载模型。"
         }
     },
-    "prefix": {
+    "system": {
         "en": {
             "placeholder": "System prompt (optional)"
         },
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index 49fed19b..ac74a4c7 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -69,7 +69,7 @@ class Runner:
         finetuning_type: str,
         quantization_bit: str,
         template: str,
-        source_prefix: str,
+        system_prompt: str,
         training_stage: str,
         dataset_dir: str,
         dataset: List[str],
@@ -114,7 +114,7 @@ class Runner:
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
             template=template,
-            source_prefix=source_prefix,
+            system_prompt=system_prompt,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),
             max_source_length=max_source_length,
@@ -170,7 +170,7 @@ class Runner:
         finetuning_type: str,
         quantization_bit: str,
         template: str,
-        source_prefix: str,
+        system_prompt: str,
         dataset_dir: str,
         dataset: List[str],
         max_source_length: int,
@@ -198,7 +198,7 @@ class Runner:
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
             template=template,
-            source_prefix=source_prefix,
+            system_prompt=system_prompt,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),
             max_source_length=max_source_length,
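Note on the user-facing effect of this rename: the loading argument is now `system_prompt` (with `|`-separated prompts assigned to the configured datasets positionally during training, a single prompt broadcast to all of them), and the OpenAI-style server maps a leading `system`-role message onto the new `system` argument of `ChatModel.chat` / `stream_chat`. A hedged client example is sketched below; host, port and model name are placeholders, and the response layout (`choices[0].message.content`) is assumed to follow the usual chat-completions schema rather than taken from this diff.

# Example client call against /v1/chat/completions; host/port/model are placeholders
# and the response layout is assumed to match the OpenAI-style chat-completions schema.
import requests

payload = {
    "model": "local-model",
    "messages": [
        {"role": "system", "content": "You are a concise assistant."},  # mapped to `system`
        {"role": "user", "content": "Summarize what a system prompt does."}
    ],
    "temperature": 0.7,
    "top_p": 0.9,
    "max_tokens": 256
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])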