# CPM-9G-8B/FM_9G/chat_model.py

import gc
from io import BytesIO
import requests
import timm
import torch
from PIL import Image
from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from torchvision.transforms import InterpolationMode, transforms
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
import os, sys
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
from vis_fm9g.generation.vllm_fm9g import VLLMFM9GBeamSearch
from vis_fm9g.model.fm9g import FM9GConfig, FM9GTorch
from vis_fm9g.model.vlu_fm9g import VLU_FM9G
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer
from vis_fm9g.utils.constants import SYSTEM
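
# chat() builds the multimodal prompt (image-placeholder tokens plus the <用户>/<AI>
# role markers), runs beam-search generation, and returns the answer together with
# the updated dialogue context so follow-up turns can reuse it.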
def chat(model, image, question, context, tokenizer, query_nums=64, vision_hidden_states=None, max_length=1024):
    if not context:
        # First turn: prepend the image placeholder span (<im_start> + query_nums unk tokens + <im_end>)
        # and the system prompt; later turns simply append the question to the accumulated context.
        question = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end + question
        final_input = f'{SYSTEM}<用户>{question}<AI>'
    else:
        final_input = f'{context}<用户>{question}<AI>'
    data_list = [
        {'input': final_input}
    ]
    res, vision_hidden_states = model.generate(
        data_list=data_list,
        max_inp_length=2048,
        beam_size=3,
        img_list=[[image]],
        max_length=max_length,
        repetition_penalty=1.1,
        temperature=0.7,
        length_penalty=3,
        return_vision_hidden_states=True
    )
    answer = res[0]
    context = final_input + answer
    return answer, context, vision_hidden_states

def load_llm(llm_path):
    config = FM9GConfig.from_json_file(llm_path)
    config.use_flash_attn = False
    cpm_model = FM9GTorch(config)
    return cpm_model

def load_vpm(vision_encoder, drop_vision_last_layer=False):
    model = timm.create_model(
        vision_encoder,
        pretrained=False,
        num_classes=0,
        dynamic_img_size=True,
        dynamic_img_pad=True
    )
    if isinstance(model, timm.models.VisionTransformer):
        # Replace the attention-pooling head (and optionally the last block) with
        # identity so the encoder returns raw patch features.
        if model.attn_pool is not None:
            model.attn_pool = torch.nn.Identity()
        if drop_vision_last_layer:
            model.blocks[-1] = torch.nn.Identity()
    return model

def load_vis_fm9g(llm_path, vision_encoder):
    llm = load_llm(llm_path)
    vpm = load_vpm(vision_encoder, drop_vision_last_layer=False)
    vision_dim = vpm.embed_dim
    model = VLU_FM9G(llm, vpm, vision_dim, query_num=256)
    return model

def load_tokenizer(vocabs_path):
    return FM9GTokenizer(vocabs_path)

def load_transform(img_size):
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD)
    ])
    return transform

if __name__ == '__main__':
    root = "checkpoint/"
    llm_path = root + 'config.json'
    vocabs_path = root + 'vocabs.txt'
    model_checkpoint = root + 'sharded'
    vision_encoder = 'eva02_enormous_patch14_clip_224.laion2b_plus'
    img_size = 448

    # Build the model with empty (meta) weights, then load the sharded checkpoint
    # and dispatch it across the available GPUs.
    with init_empty_weights():
        model = load_vis_fm9g(llm_path, vision_encoder)
    model = load_checkpoint_and_dispatch(
        model,
        model_checkpoint,
        device_map="auto",
        max_memory={0: "24GiB", 1: "24GiB"},
        no_split_module_classes=['EvaBlockPostNorm']
    )
    model.eval()

    tokenizer = load_tokenizer(vocabs_path)
    transform = load_transform(img_size)
    beam_search = VLLMFM9GBeamSearch(model, tokenizer, transform)

    # Image input
    url = 'test.jpg'
    image = Image.open(url).convert('RGB')
    # Text input
    prompt = '这幅图描述了什么?'  # "What does this picture depict?"
    answer, context, _ = chat(
        beam_search, image, prompt, context=None, tokenizer=tokenizer, query_nums=256
    )
    print(answer)
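
    # Optional follow-up turn (illustrative sketch, not part of the original demo):
    # chat() returns the accumulated dialogue context, so a second question about the
    # same image can be asked by passing that context back in, assuming chat() handles
    # multi-turn input as its signature suggests. The follow-up prompt is hypothetical.
    follow_up = '图中有哪些物体?'  # "What objects are in the picture?" (hypothetical example)
    answer2, context, _ = chat(
        beam_search, image, follow_up, context=context, tokenizer=tokenizer, query_nums=256
    )
    print(answer2)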