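# Inference demo for the FM9G vision-language model: an EVA02 vision encoder
# (timm) is combined with the FM9G language model inside VLU_FM9G, the sharded
# checkpoint is dispatched across GPUs with accelerate, and answers are produced
# by beam search through VLLMFM9GBeamSearch.
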
import gc
from io import BytesIO

import requests
import timm
import torch
from PIL import Image
from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from torchvision.transforms import InterpolationMode, transforms
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

import os, sys
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
from vis_fm9g.generation.vllm_fm9g import VLLMFM9GBeamSearch
from vis_fm9g.model.fm9g import FM9GConfig, FM9GTorch
from vis_fm9g.model.vlu_fm9g import VLU_FM9G
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer
from vis_fm9g.utils.constants import SYSTEM


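# One chat turn. On the first turn (no context) the question is prefixed with the
# image placeholder (tokenizer.im_start + query_nums * tokenizer.unk_token +
# tokenizer.im_end) and wrapped with the SYSTEM prompt and <用户>/<AI> role tags;
# on follow-up turns the running context is extended instead. query_nums is
# presumably expected to match the model's query_num (256 in this script).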
def chat(model, image, question, context, tokenizer, query_nums=64, vision_hidden_states=None, max_length=1024):
    if not context:
        question = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end + question
        final_input = f'{SYSTEM}<用户>{question}<AI>'
    else:
        final_input = f'{context}<用户>{question}<AI>'

    data_list = [
        {'input': final_input}
    ]

    res, vision_hidden_states = model.generate(
        data_list=data_list,
        max_inp_length=2048,
        beam_size=3,
        img_list=[[image]],
        max_length=max_length,
        repetition_penalty=1.1,
        temperature=0.7,
        length_penalty=3,
        return_vision_hidden_states=True
    )

    answer = res[0]

    context = final_input + answer
    return answer, context, vision_hidden_states


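# Build the FM9G language model from its JSON config with flash attention disabled;
# weights are filled in later by load_checkpoint_and_dispatch.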
def load_llm(llm_path):
    config = FM9GConfig.from_json_file(llm_path)
    config.use_flash_attn = False
    cpm_model = FM9GTorch(config)
    return cpm_model


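# Create the timm vision encoder without pretrained weights (they come from the
# sharded checkpoint). For ViT-style models the attention-pool head is replaced
# with an identity, and the last transformer block can optionally be dropped too.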
def load_vpm(vision_encoder, drop_vision_last_layer=False):
    model = timm.create_model(
        vision_encoder,
        pretrained=False,
        num_classes=0,
        dynamic_img_size=True,
        dynamic_img_pad=True
    )

    if isinstance(model, timm.models.VisionTransformer):
        if model.attn_pool is not None:
            model.attn_pool = torch.nn.Identity()

    if drop_vision_last_layer:
        model.blocks[-1] = torch.nn.Identity()

    return model


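# Assemble the full model: VLU_FM9G combines the language model and vision encoder,
# taking the encoder's embedding dimension and query_num=256 (matching the 256
# placeholder tokens used when chat() is called below).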
def load_vis_fm9g(llm_path, vision_encoder):
    llm = load_llm(llm_path)
    vpm = load_vpm(vision_encoder, drop_vision_last_layer=False)

    vision_dim = vpm.embed_dim
    model = VLU_FM9G(llm, vpm, vision_dim, query_num=256)
    return model


def load_tokenizer(vocabs_path):
    return FM9GTokenizer(vocabs_path)


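# Image preprocessing: bicubic resize to img_size x img_size, conversion to a
# tensor, and normalization with the OpenAI CLIP mean/std from timm.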
def load_transform(img_size):
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD)
    ])

    return transform


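# Entry point: instantiate the model with empty (meta) weights, then let accelerate
# load the sharded checkpoint and dispatch it across two GPUs (24 GiB each),
# keeping EVA blocks unsplit. The checkpoint/ directory is expected to contain the
# config, vocabulary and sharded weights.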
if __name__ == '__main__':
    root = "checkpoint/"
    llm_path = root + 'config.json'
    vocabs_path = root + 'vocabs.txt'
    model_checkpoint = root + 'sharded'
    vision_encoder = 'eva02_enormous_patch14_clip_224.laion2b_plus'
    img_size = 448

    with init_empty_weights():
        model = load_vis_fm9g(llm_path, vision_encoder)

    model = load_checkpoint_and_dispatch(
        model,
        model_checkpoint,
        device_map="auto",
        max_memory={0: "24GiB", 1: "24GiB"},
        no_split_module_classes=['EvaBlockPostNorm']
    )
    model.eval()

    tokenizer = load_tokenizer(vocabs_path)
    transform = load_transform(img_size)
    beam_search = VLLMFM9GBeamSearch(model, tokenizer, transform)

    # Image input (local file path)
    image_path = 'test.jpg'
    image = Image.open(image_path).convert('RGB')

    # Text input: "What does this image depict?"
    prompt = '这幅图描述了什么?'

    answer, context, _ = chat(
        beam_search, image, prompt, context=None, tokenizer=tokenizer, query_nums=256
    )

    print(answer)