diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py
index ed9ba8b8..b5dc57ff 100644
--- a/src/llamafactory/data/formatter.py
+++ b/src/llamafactory/data/formatter.py
@@ -34,8 +34,8 @@ DEFAULT_TOOL_PROMPT = (
 GLM4_TOOL_PROMPT = (
-    "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
-    "你的任务是针对用户的问题和要求提供适当的答复和支持。{tool_text}"
+    "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
+    "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
 )
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 9b51c064..3f21145d 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -97,7 +97,7 @@ class ModelArguments:
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: Literal["disabled", "sdpa", "fa2", "auto"] = field(
+    flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
         default="auto",
         metadata={"help": "Enable FlashAttention for faster training and inference."},
     )
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 35153649..24cd2601 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -58,10 +58,10 @@ def patch_config(
     is_trainable: bool,
 ) -> None:
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
-        if model_args.infer_dtype == "auto":
-            model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
-        else:
+        if model_args.infer_dtype != "auto" and not is_trainable:
             model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
+        else:
+            model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
 
     if is_torch_npu_available():
         use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py
index a2b54dce..652c341c 100644
--- a/src/llamafactory/webui/chatter.py
+++ b/src/llamafactory/webui/chatter.py
@@ -87,6 +87,7 @@ class WebChatModel(ChatModel):
             visual_inputs=get("top.visual_inputs"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             infer_backend=get("infer.infer_backend"),
+            infer_dtype=get("infer.infer_dtype"),
         )
 
         if checkpoint_path:
diff --git a/src/llamafactory/webui/components/infer.py b/src/llamafactory/webui/components/infer.py
index 03bccd7f..a0064479 100644
--- a/src/llamafactory/webui/components/infer.py
+++ b/src/llamafactory/webui/components/infer.py
@@ -32,15 +32,26 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
     input_elems = engine.manager.get_base_elems()
     elem_dict = dict()
 
-    infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
+    with gr.Row():
+        infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
+        infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto")
+
     with gr.Row():
         load_btn = gr.Button()
         unload_btn = gr.Button()
         info_box = gr.Textbox(show_label=False, interactive=False)
 
-    input_elems.update({infer_backend})
-    elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
+    input_elems.update({infer_backend, infer_dtype})
+    elem_dict.update(
+        dict(
+            infer_backend=infer_backend,
+            infer_dtype=infer_dtype,
+            load_btn=load_btn,
+            unload_btn=unload_btn,
+            info_box=info_box,
+        )
+    )
 
     chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
     elem_dict.update(chat_elems)
 
diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py
index 8e8d6fce..cd166584 100644
--- a/src/llamafactory/webui/locales.py
+++ b/src/llamafactory/webui/locales.py
@@ -1206,6 +1206,17 @@ LOCALES = {
             "label": "推理引擎",
         },
     },
+    "infer_dtype": {
+        "en": {
+            "label": "Inference data type",
+        },
+        "ru": {
+            "label": "Тип данных для вывода",
+        },
+        "zh": {
+            "label": "推理数据类型",
+        },
+    },
     "load_btn": {
         "en": {
             "value": "Load model",
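
For reference, the rule the patcher.py hunk encodes can be read as a small pure function. The sketch below is illustrative and not code from the patch: resolve_compute_dtype and its bf16 probe are hypothetical stand-ins for patch_config and infer_optim_dtype. The key behavior is that an explicit infer_dtype is honored only when the model is loaded for inference; during training the automatic bf16 > fp16 > fp32 selection always applies.

# Illustrative sketch only: resolve_compute_dtype is a hypothetical stand-in
# for the dtype resolution that patch_config() performs via infer_optim_dtype().
from typing import Optional

import torch


def resolve_compute_dtype(
    infer_dtype: str, is_trainable: bool, model_dtype: Optional[torch.dtype]
) -> torch.dtype:
    # An explicit infer_dtype wins only at inference time.
    if infer_dtype != "auto" and not is_trainable:
        return getattr(torch, infer_dtype)  # e.g. "bfloat16" -> torch.bfloat16
    # Otherwise fall back to the automatic probe (a rough stand-in for
    # infer_optim_dtype): prefer bf16 on GPUs that support it, else the
    # model's own dtype, else fp32.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return model_dtype if model_dtype is not None else torch.float32


# The override is honored at inference time:
assert resolve_compute_dtype("float16", is_trainable=False, model_dtype=torch.float32) is torch.float16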
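
Outside the Web UI, the same option can be passed programmatically through the args dict that ChatModel accepts. A hypothetical usage sketch, with an illustrative model id and template:

# Hypothetical usage sketch: the checkpoint and template below are
# illustrative choices, not values taken from this patch.
from llamafactory.chat import ChatModel

chat_model = ChatModel(
    dict(
        model_name_or_path="THUDM/glm-4-9b-chat",  # illustrative checkpoint
        template="glm4",
        infer_backend="huggingface",  # matches the first dropdown
        infer_dtype="bfloat16",  # the value the new dropdown feeds through
    )
)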