import os
import sys

import timm
import torch
from PIL import Image
from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from torchvision.transforms import InterpolationMode, transforms
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))

from vis_fm9g.generation.vllm_fm9g import VLLMFM9GBeamSearch
from vis_fm9g.model.fm9g import FM9GConfig, FM9GTorch
from vis_fm9g.model.vlu_fm9g import VLU_FM9G
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer
from vis_fm9g.utils.constants import SYSTEM


def chat(model, image, question, context, tokenizer, query_nums=64, vision_hidden_states=None, max_length=1024):
    """Run one round of visual question answering and return (answer, updated context, vision hidden states)."""
    if not context:
        # First turn: prepend the image placeholder tokens and the system prompt.
        question = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end + question
        final_input = f'{SYSTEM}<用户>{question}'
    else:
        # Follow-up turn: append the new question to the accumulated context.
        final_input = f'{context}<用户>{question}'

    data_list = [
        {'input': final_input}
    ]

    res, vision_hidden_states = model.generate(
        data_list=data_list,
        max_inp_length=2048,
        beam_size=3,
        img_list=[[image]],
        max_length=max_length,
        repetition_penalty=1.1,
        temperature=0.7,
        length_penalty=3,
        return_vision_hidden_states=True
    )
    answer = res[0]
    context = final_input + answer
    return answer, context, vision_hidden_states


def load_llm(llm_path):
    """Build the FM9G language model from its JSON config (weights are loaded later)."""
    config = FM9GConfig.from_json_file(llm_path)
    config.use_flash_attn = False
    cpm_model = FM9GTorch(config)
    return cpm_model


def load_vpm(vision_encoder, drop_vision_last_layer=False):
    """Build the timm vision encoder used as the visual backbone."""
    model = timm.create_model(
        vision_encoder,
        pretrained=False,
        num_classes=0,
        dynamic_img_size=True,
        dynamic_img_pad=True
    )
    if isinstance(model, timm.models.VisionTransformer):
        if model.attn_pool is not None:
            model.attn_pool = torch.nn.Identity()
        if drop_vision_last_layer:
            model.blocks[-1] = torch.nn.Identity()
    return model


def load_vis_fm9g(llm_path, vision_encoder):
    """Assemble the vision-language model from the LLM and the vision encoder."""
    llm = load_llm(llm_path)
    vpm = load_vpm(vision_encoder, drop_vision_last_layer=False)
    vision_dim = vpm.embed_dim
    model = VLU_FM9G(llm, vpm, vision_dim, query_num=256)
    return model


def load_tokenizer(vocabs_path):
    return FM9GTokenizer(vocabs_path)


def load_transform(img_size):
    """Image preprocessing: resize, convert to tensor, and normalize with CLIP statistics."""
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD)
    ])
    return transform


if __name__ == '__main__':
    root = "checkpoint/"
    llm_path = root + 'config.json'
    vocabs_path = root + 'vocabs.txt'
    model_checkpoint = root + 'sharded'
    vision_encoder = 'eva02_enormous_patch14_clip_224.laion2b_plus'
    img_size = 448

    # Instantiate the model on the meta device, then load and shard the checkpoint across GPUs.
    with init_empty_weights():
        model = load_vis_fm9g(llm_path, vision_encoder)
    model = load_checkpoint_and_dispatch(
        model,
        model_checkpoint,
        device_map="auto",
        max_memory={0: "24GiB", 1: "24GiB"},
        no_split_module_classes=['EvaBlockPostNorm']
    )
    model.eval()

    tokenizer = load_tokenizer(vocabs_path)
    transform = load_transform(img_size)
    beam_search = VLLMFM9GBeamSearch(model, tokenizer, transform)

    # Image input
    image_path = 'test.jpg'
    image = Image.open(image_path).convert('RGB')

    # Text input ("What does this picture describe?")
    prompt = '这幅图描述了什么?'

    answer, context, _ = chat(
        beam_search,
        image,
        prompt,
        context=None,
        tokenizer=tokenizer,
        query_nums=256
    )
    print(answer)