This commit is contained in:
hiyouga 2024-06-24 22:34:31 +08:00
parent e0014db7d2
commit fca893d73c
6 changed files with 32 additions and 9 deletions

View File

@ -34,8 +34,8 @@ DEFAULT_TOOL_PROMPT = (
GLM4_TOOL_PROMPT = (
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。{tool_text}"
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
)

View File

@ -97,7 +97,7 @@ class ModelArguments:
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
flash_attn: Literal["disabled", "sdpa", "fa2", "auto"] = field(
flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
default="auto",
metadata={"help": "Enable FlashAttention for faster training and inference."},
)

View File

@ -58,10 +58,10 @@ def patch_config(
is_trainable: bool,
) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
if model_args.infer_dtype == "auto":
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
else:
if model_args.infer_dtype != "auto" and not is_trainable:
model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
else:
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
if is_torch_npu_available():
use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]

View File

@ -87,6 +87,7 @@ class WebChatModel(ChatModel):
visual_inputs=get("top.visual_inputs"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
infer_backend=get("infer.infer_backend"),
infer_dtype=get("infer.infer_dtype"),
)
if checkpoint_path:

View File

@ -32,15 +32,26 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
with gr.Row():
infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto")
with gr.Row():
load_btn = gr.Button()
unload_btn = gr.Button()
info_box = gr.Textbox(show_label=False, interactive=False)
input_elems.update({infer_backend})
elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
input_elems.update({infer_backend, infer_dtype})
elem_dict.update(
dict(
infer_backend=infer_backend,
infer_dtype=infer_dtype,
load_btn=load_btn,
unload_btn=unload_btn,
info_box=info_box,
)
)
chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(chat_elems)

View File

@ -1206,6 +1206,17 @@ LOCALES = {
"label": "推理引擎",
},
},
"infer_dtype": {
"en": {
"label": "Inference data type",
},
"ru": {
"label": "Тип данных для вывода",
},
"zh": {
"label": "推理数据类型",
},
},
"load_btn": {
"en": {
"value": "Load model",