forked from p04798526/LLaMA-Factory-Mirror
fix #4410
commit fca893d73c
parent e0014db7d2
@@ -34,8 +34,8 @@ DEFAULT_TOOL_PROMPT = (
 GLM4_TOOL_PROMPT = (
-    "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
-    "你的任务是针对用户的问题和要求提供适当的答复和支持。{tool_text}"
+    "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
+    "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
 )
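
The new prompt, roughly translated: "You are an AI assistant named ChatGLM, developed on the GLM-4 language model trained by Zhipu AI; your task is to provide appropriate answers and support for the user's questions and requests. # Available tools". A minimal sketch of how the {tool_text} placeholder is filled; the tool description below is hypothetical, not LLaMA-Factory's own rendering:

GLM4_TOOL_PROMPT = (
    "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
    "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
)

# Hypothetical tool section appended after the "# 可用工具" ("# Available tools") heading.
tool_text = "\n\n## get_weather\n\n查询指定城市的天气。(Look up the weather for a given city.)"
print(GLM4_TOOL_PROMPT.format(tool_text=tool_text))
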
@@ -97,7 +97,7 @@ class ModelArguments:
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: Literal["disabled", "sdpa", "fa2", "auto"] = field(
+    flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
         default="auto",
         metadata={"help": "Enable FlashAttention for faster training and inference."},
    )
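
Only the order of the Literal options changes, putting the default "auto" first; the default value itself is unchanged, so runtime behavior is identical (argument parsers that turn Literal annotations into CLI choices will simply list "auto" first in help output). A runnable sketch of the field:

from dataclasses import dataclass, field
from typing import Literal

@dataclass
class ModelArguments:
    # Same field as in the diff; "auto" now leads the option list.
    flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
        default="auto",
        metadata={"help": "Enable FlashAttention for faster training and inference."},
    )

assert ModelArguments().flash_attn == "auto"
print(ModelArguments(flash_attn="fa2"))
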
@@ -58,10 +58,10 @@ def patch_config(
     is_trainable: bool,
 ) -> None:
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
-        if model_args.infer_dtype == "auto":
-            model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
-        else:
+        if model_args.infer_dtype != "auto" and not is_trainable:
             model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
+        else:
+            model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
 
     if is_torch_npu_available():
         use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
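
This is the substantive fix: previously an explicit infer_dtype overrode compute_dtype even when the model was loaded for training; now it is honored only at inference time (not is_trainable), while training keeps the bf16 > fp16 > fp32 auto-selection. A standalone sketch of the decision, with infer_optim_dtype stubbed by a simplified stand-in (the real helper also checks hardware support):

import torch

def resolve_compute_dtype(infer_dtype: str, is_trainable: bool, config_dtype) -> torch.dtype:
    def infer_optim_dtype(model_dtype):  # simplified stand-in
        if model_dtype == torch.bfloat16:
            return torch.bfloat16
        if model_dtype == torch.float16:
            return torch.float16
        return torch.float32

    if infer_dtype != "auto" and not is_trainable:
        return getattr(torch, infer_dtype)  # e.g. "float16" -> torch.float16
    return infer_optim_dtype(config_dtype)

# An explicit dtype is ignored while training but honored at inference.
print(resolve_compute_dtype("float16", is_trainable=True, config_dtype=torch.bfloat16))   # torch.bfloat16
print(resolve_compute_dtype("float16", is_trainable=False, config_dtype=torch.bfloat16))  # torch.float16
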
@@ -87,6 +87,7 @@ class WebChatModel(ChatModel):
             visual_inputs=get("top.visual_inputs"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             infer_backend=get("infer.infer_backend"),
+            infer_dtype=get("infer.infer_dtype"),
         )
 
         if checkpoint_path:
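
This threads the WebUI selection through to the chat engine: the dtype chosen in the Infer tab becomes ModelArguments.infer_dtype, which the patch_config branch above consults when the model is loaded for inference. A hypothetical stand-in for the "get" accessor, which resolves a component's current value by its "tab.element" id:

state = {"infer.infer_backend": "huggingface", "infer.infer_dtype": "bfloat16"}
get = state.get

engine_kwargs = dict(
    infer_backend=get("infer.infer_backend"),
    infer_dtype=get("infer.infer_dtype"),  # the newly forwarded value
)
print(engine_kwargs)  # {'infer_backend': 'huggingface', 'infer_dtype': 'bfloat16'}
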
@@ -32,15 +32,26 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
     input_elems = engine.manager.get_base_elems()
     elem_dict = dict()
 
     with gr.Row():
         infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
+        infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto")
 
     with gr.Row():
         load_btn = gr.Button()
         unload_btn = gr.Button()
 
     info_box = gr.Textbox(show_label=False, interactive=False)
 
-    input_elems.update({infer_backend})
-    elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
+    input_elems.update({infer_backend, infer_dtype})
+    elem_dict.update(
+        dict(
+            infer_backend=infer_backend,
+            infer_dtype=infer_dtype,
+            load_btn=load_btn,
+            unload_btn=unload_btn,
+            info_box=info_box,
+        )
+    )
 
     chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
     elem_dict.update(chat_elems)
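
The Infer tab gains a dtype dropdown next to the backend selector, and both are registered in input_elems/elem_dict so their values reach WebChatModel above. A self-contained sketch of the two dropdowns, with illustrative labels (in LLaMA-Factory the labels come from the LOCALES table below):

import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        infer_backend = gr.Dropdown(
            choices=["huggingface", "vllm"], value="huggingface", label="Inference engine"
        )
        infer_dtype = gr.Dropdown(
            choices=["auto", "float16", "bfloat16", "float32"], value="auto",
            label="Inference data type",
        )

if __name__ == "__main__":
    demo.launch()
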
@@ -1206,6 +1206,17 @@ LOCALES = {
             "label": "推理引擎",
         },
     },
+    "infer_dtype": {
+        "en": {
+            "label": "Inference data type",
+        },
+        "ru": {
+            "label": "Тип данных для вывода",
+        },
+        "zh": {
+            "label": "推理数据类型",
+        },
+    },
     "load_btn": {
         "en": {
             "value": "Load model",
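
Locale entries for the new dropdown: the Russian and Chinese labels both translate to "inference data type". A sketch of how such a label would be resolved; the get_label helper is hypothetical, not LLaMA-Factory's actual accessor:

LOCALES = {
    "infer_dtype": {
        "en": {"label": "Inference data type"},
        "ru": {"label": "Тип данных для вывода"},
        "zh": {"label": "推理数据类型"},
    },
}

def get_label(elem_id: str, lang: str) -> str:
    return LOCALES[elem_id][lang]["label"]

assert get_label("infer_dtype", "zh") == "推理数据类型"
print(get_label("infer_dtype", "en"))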