add Baichuan2 models

This commit is contained in:
hiyouga 2023-09-06 18:36:04 +08:00
parent 9224db90ea
commit 62ce65c628
3 changed files with 11 additions and 2 deletions

View File

@ -46,11 +46,16 @@ SUPPORTED_MODELS = {
"Baichuan-7B": "baichuan-inc/Baichuan-7B",
"Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
"Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base",
"Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base",
"Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
"Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat",
"InternLM-7B": "internlm/internlm-7b",
"InternLM-7B-Chat": "internlm/internlm-chat-7b",
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"XVERSE-13B": "xverse/XVERSE-13B",
"XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat",
"ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
}
@ -62,6 +67,7 @@ DEFAULT_MODULE = {
"BLOOMZ": "query_key_value",
"Falcon": "query_key_value",
"Baichuan": "W_pack",
"Baichuan2": "W_pack",
"InternLM": "q_proj,v_proj",
"Qwen": "c_attn",
"XVERSE": "q_proj,v_proj",
@ -72,7 +78,9 @@ DEFAULT_TEMPLATE = {
"LLaMA2": "llama2",
"ChineseLLaMA2": "llama2_zh",
"Baichuan": "baichuan",
"Baichuan2": "baichuan",
"InternLM": "intern",
"Qwen": "chatml",
"XVERSE": "xverse",
"ChatGLM2": "chatglm2"
}

View File

@ -1,3 +1,4 @@
import gc
import torch
from typing import TYPE_CHECKING, List, Optional, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
@ -98,6 +99,7 @@ def torch_gc() -> None:
r"""
Collects GPU memory.
"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

View File

@ -490,8 +490,7 @@ register_template(
{"token": "<reserved_103>"} # assistant token
],
system="",
sep=[],
stop_words=[]
sep=[]
)