forked from p04798526/LLaMA-Factory-Mirror
fix #4410
This commit is contained in:
parent
e0014db7d2
commit
fca893d73c
|
@ -34,8 +34,8 @@ DEFAULT_TOOL_PROMPT = (
|
||||||
|
|
||||||
|
|
||||||
GLM4_TOOL_PROMPT = (
|
GLM4_TOOL_PROMPT = (
|
||||||
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
||||||
"你的任务是针对用户的问题和要求提供适当的答复和支持。{tool_text}"
|
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -97,7 +97,7 @@ class ModelArguments:
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
|
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
|
||||||
)
|
)
|
||||||
flash_attn: Literal["disabled", "sdpa", "fa2", "auto"] = field(
|
flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
|
||||||
default="auto",
|
default="auto",
|
||||||
metadata={"help": "Enable FlashAttention for faster training and inference."},
|
metadata={"help": "Enable FlashAttention for faster training and inference."},
|
||||||
)
|
)
|
||||||
|
|
|
@ -58,10 +58,10 @@ def patch_config(
|
||||||
is_trainable: bool,
|
is_trainable: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
||||||
if model_args.infer_dtype == "auto":
|
if model_args.infer_dtype != "auto" and not is_trainable:
|
||||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
|
||||||
else:
|
|
||||||
model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
|
model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
|
||||||
|
else:
|
||||||
|
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||||
|
|
||||||
if is_torch_npu_available():
|
if is_torch_npu_available():
|
||||||
use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
|
use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
|
||||||
|
|
|
@ -87,6 +87,7 @@ class WebChatModel(ChatModel):
|
||||||
visual_inputs=get("top.visual_inputs"),
|
visual_inputs=get("top.visual_inputs"),
|
||||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||||
infer_backend=get("infer.infer_backend"),
|
infer_backend=get("infer.infer_backend"),
|
||||||
|
infer_dtype=get("infer.infer_dtype"),
|
||||||
)
|
)
|
||||||
|
|
||||||
if checkpoint_path:
|
if checkpoint_path:
|
||||||
|
|
|
@ -32,15 +32,26 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
input_elems = engine.manager.get_base_elems()
|
input_elems = engine.manager.get_base_elems()
|
||||||
elem_dict = dict()
|
elem_dict = dict()
|
||||||
|
|
||||||
infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
|
with gr.Row():
|
||||||
|
infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
|
||||||
|
infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
load_btn = gr.Button()
|
load_btn = gr.Button()
|
||||||
unload_btn = gr.Button()
|
unload_btn = gr.Button()
|
||||||
|
|
||||||
info_box = gr.Textbox(show_label=False, interactive=False)
|
info_box = gr.Textbox(show_label=False, interactive=False)
|
||||||
|
|
||||||
input_elems.update({infer_backend})
|
input_elems.update({infer_backend, infer_dtype})
|
||||||
elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
|
elem_dict.update(
|
||||||
|
dict(
|
||||||
|
infer_backend=infer_backend,
|
||||||
|
infer_dtype=infer_dtype,
|
||||||
|
load_btn=load_btn,
|
||||||
|
unload_btn=unload_btn,
|
||||||
|
info_box=info_box,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
|
chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
|
||||||
elem_dict.update(chat_elems)
|
elem_dict.update(chat_elems)
|
||||||
|
|
|
@ -1206,6 +1206,17 @@ LOCALES = {
|
||||||
"label": "推理引擎",
|
"label": "推理引擎",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"infer_dtype": {
|
||||||
|
"en": {
|
||||||
|
"label": "Inference data type",
|
||||||
|
},
|
||||||
|
"ru": {
|
||||||
|
"label": "Тип данных для вывода",
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"label": "推理数据类型",
|
||||||
|
},
|
||||||
|
},
|
||||||
"load_btn": {
|
"load_btn": {
|
||||||
"en": {
|
"en": {
|
||||||
"value": "Load model",
|
"value": "Load model",
|
||||||
|
|
Loading…
Reference in New Issue