diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py
index 26ee57ce..776e8c84 100644
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -120,6 +120,9 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 
     def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
         if request.stream:
+            if tools:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
+
             generate = stream_chat_completion(messages, system, tools, request)
             return EventSourceResponse(generate, media_type="text/event-stream")
 
diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py
index a3b23be9..928d2633 100644
--- a/src/llmtuner/data/template.py
+++ b/src/llmtuner/data/template.py
@@ -218,7 +218,7 @@ def register_template(
     default_user_formatter = StringFormatter(slots=["{{content}}"])
     default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
     default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
-    default_tool_formatter = ToolFormatter(slots="default")
+    default_tool_formatter = ToolFormatter(tool_format="default")
     default_separator_formatter = EmptyFormatter()
     templates[name] = template_class(
         format_user=format_user or default_user_formatter,
@@ -356,6 +356,14 @@ register_template(
 )
 
 
+register_template(
+    name="cpm",
+    format_user=StringFormatter(slots=["<用户>{{content}}"]),
+    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
+    force_system=True,
+)
+
+
 register_template(
     name="deepseek",
     format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
@@ -464,7 +472,7 @@ register_template(
 
 register_template(
     name="orion",
-    format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
+    format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
     format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
     force_system=True,
 )
diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py
index 8b144626..f794c846 100644
--- a/src/llmtuner/hparams/parser.py
+++ b/src/llmtuner/hparams/parser.py
@@ -63,13 +63,12 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
         if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
             raise ValueError("Cannot create new adapter upon a quantized model.")
 
-    if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
-        if finetuning_args.finetuning_type != "lora":
-            raise ValueError("Multiple adapters are only available for LoRA tuning.")
-
-        if model_args.quantization_bit is not None:
+        if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
 
+    if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Only LoRA method has adapters.")
+
 
 def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
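
Note on the app.py hunk: the new guard rejects OpenAI-style requests that combine stream=true with tools, returning HTTP 400 before any SSE stream is opened. A minimal client-side sketch of the rejected combination; the host, port, route, and payload fields below are illustrative assumptions, not taken from this patch:

import requests  # third-party HTTP client, used for illustration only

# Hypothetical request: streaming plus tools is exactly the combination
# the new guard rejects with HTTP 400 ("Cannot stream function calls.").
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumed server address and route
    json={
        "model": "default",  # placeholder model name
        "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
        "tools": [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
        "stream": True,
    },
)
print(resp.status_code)  # expected: 400
print(resp.json())       # expected detail: "Cannot stream function calls."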
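Note on the template.py hunks: besides the keyword fix (ToolFormatter takes tool_format, not slots), the new cpm template wraps each user turn in <用户> ("user"), and the orion template now emits the eos token directly after "Assistant: ". A hand-rendered sketch of the resulting single-turn prompts, using <s>/</s> as stand-ins for the tokenizer's actual bos/eos tokens; the real rendering goes through the slot/formatter machinery, so this is an approximation, not the library's API:

BOS, EOS = "<s>", "</s>"  # stand-ins; real tokens come from the tokenizer

def render_cpm(system: str, query: str) -> str:
    # cpm: bos + system prompt (force_system=True), user turn wrapped in <用户>
    return f"{BOS}{system}<用户>{query}"

def render_orion(system: str, query: str) -> str:
    # orion: bos + system prompt, eos emitted right after "Assistant: ",
    # as the changed format_user slots indicate
    return f"{BOS}{system}Human: {query}\n\nAssistant: {EOS}"

print(render_cpm("You are a helpful assistant.", "Hi"))
print(render_orion("", "Hello"))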
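Note on the parser.py hunk: the restructuring nests the single-adapter check inside the enclosing quantization_bit block and promotes the LoRA-only check to cover any adapter list, so a single adapter combined with full or freeze tuning now raises as well, where previously only multiple adapters triggered a finetuning-type error. A condensed, self-contained replica of the new control flow; SimpleNamespace stands in for the real ModelArguments/FinetuningArguments dataclasses, which carry many more fields:

from types import SimpleNamespace as NS

def verify(model_args, finetuning_args):
    # Condensed replica of the patched _verify_model_args checks.
    if model_args.quantization_bit is not None:
        if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
            raise ValueError("Quantized model only accepts a single adapter. Merge them first.")

    if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
        raise ValueError("Only LoRA method has adapters.")

# Multiple LoRA adapters on a full-precision model: accepted.
verify(NS(quantization_bit=None, adapter_name_or_path=["a1", "a2"]), NS(finetuning_type="lora"))

# Multiple adapters on a quantized model: rejected.
try:
    verify(NS(quantization_bit=4, adapter_name_or_path=["a1", "a2"]), NS(finetuning_type="lora"))
except ValueError as err:
    print(err)  # Quantized model only accepts a single adapter. Merge them first.

# Any adapter with a non-LoRA method: rejected (new behavior in this patch).
try:
    verify(NS(quantization_bit=None, adapter_name_or_path=["a1"]), NS(finetuning_type="full"))
except ValueError as err:
    print(err)  # Only LoRA method has adapters.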