Compare commits
13 Commits
Author | SHA1 | Date |
---|---|---|
|
52dbf63007 | |
|
9aeefe95d5 | |
|
a214708255 | |
|
25c3643a05 | |
|
45d7c9b99d | |
|
23d3844492 | |
![]() |
c89395164e | |
![]() |
8e693d5876 | |
![]() |
415c624322 | |
![]() |
4139ba5dfe | |
![]() |
a8d431c14f | |
![]() |
1857f60d1e | |
![]() |
a041469104 |
|
@ -0,0 +1,193 @@
|
|||
import gc
|
||||
from io import BytesIO
|
||||
|
||||
import requests
|
||||
import timm
|
||||
import torch
|
||||
import random
|
||||
import json
|
||||
from PIL import Image
|
||||
from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from torchvision.transforms import InterpolationMode, transforms
|
||||
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
||||
|
||||
import os,sys
|
||||
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
|
||||
from vis_fm9g.generation.vllm_fm9g import VLLMFM9GBeamSearch
|
||||
from vis_fm9g.model.fm9g import FM9GConfig, FM9GTorch
|
||||
from vis_fm9g.model.vlu_fm9g import VLU_FM9G
|
||||
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer
|
||||
from vis_fm9g.utils.constants import SYSTEM
|
||||
|
||||
def random_selection_device(device_num):
|
||||
device_list = [f"cuda:{i}" for i in range(device_num)]
|
||||
if len(device_list) == 1:
|
||||
return [device_list[0]] * 4
|
||||
elif len(device_list) == 2:
|
||||
a, b = device_list
|
||||
return [a, b, a, b]
|
||||
else:
|
||||
selected = random.sample(device_list, 3)
|
||||
repeated = random.choice(selected)
|
||||
return selected + [repeated]
|
||||
|
||||
def has_pt_file(directory):
|
||||
pt_files = []
|
||||
files = os.listdir(directory)
|
||||
for file in files:
|
||||
if file.endswith(".pt"):
|
||||
full_path = os.path.join(directory, file)
|
||||
pt_files.append(full_path)
|
||||
return pt_files
|
||||
|
||||
def load_checkpoint(model, pretrained_model_name_or_path, device_num=4):
|
||||
pt_files = has_pt_file(pretrained_model_name_or_path)
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
if pt_files:
|
||||
device_map = {}
|
||||
state_dict = model.state_dict()
|
||||
layer_names = list(state_dict.keys())
|
||||
device_list = random_selection_device(device_num=device_num)
|
||||
for i, layer_name in enumerate(layer_names):
|
||||
if "vpm" in layer_name: # 假设 vpm 参数包含 "vpm" 前缀
|
||||
device_map[layer_name] = device_list[0]
|
||||
elif "llm" in layer_name: # 假设 llm 参数包含 "llm" 前缀
|
||||
device_map[layer_name] = device_list[1]
|
||||
else:
|
||||
device_map[layer_name] = device_list[2]
|
||||
model = load_checkpoint_and_dispatch(model, pt_files[0], device_map=device_map)
|
||||
model.to(device)
|
||||
return model
|
||||
else:
|
||||
model_checkpoint = os.path.join(pretrained_model_name_or_path, "sharded")
|
||||
index_file = os.path.join(pretrained_model_name_or_path, "sharded", "model.safetensors.index.json")
|
||||
with open(index_file, "r") as f:
|
||||
index_data = json.load(f)
|
||||
weight_map = index_data["weight_map"]
|
||||
all_params = {name for name, _ in model.named_parameters()}
|
||||
mapped_params = set(weight_map.keys())
|
||||
unmapped_params = all_params - mapped_params
|
||||
# 解析 weight_map,设置 device_map
|
||||
#TODO:强制不同的分支放置到不同的GPU上,避免出现计算问题,但是比较丑陋,需要做后续优化
|
||||
device_list = random_selection_device(device_num=device_num)
|
||||
device_map = {}
|
||||
for param_name, weight_file in weight_map.items():
|
||||
if "vpm" in param_name: # 假设 vpm 参数包含 "vpm" 前缀
|
||||
device_map[param_name] = device_list[0]
|
||||
elif "llm" in param_name: # 假设 llm 参数包含 "llm" 前缀
|
||||
device_map[param_name] = device_list[1]
|
||||
else:
|
||||
device_map[param_name] = device_list[2]
|
||||
for param_name in unmapped_params:
|
||||
device_map[param_name] = device_list[3] # 随机选择设备
|
||||
model = load_checkpoint_and_dispatch(
|
||||
model,
|
||||
model_checkpoint,
|
||||
device_map=device_map,
|
||||
).to(device)
|
||||
return model
|
||||
|
||||
def chat(model, image, question, context, tokenizer, query_nums=64, vision_hidden_states=None, max_length=1024):
|
||||
if not context:
|
||||
question = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end + question
|
||||
final_input = f'{SYSTEM}<用户>{question}<AI>'
|
||||
else:
|
||||
final_input = f'{context}<用户>{question}<AI>'
|
||||
|
||||
data_list = [
|
||||
{'input': final_input}
|
||||
]
|
||||
|
||||
res, vision_hidden_states = model.generate(
|
||||
data_list=data_list,
|
||||
max_inp_length=2048,
|
||||
beam_size=3,
|
||||
img_list=[[image]],
|
||||
max_length=max_length,
|
||||
repetition_penalty=1.1,
|
||||
temperature=0.7,
|
||||
length_penalty=3,
|
||||
return_vision_hidden_states=True
|
||||
)
|
||||
|
||||
answer = res[0]
|
||||
|
||||
context = final_input + answer
|
||||
return answer, context, vision_hidden_states
|
||||
|
||||
|
||||
def load_llm(llm_path):
|
||||
config = FM9GConfig.from_json_file(llm_path)
|
||||
config.use_flash_attn = False
|
||||
cpm_model = FM9GTorch(config)
|
||||
return cpm_model
|
||||
|
||||
|
||||
def load_vpm(vision_encoder, drop_vision_last_layer=False):
|
||||
model = timm.create_model(
|
||||
vision_encoder,
|
||||
pretrained=False,
|
||||
num_classes=0,
|
||||
dynamic_img_size=True,
|
||||
dynamic_img_pad=True
|
||||
)
|
||||
|
||||
if isinstance(model, timm.models.VisionTransformer):
|
||||
if model.attn_pool is not None:
|
||||
model.attn_pool = torch.nn.Identity()
|
||||
|
||||
if drop_vision_last_layer:
|
||||
model.blocks[-1] = torch.nn.Identity()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_vis_fm9g(llm_path, vision_encoder):
|
||||
llm =load_llm(llm_path)
|
||||
vpm = load_vpm(vision_encoder, drop_vision_last_layer=False)
|
||||
|
||||
vision_dim = vpm.embed_dim
|
||||
model = VLU_FM9G(llm, vpm, vision_dim, query_num=256)
|
||||
return model
|
||||
|
||||
|
||||
def load_tokenizer(vocabs_path):
|
||||
return FM9GTokenizer(vocabs_path)
|
||||
|
||||
|
||||
def load_transform(img_size):
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((img_size, img_size), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD)
|
||||
])
|
||||
|
||||
return transform
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
root = "checkpoint/"
|
||||
llm_path = root + 'config.json'
|
||||
vocabs_path = root + 'vocabs.txt'
|
||||
vision_encoder = 'eva02_enormous_patch14_clip_224.laion2b_plus'
|
||||
img_size = 448
|
||||
device_num = 4
|
||||
|
||||
with init_empty_weights():
|
||||
model = load_vis_fm9g(llm_path, vision_encoder)
|
||||
model = load_checkpoint(model, root, device_num=device_num)
|
||||
model.eval()
|
||||
|
||||
tokenizer = load_tokenizer(vocabs_path)
|
||||
transform = load_transform(img_size)
|
||||
beam_search = VLLMFM9GBeamSearch(model, tokenizer, transform)
|
||||
# 图像输入
|
||||
url = 'test.jpg'
|
||||
image = Image.open(url).convert('RGB')
|
||||
# 文本输入
|
||||
prompt = '这幅图描述了什么?'
|
||||
answer, context, _ = chat(
|
||||
beam_search, image, prompt, context=None, tokenizer=tokenizer, query_nums=256
|
||||
)
|
||||
|
||||
print(answer)
|
|
@ -0,0 +1,44 @@
|
|||
[
|
||||
{
|
||||
"id": "000000052347",
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "<image>\nPlease generate detailed descriptions of the given image."
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "The image depicts a tranquil rural scene. In the foreground, there are two cows. The cow on the left is white and is standing, while the cow on the right is brown and is lying down. Both cows appear to be resting or grazing. The background of the image shows a vast landscape with rolling hills and a clear sky. The overall atmosphere of the image is peaceful and serene, capturing a moment of calm in the countryside."
|
||||
}
|
||||
],
|
||||
"image": "00000/000000052347.jpg"
|
||||
},
|
||||
{
|
||||
"id": "000000382625",
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "<image>\nPlease generate detailed descriptions of the given image."
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "The image depicts a group of individuals, likely in a celebratory or formal setting. In the foreground, there is a table covered with a blue tablecloth, upon which sit three cakes. The cakes are decorated with various designs, including what appears to be a flag and other symbols, suggesting a theme related to a specific organization or event.\n\nBehind the table, there are six people standing, arranged in two rows. The front row consists of three individuals, two of whom are wearing chef's hats, indicating they may be chefs or cooks. The third person in the front row is wearing a military uniform, suggesting a connection to the military.\n\nThe back row features three more individuals, two of whom are also in military uniforms, while the third is wearing a tie, which could imply a formal or administrative role. All individuals are facing the camera and appear to be posing for the photograph.\n\nThe setting appears to be an indoor space, possibly a dining area or a kitchen, given the presence of the table and the attire of the individuals. The lighting is bright, and the overall atmosphere seems to be one of celebration or recognition.\n\nThere are no visible texts or logos that provide specific information about the event or the individuals. The style of the image is a standard, candid photograph capturing a moment during an event."
|
||||
}
|
||||
],
|
||||
"image": "00000/000000382625.jpg"
|
||||
},
|
||||
{
|
||||
"id": "000000559875",
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "<image>\nPlease generate detailed descriptions of the given image."
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "In the image, a young girl is the main subject. She is dressed in a white tennis outfit, complete with a matching white visor. In her hands, she holds a tennis racket, which is predominantly yellow and black. The racket is adorned with the word \"HEAD\" in bold orange letters.\n\nThe girl is standing on a concrete surface, which suggests an outdoor setting. Behind her, there's a white wall, on which a black and white poster is hanging. The poster features a tennis player, further emphasizing the theme of the image.\n\nThe girl is also holding a small trophy in her left hand. The trophy is gold in color, adding a touch of elegance to the scene. The girl's smile suggests she is proud of her achievement.\n\nOverall, the image captures a moment of joy and accomplishment in the world of tennis."
|
||||
}
|
||||
],
|
||||
"image": "00000/000000559875.jpg"
|
||||
}
|
||||
]
|
After Width: | Height: | Size: 46 KiB |
After Width: | Height: | Size: 52 KiB |
After Width: | Height: | Size: 30 KiB |
|
@ -0,0 +1,44 @@
|
|||
[
|
||||
{
|
||||
"id": "000000052347",
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "<image>\nPlease generate detailed descriptions of the given image."
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "The image depicts a tranquil rural scene. In the foreground, there are two cows. The cow on the left is white and is standing, while the cow on the right is brown and is lying down. Both cows appear to be resting or grazing. The background of the image shows a vast landscape with rolling hills and a clear sky. The overall atmosphere of the image is peaceful and serene, capturing a moment of calm in the countryside."
|
||||
}
|
||||
],
|
||||
"image": "00000/000000052347.jpg"
|
||||
},
|
||||
{
|
||||
"id": "000000382625",
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "<image>\nPlease generate detailed descriptions of the given image."
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "The image depicts a group of individuals, likely in a celebratory or formal setting. In the foreground, there is a table covered with a blue tablecloth, upon which sit three cakes. The cakes are decorated with various designs, including what appears to be a flag and other symbols, suggesting a theme related to a specific organization or event.\n\nBehind the table, there are six people standing, arranged in two rows. The front row consists of three individuals, two of whom are wearing chef's hats, indicating they may be chefs or cooks. The third person in the front row is wearing a military uniform, suggesting a connection to the military.\n\nThe back row features three more individuals, two of whom are also in military uniforms, while the third is wearing a tie, which could imply a formal or administrative role. All individuals are facing the camera and appear to be posing for the photograph.\n\nThe setting appears to be an indoor space, possibly a dining area or a kitchen, given the presence of the table and the attire of the individuals. The lighting is bright, and the overall atmosphere seems to be one of celebration or recognition.\n\nThere are no visible texts or logos that provide specific information about the event or the individuals. The style of the image is a standard, candid photograph capturing a moment during an event."
|
||||
}
|
||||
],
|
||||
"image": "00000/000000382625.jpg"
|
||||
},
|
||||
{
|
||||
"id": "000000559875",
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "<image>\nPlease generate detailed descriptions of the given image."
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "In the image, a young girl is the main subject. She is dressed in a white tennis outfit, complete with a matching white visor. In her hands, she holds a tennis racket, which is predominantly yellow and black. The racket is adorned with the word \"HEAD\" in bold orange letters.\n\nThe girl is standing on a concrete surface, which suggests an outdoor setting. Behind her, there's a white wall, on which a black and white poster is hanging. The poster features a tennis player, further emphasizing the theme of the image.\n\nThe girl is also holding a small trophy in her left hand. The trophy is gold in color, adding a touch of elegance to the scene. The girl's smile suggests she is proud of her achievement.\n\nOverall, the image captures a moment of joy and accomplishment in the world of tennis."
|
||||
}
|
||||
],
|
||||
"image": "00000/000000559875.jpg"
|
||||
}
|
||||
]
|
After Width: | Height: | Size: 46 KiB |
After Width: | Height: | Size: 52 KiB |
After Width: | Height: | Size: 30 KiB |
|
@ -0,0 +1,3 @@
|
|||
0
|
||||
63557
|
||||
137410
|
|
@ -0,0 +1,3 @@
|
|||
0
|
||||
63557
|
||||
137410
|
|
@ -0,0 +1,130 @@
|
|||
import os
|
||||
import json
|
||||
import base64
|
||||
|
||||
def json_to_tsv_with_base64(json_file, output_base_path):
|
||||
"""
|
||||
将 JSON 文件转换为带 Base64 编码的 TSV 文件,并生成自定义命名的 .lineidx 文件。
|
||||
|
||||
Args:
|
||||
json_file: 输入的 JSON 文件路径。
|
||||
output_base_path: 默认的输出路径,在此路径下创建文件夹。
|
||||
data_weight: 数据权重定义。
|
||||
"""
|
||||
# 获取 JSON 文件所在的目录
|
||||
json_dir = os.path.dirname(os.path.abspath(json_file))
|
||||
image_root = os.path.join(json_dir, "images") # 根目录 + "images"
|
||||
json_filename = os.path.splitext(os.path.basename(json_file))[0]
|
||||
data_name = json_filename.split(".")[0]
|
||||
|
||||
# 创建新的文件夹,命名为 JSON 文件名
|
||||
output_folder = os.path.join(output_base_path, json_filename)
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
# 读取 JSON 数据
|
||||
with open(json_file, 'r') as jf:
|
||||
data = json.load(jf)
|
||||
|
||||
# 获取数据条目数
|
||||
num_entries = len(data)
|
||||
|
||||
# 构建 TSV 和 .lineidx 文件路径
|
||||
tsv_file = os.path.join(output_folder, f"{json_filename}-{num_entries}.tsv")
|
||||
lineidx_file = os.path.join(output_folder, f"{json_filename}-{num_entries}.lineidx")
|
||||
|
||||
# 写入 TSV 文件
|
||||
with open(tsv_file, 'w') as tsv_out:
|
||||
for item in data:
|
||||
# 提取信息
|
||||
item_id = item["id"]
|
||||
image_id = item["image"]
|
||||
# 图像路径调整为 JSON 文件所在目录的 "images" 子目录
|
||||
image_path = os.path.join(image_root, image_id)
|
||||
if not os.path.exists(image_path):
|
||||
raise FileNotFoundError(f"Image not found: {image_path}")
|
||||
|
||||
# 图像编码为 Base64
|
||||
with open(image_path, "rb") as img_file:
|
||||
image_base64 = base64.b64encode(img_file.read()).decode('utf-8')
|
||||
|
||||
# 保留原始 conversations 格式并编码为 Base64
|
||||
conversations_json = json.dumps(item["conversations"], ensure_ascii=False)
|
||||
conversations_base64 = base64.b64encode(conversations_json.encode('utf-8')).decode('utf-8')
|
||||
|
||||
# 写入 TSV
|
||||
line = f"{data_name}\t{image_base64}\t{conversations_base64}\t{image_root}\t{image_root}\t{item_id}\t{image_id}\n"
|
||||
tsv_out.write(line)
|
||||
|
||||
# 生成自定义命名的 .lineidx 文件
|
||||
create_lineidx(tsv_file, lineidx_file)
|
||||
|
||||
print(f"TSV 文件生成完成: {tsv_file}")
|
||||
print(f"索引文件生成完成: {lineidx_file}")
|
||||
|
||||
return output_folder
|
||||
|
||||
def create_lineidx(filein, idxout):
|
||||
"""
|
||||
根据 TSV 文件生成自定义命名的 .lineidx 索引文件。
|
||||
|
||||
Args:
|
||||
filein: 输入的 TSV 文件路径。
|
||||
idxout: 输出的索引文件路径。
|
||||
"""
|
||||
idxout_tmp = idxout + '.tmp'
|
||||
with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
|
||||
fsize = os.fstat(tsvin.fileno()).st_size
|
||||
fpos = 0
|
||||
while fpos != fsize:
|
||||
tsvout.write(str(fpos) + "\n")
|
||||
tsvin.readline()
|
||||
fpos = tsvin.tell()
|
||||
os.rename(idxout_tmp, idxout)
|
||||
|
||||
def process_multiple_json(json_files, output_base_path, dataset_weights, train_data=True):
|
||||
"""
|
||||
Process a list of JSON files, converting each to TSV and creating a data.json file with weights.
|
||||
|
||||
Args:
|
||||
json_files: List of paths to JSON files.
|
||||
output_base_path: Base directory to store TSV files.
|
||||
dataset_weights: Dictionary containing the weight for each dataset.
|
||||
"""
|
||||
data_sources = []
|
||||
|
||||
for json_file in json_files:
|
||||
# Get the weight from the dataset_weights dictionary
|
||||
json_filename = os.path.splitext(os.path.basename(json_file))[0]
|
||||
data_weight = dataset_weights.get(json_filename, 1) # Default weight is 1
|
||||
|
||||
# Convert JSON to TSV and get the path
|
||||
tsv_path = json_to_tsv_with_base64(json_file, output_base_path)
|
||||
|
||||
# Append to the data_sources list
|
||||
data_sources.append({
|
||||
"data_source_name": tsv_path,
|
||||
"data_source_weight": data_weight
|
||||
})
|
||||
|
||||
# 保存数据json文件,默认存储到vis_fm9g/config/data下,用于数据索引
|
||||
if train_data:
|
||||
data_json_path = os.path.join("vis_fm9g/config/data", "data.json")
|
||||
else:
|
||||
data_json_path = os.path.join("vis_fm9g/config/data", "eval_data.json")
|
||||
with open(data_json_path, 'w') as f:
|
||||
json.dump(data_sources, f, ensure_ascii=False, indent=4)
|
||||
|
||||
print(f"data.json file generated: {data_json_path}")
|
||||
|
||||
## 示例使用 ##
|
||||
if __name__ == "__main__":
|
||||
# 输入json文件的位置以及对应的数据权重,如果没有设置数据权重,默认为1
|
||||
json_files = ["dataset/data_1/data_1.json", "dataset/data_2/data_2.json"]
|
||||
# json文件会转化成tsv文件格式,可以修改tsv_file_path来改变tsv文件的存储位置
|
||||
tsv_file_path = "dataset/train"
|
||||
# 设置数据权重
|
||||
dataset_weights = {
|
||||
"data": 1,
|
||||
"data_2": 2
|
||||
}
|
||||
process_multiple_json(json_files, tsv_file_path, dataset_weights, train_data=True)
|
|
@ -0,0 +1,12 @@
|
|||
## 环境安装 ##
|
||||
cd FM9G-V
|
||||
pip install -r requirements.txt
|
||||
|
||||
## demo使用 ##
|
||||
python chat.py
|
||||
|
||||
## 数据准备 ##
|
||||
python json2tsv.py
|
||||
|
||||
## 训练 ##
|
||||
bash run_train.sh
|
|
@ -0,0 +1,23 @@
|
|||
torch==2.0.1
|
||||
torchvision==0.15.2
|
||||
transformers==4.31.0
|
||||
tokenizers>=0.12.1,<0.14
|
||||
sentencepiece==0.1.99
|
||||
shortuuid
|
||||
peft==0.4.0
|
||||
bitsandbytes==0.41.0
|
||||
pydantic<2,>=1
|
||||
markdown2[all]
|
||||
numpy==1.23.5
|
||||
scikit-learn==1.2.2
|
||||
gradio==3.35.2
|
||||
gradio_client==0.2.9
|
||||
requests
|
||||
httpx==0.24.0
|
||||
uvicorn
|
||||
fastapi
|
||||
einops==0.6.1
|
||||
einops-exts==0.0.4
|
||||
timm==0.9.8
|
||||
deepspeed==0.11.1
|
||||
flash-attn==2.3.3
|
|
@ -0,0 +1,69 @@
|
|||
#!/bin/bash
|
||||
SELF_DIR=$(cd "$(dirname "$0")" || exit 1; pwd)
|
||||
export PYTHONPATH=$(dirname $SELF_DIR):$PYTHONPATH
|
||||
echo Working Directory at `pwd`
|
||||
echo Bash at `which bash`
|
||||
echo Python at `which python`
|
||||
|
||||
nvidia-smi
|
||||
|
||||
slave_or_master=$1
|
||||
|
||||
# 支持单机多卡和多机多卡
|
||||
GPUS_PER_NODE=8
|
||||
WORLD_SIZE=${WORLD_SIZE:-1}
|
||||
RANK=${RANK:-0}
|
||||
MASTER_ADDR=${MASTER_ADDR:-`hostname`}
|
||||
MASTER_PORT=${MASTER_PORT:-12345}
|
||||
|
||||
rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT
|
||||
|
||||
if [[ $slave_or_master == "slave" ]]; then
|
||||
rdzv_endpoint=$2:$MASTER_PORT
|
||||
fi
|
||||
|
||||
TRAIN_FILE=vis_fm9g/config/data/data.json
|
||||
EVAL_FILE=vis_fm9g/config/data/eval_data.json
|
||||
LLM_PATH=vis_fm9g/config/model/fm9g-7b.json
|
||||
VOCAB_PATH=vis_fm9g/config/vocab/fm9g.txt
|
||||
MODEL_CHECKPOINT=checkpoint/sharded
|
||||
|
||||
DEEPSPEED_CONFIG=vis_fm9g/config/deepspeed/sft.json
|
||||
EXPORT_DIR=./saved
|
||||
|
||||
echo WORLD_SIZE=$WORLD_SIZE RANK=$RANK rdzv_endpoint=$rdzv_endpoint
|
||||
|
||||
# --------------- 运行参数 ---------------
|
||||
OPTS=""
|
||||
OPTS+=" --self_dir ${SELF_DIR}"
|
||||
OPTS+=" --train_file ${TRAIN_FILE}"
|
||||
OPTS+=" --eval_file ${EVAL_FILE}"
|
||||
OPTS+=" --llm_path ${LLM_PATH}"
|
||||
OPTS+=" --vocabs_path ${VOCAB_PATH}"
|
||||
OPTS+=" --model_checkpoint ${MODEL_CHECKPOINT}"
|
||||
OPTS+=" --deepspeed_config ${DEEPSPEED_CONFIG}"
|
||||
OPTS+=" --save_deepspeed"
|
||||
OPTS+=" --export_dir ${EXPORT_DIR}"
|
||||
OPTS+=" --exp_name vis-fm9g-train-eval"
|
||||
OPTS+=" --flash cuda"
|
||||
|
||||
OPTS+=" --max_len 500"
|
||||
OPTS+=" --batch_size 1"
|
||||
OPTS+=" --save_step 1000"
|
||||
OPTS+=" --epochs 1"
|
||||
OPTS+=" --query_num 256"
|
||||
OPTS+=" --vision_encoder eva02_enormous_patch14_clip_224.laion2b_plus"
|
||||
OPTS+=" --tune_resampler"
|
||||
OPTS+=" --tune_vision"
|
||||
OPTS+=" --tune_llm"
|
||||
OPTS+=" --eval"
|
||||
|
||||
RUNNER="torchrun --nnodes=${WORLD_SIZE} --nproc_per_node=${GPUS_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${rdzv_endpoint}"
|
||||
|
||||
CMD="$RUNNER ./vis_fm9g/train/train_vis_fm9g.py ${OPTS}"
|
||||
echo "-------final CMD is------"
|
||||
echo "${CMD}"
|
||||
echo "-------final CMD end------"
|
||||
|
||||
# 执行 CMD
|
||||
eval "${CMD}"
|
After Width: | Height: | Size: 128 KiB |
|
@ -0,0 +1,115 @@
|
|||
""" Init a logger with options from env variables.
|
||||
|
||||
- set log level by ``LOG_LEVEL``, default: ``INFO``;
|
||||
- output log message to file by ``LOG_FILE``, default: output to stdout.
|
||||
|
||||
TODO:
|
||||
support setting log level and log file from config file.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
|
||||
_LOG_FMT = "[%(asctime)s][%(levelname).1s][%(process)d-%(name)s-%(filename)s:%(lineno)s]- %(message)s"
|
||||
_DATE_FMT = "%Y-%m-%d,%H:%M:%S"
|
||||
|
||||
_logging_level = {
|
||||
"CRITICAL": logging.CRITICAL,
|
||||
"ERROR": logging.ERROR,
|
||||
"WARNING": logging.WARNING,
|
||||
"INFO": logging.INFO,
|
||||
# Distributed Level, print log in main proc only by default, set this level to print all messages.
|
||||
"DP": logging.INFO,
|
||||
"DEBUG": logging.DEBUG,
|
||||
None: logging.INFO,
|
||||
}
|
||||
|
||||
_level = os.environ.get("LOG_LEVEL", "INFO").upper()
|
||||
|
||||
|
||||
class ShortNameFormatter(logging.Formatter):
|
||||
def format(self, record: logging.LogRecord):
|
||||
raw = record.name # save and restore for other formatters if desired
|
||||
parts = raw.split(".")
|
||||
record.name = ".".join(p[:3] for p in parts) if len(parts) > 1 else raw # keep first char for module name.
|
||||
result = super().format(record)
|
||||
record.name = raw
|
||||
return result
|
||||
|
||||
|
||||
class StyleAdapter(logging.LoggerAdapter):
|
||||
def __init__(self, logger, extra=None, style="default"):
|
||||
super().__init__(logger, extra or {})
|
||||
self._style = style
|
||||
self._enable = self._enable()
|
||||
|
||||
@classmethod
|
||||
def _enable(cls):
|
||||
# Note: to make this Logger more standalone, perform basic check without extra deps, e.g. tf/torch et al.
|
||||
worker = os.getenv("WORKER")
|
||||
rank = os.getenv("RANK")
|
||||
# not in DP/DDP mode or proc_id = "0"
|
||||
is_main = (not worker and not rank) or (worker == "0" or rank == "0")
|
||||
is_jeeves_job = os.getenv("JEEVES_JOB_ID")
|
||||
return _level in ["DEBUG", "DP"] or is_jeeves_job or is_main
|
||||
|
||||
def _format(self, *msgs, color: str = None):
|
||||
if self._style == "legacy":
|
||||
if len(msgs) == 1:
|
||||
msg_str = msgs[0]
|
||||
else:
|
||||
msg_str = msgs[0] % msgs[1:]
|
||||
else:
|
||||
msg_str = ", ".join([str(msg) for msg in msgs])
|
||||
|
||||
if color:
|
||||
pass
|
||||
return msg_str
|
||||
|
||||
def log(self, level, msg, *args, **kwargs):
|
||||
color = kwargs.pop("color", None)
|
||||
if self.isEnabledFor(level) and self._enable:
|
||||
msg, kwargs = self.process(msg, kwargs)
|
||||
msg_str = self._format(msg, *args, color=color)
|
||||
# noinspection PyProtectedMember
|
||||
self.logger._log(level, msg_str, (), **kwargs)
|
||||
|
||||
|
||||
def init_logger(name="ai", filename=os.environ.get("LOG_FILE", ""), fmt=_LOG_FMT, level=_level, style="legacy"):
|
||||
"""init logger
|
||||
|
||||
Args:
|
||||
name(str): optional, default: ai.
|
||||
filename(str): optional, default: "". Output log to file if specified, by default is set by env `LOG_FILE`.
|
||||
fmt(str): optional, default: _LOG_FMT
|
||||
level(str): optional, default: INFO
|
||||
style(str): optional, choice from ["print", "legacy"]
|
||||
- legacy: take first argument as a formatter, the remaining positional arguments as message values.
|
||||
this is consistent with the constraint of `logging` pkg
|
||||
- print: all positional arguments are message values which will be concatenated with ", "
|
||||
|
||||
Returns:
|
||||
a logger instance
|
||||
|
||||
Examples:
|
||||
>>> log = init_logger("log2stdout", level="INFO")
|
||||
>>> log.error("info")
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(_logging_level[level])
|
||||
if fmt:
|
||||
# formatter = logging.Formatter(fmt, datefmt=_DATE_FMT)
|
||||
formatter = ShortNameFormatter(fmt, datefmt=_DATE_FMT)
|
||||
else:
|
||||
formatter = None
|
||||
|
||||
if not logger.hasHandlers():
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
logging.basicConfig(format=fmt, level=_logging_level[_level], handlers=[handler])
|
||||
|
||||
if filename:
|
||||
handler = logging.FileHandler(filename)
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
return StyleAdapter(logger, style=style)
|
|
@ -0,0 +1,125 @@
|
|||
# --------------------------------------------------------
|
||||
# BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers (https://arxiv.org/abs/2208.06366)
|
||||
# Github source: https://github.com/microsoft/unilm/tree/master/beitv2
|
||||
# Copyright (c) 2022 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# By Zhiliang Peng
|
||||
# Based on BEiT, timm, DeiT and DINO code bases
|
||||
# https://github.com/microsoft/unilm/tree/master/beit
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
||||
# https://github.com/facebookresearch/deit/
|
||||
# https://github.com/facebookresearch/dino
|
||||
# --------------------------------------------------------'
|
||||
import pickle
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import torch.nn as nn
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def weights_init(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find('Conv') != -1:
|
||||
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||
elif classname.find('BatchNorm') != -1:
|
||||
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
||||
nn.init.constant_(m.bias.data, 0)
|
||||
|
||||
|
||||
def plot_images(images: dict):
|
||||
x = images["input"]
|
||||
reconstruction = images["rec"]
|
||||
half_sample = images["half_sample"]
|
||||
new_sample = images["new_sample"]
|
||||
|
||||
fig, axarr = plt.subplots(1, 4)
|
||||
axarr[0].imshow(x.cpu().detach().numpy()[0].transpose(1, 2, 0))
|
||||
axarr[1].imshow(reconstruction.cpu().detach().numpy()[0].transpose(1, 2, 0))
|
||||
axarr[2].imshow(half_sample.cpu().detach().numpy()[0].transpose(1, 2, 0))
|
||||
axarr[3].imshow(new_sample.cpu().detach().numpy()[0].transpose(1, 2, 0))
|
||||
plt.show()
|
||||
|
||||
|
||||
def get_model(model):
|
||||
if isinstance(model, torch.nn.DataParallel) \
|
||||
or isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
return model.module
|
||||
else:
|
||||
return model
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
|
||||
if is_main_process():
|
||||
torch.save(*args, **kwargs)
|
||||
|
||||
|
||||
def all_gather(data):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||
Args:
|
||||
data: any picklable object
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
# serialized to a Tensor
|
||||
buffer = pickle.dumps(data)
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to("cuda")
|
||||
|
||||
# obtain Tensor size of each rank
|
||||
local_size = torch.LongTensor([tensor.numel()]).to("cuda")
|
||||
size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
|
||||
dist.all_gather(size_list, local_size)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
tensor_list = []
|
||||
for _ in size_list:
|
||||
tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
|
||||
if local_size != max_size:
|
||||
padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
return data_list
|
||||
|
||||
def mean(lst):
|
||||
return sum(lst) / len(lst)
|
|
@ -0,0 +1,10 @@
|
|||
[
|
||||
{
|
||||
"data_source_name": "dataset/train/data_1",
|
||||
"data_source_weight": 1
|
||||
},
|
||||
{
|
||||
"data_source_name": "dataset/train/data_2",
|
||||
"data_source_weight": 2
|
||||
}
|
||||
]
|
|
@ -0,0 +1,3 @@
|
|||
[
|
||||
{ "data_source_name": "pretrain_eval_eval", "data_source_weight": 1 }
|
||||
]
|
|
@ -0,0 +1,33 @@
|
|||
{
|
||||
"train_micro_batch_size_per_gpu": 16,
|
||||
"gradient_accumulation_steps": 16,
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": 5e-5,
|
||||
"betas": [
|
||||
0.9,
|
||||
0.98
|
||||
],
|
||||
"weight_decay": 0.01
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 1e-6,
|
||||
"warmup_max_lr": 1e-5,
|
||||
"warmup_num_steps": 500
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"initial_scale_power": 10,
|
||||
"auto_cast": true
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2
|
||||
},
|
||||
"steps_per_print": 50,
|
||||
"gradient_clipping": 1.0
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
{
|
||||
"train_micro_batch_size_per_gpu": 1,
|
||||
"gradient_accumulation_steps": 16,
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": 1e-5,
|
||||
"betas": [
|
||||
0.9,
|
||||
0.98
|
||||
],
|
||||
"weight_decay": 0.01
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 1e-6,
|
||||
"warmup_max_lr": 1e-5,
|
||||
"warmup_num_steps": 500
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"initial_scale_power": 10,
|
||||
"auto_cast": true
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2
|
||||
},
|
||||
"steps_per_print": 50,
|
||||
"gradient_clipping": 1.0
|
||||
}
|
|
@ -3,25 +3,14 @@
|
|||
"dropout_p": 0.0,
|
||||
"eps": 1e-05,
|
||||
"half": true,
|
||||
"half_type": "bf16",
|
||||
"use_flash_attn": true,
|
||||
"flash_attn_mask_shape": "2d",
|
||||
"dim_model": 4096,
|
||||
"dim_ff": 14336,
|
||||
"dim_ff": 11008,
|
||||
"dim_head": 128,
|
||||
"num_heads": 32,
|
||||
"num_kv_heads": 32,
|
||||
"num_layers": 32,
|
||||
"activate_fn": "silu",
|
||||
"init_std": 0.10,
|
||||
"scale": false,
|
||||
"scale_emb": 12,
|
||||
"scale_depth": -1,
|
||||
"model_type": "fm9g",
|
||||
"architectures": [
|
||||
"FM9GForCausalLM"
|
||||
],
|
||||
"qk_norm": false,
|
||||
"tie_lm_head": false,
|
||||
"ffn_gated": true
|
||||
}
|
||||
"scale": false
|
||||
}
|
|
@ -119687,10 +119687,10 @@
|
|||
"𠳐"
|
||||
"𥻗"
|
||||
"𬉼"
|
||||
"<|im_start|>"
|
||||
"<|im_end|>"
|
||||
"<pad_2>"
|
||||
"<pad_3>"
|
||||
"<pad_4>"
|
||||
"<pad_5>"
|
||||
"<pad_6>"
|
||||
"<image>"
|
||||
"</image>"
|
||||
"<ref>"
|
||||
"</ref>"
|
||||
"<box>"
|
||||
"</box>"
|
||||
"<quad>"
|
|
@ -0,0 +1,347 @@
|
|||
import io
|
||||
import os
|
||||
import re
|
||||
import glob
|
||||
import math
|
||||
import json
|
||||
import base64
|
||||
import random
|
||||
import copy
|
||||
|
||||
from PIL import Image
|
||||
from typing import List
|
||||
|
||||
|
||||
class Register(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Register, self).__init__(*args, **kwargs)
|
||||
self._dict = {}
|
||||
|
||||
def register(self, target):
|
||||
def add_register_item(keys, value):
|
||||
if not callable(value):
|
||||
raise Exception(
|
||||
f"Register object must be callable! But receice:{value} is not callable!")
|
||||
|
||||
if not isinstance(keys, list):
|
||||
keys = [keys]
|
||||
|
||||
for key in keys:
|
||||
if key in self._dict:
|
||||
print(
|
||||
f"error: \033[33m{value.__name__} has been registered before, so we will overriden it\033[0m")
|
||||
exit()
|
||||
|
||||
self[key] = value
|
||||
return value
|
||||
|
||||
if callable(target):
|
||||
return add_register_item(target.__name__, target)
|
||||
else:
|
||||
return lambda x: add_register_item(target, x)
|
||||
|
||||
def __call__(self, target):
|
||||
return self.register(target)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self._dict[key] = value
|
||||
|
||||
def __getitem__(self, key):
|
||||
# 如果 key 存在于注册的处理器中,直接返回
|
||||
if key in self._dict:
|
||||
return self._dict[key]
|
||||
else:
|
||||
# 如果 key 不存在,使用默认处理器
|
||||
return self._dict['default']
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self._dict
|
||||
|
||||
def __str__(self):
|
||||
return str(self._dict)
|
||||
|
||||
def keys(self):
|
||||
return self._dict.keys()
|
||||
|
||||
def values(self):
|
||||
return self._dict.values()
|
||||
|
||||
def items(self):
|
||||
return self._dict.items()
|
||||
|
||||
|
||||
register_data_processor = Register()
|
||||
register_data_path = Register()
|
||||
|
||||
def vqa_instruction_templates(question, idx=None):
|
||||
instructions = [
|
||||
"{Question} A short answer to the question is",
|
||||
"Given the image, answer the following question with no more than three words. {Question}",
|
||||
"Based on the image, respond to this question with a short answer: {Question} Answer:",
|
||||
"Use the provided image to answer the question: {Question} Provide your answer as short as possible:",
|
||||
]
|
||||
if idx is None:
|
||||
new_question = random.choice(
|
||||
instructions).replace("{Question}", question)
|
||||
else:
|
||||
new_question = instructions[idx].replace("{Question}", question)
|
||||
|
||||
return new_question
|
||||
|
||||
|
||||
def caption_instruction_templates():
|
||||
instructions = [
|
||||
"Describe the image concisely.",
|
||||
"Provide a brief description of the given image.",
|
||||
"Offer a succinct explanation of the picture presented.",
|
||||
"Summarize the visual content of the image.",
|
||||
"Give a short and clear explanation of the subsequent image.",
|
||||
"Share a concise interpretation of the image provided.",
|
||||
"Present a compact description of the photo's key features.",
|
||||
"Relay a brief, clear account of the picture shown.",
|
||||
"Render a clear and concise summary of the photo.",
|
||||
"Write a terse but informative summary of the picture.",
|
||||
"Create a compact narrative representing the image presented."
|
||||
]
|
||||
|
||||
new_question = random.choice(instructions)
|
||||
|
||||
return new_question
|
||||
|
||||
|
||||
def ocr_instruction_templates():
|
||||
instructions = [
|
||||
"Identify the text in the image with position."
|
||||
"Pinpoint and indicate the text and its location within the image."
|
||||
"Find the text in the image and identify its positional."
|
||||
"Detect the text within the image and specify its position."
|
||||
"Locate the text in the image and detail its position."
|
||||
]
|
||||
|
||||
new_question = random.choice(instructions)
|
||||
|
||||
return new_question
|
||||
|
||||
|
||||
def textvqa_instruction_templates(question):
|
||||
instructions = [
|
||||
"Answer the question shortly by reading the texts. {Question}"
|
||||
"After reading the text in the image, {Question} A short answer to the question is",
|
||||
"Given the text in the image, answer the following question with no more than three words. {Question}"
|
||||
]
|
||||
|
||||
new_question = random.choice(instructions).replace("{Question}", question)
|
||||
|
||||
return new_question
|
||||
|
||||
|
||||
def load_multimodal_conversation(text_b64, img_b64_buffer):
|
||||
map_role = {
|
||||
'human': 'human',
|
||||
'gpt': 'gpt'
|
||||
}
|
||||
|
||||
text = base64.b64decode(text_b64).decode('utf-8')
|
||||
list_conv = json.loads(text)
|
||||
|
||||
out: List[dict] = []
|
||||
for idx, sentence in enumerate(list_conv):
|
||||
value = sentence['value']
|
||||
|
||||
if idx == 0 and '<image>' not in value:
|
||||
value = f"<image>\n{value}"
|
||||
if idx != 0 and '<image>' in value:
|
||||
value = value.replace('<image>', '')
|
||||
|
||||
out.append({
|
||||
'from': map_role[sentence['from']],
|
||||
'value': value
|
||||
})
|
||||
|
||||
img_io = io.BytesIO(base64.b64decode(img_b64_buffer))
|
||||
img_io.seek(0)
|
||||
image = Image.open(img_io).convert('RGB')
|
||||
return image, out
|
||||
|
||||
def load_fm9g_multimodal_conversation(text_b64, img_b64_buffer):
|
||||
map_role = {
|
||||
'human': 'human',
|
||||
'gpt': 'gpt'
|
||||
}
|
||||
|
||||
text = base64.b64decode(text_b64).decode('utf-8')
|
||||
list_conv = json.loads(text)
|
||||
|
||||
out: List[dict] = []
|
||||
for idx, sentence in enumerate(list_conv):
|
||||
value = sentence['value']
|
||||
value = re.sub(r'<image>.+?</image>', '<image>', value)
|
||||
|
||||
out.append({
|
||||
'from': map_role[sentence['from']],
|
||||
'value': value.strip()
|
||||
})
|
||||
|
||||
img_io = io.BytesIO(base64.b64decode(img_b64_buffer))
|
||||
img_io.seek(0)
|
||||
image = Image.open(img_io).convert('RGB')
|
||||
return image, out
|
||||
|
||||
def load_pretrain_conversation(text_b64, img_b64_buffer):
|
||||
map_role = {
|
||||
'human': 'human',
|
||||
'gpt': 'gpt'
|
||||
}
|
||||
|
||||
text = base64.b64decode(text_b64).decode('utf-8')
|
||||
list_conv = json.loads(text)
|
||||
|
||||
out: List[dict] = []
|
||||
for idx, sentence in enumerate(list_conv):
|
||||
print(sentence)
|
||||
value = sentence['value']
|
||||
value = re.sub(r'<image>.+?</image>', '<image>', value)
|
||||
|
||||
out.append({
|
||||
'from': map_role[sentence['from']],
|
||||
'value': value.strip()
|
||||
})
|
||||
|
||||
img_io = io.BytesIO(base64.b64decode(img_b64_buffer))
|
||||
img_io.seek(0)
|
||||
image = Image.open(img_io).convert('RGB')
|
||||
return image, out
|
||||
|
||||
|
||||
def b64_to_PIL_image(img_b64_buffer):
|
||||
img_io = io.BytesIO(base64.b64decode(img_b64_buffer))
|
||||
img_io.seek(0)
|
||||
image = Image.open(img_io).convert('RGB')
|
||||
return image
|
||||
|
||||
def wrap_qa_to_single_turn_multimodal_conv(answer, question):
|
||||
if '<image>' not in question:
|
||||
question = f"<image>\n{question}"
|
||||
|
||||
out = [
|
||||
{"from": "human", "value": question},
|
||||
{"from": "gpt", "value": answer}
|
||||
]
|
||||
return question, out
|
||||
|
||||
def wrap_generation_single_turn_conv(out, template_func):
|
||||
conv = [
|
||||
{
|
||||
"from": "human",
|
||||
"value": f"<image>\n{template_func()}"
|
||||
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": out
|
||||
}
|
||||
]
|
||||
return conv
|
||||
|
||||
def wrap_ocr_generation_single_turn_conv(out):
|
||||
return wrap_generation_single_turn_conv(out, ocr_instruction_templates)
|
||||
|
||||
|
||||
def wrap_caption_generation_single_turn_conv(out):
|
||||
return wrap_generation_single_turn_conv(out, caption_instruction_templates)
|
||||
|
||||
|
||||
def gather_data_files_by_glob(root: str, pattern='*.tsv'):
|
||||
filenames = []
|
||||
|
||||
for fullpath in glob.glob(f'{root}/{pattern}'):
|
||||
filename = fullpath.split('/')[-1]
|
||||
filenames.append(filename)
|
||||
return root, filenames
|
||||
|
||||
@register_data_path('default')
|
||||
def default_data_path(data_dir):
|
||||
return gather_data_files_by_glob(data_dir, '*.tsv')
|
||||
|
||||
@register_data_processor('default')
|
||||
def default_data_processor(img_b64_buffer, text_b64, origin_dataset, origin_split, origin_split_inner_idx, img_path,
|
||||
intent, img_transformer=None):
|
||||
if intent == 'pretrain' or intent == 'sft':
|
||||
image, out = load_multimodal_conversation(text_b64, img_b64_buffer)
|
||||
|
||||
metainfo = {
|
||||
"origin_dataset": origin_dataset, # llava folder name
|
||||
"origin_split": origin_split, # llava parquet file name
|
||||
"origin_idx": origin_split_inner_idx, # index in llava parquet file
|
||||
"image_id": img_path, # cocoid
|
||||
}
|
||||
|
||||
return {
|
||||
'image': image,
|
||||
'conversations': out,
|
||||
'idx': origin_split_inner_idx,
|
||||
'metainfo': metainfo,
|
||||
}
|
||||
else:
|
||||
raise NotImplemented
|
||||
|
||||
@register_data_path('llava')
|
||||
def llava_instruct_data_path():
|
||||
data_dir = "dataset/train/default"
|
||||
return gather_data_files_by_glob(data_dir, '*.tsv')
|
||||
|
||||
|
||||
@register_data_processor('llava')
|
||||
def llava_instruct_processor(img_b64_buffer, text_b64, origin_dataset, origin_split, origin_split_inner_idx, img_path,
|
||||
intent, img_transformer=None):
|
||||
if intent == 'pretrain' or intent == 'sft':
|
||||
image, out = load_multimodal_conversation(text_b64, img_b64_buffer)
|
||||
|
||||
metainfo = {
|
||||
"origin_dataset": origin_dataset, # llava folder name
|
||||
"origin_split": origin_split, # llava parquet file name
|
||||
"origin_idx": origin_split_inner_idx, # index in llava parquet file
|
||||
"image_id": img_path, # cocoid
|
||||
}
|
||||
|
||||
return {
|
||||
'image': image,
|
||||
'conversations': out,
|
||||
'idx': origin_split_inner_idx,
|
||||
'metainfo': metainfo,
|
||||
}
|
||||
else:
|
||||
raise NotImplemented
|
||||
|
||||
@register_data_path('pretrain_eval_eval')
|
||||
def pretrain_eval_train_data_path():
|
||||
data_dir = "dataset/eval/default"
|
||||
return gather_data_files_by_glob(data_dir, '*.tsv')
|
||||
|
||||
@register_data_processor('pretrain_eval_eval')
|
||||
def unimmchat_processor(img_b64_buffer, text_b64, origin_dataset, origin_split, origin_split_inner_idx, img_path,
|
||||
intent, img_transformer=None):
|
||||
if intent == 'pretrain' or intent == 'sft' or intent == 'eval':
|
||||
if img_b64_buffer == '<no_image>':
|
||||
image = '<no_image>'
|
||||
out = base64.b64decode(text_b64).decode('utf-8')
|
||||
out = json.loads(out)
|
||||
else:
|
||||
image, out = load_multimodal_conversation(text_b64, img_b64_buffer)
|
||||
|
||||
metainfo = {
|
||||
"origin_dataset": origin_dataset, # unimm-chat folder name
|
||||
"origin_split": origin_split, # unimm-chat parquet file name
|
||||
"origin_idx": origin_split_inner_idx, # index in unimm-chat parquet file
|
||||
"image_id": img_path, # cocoid
|
||||
}
|
||||
|
||||
return {
|
||||
'image': image,
|
||||
'conversations': out,
|
||||
'idx': origin_split_inner_idx,
|
||||
'metainfo': metainfo,
|
||||
}
|
||||
else:
|
||||
raise NotImplemented
|
||||
|
|
@ -0,0 +1,227 @@
|
|||
import io
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
|
||||
import numpy
|
||||
import base64
|
||||
|
||||
import os.path as op
|
||||
import torch.utils.data as torch_data
|
||||
|
||||
from PIL import Image
|
||||
from typing import List, Iterator
|
||||
from vis_fm9g.dataset.tsv_file import TSVFile
|
||||
from vis_fm9g.dataset.data import register_data_processor
|
||||
from vis_fm9g.dataset.itembuilder import ItemBuilder
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
class MultimodalQADataset(torch_data.Dataset):
|
||||
def __init__(self, qa_file, question_process):
|
||||
'''
|
||||
qa_file: jsonl file that each line is a dict like {
|
||||
'image': b64img,
|
||||
'question': question_text
|
||||
}
|
||||
'''
|
||||
super().__init__()
|
||||
|
||||
self.qa_file = qa_file
|
||||
self.qa_data = [json.loads(line) for line in open(self.qa_file)]
|
||||
if isinstance(self.qa_data[0], list):
|
||||
self.qa_data = self.qa_data[0] # unwrap one-line json question file
|
||||
|
||||
self.question_process = question_process
|
||||
|
||||
def __getitem__(self, index):
|
||||
item = self.qa_data[index]
|
||||
|
||||
img_b64 = item['image']
|
||||
image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert('RGB')
|
||||
|
||||
raw_question = item['question']
|
||||
question_text = self.question_process(raw_question)
|
||||
return {
|
||||
'image': image,
|
||||
'raw_question': raw_question,
|
||||
'question': question_text
|
||||
}
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.qa_data)
|
||||
|
||||
|
||||
class SingleDataSourceDataset(torch_data.Dataset):
|
||||
def __init__(self, ds_name, item_builder: ItemBuilder, data_dir, tsv_filenames: List[str], intent='sft') -> None:
|
||||
super().__init__()
|
||||
|
||||
self.data_dir = data_dir
|
||||
self.filenames = tsv_filenames
|
||||
self.ds_name = ds_name
|
||||
|
||||
self.sizes = []
|
||||
for filename in self.filenames:
|
||||
try:
|
||||
size = int(filename[:-4].split('-')[-1])
|
||||
except:
|
||||
raise ValueError(f'TSV Data File {filename} is not valid, last component separated by `-` must be the number of sample in this file')
|
||||
self.sizes.append(size)
|
||||
|
||||
self.file_border_index = []
|
||||
self.prepare_border_index()
|
||||
|
||||
self.item_builder = item_builder
|
||||
self.files = self.filenames[:]
|
||||
self.intent = intent
|
||||
|
||||
|
||||
def prepare_border_index(self):
|
||||
self.file_border_index = [0]
|
||||
|
||||
temp_sum = 0
|
||||
for size in self.sizes:
|
||||
temp_sum += size
|
||||
self.file_border_index.append(temp_sum)
|
||||
|
||||
|
||||
def get_file_idx_and_row_idx(self, index):
|
||||
found = False
|
||||
file_idx = -1
|
||||
|
||||
for border_idx, border in enumerate(self.file_border_index):
|
||||
if index < border:
|
||||
file_idx = border_idx - 1
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
raise ValueError(f'Index {index} out of range for {self.size_sum} border markers')
|
||||
|
||||
offset = self.file_border_index[file_idx]
|
||||
row_idx = index - offset
|
||||
return file_idx, row_idx
|
||||
|
||||
def __len__(self):
|
||||
return self.file_border_index[-1]
|
||||
|
||||
def __getitem__(self, index):
|
||||
file_idx, row_idx = self.get_file_idx_and_row_idx(index)
|
||||
try:
|
||||
sample = self.fetch_sample(file_idx, row_idx)
|
||||
item = self.item_builder.build_item(sample)
|
||||
except:
|
||||
logger.warning(f"data fetch error")
|
||||
return self.__getitem__(random.randint(0, len(self)))
|
||||
return item
|
||||
|
||||
def fetch_sample(self, file_idx, row_idx):
|
||||
file = self.files[file_idx]
|
||||
if isinstance(file, str):
|
||||
self.prepare_file(file_idx)
|
||||
file = self.files[file_idx]
|
||||
|
||||
assert isinstance(file, TSVFile), f'Expecting TSVFile but get {file} as {type(file)}'
|
||||
|
||||
# tsv line as tuple
|
||||
sample = file[row_idx]
|
||||
ds_name, *values = sample
|
||||
# data dict
|
||||
sample = register_data_processor[self.ds_name](*values, intent=self.intent)
|
||||
if row_idx + 1 == len(file):
|
||||
del file
|
||||
self.files[file_idx] = self.filenames[file_idx]
|
||||
|
||||
return sample
|
||||
|
||||
def prepare_file(self, idx):
|
||||
filename = self.filenames[idx]
|
||||
file = TSVFile(op.join(self.data_dir, filename))
|
||||
self.files[idx] = file
|
||||
|
||||
|
||||
class IterableSingleDataSourceDataset(torch_data.IterableDataset):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
raise NotImplemented
|
||||
|
||||
|
||||
class MultiDataSourceDataset(torch_data.Dataset):
|
||||
def __init__(self, data_sources: List[SingleDataSourceDataset], data_source_weights: List[int]):
|
||||
super().__init__()
|
||||
|
||||
self.ds_list = data_sources
|
||||
|
||||
self.sum_weight = sum(data_source_weights)
|
||||
self.ds_weights = data_source_weights
|
||||
for weight in self.ds_weights:
|
||||
assert isinstance(weight, int), 'weight must be integer'
|
||||
|
||||
self.offset2ds = {}
|
||||
self.offset2wt = {}
|
||||
self.offset2pd = {}
|
||||
self.prepare_offset2ds()
|
||||
|
||||
ds_loops = []
|
||||
for ds, wt in zip(self.ds_list, self.ds_weights):
|
||||
ds_loop = len(ds) // wt
|
||||
ds_loops.append(ds_loop)
|
||||
max_loop = max(ds_loops)
|
||||
self.size = max_loop * self.sum_weight
|
||||
|
||||
def prepare_offset2ds(self):
|
||||
offset = 0
|
||||
for ds, weight in zip(self.ds_list, self.ds_weights):
|
||||
pd = offset
|
||||
for _ in range(weight):
|
||||
self.offset2ds[offset] = ds
|
||||
self.offset2wt[offset] = weight
|
||||
self.offset2pd[offset] = pd
|
||||
offset += 1
|
||||
|
||||
def __getitem__(self, index):
|
||||
n_loop = index // self.sum_weight
|
||||
offset = index % self.sum_weight
|
||||
|
||||
ds = self.offset2ds[offset]
|
||||
ds_inner_idx = n_loop * self.offset2wt[offset] + offset - self.offset2pd[offset]
|
||||
ds_inner_idx = ds_inner_idx % len(ds)
|
||||
|
||||
return ds[ds_inner_idx]
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
|
||||
class IterableMultiDataSourceDataset(torch_data.IterableDataset):
|
||||
def __init__(self, data_sources, data_source_weights):
|
||||
super().__init__()
|
||||
|
||||
self.ds_list = data_sources
|
||||
|
||||
sum_weight = sum(data_source_weights)
|
||||
self.ds_weights = [x / sum_weight for x in data_source_weights]
|
||||
|
||||
self.ds_consumption = []
|
||||
self.ds_sizes = [len(ds) for ds in self.ds_list]
|
||||
|
||||
def __next__(self):
|
||||
ds_idx = numpy.random.choice(range(len(self.ds_list)), 1, p=self.ds_weights)[0]
|
||||
data_source = self.ds_list[ds_idx]
|
||||
|
||||
self.ds_consumption[ds_idx] += 1
|
||||
if self.ds_consumption[ds_idx] % self.ds_sizes[ds_idx] == 0:
|
||||
self.report_consumption()
|
||||
|
||||
sample = next(data_source)
|
||||
return sample
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
return self
|
||||
|
||||
def __len__(self):
|
||||
return sum(self.ds_sizes)
|
||||
|
||||
def report_consumption(self):
|
||||
for ds, consumption, size in zip(self.ds_list, self.ds_consumption, self.ds_sizes):
|
||||
print(f'Data {ds} consumption: {consumption / size:.2f} epoch', flush=True)
|
|
@ -0,0 +1,308 @@
|
|||
import io
|
||||
import json
|
||||
from typing import Dict, Tuple, List, Any
|
||||
|
||||
import torch
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from PIL import Image, PngImagePlugin
|
||||
from torch.utils.data import default_collate
|
||||
|
||||
from utils.logger import init_logger
|
||||
from vis_fm9g.dataset.utils import convert_data_to_id
|
||||
from vis_fm9g.dataset.utils import convert_conversation_data_to_id
|
||||
from vis_fm9g.dataset.utils import pad
|
||||
import random
|
||||
|
||||
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer
|
||||
from vis_fm9g.utils.constants import usr_indicator, bot_indicator
|
||||
from vis_fm9g.dataset.prompts import caption_zh, caption_en
|
||||
|
||||
LARGE_ENOUGH_NUMBER = 100
|
||||
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
|
||||
|
||||
logger = init_logger()
|
||||
|
||||
def is_contain_chinese(check_str):
|
||||
"""
|
||||
判断字符串中是否包含中文
|
||||
:param check_str: {str} 需要检测的字符串
|
||||
:return: {bool} 包含返回True, 不包含返回False
|
||||
"""
|
||||
for ch in check_str:
|
||||
if u'\u4e00' <= ch <= u'\u9fff':
|
||||
return True
|
||||
|
||||
|
||||
def maybe_select_text(raw_text):
|
||||
candidates = raw_text.split('<cap_sep>')
|
||||
return random.choice(candidates)
|
||||
|
||||
|
||||
def maybe_parse_json(raw_text: str):
|
||||
# VG raw
|
||||
if raw_text.startswith('[{') and raw_text.endswith('}]'):
|
||||
try:
|
||||
data = json.loads(raw_text)
|
||||
text_list = [x['phrase'] for x in data if x['height'] > 160 and x['width'] > 160]
|
||||
if len(text_list) == 0:
|
||||
return max(data, key=lambda x: len(x['phrase'].split()))['phrase']
|
||||
else:
|
||||
return random.choice(text_list)
|
||||
except:
|
||||
return raw_text
|
||||
else:
|
||||
return raw_text
|
||||
|
||||
def clean_text(raw_text):
|
||||
text = raw_text.replace('<PERSON>', '')
|
||||
text = maybe_parse_json(maybe_select_text(text))
|
||||
return text
|
||||
|
||||
|
||||
def check_text_valid(raw_text):
|
||||
if pd.isna(raw_text):
|
||||
return False
|
||||
if not is_contain_chinese(raw_text) and len(raw_text.split()) <= 3:
|
||||
return False
|
||||
if '<img' in raw_text or '<a href' in raw_text:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_image_placeholder(tokenizer, query_len, use_im_start_end=False):
|
||||
if use_im_start_end:
|
||||
return tokenizer.im_start + tokenizer.unk_token * query_len + tokenizer.im_end
|
||||
else:
|
||||
return tokenizer.unk_token * query_len
|
||||
|
||||
|
||||
class ItemBuilder():
|
||||
def __init__(self, transform=None):
|
||||
self.transform = transform
|
||||
|
||||
def build_item(self, data):
|
||||
if self.transform is not None:
|
||||
return self.transform(data)
|
||||
return data
|
||||
|
||||
|
||||
# --------------------- FM9G ---------------------
|
||||
class FM9GBuilder(ItemBuilder):
|
||||
def __init__(self, tokenizer: FM9GTokenizer, max_len, transform=None, skip_overlength=False):
|
||||
super().__init__(transform)
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.max_len = max_len
|
||||
self.skip_overlength = skip_overlength
|
||||
|
||||
def convert_data(self, inp_dicts: List[Dict], raw_data):
|
||||
res = []
|
||||
for inp_dict in inp_dicts:
|
||||
input_ids, context = convert_data_to_id(self.tokenizer, data=inp_dict)
|
||||
if len(input_ids) > self.max_len:
|
||||
if self.skip_overlength:
|
||||
if random.random() > 0.95:
|
||||
logger.warn(f"overlength={len(input_ids)}, raw_inp={inp_dict}, skip data")
|
||||
else:
|
||||
logger.warn(f"overlength={len(input_ids)}, skip data")
|
||||
continue
|
||||
|
||||
input_ids = input_ids[: self.max_len]
|
||||
context = context[: self.max_len]
|
||||
|
||||
res.append({
|
||||
'input_ids': torch.from_numpy(input_ids).unsqueeze(0),
|
||||
'context': torch.from_numpy(context).unsqueeze(0),
|
||||
'raw_data': raw_data,
|
||||
})
|
||||
|
||||
return res
|
||||
|
||||
def convert_conversation_data(self, conversation_list: List[List]):
|
||||
res = []
|
||||
for conversation in conversation_list:
|
||||
input_ids, context, raw = convert_conversation_data_to_id(self.tokenizer, data=conversation, predict_roles={bot_indicator})
|
||||
if len(input_ids) > self.max_len:
|
||||
if self.skip_overlength:
|
||||
if random.random() > 0.95:
|
||||
logger.warn(f"overlength={len(input_ids)}, raw_inp={conversation}, skip data")
|
||||
else:
|
||||
logger.warn(f"overlength={len(input_ids)}, skip data")
|
||||
continue
|
||||
|
||||
input_ids = input_ids[: self.max_len]
|
||||
context = context[: self.max_len]
|
||||
res.append({
|
||||
'input_ids': torch.from_numpy(input_ids).unsqueeze(0),
|
||||
'context': torch.from_numpy(context).unsqueeze(0),
|
||||
'raw_data': raw,
|
||||
})
|
||||
return res
|
||||
|
||||
def build_image_bound(self, res, images):
|
||||
return_res = []
|
||||
if isinstance(images, List) and len(images) > 0:
|
||||
images = torch.stack(images)
|
||||
for r in res:
|
||||
# r['input_ids'] (1, len)
|
||||
image_start_tokens = torch.where(r['input_ids'][0] == self.tokenizer.encoder[self.tokenizer.im_start])[0]
|
||||
# 跳过 im_start
|
||||
image_start_tokens += 1
|
||||
image_end_tokens = torch.where(r['input_ids'][0] == self.tokenizer.encoder[self.tokenizer.im_end])[0]
|
||||
|
||||
if len(image_start_tokens) != len(image_end_tokens) or len(image_start_tokens) > len(images):
|
||||
continue
|
||||
|
||||
image_bound = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)])
|
||||
|
||||
r['pixel_values'] = images[:len(image_start_tokens)]
|
||||
r['image_bound'] = image_bound
|
||||
return_res.append(r)
|
||||
|
||||
return return_res
|
||||
|
||||
|
||||
def build_item(self, data):
|
||||
NotImplementedError("build_item is not implemented.")
|
||||
|
||||
|
||||
class FM9GImageTextBuilder(FM9GBuilder):
|
||||
def __init__(self, tokenizer: FM9GTokenizer, max_len, transform=None, query_len=64, min_resolution=0, skip_overlength=False):
|
||||
super().__init__(tokenizer, max_len, transform, skip_overlength)
|
||||
self.query_len = query_len
|
||||
self.min_resolution = min_resolution
|
||||
|
||||
def build_item(self, data):
|
||||
text = data['conversations']
|
||||
image = data['image']
|
||||
source = data.get('metainfo', {}).get('origin_dataset', 'unk')
|
||||
|
||||
image = self.transform(image)
|
||||
|
||||
raw_data = {'text': text}
|
||||
image_placeholder = get_image_placeholder(self.tokenizer, self.query_len, use_im_start_end=True)
|
||||
messages = []
|
||||
for i in range(len(text)):
|
||||
role = text[i]['from']
|
||||
role = usr_indicator if role == 'human' else bot_indicator
|
||||
value = self.tokenizer.escape(text[i]['value'])
|
||||
if '<image>' in value:
|
||||
value = value.replace('<image>', image_placeholder)
|
||||
messages.append((role, value))
|
||||
res = self.convert_conversation_data([messages])
|
||||
self.build_image_bound(res, images=[image])
|
||||
for r in res:
|
||||
r['source'] = source
|
||||
|
||||
return res[0]
|
||||
|
||||
class FM9GCollater:
|
||||
def __init__(self, tokenizer: FM9GTokenizer, max_len: int, unpad: bool = False, unilm: bool = False):
|
||||
self.tokenizer = tokenizer
|
||||
self._max_length = max_len
|
||||
self._unpad = unpad
|
||||
self._unilm = unilm
|
||||
self.pad_keys = ['input_ids', 'context']
|
||||
|
||||
def __call__(self, batch):
|
||||
batch_cnt = len(batch)
|
||||
if self._unpad: # for flash_attention cuda
|
||||
max_length = self._max_length * batch_cnt
|
||||
batch_size = 1
|
||||
else:
|
||||
max_length = self._max_length
|
||||
batch_size = batch_cnt
|
||||
|
||||
inputs = np.zeros((batch_size, max_length), dtype=np.int32)
|
||||
context_origin = np.zeros((batch_size, max_length), dtype=np.int8)
|
||||
context = np.zeros((batch_size, max_length), dtype=np.int8)
|
||||
tgt = np.full((batch_size, max_length), -100, dtype=np.int32)
|
||||
|
||||
spans = np.zeros((batch_size, max_length), dtype=np.int32)
|
||||
length = np.zeros((batch_size,), dtype=np.int32)
|
||||
position_ids = np.zeros((batch_size, max_length), dtype=np.int32)
|
||||
|
||||
if self._unpad: # for flash_attention cuda, force batch_size=1
|
||||
flatten_input_ids = np.concatenate([batch[i]['input_ids'][0] for i in range(batch_cnt)], axis=0)
|
||||
flatten_context = np.concatenate([batch[i]['context'][0] for i in range(batch_cnt)], axis=0)
|
||||
instance_length = flatten_input_ids.shape[0]
|
||||
inputs[0, : instance_length] = flatten_input_ids
|
||||
context_origin[0, : instance_length] = flatten_context
|
||||
length[0] = instance_length
|
||||
if self._unilm:
|
||||
context[0, : instance_length] = flatten_context
|
||||
# flatten batch
|
||||
_spans = [list(np.cumsum([batch[i]['input_ids'][0].shape[0] for i in range(batch_cnt)]))]
|
||||
|
||||
else:
|
||||
for i in range(batch_cnt):
|
||||
instance_length = batch[i]['input_ids'][0].shape[0]
|
||||
inputs[i, :instance_length] = batch[i]['input_ids'][0]
|
||||
context_origin[i, : instance_length] = batch[i]['context'][0]
|
||||
length[i] = instance_length
|
||||
if self._unilm:
|
||||
context[i, :instance_length] = batch[i]['context'][0]
|
||||
_spans = [[batch[i]['input_ids'][0].shape[0]] for i in range(batch_cnt)]
|
||||
|
||||
|
||||
# cu_seqlens 和 max_seqlen 在 flash_attention cuda 时需要
|
||||
if _spans[0][-1] != max_length:
|
||||
cu_seqlens = np.array([0] + _spans[0] + [max_length], dtype=np.int32)
|
||||
else:
|
||||
cu_seqlens = np.array([0] + _spans[0], dtype=np.int32)
|
||||
max_seqlen = int(np.max(cu_seqlens[1:] - cu_seqlens[:-1]))
|
||||
|
||||
raw_data_list: List[Any] = [batch[i]['raw_data'] for i in range(batch_cnt)]
|
||||
source_list: List[Any] = [batch[i].get('source', 'unk') for i in range(batch_cnt)]
|
||||
|
||||
for i in range(batch_size):
|
||||
instance_length = length[i]
|
||||
span_begin = 0
|
||||
for span_id, span_end in enumerate(_spans[i]):
|
||||
spans[i, span_begin: span_end] = span_id
|
||||
position_ids[i, span_begin:span_end] = np.arange(span_end - span_begin)
|
||||
span_begin = span_end
|
||||
for j in range(instance_length):
|
||||
idx = inputs[i][j]
|
||||
if j > 1:
|
||||
if context_origin[i][j] == 0:
|
||||
if idx != self.tokenizer.bos_id and inputs[i][j - 1] != self.tokenizer.eos_id:
|
||||
tgt[i, j - 1] = idx
|
||||
if context_origin[i][j] == 1 and context_origin[i][j-1] == 0:
|
||||
if idx != self.tokenizer.bos_id and inputs[i][j - 1] != self.tokenizer.eos_id:
|
||||
tgt[i, j - 1] = self.tokenizer.eos_id
|
||||
|
||||
|
||||
data = {}
|
||||
# image
|
||||
if 'pixel_values' in batch[0]:
|
||||
if self._unpad:
|
||||
data['pixel_values'] = [torch.vstack([i['pixel_values'] for i in batch])]
|
||||
else:
|
||||
data['pixel_values'] = [i['pixel_values'] for i in batch]
|
||||
|
||||
|
||||
# image_bound
|
||||
if 'image_bound' in batch[0]:
|
||||
if self._unpad:
|
||||
image_bounds = []
|
||||
for i in range(batch_cnt):
|
||||
offset = _spans[0][i-1] if i > 0 else 0
|
||||
image_bounds.append(batch[i]['image_bound'] + offset)
|
||||
data['image_bound'] = [torch.vstack(image_bounds)]
|
||||
else:
|
||||
data['image_bound'] = [i['image_bound'] for i in batch]
|
||||
|
||||
data['input_ids'] = torch.from_numpy(inputs)
|
||||
data['context'] = torch.from_numpy(context) > 0
|
||||
data['length'] = torch.from_numpy(length)
|
||||
data['spans'] = torch.from_numpy(spans)
|
||||
data['cu_seqlens'] = torch.from_numpy(cu_seqlens)
|
||||
data['max_seqlen'] = max_seqlen
|
||||
data['position_ids'] = torch.from_numpy(position_ids)
|
||||
data['target'] = torch.from_numpy(tgt)
|
||||
data['raw_data'] = raw_data_list
|
||||
data['source'] = source_list
|
||||
|
||||
return data
|
|
@ -0,0 +1,24 @@
|
|||
caption_en = [
|
||||
'Describe the image concisely',
|
||||
'Provide a brief description of the given image',
|
||||
'Offer a succinct explanation of the picture presented',
|
||||
'Summarize the visual content of the image',
|
||||
'Share a conciseinter pretation of the image provided',
|
||||
'Present a compact description of the photo’s key features',
|
||||
'Relay a brief and clear account of the picture shown',
|
||||
'Render a clear and concise summary of the photo',
|
||||
'Write a terse but informative summary of the picture',
|
||||
'Create a compact narrative representing the image presented',
|
||||
]
|
||||
|
||||
caption_zh = [
|
||||
'简明扼要地描述图像',
|
||||
'提供给定图像的简短描述',
|
||||
'对所示的图片进行简要的解释',
|
||||
'总结图像的视觉内容',
|
||||
'对所提供的图像进行简要的解释',
|
||||
'简明扼要并清楚地说明所示图片',
|
||||
'对这张照片作一个简明扼要的总结',
|
||||
'写一篇简洁但内容丰富的图片摘要',
|
||||
'创造一个紧凑的叙事来代表所呈现的图像',
|
||||
]
|
|
@ -0,0 +1,106 @@
|
|||
# Copyright (c) 2021 Microsoft Corporation. Licensed under the MIT license.
|
||||
import os
|
||||
import logging
|
||||
import os.path as op
|
||||
|
||||
LARGEST_TSV_SIZE = 500_000
|
||||
|
||||
# LARGEST_TSV_SIZE = 10_000
|
||||
|
||||
|
||||
def create_lineidx(filein, idxout):
|
||||
idxout_tmp = idxout + '.tmp'
|
||||
with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
|
||||
fsize = os.fstat(tsvin.fileno()).st_size
|
||||
fpos = 0
|
||||
while fpos != fsize:
|
||||
tsvout.write(str(fpos)+"\n")
|
||||
tsvin.readline()
|
||||
fpos = tsvin.tell()
|
||||
os.rename(idxout_tmp, idxout)
|
||||
|
||||
|
||||
def read_to_character(fp, c):
|
||||
result = []
|
||||
while True:
|
||||
s = fp.read(32)
|
||||
assert s != ''
|
||||
if c in s:
|
||||
result.append(s[: s.index(c)])
|
||||
break
|
||||
else:
|
||||
result.append(s)
|
||||
return ''.join(result)
|
||||
|
||||
|
||||
class TSVFile(object):
|
||||
def __init__(self, tsv_file, generate_lineidx=False):
|
||||
self.tsv_file = tsv_file
|
||||
self.lineidx = op.splitext(tsv_file)[0] + '.lineidx'
|
||||
self._fp = None
|
||||
self._lineidx = None
|
||||
# the process always keeps the process which opens the file.
|
||||
# If the pid is not equal to the currrent pid, we will re-open the file.
|
||||
self.pid = None
|
||||
# generate lineidx if not exist
|
||||
if not op.isfile(self.lineidx) and generate_lineidx:
|
||||
create_lineidx(self.tsv_file, self.lineidx)
|
||||
|
||||
def __del__(self):
|
||||
if self._fp:
|
||||
self._fp.close()
|
||||
|
||||
def __str__(self):
|
||||
return "TSVFile(tsv_file='{}')".format(self.tsv_file)
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
def num_rows(self):
|
||||
self._ensure_lineidx_loaded()
|
||||
assert len(
|
||||
self._lineidx) <= LARGEST_TSV_SIZE, f'Do not support TSVFile larger than {LARGEST_TSV_SIZE} yet'
|
||||
return len(self._lineidx)
|
||||
|
||||
def seek(self, idx):
|
||||
self._ensure_tsv_opened()
|
||||
self._ensure_lineidx_loaded()
|
||||
try:
|
||||
pos = self._lineidx[idx]
|
||||
except:
|
||||
logging.info('{}-{}'.format(self.tsv_file, idx))
|
||||
raise
|
||||
self._fp.seek(pos)
|
||||
return [s.strip() for s in self._fp.readline().split('\t')]
|
||||
|
||||
def seek_first_column(self, idx):
|
||||
self._ensure_tsv_opened()
|
||||
self._ensure_lineidx_loaded()
|
||||
pos = self._lineidx[idx]
|
||||
self._fp.seek(pos)
|
||||
return read_to_character(self._fp, '\t')
|
||||
|
||||
def get_key(self, idx):
|
||||
return self.seek_first_column(idx)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.seek(index)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_rows()
|
||||
|
||||
def _ensure_lineidx_loaded(self):
|
||||
if self._lineidx is None:
|
||||
logging.debug('loading lineidx: {}'.format(self.lineidx))
|
||||
with open(self.lineidx, 'r') as fp:
|
||||
self._lineidx = [int(i.strip()) for i in fp.readlines()]
|
||||
|
||||
def _ensure_tsv_opened(self):
|
||||
if self._fp is None:
|
||||
self._fp = open(self.tsv_file, 'r')
|
||||
self.pid = os.getpid()
|
||||
|
||||
if self.pid != os.getpid():
|
||||
# logging.info('re-open {} because the process id changed'.format(self.tsv_file))
|
||||
self._fp = open(self.tsv_file, 'r')
|
||||
self.pid = os.getpid()
|
|
@ -0,0 +1,170 @@
|
|||
import importlib.machinery
|
||||
import importlib.util
|
||||
import types
|
||||
from typing import Any, Set
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from torch.utils.data import BatchSampler
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer
|
||||
from vis_fm9g.utils.constants import SYSTEM
|
||||
|
||||
FM9GInputType = Union[str, Dict[str, "FM9GInputType"]]
|
||||
|
||||
class _TransformFuncDict(TypedDict):
|
||||
loader: importlib.machinery.SourceFileLoader
|
||||
module: types.ModuleType
|
||||
last_m: float
|
||||
|
||||
|
||||
class FM9GBatch(TypedDict):
|
||||
inputs: NDArray[np.int32]
|
||||
length: NDArray[np.int32]
|
||||
context: NDArray[np.bool_]
|
||||
sample_ids: NDArray[np.int32]
|
||||
spans: NDArray[np.int32]
|
||||
target: NDArray[np.int32]
|
||||
task_ids: NDArray[np.int32]
|
||||
task_names: List[str]
|
||||
raw_data: List[Any]
|
||||
|
||||
|
||||
def convert_data_to_id(tokenizer: FM9GTokenizer, data: Any):
|
||||
"""
|
||||
data: {
|
||||
'input': xxx,
|
||||
'output': xxx
|
||||
}
|
||||
"""
|
||||
input_ids = tokenizer.encode(data["input"])
|
||||
output_ids = tokenizer.encode(data["output"])
|
||||
ids = [tokenizer.bos_id] + input_ids + output_ids + [tokenizer.eos_id]
|
||||
ids = np.array(ids, dtype=np.int32)
|
||||
context = np.zeros((ids.shape[0],), dtype=np.int8)
|
||||
|
||||
# context = 0 需要 target
|
||||
context[: len(input_ids) + 1] = 1
|
||||
return ids, context
|
||||
|
||||
|
||||
def convert_conversation_data_to_id(tokenizer: FM9GTokenizer, data: Any, predict_roles: Set):
|
||||
"""
|
||||
predict_roles: {'<AI>'}
|
||||
data: [
|
||||
('<用户>', xxxx),
|
||||
('<AI>', xxxx)
|
||||
]
|
||||
"""
|
||||
assert (set([i[0] for i in data]) & predict_roles)
|
||||
|
||||
|
||||
if SYSTEM:
|
||||
system = tokenizer.bos_token + SYSTEM + '\n'
|
||||
else:
|
||||
system = tokenizer.bos_token
|
||||
sys_idx = tokenizer.encode(system)
|
||||
ret = system
|
||||
|
||||
input_ids = [sys_idx] if sys_idx else []
|
||||
context = [np.ones((len(sys_idx),), dtype=np.int8)]
|
||||
|
||||
for idx, (role, message) in enumerate(data):
|
||||
prefix = role
|
||||
# 最后一句加上 eos
|
||||
if idx == len(data)-1:
|
||||
message = message + tokenizer.eos_token
|
||||
|
||||
prefix_ids = tokenizer.encode(prefix)
|
||||
message_ids = tokenizer.encode(message)
|
||||
|
||||
input_ids.append(prefix_ids)
|
||||
|
||||
input_ids.append(message_ids)
|
||||
context.append(np.ones((len(prefix_ids),), dtype=np.int8))
|
||||
|
||||
if role in predict_roles:
|
||||
context.append(np.zeros((len(message_ids),), dtype=np.int8))
|
||||
else:
|
||||
context.append(np.ones((len(message_ids),), dtype=np.int8))
|
||||
|
||||
ret += (prefix + message)
|
||||
|
||||
ids = np.hstack(input_ids)
|
||||
context = np.hstack(context)
|
||||
|
||||
return ids, context, ret
|
||||
|
||||
|
||||
def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"):
|
||||
items = []
|
||||
if isinstance(orig_items[0][key], list):
|
||||
assert isinstance(orig_items[0][key][0], torch.Tensor)
|
||||
for it in orig_items:
|
||||
for tr in it[key]:
|
||||
items.append({key: tr})
|
||||
else:
|
||||
assert isinstance(orig_items[0][key], torch.Tensor)
|
||||
items = orig_items
|
||||
|
||||
batch_size = len(items)
|
||||
shape = items[0][key].shape
|
||||
dim = len(shape)
|
||||
assert dim <= 3
|
||||
if max_length is None:
|
||||
max_length = 0
|
||||
max_length = max(max_length, max(item[key].shape[-1] for item in items))
|
||||
min_length = min(item[key].shape[-1] for item in items)
|
||||
dtype = items[0][key].dtype
|
||||
|
||||
if dim == 1:
|
||||
return torch.cat([item[key] for item in items], dim=0)
|
||||
elif dim == 2:
|
||||
if max_length == min_length:
|
||||
return torch.cat([item[key] for item in items], dim=0)
|
||||
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
|
||||
else:
|
||||
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
|
||||
|
||||
for i, item in enumerate(items):
|
||||
if dim == 2:
|
||||
if padding_side == "left":
|
||||
tensor[i, -len(item[key][0]):] = item[key][0].clone()
|
||||
else:
|
||||
tensor[i, : len(item[key][0])] = item[key][0].clone()
|
||||
elif dim == 3:
|
||||
if padding_side == "left":
|
||||
tensor[i, -len(item[key][0]):, :] = item[key][0].clone()
|
||||
else:
|
||||
tensor[i, : len(item[key][0]), :] = item[key][0].clone()
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
class SkipBatchSampler(BatchSampler):
|
||||
"""
|
||||
A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
|
||||
"""
|
||||
|
||||
def __init__(self, batch_sampler, skip_batches=0):
|
||||
self.batch_sampler = batch_sampler
|
||||
self.skip_batches = skip_batches
|
||||
self.first_epoch = True
|
||||
|
||||
def __iter__(self):
|
||||
for index, samples in enumerate(self.batch_sampler):
|
||||
if index >= self.skip_batches and self.first_epoch:
|
||||
yield samples
|
||||
self.first_epoch = False
|
||||
|
||||
@property
|
||||
def total_length(self):
|
||||
return len(self.batch_sampler)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.batch_sampler) - self.skip_batches
|
|
@ -0,0 +1,112 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
|
||||
# This function has been mostly taken from huggingface conversational ai code at
|
||||
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
|
||||
|
||||
if top_k > 0:
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits[indices_to_remove] = filter_value
|
||||
|
||||
batch_size = logits.size()[0]
|
||||
if top_p > 0.0:
|
||||
logits = logits.view(batch_size, -1).contiguous()
|
||||
for index in range(len(logits)):
|
||||
|
||||
sorted_logits, sorted_indices = torch.sort(logits[index].view(-1), descending=True)
|
||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
# Remove tokens with cumulative probability above the threshold
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
# Shift the indices to the right to keep also the first token above the threshold
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
||||
logits[index][indices_to_remove] = filter_value
|
||||
|
||||
logits = logits.view(batch_size, -1).contiguous()
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def apply_repetition_penalty(
|
||||
logits,
|
||||
batch_size,
|
||||
num_beams,
|
||||
prev_output_tokens,
|
||||
repetition_penalty,
|
||||
start_idx=None,
|
||||
end_idx=None,
|
||||
window_size=None,
|
||||
):
|
||||
# only conduct repetition penalty for the output
|
||||
assert repetition_penalty >= 1, "repetition penalty coefficient should >= 1"
|
||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||
for i in range(batch_size * num_beams):
|
||||
if start_idx is None or end_idx is None:
|
||||
output_tokens = prev_output_tokens[i].tolist()
|
||||
else:
|
||||
if end_idx >= start_idx:
|
||||
if window_size:
|
||||
output_tokens = prev_output_tokens[i][
|
||||
max(start_idx, end_idx + 1 - window_size): end_idx + 1
|
||||
].tolist()
|
||||
else:
|
||||
output_tokens = prev_output_tokens[i][start_idx: end_idx + 1].tolist()
|
||||
else:
|
||||
output_tokens = []
|
||||
for previous_token in set(output_tokens):
|
||||
# if score < 0 then repetition penalty has to
|
||||
# multiplied to reduce the previous token probability
|
||||
if logits[i, previous_token] < 0:
|
||||
logits[i, previous_token] *= repetition_penalty
|
||||
else:
|
||||
logits[i, previous_token] /= repetition_penalty
|
||||
|
||||
|
||||
class BeamHypotheses:
|
||||
def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
|
||||
"""
|
||||
Initialize n-best list of hypotheses.
|
||||
"""
|
||||
self.max_len = max_len
|
||||
self.length_penalty = length_penalty
|
||||
self.early_stopping = early_stopping
|
||||
self.n_hyp = n_hyp
|
||||
self.hyp = []
|
||||
self.worst_score = 1e9
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Number of hypotheses in the list.
|
||||
"""
|
||||
return len(self.hyp)
|
||||
|
||||
def add(self, hyp, sum_logprobs):
|
||||
"""
|
||||
Add a new hypothesis to the list.
|
||||
"""
|
||||
score = sum_logprobs / len(hyp) ** self.length_penalty
|
||||
|
||||
if len(self) < self.n_hyp or score > self.worst_score:
|
||||
self.hyp.append((score, hyp))
|
||||
if len(self) > self.n_hyp:
|
||||
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
|
||||
del self.hyp[sorted_scores[0][1]]
|
||||
self.worst_score = sorted_scores[1][0]
|
||||
else:
|
||||
self.worst_score = min(score, self.worst_score)
|
||||
|
||||
def is_done(self, best_sum_logprobs, cur_len):
|
||||
"""
|
||||
If there are enough hypotheses and that none of the hypotheses being generated
|
||||
can become better than the worst one in the heap, then we are done with this sentence.
|
||||
"""
|
||||
if len(self) < self.n_hyp:
|
||||
return False
|
||||
elif self.early_stopping:
|
||||
return True
|
||||
else:
|
||||
return self.worst_score >= best_sum_logprobs / cur_len**self.length_penalty
|
|
@ -0,0 +1,425 @@
|
|||
from typing import Any, List, Optional, Union
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vis_fm9g.generation.generation_utils import BeamHypotheses, apply_repetition_penalty, \
|
||||
top_k_top_p_filtering
|
||||
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer, LlamaTokenizerWrapper
|
||||
from vis_fm9g.model.vlu_fm9g import VLU_FM9G
|
||||
from vis_fm9g.dataset.utils import pad
|
||||
|
||||
|
||||
class VLLMFM9GGeneration:
|
||||
def __init__(self, model: VLU_FM9G, tokenizer: Union[FM9GTokenizer, LlamaTokenizerWrapper], transform):
|
||||
model.eval()
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.transform = transform
|
||||
|
||||
def _convert_to_tensors(self, data: Any, max_inp_length: Optional[int] = None):
|
||||
if isinstance(self.tokenizer, LlamaTokenizerWrapper) and self.tokenizer.add_bos_token:
|
||||
input_ids = self.tokenizer.encode(data["input"])
|
||||
else:
|
||||
input_ids = [self.tokenizer.bos_id] + self.tokenizer.encode(data["input"])
|
||||
if max_inp_length is not None:
|
||||
input_ids = input_ids[: max_inp_length]
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int32)
|
||||
|
||||
image_start_tokens = torch.where(input_ids == self.tokenizer.im_start_id)[0]
|
||||
# 跳过 im_start
|
||||
image_start_tokens += 1
|
||||
image_end_tokens = torch.where(input_ids == self.tokenizer.im_end_id)[0]
|
||||
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
|
||||
image_bound = torch.hstack(
|
||||
[image_start_tokens[: valid_image_nums].unsqueeze(-1),
|
||||
image_end_tokens[:valid_image_nums].unsqueeze(-1)]
|
||||
)
|
||||
|
||||
model_input = {}
|
||||
model_input["input_ids"] = input_ids.unsqueeze(0)
|
||||
|
||||
model_input["context"] = torch.zeros(
|
||||
(model_input["input_ids"].shape[0], model_input["input_ids"].shape[1]), dtype=torch.int16
|
||||
)
|
||||
model_input["span"] = torch.ones((model_input["input_ids"].shape[1],), dtype=torch.int16).unsqueeze(0)
|
||||
model_input["length"] = torch.tensor([model_input["input_ids"].shape[1]], dtype=torch.int16)
|
||||
model_input["image_bound"] = image_bound
|
||||
|
||||
return model_input
|
||||
|
||||
def _process_list(self, data_list: List[Any], max_inp_length: Optional[int] = None):
|
||||
pad_keys = ['input_ids', 'context', 'span']
|
||||
input_tensors = []
|
||||
for data in data_list:
|
||||
input_tensors.append(self._convert_to_tensors(data, max_inp_length))
|
||||
padded = {}
|
||||
for key in pad_keys:
|
||||
padded[key] = pad(input_tensors, key, padding_side="left").cuda()
|
||||
padded['length'] = torch.hstack([i['length'] for i in input_tensors]).cuda()
|
||||
padded['image_bound'] = [i['image_bound'] for i in input_tensors]
|
||||
|
||||
return padded
|
||||
|
||||
def generate(
|
||||
self,
|
||||
data_list=None,
|
||||
img_list=None,
|
||||
max_inp_length: Optional[int] = None,
|
||||
vision_hidden_states=None,
|
||||
return_vision_hidden_states=False,
|
||||
use_transform=True,
|
||||
**kwargs
|
||||
):
|
||||
# img_list List[List[images]]
|
||||
assert data_list is not None
|
||||
bs = len(data_list)
|
||||
if img_list == None:
|
||||
img_list = [[] for i in range(bs)]
|
||||
assert bs == len(img_list)
|
||||
|
||||
model_inputs = self._process_list(data_list, max_inp_length)
|
||||
|
||||
if vision_hidden_states is None:
|
||||
if use_transform:
|
||||
pixel_values = []
|
||||
for i in range(bs):
|
||||
img_inps = []
|
||||
for img in img_list[i]:
|
||||
img_inps.append(self.transform(img))
|
||||
if img_inps:
|
||||
pixel_values.append(torch.stack(img_inps).cuda())
|
||||
else:
|
||||
pixel_values.append([])
|
||||
model_inputs['pixel_values'] = pixel_values
|
||||
else:
|
||||
pixel_values = img_list
|
||||
model_inputs['pixel_values'] = pixel_values
|
||||
else:
|
||||
model_inputs['vision_hidden_states'] = vision_hidden_states
|
||||
|
||||
with torch.inference_mode():
|
||||
model_inputs['hidden_states'], vision_hidden_states = self.model.get_vllm_embedding(
|
||||
model_inputs)
|
||||
result = self._decode(model_inputs, **kwargs)
|
||||
|
||||
if return_vision_hidden_states:
|
||||
return result, vision_hidden_states
|
||||
|
||||
return result
|
||||
|
||||
def _decode(self, model_inputs, **kwargs):
|
||||
raise NotImplementedError("_decode is not implemented.")
|
||||
|
||||
def _decode_text(self, result_ids):
|
||||
result_text = []
|
||||
for result in result_ids:
|
||||
if result[-1] == self.tokenizer.eos_id:
|
||||
result = result[:-1]
|
||||
result_text.append(self.tokenizer.decode(result))
|
||||
return result_text
|
||||
|
||||
|
||||
class VLLMFM9GBeamSearch(VLLMFM9GGeneration):
|
||||
def _decode(
|
||||
self,
|
||||
model_inputs,
|
||||
beam_size=3,
|
||||
max_length=100,
|
||||
min_length=0,
|
||||
repetition_penalty=1.0,
|
||||
length_penalty=1.0,
|
||||
temperature=1.0,
|
||||
repetition_window=None,
|
||||
):
|
||||
"""
|
||||
Beam search
|
||||
Args:
|
||||
model_inputs (dict): input ids.
|
||||
beam_size (int, optional, defaults to 3): beam size of beam search.
|
||||
generate_length (int, optional, defaults to 100): maximum generation length.
|
||||
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
|
||||
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
|
||||
""" # noqa: E501
|
||||
# expand dimmension
|
||||
batch_size = model_inputs["input_ids"].size(0)
|
||||
input: torch.Tensor = (
|
||||
model_inputs["input_ids"]
|
||||
.unsqueeze(1)
|
||||
.expand(batch_size, beam_size, -1)
|
||||
.contiguous()
|
||||
.view(batch_size * beam_size, -1)
|
||||
)
|
||||
length = (
|
||||
model_inputs["length"]
|
||||
.unsqueeze(1)
|
||||
.expand(batch_size, beam_size)
|
||||
.contiguous()
|
||||
.view(
|
||||
batch_size * beam_size,
|
||||
)
|
||||
)
|
||||
span: torch.Tensor = (
|
||||
model_inputs["span"]
|
||||
.unsqueeze(1)
|
||||
.expand(batch_size, beam_size, -1)
|
||||
.contiguous()
|
||||
.view(batch_size * beam_size, -1)
|
||||
)
|
||||
|
||||
context: torch.Tensor = (
|
||||
model_inputs["context"]
|
||||
.unsqueeze(1)
|
||||
.expand(batch_size, beam_size, -1)
|
||||
.contiguous()
|
||||
.view(batch_size * beam_size, -1)
|
||||
)
|
||||
|
||||
hidden_states: torch.Tensor = (
|
||||
model_inputs["hidden_states"]
|
||||
.unsqueeze(1)
|
||||
.expand(batch_size, beam_size, *model_inputs["hidden_states"].shape[1:])
|
||||
.contiguous()
|
||||
.view(batch_size * beam_size, *model_inputs["hidden_states"].shape[1:])
|
||||
)
|
||||
|
||||
done = [False for _ in range(batch_size)]
|
||||
|
||||
beam_scores = torch.zeros((batch_size, beam_size), dtype=torch.float, device=input.device)
|
||||
beam_scores[:, 1:] = -1e9
|
||||
beam_scores = beam_scores.view(-1)
|
||||
|
||||
# generated hypotheses
|
||||
generated_hyps = [
|
||||
BeamHypotheses(beam_size, max_length, length_penalty=length_penalty, early_stopping=False)
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
|
||||
pred_start_index = input.size(-1)
|
||||
past_key_values = None
|
||||
|
||||
for i in range(max_length + 1):
|
||||
if i == 0:
|
||||
logits, _, past_key_values = self.model.llm.inference(
|
||||
input=input,
|
||||
context=context,
|
||||
span=span,
|
||||
length=length,
|
||||
past_key_values=past_key_values,
|
||||
hidden_states=hidden_states
|
||||
)
|
||||
else:
|
||||
logits, _, past_key_values = self.model.llm.inference(
|
||||
input=input[:, -1:],
|
||||
context=context,
|
||||
span=span,
|
||||
length=length,
|
||||
past_key_values=past_key_values
|
||||
)
|
||||
# skip all steps when we are done with each sentence
|
||||
if all(done):
|
||||
break
|
||||
|
||||
# (batch * beam, seqlen, model_dim)
|
||||
logits = logits[:, -1, :]
|
||||
if i == 0:
|
||||
logits[:, self.tokenizer.eos_id] = -float("inf")
|
||||
|
||||
apply_repetition_penalty(
|
||||
logits,
|
||||
batch_size,
|
||||
beam_size,
|
||||
input,
|
||||
repetition_penalty,
|
||||
pred_start_index,
|
||||
input.size(-1) - 1,
|
||||
repetition_window,
|
||||
)
|
||||
logits = logits / temperature
|
||||
if i < min_length:
|
||||
logits[:, self.tokenizer.eos_id] = -float("inf")
|
||||
|
||||
scores = F.log_softmax(logits, dim=-1)
|
||||
|
||||
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * beam_size, vocab_size)
|
||||
|
||||
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
|
||||
next_scores = next_scores.view(batch_size, -1) # (batch_size, beam_size * vocab_size)
|
||||
next_scores, next_words = torch.topk(next_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
|
||||
|
||||
assert next_scores.size() == next_words.size() == (batch_size, 2 * beam_size)
|
||||
next_batch_beam = []
|
||||
|
||||
for sent_id in range(batch_size):
|
||||
# if we are done with this sentence
|
||||
done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item(), i)
|
||||
if done[sent_id]:
|
||||
next_batch_beam.extend([(0, 0, 0)] * beam_size) # pad the batch
|
||||
continue
|
||||
|
||||
# next sentence beam content
|
||||
next_sent_beam = []
|
||||
|
||||
# next words for this sentence
|
||||
for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
|
||||
# get beam and word IDs
|
||||
beam_id = torch.div(idx, scores.size(-1), rounding_mode="floor")
|
||||
word_id = idx % scores.size(-1)
|
||||
|
||||
# end of sentence, or next word
|
||||
if word_id == self.tokenizer.eos_id or i == max_length:
|
||||
generated_hyps[sent_id].add(
|
||||
input[sent_id * beam_size + beam_id, pred_start_index:].clone().cpu().tolist(),
|
||||
value.item(),
|
||||
)
|
||||
else:
|
||||
next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))
|
||||
|
||||
# the beam for next step is full
|
||||
if len(next_sent_beam) == beam_size:
|
||||
break
|
||||
|
||||
# update next beam content
|
||||
assert len(next_sent_beam) == 0 if i == max_length else beam_size
|
||||
if len(next_sent_beam) == 0:
|
||||
next_sent_beam = [(0, 0, 0)] * beam_size # pad the batch
|
||||
next_batch_beam.extend(next_sent_beam)
|
||||
assert len(next_batch_beam) == beam_size * (sent_id + 1)
|
||||
|
||||
# we have reached the last step
|
||||
if i == max_length:
|
||||
break
|
||||
|
||||
# sanity check / prepare next batch
|
||||
assert len(next_batch_beam) == batch_size * beam_size
|
||||
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
|
||||
beam_words = input.new([x[1] for x in next_batch_beam])
|
||||
beam_idx = torch.tensor([x[2] for x in next_batch_beam], device=input.device).long()
|
||||
# re-order batch and internal states
|
||||
input = input[beam_idx, :]
|
||||
|
||||
past_key_values["buffer"] = [list(each) if each is not None else each for each in past_key_values["buffer"]] # type: ignore # noqa: E501
|
||||
for key_value_layer in past_key_values["buffer"]:
|
||||
if key_value_layer is not None:
|
||||
key_value_layer[0] = key_value_layer[0][beam_idx]
|
||||
key_value_layer[1] = key_value_layer[1][beam_idx]
|
||||
|
||||
input = torch.cat([input, beam_words.unsqueeze(1)], dim=-1)
|
||||
context = torch.cat(
|
||||
[context, context[:, -1:]],
|
||||
dim=-1,
|
||||
)
|
||||
length += 1
|
||||
|
||||
span = torch.cat([span, span[:, -1:]], dim=-1)
|
||||
|
||||
# select the best hypotheses
|
||||
|
||||
results = []
|
||||
for i, hypotheses in enumerate(generated_hyps):
|
||||
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
|
||||
results.append(best_hyp)
|
||||
|
||||
result_text = self._decode_text(results)
|
||||
return result_text
|
||||
|
||||
|
||||
class VLLMFM9GRandomSampling(VLLMFM9GGeneration):
|
||||
def _decode(
|
||||
self,
|
||||
model_inputs,
|
||||
min_length=24,
|
||||
max_length=50,
|
||||
top_k=0.0,
|
||||
top_p=1.0,
|
||||
temperature=1,
|
||||
repetition_penalty=1.0,
|
||||
repetition_window=None
|
||||
):
|
||||
"""
|
||||
Top-k and top-p sampling.
|
||||
Args:
|
||||
model_inputs (dict): input ids
|
||||
max_length (int, optional, defaults to 100): maximum generation length
|
||||
top_k (int, optional, defaults to 0): keep only top k tokens with highest probability. 0 means keeping all tokens.
|
||||
top_p (int, optional, defaults to 0.9): keep the top tokens with cumulative probability >= top_p.
|
||||
temperature (int, optional, defaults to 0.9): the value that can cool down the logits distribution.
|
||||
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
|
||||
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
|
||||
""" # noqa: E501
|
||||
max_length += 1
|
||||
|
||||
input = model_inputs["input_ids"]
|
||||
context = model_inputs["context"]
|
||||
length = model_inputs["length"]
|
||||
span = model_inputs["span"]
|
||||
hidden_states = model_inputs["hidden_states"]
|
||||
batch_size = input.size(0)
|
||||
|
||||
pred_start_index = input.size(-1)
|
||||
past_key_values = None
|
||||
done = [False for _ in range(batch_size)]
|
||||
results = [None for _ in range(batch_size)]
|
||||
for i in range(max_length):
|
||||
if i == 0:
|
||||
logits, _, past_key_values = self.model.llm.inference(
|
||||
input=input,
|
||||
context=context,
|
||||
length=length,
|
||||
span=span,
|
||||
past_key_values=past_key_values,
|
||||
hidden_states=hidden_states
|
||||
)
|
||||
else:
|
||||
logits, _, past_key_values = self.model.llm.inference(
|
||||
input=input[:, -1:],
|
||||
context=context,
|
||||
length=length,
|
||||
span=span,
|
||||
past_key_values=past_key_values
|
||||
)
|
||||
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if i == 0:
|
||||
logits[:, self.tokenizer.eos_id] = -float("inf")
|
||||
|
||||
apply_repetition_penalty(
|
||||
logits,
|
||||
batch_size,
|
||||
1,
|
||||
input,
|
||||
repetition_penalty,
|
||||
pred_start_index,
|
||||
input.size(-1) - 1,
|
||||
repetition_window,
|
||||
)
|
||||
|
||||
logits = logits / temperature
|
||||
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
||||
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
for idx in range(batch_size):
|
||||
if not done[idx] and (next_token[idx].item() == self.tokenizer.eos_id or i == max_length - 1):
|
||||
done[idx] = True
|
||||
results[idx] = input[idx, pred_start_index:].clone().cpu().tolist() # type: ignore # noqa: E501
|
||||
|
||||
if sum(done) == batch_size:
|
||||
break
|
||||
|
||||
# update input ids
|
||||
input = torch.cat([input, next_token], dim=-1)
|
||||
length += 1
|
||||
|
||||
context = torch.cat(
|
||||
[context, context[:, -1:]],
|
||||
dim=-1,
|
||||
)
|
||||
span = torch.cat(
|
||||
[span, span[:, -1:]],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
result_text = self._decode_text(results)
|
||||
return result_text
|
|
@ -11,5 +11,4 @@ from .position_embedding import ChatGLMRotaryEmbedding
|
|||
from .position_embedding import RotaryEmbedding
|
||||
from .position_embedding import RotaryEmbeddingESM
|
||||
from .position_embedding import SegmentPositionEmbedding
|
||||
from .transformer import Encoder
|
||||
#from _attention_pp_sp import OpAttnPipeSP
|
||||
from .transformer import Encoder
|
|
@ -0,0 +1,283 @@
|
|||
import math
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
try:
|
||||
from .flash_triton import FlashAttnFunc
|
||||
except:
|
||||
FlashAttnFunc = None
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from .linear import Linear
|
||||
from .position_embedding import apply_chatglm_rotary_pos_emb
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
|
||||
except:
|
||||
flash_attn_unpadded_func = None
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
except:
|
||||
flash_attn_varlen_func = None
|
||||
|
||||
try:
|
||||
from flash_attn.bert_padding import pad_input
|
||||
from flash_attn.bert_padding import unpad_input
|
||||
except:
|
||||
pad_input = None
|
||||
unpad_input = None
|
||||
|
||||
|
||||
class FlashSelfAttention(torch.nn.Module):
|
||||
"""Implement the scaled dot product attention with softmax.
|
||||
Arguments
|
||||
---------
|
||||
softmax_scale: The temperature to use for the softmax attention.
|
||||
(default: 1/sqrt(d_keys) where d_keys is computed at
|
||||
runtime)
|
||||
attention_dropout: The dropout rate to apply to the attention
|
||||
(default: 0.0)
|
||||
"""
|
||||
|
||||
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
|
||||
super().__init__()
|
||||
assert flash_attn_unpadded_func is not None, (
|
||||
"Please install FlashAttention first, " "e.g., with pip install flash-attn"
|
||||
)
|
||||
assert rearrange is not None, "Please install einops first, e.g., with pip install einops"
|
||||
self.causal = causal
|
||||
self.softmax_scale = softmax_scale
|
||||
self.dropout_p = attention_dropout
|
||||
|
||||
def forward(self, q, k, v, attention_mask=None, length_mask=None, context_mask=None):
|
||||
"""Implements the multihead softmax attention.
|
||||
Arguments
|
||||
---------
|
||||
q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
|
||||
"""
|
||||
assert q.dtype in [torch.float16, torch.bfloat16], q.dtype
|
||||
assert q.is_cuda
|
||||
batch_size, seqlen = q.shape[0], q.shape[1]
|
||||
d = q.shape[-1]
|
||||
|
||||
if length_mask is not None:
|
||||
q, k, v = [rearrange(x, "b s h d -> b s (h d)") for x in [q, k, v]]
|
||||
q, indices_q, cu_seqlens, max_s = unpad_input(q, length_mask)
|
||||
k, _, _, _ = unpad_input(k, length_mask)
|
||||
v, _, _, _ = unpad_input(v, length_mask)
|
||||
q, k, v = [rearrange(x, "nnz (h d) -> nnz h d", d=d) for x in [q, k, v]]
|
||||
output = flash_attn_unpadded_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
max_s,
|
||||
self.dropout_p if self.training else 0.0,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=self.causal,
|
||||
attention_mask=attention_mask,
|
||||
context_mask=context_mask[:, :max_s],
|
||||
)
|
||||
# TODO reimplement (un)pad_input to remove redundant rearranges.
|
||||
output = rearrange(output, "nnz h d -> nnz (h d)")
|
||||
output = pad_input(output, indices_q, batch_size, seqlen)
|
||||
output = rearrange(output, "b s (h d) -> b s h d", d=d)
|
||||
else:
|
||||
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
|
||||
max_s = seqlen
|
||||
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=q.device)
|
||||
|
||||
output = flash_attn_unpadded_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
max_s,
|
||||
self.dropout_p if self.training else 0.0,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=self.causal,
|
||||
attention_mask=attention_mask,
|
||||
context_mask=context_mask,
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
|
||||
return output
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
dim_head: int,
|
||||
dtype: torch.dtype = torch.half,
|
||||
dropout_p: Optional[float] = None,
|
||||
scale: bool = True,
|
||||
add_qkv_bias: bool = False,
|
||||
use_flash_attn: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dim_model = dim_model
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_groups = num_heads // num_kv_heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.project_q = Linear(
|
||||
self.dim_model, self.num_heads * self.dim_head, bias=add_qkv_bias, dtype=dtype, scale=scale
|
||||
)
|
||||
self.project_k = Linear(
|
||||
self.dim_model, self.num_kv_heads * self.dim_head, bias=add_qkv_bias, dtype=dtype, scale=scale
|
||||
)
|
||||
self.project_v = Linear(
|
||||
self.dim_model, self.num_kv_heads * self.dim_head, bias=add_qkv_bias, dtype=dtype, scale=scale
|
||||
)
|
||||
|
||||
self.attention_out = Linear(self.num_heads * self.dim_head, self.dim_model, dtype=dtype, scale=scale)
|
||||
|
||||
self.softmax = torch.nn.Softmax(dim=-1)
|
||||
|
||||
if dropout_p is not None:
|
||||
self.dropout = torch.nn.Dropout(p=dropout_p)
|
||||
self.dropout_p = dropout_p
|
||||
else:
|
||||
self.dropout = None
|
||||
|
||||
# if use_flash_attn:
|
||||
# self.core_attention_flash = FlashSelfAttention(causal=False, attention_dropout=0.0)
|
||||
self.use_flash_attn = use_flash_attn
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_q: torch.Tensor,
|
||||
hidden_kv: torch.Tensor,
|
||||
attention_mask: torch.BoolTensor,
|
||||
position_bias: torch.Tensor,
|
||||
use_cache: bool = False,
|
||||
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
pos_bias_type: Optional[str] = "relative",
|
||||
length_mask: Optional[torch.Tensor] = None,
|
||||
context_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask_bias: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: int = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""This model inherits from bmt.DistributedModule.
|
||||
Args:
|
||||
hidden_q (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix.
|
||||
hidden_kv (:obj:`torch.Tensor` of shape ``(batch, len_k, dim_model)``): Length of input sequence before padding.
|
||||
attention_mask (:obj:`torch.Tensor` of shape ``(batch, len_q, len_k)``): Used to avoid performing attention on padding token indices.
|
||||
position_bias(:obj:`torch.Tensor` of shape ``(num_heads, len_q, len_k)`` or ``(1, num_heads, len_k, len_q)``): Provide positional information about tensor `key_value` and `query`.
|
||||
Return:
|
||||
out (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): The attention output.
|
||||
""" # noqa: E501
|
||||
|
||||
batch_size = hidden_q.size(0)
|
||||
len_q = hidden_q.size(1)
|
||||
len_k = hidden_kv.size(1)
|
||||
|
||||
h_q = self.project_q(hidden_q)
|
||||
h_k = self.project_k(hidden_kv)
|
||||
h_v = self.project_v(hidden_kv)
|
||||
|
||||
if not self.use_flash_attn:
|
||||
h_q = h_q / math.sqrt(math.sqrt(self.dim_head))
|
||||
h_k = h_k / math.sqrt(math.sqrt(self.dim_head))
|
||||
|
||||
h_q = h_q.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
|
||||
h_k = h_k.view(batch_size, len_k, self.num_kv_heads, self.dim_head).permute(0, 2, 1, 3)
|
||||
h_v = h_v.view(batch_size, len_k, self.num_kv_heads, self.dim_head).permute(0, 2, 1, 3)
|
||||
|
||||
if pos_bias_type == "rotary":
|
||||
# b h s d
|
||||
h_q, h_k = position_bias(h_q, h_k, -2, offset=past_kv[0].size(-2) if past_kv is not None else 0)
|
||||
elif pos_bias_type == "chatglm_rotary":
|
||||
h_q = apply_chatglm_rotary_pos_emb(h_q, position_bias)
|
||||
h_k = apply_chatglm_rotary_pos_emb(h_k, position_bias)
|
||||
|
||||
if past_kv is not None:
|
||||
h_k = torch.cat([past_kv[0], h_k], dim=-2)
|
||||
h_v = torch.cat([past_kv[1], h_v], dim=-2)
|
||||
len_k = h_k.size(-2)
|
||||
|
||||
# (b, n_h, len_q, d_h) @ (b, n_h, d_h, len_k) -> (b, n_h, len_q, len_k)
|
||||
# (b, n_kv_h, n_h_groups*len_q, d_h) @ (b, n_kv_h, d_h, len_k) -> (b, n_kv_h, n_h_groups*len_q, len_k) -> (b, n_h, len_q, len_k)
|
||||
if self.head_groups == 1:
|
||||
score = torch.matmul(h_q, h_k.transpose(-1, -2)) # / math.sqrt(self.dim_head) moved to line 75~76
|
||||
else:
|
||||
score = torch.matmul(
|
||||
h_q.reshape(batch_size, self.num_kv_heads, self.head_groups * len_q, self.dim_head),
|
||||
h_k.transpose(-1, -2),
|
||||
).view(
|
||||
batch_size, self.num_heads, len_q, len_k
|
||||
) # / math.sqrt(self.dim_head) moved to line 75~76
|
||||
if pos_bias_type == "relative":
|
||||
if len_q == 1: # inference with cache
|
||||
if len(position_bias.size()) == 4:
|
||||
position_bias = position_bias[:, :, -1:, :]
|
||||
else:
|
||||
position_bias = position_bias[:, -1:, :]
|
||||
score = score + position_bias
|
||||
score = torch.masked_fill(
|
||||
score,
|
||||
attention_mask.view(batch_size, 1, len_q, len_k) == False,
|
||||
torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
|
||||
)
|
||||
|
||||
score = self.softmax(score)
|
||||
|
||||
score = torch.masked_fill(
|
||||
score,
|
||||
attention_mask.view(batch_size, 1, len_q, len_k) == False,
|
||||
torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
|
||||
)
|
||||
|
||||
if self.dropout is not None:
|
||||
score = self.dropout(score)
|
||||
|
||||
# (b, n_h, len_q, len_k) @ (b, n_h, len_k, d_h) -> (b, n_h, len_q, d_h)
|
||||
# (b, n_kv_h, n_h_groups*len_q, len_k) @ (b, n_kv_h, len_k, d_h) -> (b, n_kv_h, n_h_groups*len_q, d_h) -> (b, n_h, len_q, d_h)
|
||||
score = torch.matmul(score.view(batch_size, self.num_kv_heads, self.head_groups * len_q, len_k), h_v).view(
|
||||
batch_size, self.num_heads, len_q, self.dim_head
|
||||
)
|
||||
|
||||
score = score.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3)
|
||||
score = score.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)
|
||||
|
||||
else:
|
||||
if attention_mask_bias is not None:
|
||||
assert pos_bias_type == "rotary"
|
||||
h_q = h_q.view(batch_size, len_q, self.num_heads, self.dim_head) # .permute(0, 2, 1, 3)
|
||||
h_k = h_k.view(batch_size, len_k, self.num_kv_heads, self.dim_head) # .permute(0, 2, 1, 3)
|
||||
h_v = h_v.view(batch_size, len_k, self.num_kv_heads, self.dim_head) # .permute(0, 2, 1, 3)
|
||||
h_q, h_k = position_bias(h_q, h_k, -3)
|
||||
score = FlashAttnFunc.apply(h_q, h_k, h_v, attention_mask_bias, False, None)
|
||||
else:
|
||||
if pos_bias_type == "chatglm_rotary":
|
||||
raise NotImplemented("No FlashAttn version for ChatGLM at present!")
|
||||
h_q = h_q.view(len_q, self.num_heads, self.dim_head) # .permute(0, 2, 1, 3)
|
||||
h_k = h_k.view(len_k, self.num_kv_heads, self.dim_head) # .permute(0, 2, 1, 3)
|
||||
h_v = h_v.view(len_k, self.num_kv_heads, self.dim_head) # .permute(0, 2, 1, 3)
|
||||
h_q, h_k = position_bias(
|
||||
h_q, h_k, -3, cu_seqlens=cu_seqlens, max_length=max_seqlen, position_ids=position_ids
|
||||
)
|
||||
|
||||
score = flash_attn_varlen_func(
|
||||
h_q, h_k, h_v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, self.dropout_p, causal=True
|
||||
)
|
||||
|
||||
score = score.view(batch_size, len_q, self.num_heads * self.dim_head)
|
||||
|
||||
score = self.attention_out(score)
|
||||
|
||||
if use_cache:
|
||||
return score, (h_k, h_v)
|
||||
else:
|
||||
return score
|
|
@ -2,7 +2,6 @@ from typing import Optional
|
|||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import bmtrain as bmt
|
||||
import torch
|
||||
|
||||
from .attention import Attention
|
||||
|
@ -12,7 +11,7 @@ from .position_embedding import RotaryEmbedding
|
|||
from .position_embedding import RotaryEmbeddingESM
|
||||
|
||||
|
||||
class SelfAttentionBlock(bmt.DistributedModule):
|
||||
class SelfAttentionBlock(torch.nn.Module):
|
||||
"""The whole cross-attention block. A sequence of operation. Consists of layernorm, self-attention and residual connection.
|
||||
|
||||
Args:
|
||||
|
@ -31,12 +30,11 @@ class SelfAttentionBlock(bmt.DistributedModule):
|
|||
num_kv_heads: int,
|
||||
dim_head: int,
|
||||
dtype=torch.half,
|
||||
eps: float = 1e-5,
|
||||
eps: float = 1e-6,
|
||||
dropout_p: Optional[float] = None,
|
||||
scale: bool = True,
|
||||
add_qkv_bias: bool = False,
|
||||
use_flash_attn: bool = False,
|
||||
tp: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -56,7 +54,6 @@ class SelfAttentionBlock(bmt.DistributedModule):
|
|||
scale=scale,
|
||||
add_qkv_bias=add_qkv_bias,
|
||||
use_flash_attn=use_flash_attn,
|
||||
tp=tp,
|
||||
)
|
||||
|
||||
if dropout_p:
|
||||
|
@ -73,6 +70,7 @@ class SelfAttentionBlock(bmt.DistributedModule):
|
|||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
pos_bias_type: Optional[str] = "relative",
|
||||
length_mask: Optional[torch.Tensor] = None,
|
||||
context_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask_bias: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: int = None,
|
||||
|
@ -98,6 +96,7 @@ class SelfAttentionBlock(bmt.DistributedModule):
|
|||
past_key_value,
|
||||
pos_bias_type=pos_bias_type,
|
||||
length_mask=length_mask,
|
||||
context_mask=context_mask,
|
||||
attention_mask_bias=attention_mask_bias,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
|
@ -137,7 +136,6 @@ class FFNBlock(torch.nn.Module):
|
|||
eps: float = 1e-6,
|
||||
dropout_p: Optional[float] = 0,
|
||||
scale: bool = True,
|
||||
tp: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -154,7 +152,6 @@ class FFNBlock(torch.nn.Module):
|
|||
dtype=dtype,
|
||||
dropout_p=dropout_p,
|
||||
scale=scale,
|
||||
tp=tp,
|
||||
)
|
||||
|
||||
if dropout_p:
|
||||
|
@ -178,7 +175,9 @@ class FFNBlock(torch.nn.Module):
|
|||
x = self.ffn(x)
|
||||
if self.dropout is not None:
|
||||
x = self.dropout(x)
|
||||
hidden_states = hidden_states + x # / 1.05
|
||||
|
||||
hidden_states = hidden_states + x
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
@ -211,7 +210,6 @@ class TransformerBlock(torch.nn.Module):
|
|||
mask_att: bool = False,
|
||||
mask_ffn: bool = False,
|
||||
use_flash_attn: bool = False,
|
||||
tp: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.mask_att = mask_att
|
||||
|
@ -229,7 +227,6 @@ class TransformerBlock(torch.nn.Module):
|
|||
scale=scale,
|
||||
add_qkv_bias=add_qkv_bias,
|
||||
use_flash_attn=use_flash_attn,
|
||||
tp=tp,
|
||||
)
|
||||
|
||||
if not self.mask_ffn:
|
||||
|
@ -241,7 +238,6 @@ class TransformerBlock(torch.nn.Module):
|
|||
eps=eps,
|
||||
dropout_p=dropout_p,
|
||||
scale=scale,
|
||||
tp=tp,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
@ -253,6 +249,7 @@ class TransformerBlock(torch.nn.Module):
|
|||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
pos_bias_type: Optional[str] = "relative",
|
||||
length_mask: Optional[torch.Tensor] = None,
|
||||
context_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask_bias: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
|
@ -279,12 +276,12 @@ class TransformerBlock(torch.nn.Module):
|
|||
past_key_value=past_key_value,
|
||||
pos_bias_type=pos_bias_type,
|
||||
length_mask=length_mask,
|
||||
context_mask=context_mask,
|
||||
attention_mask_bias=attention_mask_bias,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
hidden_states, current_key_value = hidden_states
|
||||
else:
|
||||
|
@ -297,4 +294,4 @@ class TransformerBlock(torch.nn.Module):
|
|||
if use_cache:
|
||||
return hidden_states, current_key_value
|
||||
else:
|
||||
return hidden_states
|
||||
return hidden_states
|
|
@ -1,14 +1,13 @@
|
|||
import math
|
||||
from typing import Optional
|
||||
|
||||
import bmtrain as bmt
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .position_embedding import RotaryEmbedding
|
||||
|
||||
|
||||
class Embedding(bmt.DistributedModule):
|
||||
class Embedding(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
|
@ -21,10 +20,7 @@ class Embedding(bmt.DistributedModule):
|
|||
super().__init__()
|
||||
|
||||
self.dim_model = embedding_size
|
||||
self.weight = bmt.DistributedParameter(
|
||||
torch.empty(vocab_size, embedding_size, dtype=dtype),
|
||||
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
|
||||
)
|
||||
self.weight = torch.nn.parameter.Parameter(torch.empty(vocab_size, embedding_size, dtype=dtype))
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, ids: torch.Tensor):
|
||||
|
@ -39,7 +35,7 @@ class Embedding(bmt.DistributedModule):
|
|||
embeds = F.embedding(ids, self.weight) / math.sqrt(self.dim_model)
|
||||
else:
|
||||
embeds = F.embedding(ids, self.weight)
|
||||
return embeds
|
||||
return embeds.clone()
|
||||
|
||||
def projection(self, x: torch.Tensor):
|
||||
"""
|
||||
|
@ -56,7 +52,7 @@ class Embedding(bmt.DistributedModule):
|
|||
return logits
|
||||
|
||||
|
||||
class EmbeddingExt(bmt.DistributedModule):
|
||||
class EmbeddingExt(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
|
@ -71,10 +67,10 @@ class EmbeddingExt(bmt.DistributedModule):
|
|||
self.dim_model = embedding_size
|
||||
self.rotary_emb = RotaryEmbedding(dim=embedding_size, distance_scale=distance_scale, dtype=dtype)
|
||||
|
||||
self.weight = bmt.DistributedParameter(
|
||||
self.weight = torch.nn.parameter.Parameter(
|
||||
torch.empty(vocab_size, embedding_size, dtype=dtype),
|
||||
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
|
||||
)
|
||||
torch.nn.init.normal_(self.weight, mean=init_mean, std=init_std)
|
||||
|
||||
def forward(self, ids: torch.Tensor, ids_sub: torch.Tensor):
|
||||
"""
|
||||
|
@ -101,69 +97,4 @@ class EmbeddingExt(bmt.DistributedModule):
|
|||
if ext_table is not None:
|
||||
logits_ext = F.linear(x, ext_table)
|
||||
logits = torch.cat([logits, logits_ext], dim=-1)
|
||||
return logits
|
||||
|
||||
|
||||
class VocabParallelEmbedding(bmt.DistributedModule):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
embedding_size: int,
|
||||
dtype: torch.dtype = torch.half,
|
||||
scale: bool = True,
|
||||
init_mean: float = 0.0,
|
||||
init_std: float = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dim_model = embedding_size
|
||||
assert vocab_size % config["tp_size"] == 0
|
||||
self.vocab_size_per_partition = vocab_size // config["tp_size"]
|
||||
self.start_index = config["tp_rank"] * self.vocab_size_per_partition
|
||||
self.end_index = (config["tp_rank"] + 1) * self.vocab_size_per_partition
|
||||
self.weight = bmt.DistributedParameter(
|
||||
torch.empty(self.vocab_size_per_partition, embedding_size, dtype=dtype),
|
||||
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
|
||||
tp_split_dim=0,
|
||||
tp_mode=True,
|
||||
)
|
||||
|
||||
def forward(self, ids: torch.Tensor, gather_input=True):
|
||||
"""
|
||||
Args:
|
||||
ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.
|
||||
gather_input (bool) : whether gather input is required between tensor parallel group)
|
||||
Return:
|
||||
:obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
|
||||
""" # noqa: E501
|
||||
|
||||
if gather_input:
|
||||
ids = all_gather(ids, comm=config["tp_comm"])
|
||||
input_mask = (ids < self.start_index) | (ids >= self.end_index)
|
||||
ids = ids.clone() - self.start_index
|
||||
ids[input_mask] = 0
|
||||
|
||||
embeds = F.embedding(ids, self.weight)
|
||||
|
||||
embeds[input_mask, :] = 0.0
|
||||
embeds = all_reduce(embeds, op="sum", comm=config["tp_comm"])
|
||||
embed_list = embeds.chunk(config["tp_size"], dim=0)
|
||||
embeds = embed_list[config["tp_rank"]].flatten(0, 1)
|
||||
|
||||
if self.scale:
|
||||
embeds = embeds / math.sqrt(self.dim_model)
|
||||
|
||||
return embeds
|
||||
|
||||
def projection(self, x: torch.Tensor, gather_output=False, gather_input=True):
|
||||
"""
|
||||
Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size.
|
||||
Args:
|
||||
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
|
||||
Returns:
|
||||
:obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output.
|
||||
""" # noqa: E501
|
||||
if self.scale:
|
||||
x = x / math.sqrt(self.dim_model)
|
||||
out = bmt.nn.OpParallelLinear.apply(x, self.weight, None, gather_input, gather_output, False, None)
|
||||
return out
|
||||
return logits
|
|
@ -1,13 +1,11 @@
|
|||
from typing import Optional
|
||||
|
||||
import bmtrain as bmt
|
||||
import torch
|
||||
|
||||
from .linear import LastLinear
|
||||
from .linear import Linear
|
||||
|
||||
|
||||
class DenseGatedACT(bmt.DistributedModule):
|
||||
class DenseGatedACT(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
|
@ -15,7 +13,6 @@ class DenseGatedACT(bmt.DistributedModule):
|
|||
activate_fn: str = "gelu",
|
||||
scale: bool = True,
|
||||
dtype=torch.half,
|
||||
tp: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -25,7 +22,6 @@ class DenseGatedACT(bmt.DistributedModule):
|
|||
dtype=dtype,
|
||||
scale=scale,
|
||||
scale_before=False,
|
||||
tp=tp,
|
||||
)
|
||||
|
||||
self.w_1 = Linear(
|
||||
|
@ -34,9 +30,7 @@ class DenseGatedACT(bmt.DistributedModule):
|
|||
dtype=dtype,
|
||||
scale=scale,
|
||||
scale_before=False,
|
||||
tp=tp,
|
||||
)
|
||||
|
||||
if activate_fn == "gelu":
|
||||
self.act = torch.nn.GELU()
|
||||
elif activate_fn == "silu":
|
||||
|
@ -45,8 +39,7 @@ class DenseGatedACT(bmt.DistributedModule):
|
|||
raise NotImplementedError(f"{activate_fn} is not supported")
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""This model inherits from bmt.DistributedModule.
|
||||
Transform an input tensor from one feature space to another via a nonlinear operation
|
||||
"""Transform an input tensor from one feature space to another via a nonlinear operation
|
||||
|
||||
Args:
|
||||
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): Tensor that will be subject to nonlinear operations.
|
||||
|
@ -62,7 +55,7 @@ class DenseGatedACT(bmt.DistributedModule):
|
|||
return x
|
||||
|
||||
|
||||
class FeedForward(bmt.DistributedModule):
|
||||
class FeedForward(torch.nn.Module):
|
||||
r"""FeedForward module
|
||||
|
||||
Args:
|
||||
|
@ -85,7 +78,6 @@ class FeedForward(bmt.DistributedModule):
|
|||
dtype=torch.half,
|
||||
dropout_p: Optional[float] = None,
|
||||
scale: bool = True,
|
||||
tp: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -95,7 +87,6 @@ class FeedForward(bmt.DistributedModule):
|
|||
activate_fn=activate_fn,
|
||||
dtype=dtype,
|
||||
scale=scale,
|
||||
tp=tp,
|
||||
)
|
||||
|
||||
if dropout_p is not None:
|
||||
|
@ -103,13 +94,12 @@ class FeedForward(bmt.DistributedModule):
|
|||
else:
|
||||
self.dropout = None
|
||||
|
||||
self.w_out = LastLinear(
|
||||
self.w_out = Linear(
|
||||
dim_in=dim_ff,
|
||||
dim_out=dim_model,
|
||||
dtype=dtype,
|
||||
scale=scale,
|
||||
scale_before=False,
|
||||
tp=tp * 2,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
|
@ -127,4 +117,4 @@ class FeedForward(bmt.DistributedModule):
|
|||
|
||||
x = self.w_out(x)
|
||||
|
||||
return x
|
||||
return x
|
|
@ -1063,4 +1063,4 @@ class FlashAttnFunc(torch.autograd.Function):
|
|||
return dq, dk, dv, None, None, None
|
||||
|
||||
|
||||
flash_attn_func = FlashAttnFunc.apply
|
||||
flash_attn_func = FlashAttnFunc.apply
|
|
@ -1,8 +1,7 @@
|
|||
import bmtrain as bmt
|
||||
import torch
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@torch.jit.script # type: ignore
|
||||
def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
|
||||
old_dtype = hidden.dtype
|
||||
variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
|
||||
|
@ -10,21 +9,22 @@ def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
|
|||
return hidden * weight
|
||||
|
||||
|
||||
class LayerNorm(bmt.DistributedModule):
|
||||
class LayerNorm(torch.nn.Module):
|
||||
"""RMS LayerNorm"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_norm: int,
|
||||
dtype: torch.dtype = torch.half,
|
||||
eps: float = 1e-5,
|
||||
eps: float = 1e-6,
|
||||
init_var: float = 1.0,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.eps = eps
|
||||
self.dim_norm = dim_norm
|
||||
self.weight = bmt.DistributedParameter(torch.full((dim_norm,), init_var, dtype=dtype))
|
||||
self.weight = torch.nn.parameter.Parameter(torch.full((dim_norm,), init_var, dtype=dtype))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
|
@ -34,4 +34,4 @@ class LayerNorm(bmt.DistributedModule):
|
|||
:obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``: The layernorm output.
|
||||
""" # noqa: E501
|
||||
assert x.size(-1) == self.dim_norm
|
||||
return rms_layernorm(x, self.weight, self.eps)
|
||||
return rms_layernorm(x, self.weight, self.eps)
|
|
@ -0,0 +1,45 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
dim_out: int,
|
||||
bias: bool = False,
|
||||
dtype: torch.dtype = torch.half,
|
||||
init_mean: float = 0.0,
|
||||
init_std: float = 1,
|
||||
scale: bool = True,
|
||||
scale_before: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim_in = self.in_features = dim_in
|
||||
self.dim_out = self.out_features = dim_out
|
||||
self.scale = scale
|
||||
self.scale_before = scale_before
|
||||
self.weight = torch.nn.parameter.Parameter(torch.empty((dim_out, dim_in), dtype=dtype))
|
||||
self.bias = None
|
||||
if bias:
|
||||
self.bias = torch.nn.parameter.Parameter(torch.empty(dim_out, dtype=dtype))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
|
||||
Returns:
|
||||
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
|
||||
""" # noqa: E501
|
||||
if self.scale:
|
||||
if self.scale_before:
|
||||
x = x / math.sqrt(self.dim_in)
|
||||
x = F.linear(x, self.weight, bias=self.bias)
|
||||
else:
|
||||
x = F.linear(x, self.weight, bias=self.bias)
|
||||
x = x / math.sqrt(self.dim_in)
|
||||
else:
|
||||
x = F.linear(x, self.weight, bias=self.bias)
|
||||
return x
|
|
@ -2,17 +2,11 @@ import math
|
|||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import bmtrain as bmt
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from flash_attn.layers.rotary import apply_rotary_emb_func
|
||||
except:
|
||||
apply_rotary_emb_func = None
|
||||
|
||||
|
||||
class SegmentPositionEmbedding(bmt.DistributedModule):
|
||||
class SegmentPositionEmbedding(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
|
@ -32,10 +26,10 @@ class SegmentPositionEmbedding(bmt.DistributedModule):
|
|||
self.bidirectional = bidirectional
|
||||
self.num_segments = num_segments
|
||||
|
||||
self.relative_attention_bias = bmt.DistributedParameter(
|
||||
torch.empty(num_segments * num_segments + num_buckets, num_heads, dtype=dtype),
|
||||
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
|
||||
self.relative_attention_bias = torch.nn.parameter.Parameter(
|
||||
torch.empty(num_segments * num_segments + num_buckets, num_heads, dtype=dtype)
|
||||
)
|
||||
torch.nn.init.normal_(self.relative_attention_bias, mean=init_mean, std=init_std)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -107,7 +101,7 @@ class SegmentPositionEmbedding(bmt.DistributedModule):
|
|||
return relative_buckets
|
||||
|
||||
|
||||
class BucketPositionBias(bmt.DistributedModule):
|
||||
class BucketPositionBias(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
|
@ -125,10 +119,10 @@ class BucketPositionBias(bmt.DistributedModule):
|
|||
self.num_segment_bucket = num_segment_bucket
|
||||
self.max_distance = max_distance
|
||||
|
||||
self.relative_attention_bias = bmt.DistributedParameter(
|
||||
torch.empty(num_buckets + num_segment_bucket, num_heads, dtype=dtype),
|
||||
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
|
||||
self.relative_attention_bias = torch.nn.parameter.Parameter(
|
||||
torch.empty(num_buckets + num_segment_bucket, num_heads, dtype=dtype)
|
||||
)
|
||||
torch.nn.init.normal_(self.relative_attention_bias, mean=init_mean, std=init_std)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -186,11 +180,11 @@ class BucketPositionBias(bmt.DistributedModule):
|
|||
return relative_buckets
|
||||
|
||||
|
||||
class RotaryEmbedding(bmt.DistributedModule):
|
||||
class RotaryEmbedding(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
base: Union[int, float] = 10000,
|
||||
base=10000,
|
||||
distance_scale: Union[int, float] = 1,
|
||||
dtype: torch.dtype = torch.half,
|
||||
):
|
||||
|
@ -226,19 +220,19 @@ def rotate_half(x):
|
|||
|
||||
|
||||
def apply_rotary_pos_emb(x, cos, sin, seq_dim, offset):
|
||||
if x.size(seq_dim) < cos.size(seq_dim): # == do not need narrow
|
||||
if x.size(seq_dim) < cos.size(seq_dim):
|
||||
cos = cos.narrow(seq_dim, offset, x.size(seq_dim))
|
||||
sin = sin.narrow(seq_dim, offset, x.size(seq_dim))
|
||||
return (x * cos) + (rotate_half(x) * sin)
|
||||
|
||||
|
||||
def unpad_apply_rotary_pos_emb(x, cos, sin, seq_dim, position_ids):
|
||||
cos = cos.index_select(seq_dim, position_ids.view(-1))
|
||||
sin = sin.index_select(seq_dim, position_ids.view(-1))
|
||||
cos = cos.index_select(seq_dim, position_ids.squeeze(0))
|
||||
sin = sin.index_select(seq_dim, position_ids.squeeze(0))
|
||||
return (x * cos) + (rotate_half(x) * sin)
|
||||
|
||||
|
||||
class RotaryEmbeddingESM(bmt.DistributedModule):
|
||||
class RotaryEmbeddingESM(torch.nn.Module):
|
||||
"""
|
||||
Rotary position embeddings based on those in
|
||||
[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
|
||||
|
@ -274,7 +268,8 @@ class RotaryEmbeddingESM(bmt.DistributedModule):
|
|||
self.apply_rotary_pos_emb = apply_rotary_pos_emb
|
||||
self.unpad_apply_rotary_pos_emb = unpad_apply_rotary_pos_emb
|
||||
|
||||
def _update_cos_sin_tables(self, x, seq_dim, seq_len):
|
||||
def _update_cos_sin_tables(self, x, seq_dim, offset):
|
||||
seq_len = x.size(seq_dim) + offset
|
||||
if seq_len > self._seq_len_cached or self._cos_cached.device != x.device:
|
||||
self._seq_len_cached = seq_len
|
||||
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
|
||||
|
@ -295,19 +290,11 @@ class RotaryEmbeddingESM(bmt.DistributedModule):
|
|||
self, q: torch.Tensor, k: torch.Tensor, seq_dim, offset=0, cu_seqlens=None, max_length=None, position_ids=None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
seq_dim = (seq_dim + k.dim()) % k.dim()
|
||||
if cu_seqlens is None:
|
||||
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dim, k.size(seq_dim) + offset)
|
||||
return (
|
||||
self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dim, offset),
|
||||
self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dim, offset),
|
||||
)
|
||||
else:
|
||||
assert offset == 0, "past kv is not supported in flash attn"
|
||||
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dim, max_length)
|
||||
return (
|
||||
self.unpad_apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dim, position_ids),
|
||||
self.unpad_apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dim, position_ids),
|
||||
)
|
||||
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dim, offset)
|
||||
return (
|
||||
self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dim, offset),
|
||||
self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dim, offset),
|
||||
)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
|
@ -334,7 +321,7 @@ def apply_chatglm_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> t
|
|||
return ret
|
||||
|
||||
|
||||
class ChatGLMRotaryEmbedding(bmt.DistributedModule):
|
||||
class ChatGLMRotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, dim, device="cuda", dtype=torch.float16, persistent=True):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=dtype, device=device) / dim))
|
||||
|
@ -365,4 +352,4 @@ class ChatGLMRotaryEmbedding(bmt.DistributedModule):
|
|||
return cache
|
||||
|
||||
def forward(self, max_seq_len, offset: int = 0):
|
||||
return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
|
||||
return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
|
|
@ -2,14 +2,13 @@ from typing import List
|
|||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import bmtrain as bmt
|
||||
import torch
|
||||
|
||||
from .blocks import TransformerBlock
|
||||
from .layernorm import LayerNorm
|
||||
|
||||
|
||||
class Encoder(bmt.DistributedModule):
|
||||
class Encoder(torch.nn.Module):
|
||||
"""Layers of encoder transformer blocks plus an final layernorm.
|
||||
|
||||
Args:
|
||||
|
@ -19,7 +18,7 @@ class Encoder(bmt.DistributedModule):
|
|||
num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
|
||||
dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
|
||||
dtype (optional): Defaults to torch.half.
|
||||
eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5.
|
||||
eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
|
||||
dropout_p (float, optional): Defaults to 0.
|
||||
""" # noqa: E501
|
||||
|
||||
|
@ -33,49 +32,47 @@ class Encoder(bmt.DistributedModule):
|
|||
num_kv_heads: int = -1,
|
||||
activate_fn: str = "gelu",
|
||||
dtype: torch.dtype = torch.half,
|
||||
eps: float = 1e-5,
|
||||
eps: float = 1e-6,
|
||||
dropout_p: Optional[float] = None,
|
||||
scale: bool = True,
|
||||
add_qkv_bias: bool = False,
|
||||
mask_modules: Optional[List[Tuple[bool, bool]]] = None,
|
||||
use_flash_attn: bool = False,
|
||||
tp: int = 0,
|
||||
disabled_checkpoint: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
if num_kv_heads == -1:
|
||||
num_kv_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
|
||||
if mask_modules is not None:
|
||||
assert len(mask_modules) == num_layers, "The total number of masks should equal to num_layers"
|
||||
for mask_module in mask_modules:
|
||||
assert len(mask_module) == 2, "For encoder, each mask should be (mask_att, mask_ffn)"
|
||||
else:
|
||||
mask_modules = [(False, False)] * num_layers
|
||||
self.layers = bmt.TransformerBlockList(
|
||||
|
||||
self.layers = torch.nn.ModuleList(
|
||||
[
|
||||
bmt.CheckpointBlock(
|
||||
TransformerBlock(
|
||||
dim_model=dim_model,
|
||||
dim_ff=dim_ff,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
dim_head=dim_head,
|
||||
activate_fn=activate_fn,
|
||||
dtype=dtype,
|
||||
eps=eps,
|
||||
dropout_p=dropout_p,
|
||||
scale=scale,
|
||||
add_qkv_bias=add_qkv_bias,
|
||||
mask_att=mask_modules[ith][0],
|
||||
mask_ffn=mask_modules[ith][1],
|
||||
use_flash_attn=use_flash_attn,
|
||||
tp=tp,
|
||||
),
|
||||
TransformerBlock(
|
||||
dim_model=dim_model,
|
||||
dim_ff=dim_ff,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
dim_head=dim_head,
|
||||
activate_fn=activate_fn,
|
||||
dtype=dtype,
|
||||
eps=eps,
|
||||
dropout_p=dropout_p,
|
||||
scale=scale,
|
||||
add_qkv_bias=add_qkv_bias,
|
||||
mask_att=mask_modules[ith][0],
|
||||
mask_ffn=mask_modules[ith][1],
|
||||
use_flash_attn=use_flash_attn,
|
||||
)
|
||||
for ith in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.output_layernorm = LayerNorm(dim_norm=dim_model, dtype=dtype, eps=eps)
|
||||
|
||||
def forward(
|
||||
|
@ -87,6 +84,7 @@ class Encoder(bmt.DistributedModule):
|
|||
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
||||
pos_bias_type: Optional[str] = "relative",
|
||||
length_mask: Optional[torch.Tensor] = None,
|
||||
context_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask_bias: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
|
@ -103,19 +101,21 @@ class Encoder(bmt.DistributedModule):
|
|||
|
||||
""" # noqa: E501
|
||||
if not use_cache:
|
||||
hidden_states = self.layers(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_bias,
|
||||
False,
|
||||
None,
|
||||
pos_bias_type,
|
||||
length_mask,
|
||||
attention_mask_bias,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
position_ids,
|
||||
)
|
||||
for i, module in enumerate(self.layers):
|
||||
hidden_states = module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_bias,
|
||||
False,
|
||||
None,
|
||||
pos_bias_type,
|
||||
length_mask,
|
||||
context_mask,
|
||||
attention_mask_bias,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
position_ids,
|
||||
)
|
||||
hidden_states = self.output_layernorm(hidden_states)
|
||||
return hidden_states
|
||||
else:
|
||||
|
@ -131,6 +131,7 @@ class Encoder(bmt.DistributedModule):
|
|||
past_key_values[i] if past_key_values else None,
|
||||
pos_bias_type,
|
||||
length_mask,
|
||||
context_mask,
|
||||
attention_mask_bias,
|
||||
)
|
||||
if use_cache:
|
||||
|
@ -141,4 +142,4 @@ class Encoder(bmt.DistributedModule):
|
|||
if use_cache:
|
||||
return hidden_states, current_key_values, current_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
return hidden_states
|
|
@ -0,0 +1,237 @@
|
|||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from vis_fm9g.layers import Embedding
|
||||
from vis_fm9g.layers import Encoder
|
||||
from vis_fm9g.layers import RotaryEmbeddingESM
|
||||
from vis_fm9g.utils.config import Config
|
||||
|
||||
|
||||
class FM9GConfig(Config):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
dim_model=4096,
|
||||
num_heads=32,
|
||||
num_kv_heads=32,
|
||||
dim_head=128,
|
||||
dim_ff=11008,
|
||||
num_layers=32,
|
||||
dropout_p=0.0,
|
||||
activate_fn="silu",
|
||||
scale=True,
|
||||
eps=1e-5,
|
||||
bf16: bool = False,
|
||||
half: bool = True,
|
||||
mask_modules: Optional[List[Tuple[bool, bool]]] = None,
|
||||
use_flash_attn: bool = True,
|
||||
flash_attn_mask_shape="1d",
|
||||
flash_impl="cuda",
|
||||
base=10000,
|
||||
):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.dim_model = dim_model
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.dim_head = dim_head
|
||||
self.dim_ff = dim_ff
|
||||
self.num_layers = num_layers
|
||||
self.dropout_p = dropout_p
|
||||
self.activate_fn = activate_fn
|
||||
self.scale = scale
|
||||
self.eps = eps
|
||||
if bf16:
|
||||
self.dtype = torch.bfloat16
|
||||
elif half:
|
||||
self.dtype = torch.float16
|
||||
else:
|
||||
self.dtype = torch.float
|
||||
self.flash_impl = flash_impl
|
||||
self.mask_modules = mask_modules
|
||||
self.use_flash_attn = use_flash_attn
|
||||
self.flash_attn_mask_shape = flash_attn_mask_shape
|
||||
self.base = base
|
||||
|
||||
|
||||
class FM9GInferenceState(TypedDict):
|
||||
buffer_context: torch.Tensor
|
||||
buffer_sample_ids: torch.Tensor
|
||||
buffer: List[Tuple[torch.Tensor, torch.Tensor]]
|
||||
|
||||
|
||||
class FM9GTorch(torch.nn.Module):
|
||||
def __init__(self, config: FM9GConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.encoder = Encoder(
|
||||
num_layers=config.num_layers,
|
||||
dim_model=config.dim_model,
|
||||
dim_ff=config.dim_ff,
|
||||
num_heads=config.num_heads,
|
||||
num_kv_heads=config.num_kv_heads,
|
||||
dim_head=config.dim_head,
|
||||
activate_fn=config.activate_fn,
|
||||
dtype=config.dtype,
|
||||
eps=config.eps,
|
||||
dropout_p=config.dropout_p,
|
||||
scale=config.scale,
|
||||
mask_modules=config.mask_modules,
|
||||
use_flash_attn=config.use_flash_attn,
|
||||
)
|
||||
|
||||
self.input_embedding = Embedding(
|
||||
vocab_size=config.vocab_size,
|
||||
embedding_size=config.dim_model,
|
||||
scale=config.scale,
|
||||
dtype=config.dtype,
|
||||
init_std=0.02,
|
||||
)
|
||||
|
||||
self.position_bias = RotaryEmbeddingESM(
|
||||
dim=config.dim_head, dtype=config.dtype, base=config.base, persistent=False, mixed_precision=True
|
||||
)
|
||||
|
||||
self.lm_head = Embedding(
|
||||
vocab_size=config.vocab_size,
|
||||
embedding_size=config.dim_model,
|
||||
scale=config.scale,
|
||||
dtype=config.dtype,
|
||||
init_std=0.02,
|
||||
)
|
||||
self.flash_impl = config.flash_impl
|
||||
self.use_flash_attn = config.use_flash_attn
|
||||
self.flash_attn_mask_shape = config.flash_attn_mask_shape
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input: torch.Tensor, # (batch, seqlen) int32
|
||||
length: torch.Tensor = None, # (batch) int32
|
||||
context: torch.Tensor = None, # (batch, seqlen) bool
|
||||
span: torch.Tensor = None, # (batch, seqlen) int32
|
||||
cu_seqlens: torch.Tensor = None, # (real_batch+2) int32
|
||||
max_seqlen: int = None,
|
||||
position_ids: torch.Tensor = None, # (batch, seqlen) int32
|
||||
hidden_states: torch.Tensor = None
|
||||
):
|
||||
batch = input.size(0)
|
||||
seqlen = input.size(1)
|
||||
device = input.device
|
||||
|
||||
if length is not None and length.dim() == 1:
|
||||
length = torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
|
||||
|
||||
# processing masks and position bias bucket
|
||||
if not self.use_flash_attn or (self.flash_attn_mask_shape == "2d" and self.flash_impl == "triton"):
|
||||
with torch.no_grad():
|
||||
# directional mask
|
||||
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(
|
||||
-1, 1
|
||||
)
|
||||
# context mask
|
||||
attention_mask = context[:, None, :] | (
|
||||
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
|
||||
)
|
||||
# span mask
|
||||
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
|
||||
# length mask
|
||||
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
|
||||
|
||||
if hidden_states is None:
|
||||
hidden_states = self.input_embedding(input)
|
||||
|
||||
if self.use_flash_attn:
|
||||
if self.flash_attn_mask_shape == "1d":
|
||||
hidden_states = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
position_bias=self.position_bias,
|
||||
pos_bias_type="rotary",
|
||||
length_mask=length,
|
||||
context_mask=context.to(torch.int16) + 2 * (span.to(torch.int16) + length.to(torch.int16)),
|
||||
)
|
||||
else:
|
||||
if self.flash_impl == "triton":
|
||||
mask = attention_mask.unsqueeze(dim=1).contiguous()
|
||||
attention_mask_bias = torch.zeros_like(mask, device="cuda", dtype=torch.float16)
|
||||
attention_mask_bias[mask == False] -= torch.inf
|
||||
else:
|
||||
attention_mask_bias = None
|
||||
assert cu_seqlens is not None, "cu_seqlens are needed in Flash Attention cuda impl"
|
||||
hidden_states = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
position_bias=self.position_bias,
|
||||
pos_bias_type="rotary",
|
||||
length_mask=None,
|
||||
context_mask=None,
|
||||
attention_mask_bias=attention_mask_bias,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
else:
|
||||
hidden_states = self.encoder(
|
||||
hidden_states, attention_mask=attention_mask, position_bias=self.position_bias, pos_bias_type="rotary"
|
||||
)
|
||||
|
||||
logits = self.lm_head.projection(hidden_states)
|
||||
|
||||
return logits, hidden_states
|
||||
|
||||
def inference(
|
||||
self,
|
||||
input: torch.Tensor, # (batch, len_q) int32
|
||||
length: torch.Tensor, # (batch) int32
|
||||
context: torch.Tensor, # (batch, seqlen) int16
|
||||
span: torch.Tensor, # (batch, seqlen) int32
|
||||
past_key_values: Optional[FM9GInferenceState] = None,
|
||||
hidden_states: torch.Tensor = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, FM9GInferenceState]:
|
||||
batch = input.size(0)
|
||||
len_q = input.size(1)
|
||||
len_buffer = 0
|
||||
if past_key_values is None:
|
||||
present_buffer = None
|
||||
else:
|
||||
present_buffer = past_key_values["buffer"]
|
||||
len_buffer = present_buffer[0][0].shape[-2]
|
||||
seqlen = len_buffer + len_q
|
||||
with torch.no_grad():
|
||||
device = input.device
|
||||
if length.dim() == 1:
|
||||
length = (torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) + length[:, None]) >= seqlen
|
||||
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
|
||||
# context mask
|
||||
attention_mask = context[:, None, :] | (
|
||||
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
|
||||
)
|
||||
# span mask
|
||||
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
|
||||
# length mask
|
||||
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
|
||||
|
||||
if hidden_states is None:
|
||||
hidden_states = self.input_embedding(input)
|
||||
|
||||
hidden_states, present_key_values, _ = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask[:, len_buffer:],
|
||||
position_bias=self.position_bias,
|
||||
use_cache=True,
|
||||
past_key_values=present_buffer,
|
||||
pos_bias_type="rotary",
|
||||
)
|
||||
|
||||
logits = self.lm_head.projection(hidden_states)
|
||||
|
||||
return (
|
||||
logits,
|
||||
hidden_states,
|
||||
{"buffer": present_key_values},
|
||||
)
|
|
@ -0,0 +1,158 @@
|
|||
# Copyright (c) Alibaba Cloud.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.init import trunc_normal_
|
||||
|
||||
|
||||
def get_abs_pos(abs_pos, tgt_size):
|
||||
# abs_pos: L, C
|
||||
# tgt_size: M
|
||||
# return: M, C
|
||||
src_size = int(math.sqrt(abs_pos.size(0)))
|
||||
tgt_size = int(math.sqrt(tgt_size))
|
||||
dtype = abs_pos.dtype
|
||||
|
||||
if src_size != tgt_size:
|
||||
return F.interpolate(
|
||||
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
|
||||
size=(tgt_size, tgt_size),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
|
||||
else:
|
||||
return abs_pos
|
||||
|
||||
|
||||
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
grid_h = np.arange(grid_size, dtype=np.float32)
|
||||
grid_w = np.arange(grid_size, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_size, grid_size])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token:
|
||||
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
||||
omega /= embed_dim / 2.
|
||||
omega = 1. / 10000 ** omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
class Resampler(nn.Module):
|
||||
"""
|
||||
A 2D perceiver-resampler network with one cross attention layers by
|
||||
(grid_size**2) learnable queries and 2d sincos pos_emb
|
||||
Outputs:
|
||||
A tensor with the shape of (grid_size**2, embed_dim)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
grid_size,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
kv_dim=None,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6)
|
||||
):
|
||||
super().__init__()
|
||||
self.num_queries = grid_size ** 2
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
|
||||
).requires_grad_(False)
|
||||
|
||||
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
|
||||
trunc_normal_(self.query, std=.02)
|
||||
|
||||
if kv_dim is not None and kv_dim != embed_dim:
|
||||
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
|
||||
else:
|
||||
self.kv_proj = nn.Identity()
|
||||
|
||||
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
|
||||
self.ln_q = norm_layer(embed_dim)
|
||||
self.ln_kv = norm_layer(embed_dim)
|
||||
|
||||
self.ln_post = norm_layer(embed_dim)
|
||||
self.proj = nn.Parameter((embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def forward(self, x, attn_mask=None):
|
||||
|
||||
pos_embed = get_abs_pos(self.pos_embed, x.size(1))
|
||||
|
||||
x = self.kv_proj(x)
|
||||
x = self.ln_kv(x).permute(1, 0, 2)
|
||||
|
||||
N = x.shape[1]
|
||||
q = self.ln_q(self.query)
|
||||
out = self.attn(
|
||||
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
|
||||
x + pos_embed.unsqueeze(1),
|
||||
x,
|
||||
attn_mask=attn_mask)[0]
|
||||
x = out.permute(1, 0, 2)
|
||||
|
||||
x = self.ln_post(x)
|
||||
x = x @ self.proj
|
||||
return x
|
||||
|
||||
def _repeat(self, query, N: int):
|
||||
return query.unsqueeze(1).repeat(1, N, 1)
|
|
@ -0,0 +1,110 @@
|
|||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import timm
|
||||
import torch
|
||||
from transformers.utils import ModelOutput
|
||||
|
||||
from vis_fm9g.model.fm9g import FM9GTorch
|
||||
from vis_fm9g.model.resampler import Resampler
|
||||
|
||||
|
||||
@dataclass
|
||||
class CausalVLLMOutput(ModelOutput):
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
class VLU_FM9G(torch.nn.Module):
|
||||
def __init__(self, llm: FM9GTorch, vpm: timm.models.VisionTransformer, vision_dim, query_num) -> None:
|
||||
super().__init__()
|
||||
self.vpm = vpm
|
||||
self.llm = llm
|
||||
|
||||
self.vision_dim = vision_dim
|
||||
self.query_num = query_num
|
||||
|
||||
embed_dim = self.llm.config.dim_model
|
||||
self.resampler = Resampler(
|
||||
grid_size=int(math.sqrt(query_num)),
|
||||
embed_dim=embed_dim,
|
||||
num_heads=embed_dim // 128,
|
||||
kv_dim=self.vpm.embed_dim,
|
||||
)
|
||||
|
||||
def get_vision_embedding(self, pixel_values):
|
||||
res = []
|
||||
dtype = self.vpm.pos_embed.data.dtype
|
||||
for pixel_value in pixel_values:
|
||||
vision_embedding = self.vpm.forward_features(pixel_value.unsqueeze(0).type(dtype))
|
||||
if hasattr(self.vpm, 'num_prefix_tokens') and self.vpm.num_prefix_tokens > 0:
|
||||
vision_embedding = vision_embedding[:, self.vpm.num_prefix_tokens:]
|
||||
res.append(self.resampler(vision_embedding))
|
||||
return torch.vstack(res)
|
||||
|
||||
|
||||
def get_vllm_embedding(self, data):
|
||||
if 'vision_hidden_states' not in data:
|
||||
pixel_values_list = data['pixel_values']
|
||||
vision_hidden_states = []
|
||||
for pixel_values in pixel_values_list:
|
||||
if len(pixel_values) > 0:
|
||||
vision_hidden_states.append(self.get_vision_embedding(pixel_values))
|
||||
else:
|
||||
vision_hidden_states.append([])
|
||||
else:
|
||||
vision_hidden_states = data['vision_hidden_states']
|
||||
|
||||
vllm_embedding = self.llm.input_embedding(data['input_ids'])
|
||||
vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance(
|
||||
i, torch.Tensor) else i for i in vision_hidden_states]
|
||||
|
||||
bs = len(data['input_ids'])
|
||||
for i in range(bs):
|
||||
cur_vs_hs = vision_hidden_states[i]
|
||||
if len(cur_vs_hs) > 0:
|
||||
cur_vllm_emb = vllm_embedding[i]
|
||||
cur_image_bound = data['image_bound'][i]
|
||||
|
||||
image_indices = torch.stack(
|
||||
[torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
|
||||
).to(vllm_embedding.device)
|
||||
|
||||
cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
|
||||
cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
|
||||
|
||||
return vllm_embedding, vision_hidden_states
|
||||
|
||||
def forward(self, data, **kwargs):
|
||||
vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
|
||||
|
||||
input_ids = data["input_ids"]
|
||||
input_length = data["length"]
|
||||
if self.llm.config.flash_impl == 'cuda':
|
||||
cu_seqlens = data["cu_seqlens"]
|
||||
max_seqlen = data["max_seqlen"]
|
||||
position_ids = data["position_ids"]
|
||||
logits, hidden_states = self.llm(
|
||||
input_ids,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
position_ids=position_ids,
|
||||
hidden_states=vllm_embedding
|
||||
)
|
||||
else:
|
||||
input_context = torch.zeros_like(input_ids).cuda().bool()
|
||||
input_span = data["spans"]
|
||||
logits, hidden_states = self.llm(
|
||||
input_ids,
|
||||
input_length,
|
||||
input_context,
|
||||
input_span,
|
||||
hidden_states=vllm_embedding
|
||||
)
|
||||
|
||||
return CausalVLLMOutput(
|
||||
logits=logits,
|
||||
hidden_states=hidden_states,
|
||||
)
|
|
@ -1,12 +1,14 @@
|
|||
import os
|
||||
import io
|
||||
import json
|
||||
from typing import Dict
|
||||
from typing import IO
|
||||
from typing import List
|
||||
|
||||
import pkg_resources
|
||||
from pytrie import StringTrie
|
||||
from transformers import LlamaTokenizer
|
||||
from transformers.tokenization_utils import Trie
|
||||
|
||||
file_path = os.path.dirname(__file__)
|
||||
|
||||
def load_vocab(fp: IO[bytes]) -> Dict[str, int]:
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
|
@ -23,50 +25,66 @@ def load_vocab(fp: IO[bytes]) -> Dict[str, int]:
|
|||
|
||||
|
||||
class FM9GTokenizer(object):
|
||||
def __init__(self, path=None):
|
||||
def __init__(self, vocabs_path=None):
|
||||
self.unk_token = "<unk>"
|
||||
self.bos_token = "<s>"
|
||||
self.eos_token = "</s>"
|
||||
self.im_start = "<image>"
|
||||
self.im_end = "</image>"
|
||||
self.ref_start = "<ref>"
|
||||
self.ref_end = "</ref>"
|
||||
self.box_start = "<quad>"
|
||||
self.box_end = "</quad>"
|
||||
self.quad_start = "<quad>"
|
||||
self.quad_end = "</quad>"
|
||||
|
||||
self.tire = Trie()
|
||||
|
||||
self.byte_list = ["<0x0{}>".format(hex(i).upper()[2:]) for i in range(0x10)] + [
|
||||
"<0x{}>".format(hex(i).upper()[2:]) for i in range(0x10, 0x100)
|
||||
]
|
||||
|
||||
self._special_token_set = set([self.unk_token, self.bos_token, self.eos_token] + self.byte_list)
|
||||
self._special_token_set = {self.unk_token, self.bos_token, self.eos_token,
|
||||
self.im_start, self.im_end,
|
||||
self.ref_start, self.ref_end,
|
||||
self.box_start, self.box_end,
|
||||
self.quad_start, self.quad_end}
|
||||
|
||||
if path:
|
||||
all_tokens = load_vocab(io.FileIO(path, "rb"))
|
||||
else:
|
||||
all_tokens = load_vocab(pkg_resources.resource_stream("fm9g", "/fm9g/vocabs/fm9g.txt"))
|
||||
self._byte_set = set(self.byte_list)
|
||||
|
||||
# never split special token
|
||||
for t in self._special_token_set:
|
||||
self.tire.add(t)
|
||||
|
||||
if not vocabs_path:
|
||||
vocabs_path = os.path.join(file_path, "../config/caterpillar.txt")
|
||||
|
||||
all_tokens = load_vocab(io.FileIO(vocabs_path, "rb"))
|
||||
|
||||
self.encoder: Dict[str, int] = {}
|
||||
self._special_encoder: Dict[str, int] = {}
|
||||
for token, token_id in all_tokens.items():
|
||||
if token in self._special_token_set:
|
||||
self.encoder[token] = token_id
|
||||
self._special_encoder[token] = token_id
|
||||
elif token in self._byte_set:
|
||||
self._special_encoder[token] = token_id
|
||||
else:
|
||||
self.encoder[token] = token_id
|
||||
|
||||
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self._byte_decoder = {self._special_encoder[token]: i for i, token in enumerate(self.byte_list)}
|
||||
|
||||
self._max_word_len = max([len(x) for x in self.encoder.keys()])
|
||||
|
||||
self._len_word_first = {}
|
||||
for x in self.encoder.keys():
|
||||
if not x[0] in self._len_word_first:
|
||||
self._len_word_first[x[0]] = 1
|
||||
if len(x) > self._len_word_first[x[0]]:
|
||||
self._len_word_first[x[0]] = len(x)
|
||||
self.tencoder = StringTrie(self.encoder)
|
||||
|
||||
def get_piece(self, text: str) -> str:
|
||||
if text[0] in self._len_word_first:
|
||||
text = text[: self._len_word_first[text[0]]]
|
||||
len_text = len(text)
|
||||
for i in range(len(text)):
|
||||
sub = text[: len_text - i]
|
||||
if sub in self.encoder:
|
||||
return sub
|
||||
text = text[: self._max_word_len]
|
||||
len_text = len(text)
|
||||
for i in range(len(text)):
|
||||
sub = text[: len_text - i]
|
||||
if sub in self.encoder:
|
||||
return sub
|
||||
return text[0]
|
||||
|
||||
@property
|
||||
|
@ -85,16 +103,26 @@ class FM9GTokenizer(object):
|
|||
def unk_id(self):
|
||||
return self._special_encoder[self.unk_token]
|
||||
|
||||
@property
|
||||
def im_start_id(self):
|
||||
return self._special_encoder[self.im_start]
|
||||
|
||||
@property
|
||||
def im_end_id(self):
|
||||
return self._special_encoder[self.im_end]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encoder) + len(self._special_encoder)
|
||||
|
||||
def tokenize(self, text: str) -> List[str]:
|
||||
texts = self.tire.split(text)
|
||||
output_tokens: List[str] = []
|
||||
st = 0
|
||||
while st < len(text):
|
||||
piece = self.get_piece(text[st:])
|
||||
output_tokens.append(piece)
|
||||
st += len(piece)
|
||||
for text in texts:
|
||||
st = 0
|
||||
while st < len(text):
|
||||
piece = self.get_piece(text[st:])
|
||||
output_tokens.append(piece)
|
||||
st += len(piece)
|
||||
return output_tokens
|
||||
|
||||
@staticmethod
|
||||
|
@ -106,8 +134,6 @@ class FM9GTokenizer(object):
|
|||
return text
|
||||
|
||||
def encode(self, text: str) -> List[int]:
|
||||
#if len(text) > 20480:
|
||||
# return [0 for _ in range(20480)]
|
||||
ret = []
|
||||
for x in self.tokenize(text):
|
||||
if x in self.encoder:
|
||||
|
@ -127,25 +153,6 @@ class FM9GTokenizer(object):
|
|||
st += 1
|
||||
elif tokens[st] in self._byte_decoder:
|
||||
if (
|
||||
st + 3 < len(tokens)
|
||||
and tokens[st + 1] in self._byte_decoder
|
||||
and tokens[st + 2] in self._byte_decoder
|
||||
and tokens[st + 3] in self._byte_decoder
|
||||
):
|
||||
first_id = self._byte_decoder[tokens[st]]
|
||||
plane_id = self._byte_decoder[tokens[st + 1]]
|
||||
row_id = self._byte_decoder[tokens[st + 2]]
|
||||
cell_id = self._byte_decoder[tokens[st + 3]]
|
||||
int_bytes = int.to_bytes(first_id << 24 | plane_id << 16 | row_id << 8 | cell_id, 4, "big")
|
||||
try:
|
||||
decoded_str = int_bytes.decode("utf-8", errors="replace")
|
||||
ret.append(decoded_str)
|
||||
#print(decoded_str)
|
||||
except UnicodeDecodeError as e:
|
||||
print(f"UnicodeDecodeError: {e}")
|
||||
|
||||
st += 4
|
||||
elif (
|
||||
st + 2 < len(tokens)
|
||||
and tokens[st + 1] in self._byte_decoder
|
||||
and tokens[st + 2] in self._byte_decoder
|
||||
|
@ -153,33 +160,16 @@ class FM9GTokenizer(object):
|
|||
plane_id = self._byte_decoder[tokens[st]]
|
||||
row_id = self._byte_decoder[tokens[st + 1]]
|
||||
cell_id = self._byte_decoder[tokens[st + 2]]
|
||||
int_bytes = int.to_bytes(plane_id << 16 | row_id << 8 | cell_id, 3, "big")
|
||||
try:
|
||||
decoded_str = int_bytes.decode("utf-8", errors="replace")
|
||||
ret.append(decoded_str)
|
||||
except UnicodeDecodeError as e:
|
||||
print(f"UnicodeDecodeError: {e}")
|
||||
ret.append(int.to_bytes(plane_id << 16 | row_id << 8 | cell_id, 3, "big").decode("utf-8"))
|
||||
st += 3
|
||||
elif st + 1 < len(tokens) and tokens[st + 1] in self._byte_decoder:
|
||||
row_id = self._byte_decoder[tokens[st]]
|
||||
cell_id = self._byte_decoder[tokens[st + 1]]
|
||||
int_bytes = int.to_bytes(row_id << 8 | cell_id, 2, "big")
|
||||
try:
|
||||
decoded_str = int_bytes.decode("utf-8", errors="replace")
|
||||
ret.append(decoded_str)
|
||||
except UnicodeDecodeError as e:
|
||||
print(f"UnicodeDecodeError: {e}")
|
||||
#ret.append(int.to_bytes(row_id << 8 | cell_id, 2, "big").decode("utf-8"))
|
||||
ret.append(int.to_bytes(row_id << 8 | cell_id, 2, "big").decode("utf-8"))
|
||||
st += 2
|
||||
else:
|
||||
cell_id = self._byte_decoder[tokens[st]]
|
||||
int_bytes = int.to_bytes(cell_id, 1, "big")
|
||||
try:
|
||||
decoded_str = int_bytes.decode("utf-8", errors="replace")
|
||||
ret.append(decoded_str)
|
||||
except UnicodeDecodeError as e:
|
||||
print(f"UnicodeDecodeError: {e}")
|
||||
#ret.append(int.to_bytes(cell_id, 1, "big").decode("utf-8"))
|
||||
ret.append(int.to_bytes(cell_id, 1, "big").decode("utf-8"))
|
||||
st += 1
|
||||
elif tokens[st] == self.eos_id:
|
||||
ret.append(self.eos_token)
|
||||
|
@ -196,16 +186,53 @@ class FM9GTokenizer(object):
|
|||
# wrap unicode encoding into a helper function
|
||||
ids = []
|
||||
utf8_id = token.encode("utf-8")
|
||||
for _id in utf8_id:
|
||||
ids.append(self._special_encoder[self.byte_list[_id]])
|
||||
plane_id = utf8_id[-3] if len(utf8_id) >= 3 else 0
|
||||
row_id = utf8_id[-2] if len(utf8_id) >= 2 else 0
|
||||
cell_id = utf8_id[-1] if len(utf8_id) >= 1 else 0
|
||||
if plane_id > 0:
|
||||
ids.append(self._special_encoder[self.byte_list[plane_id]])
|
||||
if row_id > 0:
|
||||
ids.append(self._special_encoder[self.byte_list[row_id]])
|
||||
ids.append(self._special_encoder[self.byte_list[cell_id]])
|
||||
return ids
|
||||
|
||||
def next_token(self, text):
|
||||
# fast next token matching
|
||||
token, token_id = self.tencoder.longest_prefix_item(text, (None, None))
|
||||
if token is None:
|
||||
token = text[0]
|
||||
token_ids = self._encode_unicode(token)
|
||||
else:
|
||||
token_ids = [token_id]
|
||||
return token, token_ids
|
||||
|
||||
class LlamaTokenizerWrapper(LlamaTokenizer):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.im_start = "<image>"
|
||||
self.im_end = "</image>"
|
||||
self.ref_start = "<ref>"
|
||||
self.ref_end = "</ref>"
|
||||
self.box_start = "<box>"
|
||||
self.box_end = "</box>"
|
||||
self.quad_start = "<quad>"
|
||||
self.quad_end = "</quad>"
|
||||
|
||||
@property
|
||||
def eos_id(self):
|
||||
return self.sp_model.eos_id()
|
||||
|
||||
@property
|
||||
def bos_id(self):
|
||||
return self.sp_model.bos_id()
|
||||
|
||||
@property
|
||||
def unk_id(self):
|
||||
return self.sp_model.unk_id()
|
||||
|
||||
@property
|
||||
def im_start_id(self):
|
||||
return self._convert_token_to_id(self.im_start)
|
||||
|
||||
@property
|
||||
def im_end_id(self):
|
||||
return self._convert_token_to_id(self.im_end)
|
||||
|
||||
@staticmethod
|
||||
def escape(text: str) -> str:
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def unescape(text: str) -> str:
|
||||
return text
|
|
@ -0,0 +1,56 @@
|
|||
import json
|
||||
import os
|
||||
import shutil
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.distributed
|
||||
from deepspeed.utils import logger
|
||||
|
||||
from utils.utils import is_main_process
|
||||
|
||||
|
||||
def export(vllm_engine, cur_epoch_step, global_step, epoch, args):
|
||||
if args.save_deepspeed:
|
||||
logger.info(f'start to deepspped ckpt, save_dir={args.exp_ckpt_dir}')
|
||||
vllm_engine.save_checkpoint(save_dir=args.exp_ckpt_dir, tag=f'global_step{global_step}', client_state={
|
||||
'checkpoint_step': global_step, 'epoch': epoch, 'cur_epoch_step': cur_epoch_step})
|
||||
|
||||
export_model_dir = os.path.join(args.exp_ckpt_dir, f'{args.exp_name}_epoch_{epoch}_ckpt_{global_step}')
|
||||
os.makedirs(export_model_dir, exist_ok=True)
|
||||
base_file_name = f'{args.exp_name}_{cur_epoch_step}_{global_step}'
|
||||
|
||||
# model files
|
||||
if is_main_process():
|
||||
model_state_dict_path = os.path.join(export_model_dir, base_file_name + '.pt')
|
||||
# config 和 vocabs 和模型文件一起存储
|
||||
model_cfg_path = os.path.join(export_model_dir, 'config.json')
|
||||
model_vocab_path = os.path.join(export_model_dir, 'vocabs.txt')
|
||||
paths = [model_state_dict_path, model_cfg_path, model_vocab_path]
|
||||
|
||||
torch.save(vllm_engine.module.state_dict(), model_state_dict_path)
|
||||
shutil.copy(args.llm_path, model_cfg_path)
|
||||
shutil.copy(args.vocabs_path, model_vocab_path)
|
||||
|
||||
info = {
|
||||
'global_step': global_step,
|
||||
'epoch': epoch,
|
||||
'cur_epoch_step': cur_epoch_step,
|
||||
'last_ckpt': model_state_dict_path,
|
||||
'config': model_cfg_path,
|
||||
'vocab': model_vocab_path
|
||||
}
|
||||
with open(os.path.join(export_model_dir, 'lastest_info'), 'w') as f:
|
||||
json.dump(info, f, indent=2)
|
||||
logger.info(f'Successfully save model files! {paths}')
|
||||
torch.distributed.barrier()
|
||||
|
||||
|
||||
def export_eval_file(df, global_step, args):
|
||||
export_dir = os.path.join(args.exp_ckpt_dir, 'pretrain_eval')
|
||||
os.makedirs(export_dir, exist_ok=True)
|
||||
base_file_name = f'{"_".join(args.model_checkpoint.split("/")[-2:])}_{args.vision_encoder.split("_")[0]}_{global_step}'
|
||||
|
||||
if is_main_process():
|
||||
eval_result_path = os.path.join(export_dir, base_file_name + '.csv')
|
||||
logger.info(f'save eval result file to {eval_result_path}')
|
||||
df.to_csv(eval_result_path, index=False)
|
|
@ -0,0 +1,212 @@
|
|||
import os
|
||||
import glob
|
||||
import argparse
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from utils.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__, level="INFO")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser('VLLM pre-training script', add_help=False)
|
||||
|
||||
parser.add_argument('--self_dir', type=str)
|
||||
parser.add_argument('--train_file', type=str)
|
||||
parser.add_argument('--eval_file', type=str)
|
||||
parser.add_argument('--test_file', type=str)
|
||||
parser.add_argument('--exp_name', type=str, default='multimodal')
|
||||
|
||||
parser.add_argument('--batch_size', default=2, type=int)
|
||||
parser.add_argument('--split_by_rank', action='store_true', default=False, help="split parquet by rank")
|
||||
parser.add_argument('--epochs', default=100, type=int)
|
||||
parser.add_argument('--log_step', default=50, type=int)
|
||||
parser.add_argument('--save_step', default=100, type=int)
|
||||
parser.add_argument('--sft', action='store_true', help='is training all parameter')
|
||||
parser.add_argument('--tune_vision', action='store_true', help='is train vision parameter')
|
||||
parser.add_argument('--tune_resampler', action='store_true', help='is train resampler parameter')
|
||||
parser.add_argument('--tune_llm', action='store_true', help='is train llm parameter')
|
||||
parser.add_argument('--delta_tuning', action='store_true', help='is use lora')
|
||||
|
||||
# Model parameters
|
||||
parser.add_argument('--img_size', default=224, type=int)
|
||||
parser.add_argument('--llm_path', default=None, help='Path to LLM model to use', type=str)
|
||||
parser.add_argument('--vocabs_path', default=None, help='Path to vocabs to use', type=str)
|
||||
parser.add_argument('--model_checkpoint', default=None, help='Path to VLLM model to use', type=str)
|
||||
parser.add_argument('--llm_checkpoint', default=None, help='Path to LLM model to use', type=str)
|
||||
parser.add_argument('--data_state_dict_path', default=None, help='Path to dataset state dict', type=str)
|
||||
parser.add_argument('--vpm_path', help='Path to VPM model to use', type=str)
|
||||
parser.add_argument('--vpm_checkpoint', help='Path to VPM model to use', type=str)
|
||||
parser.add_argument('--vision_encoder', default='eva02_enormous_patch14_clip_224.laion2b_plus',
|
||||
choices=['eva02_enormous_patch14_clip_224.laion2b_plus', 'vit_so400m_patch14_siglip_384.webli'], type=str)
|
||||
parser.add_argument('--drop_vision_last_layer', action='store_true', help='is drop last vit layer')
|
||||
parser.add_argument('--prefix', default=None, help='Path prefix to save file', type=str)
|
||||
|
||||
# deepspeed
|
||||
parser.add_argument('--deepspeed_config', default=None, help='Path to deepspeed config to use', type=str)
|
||||
parser.add_argument('--save_deepspeed', action='store_true', default=False, help="is save deepspeed checkpoint")
|
||||
|
||||
# vlu
|
||||
parser.add_argument('--skip_overlength', action='store_true', default=False, help="is skip over length data")
|
||||
parser.add_argument('--skip_no_image', action='store_true', default=False, help="is skip no image data")
|
||||
parser.add_argument("--flash", default="none", choices=["none", "1d", "triton", "cuda"])
|
||||
|
||||
parser.add_argument('--max_length', default=256, type=int, help='max length of input')
|
||||
|
||||
# ----- Training -----
|
||||
parser.add_argument('--device', default='cuda',
|
||||
help='device to use for training / testing')
|
||||
parser.add_argument('--query_num', default=32, type=int,
|
||||
help='query numbers')
|
||||
parser.add_argument('--max_len', default=96, type=int,
|
||||
help='max len')
|
||||
parser.add_argument('--seed', default=0, type=int)
|
||||
parser.add_argument('--start_epoch', default=0, type=int)
|
||||
parser.add_argument('--start_step', default=0, type=int)
|
||||
parser.add_argument('--skip_step', default=0, type=int)
|
||||
parser.add_argument('--num_workers', default=5, type=int)
|
||||
parser.add_argument('--pin_mem', action='store_true',
|
||||
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
||||
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
|
||||
help='')
|
||||
parser.add_argument('--eval', action='store_true', default=False,
|
||||
help="Perform evaluation only")
|
||||
parser.add_argument('--eval_step', default=5000, type=int, help='evaluate step')
|
||||
parser.add_argument('--save_dataset', action='store_true', default=False, help="is save dataset state dict")
|
||||
parser.add_argument('--export_dir', default=None, help='Path to export', type=str)
|
||||
|
||||
|
||||
# 项目训练中断后,标识是否基于之前训练中断时已保存的 deepspeed 参数(梯度、优化器等)续训
|
||||
# 替代原有的 load_ckpt_dir,load_ckpt_tag
|
||||
parser.add_argument('--need_resume', action='store_true', default=False,
|
||||
help="resume with deepspeed states")
|
||||
parser.add_argument('--need_resume_tag')
|
||||
|
||||
# ----- distributed training parameters -----
|
||||
parser.add_argument('--world_size', default=1, type=int,
|
||||
help='number of distributed processes')
|
||||
parser.add_argument('--local_rank', default=-1, type=int)
|
||||
parser.add_argument('--dist_on_itp', action='store_true')
|
||||
parser.add_argument('--dist_url', default='env://',
|
||||
help='url used to set up distributed training')
|
||||
|
||||
args = parser.parse_args()
|
||||
# 1、文件保存/导出相关
|
||||
args.tensorboard = '{base}/{timestamp}'.format(
|
||||
base=os.path.join(args.export_dir, args.exp_name, 'tensorboard'), timestamp=datetime.now().strftime("%Y%m%d%H%M%S"))
|
||||
args.exp_ckpt_dir= os.path.join(args.export_dir, args.exp_name)
|
||||
os.makedirs(args.tensorboard, exist_ok=True)
|
||||
|
||||
# ----- repo 内路径相关参数 -----
|
||||
# 模型 config,从基准模型的 config 复制而来
|
||||
if not args.llm_path:
|
||||
args.llm_path = _check_default_path(os.path.join(args.self_dir, 'config/config.json'))
|
||||
if not args.vocabs_path:
|
||||
args.vocabs_path = _check_default_path(os.path.join(args.self_dir, 'config/vocabs.txt'))
|
||||
if not args.deepspeed_config:
|
||||
args.deepspeed_config = _check_default_path(os.path.join(args.self_dir, 'config/deepspeed.json'))
|
||||
|
||||
logger.info("get_args() done")
|
||||
return args
|
||||
|
||||
|
||||
|
||||
def _check_default_path(path: str):
|
||||
if os.path.exists(path):
|
||||
return path
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def _extract_ckpt_path(base_dir: str):
|
||||
paths = glob.glob(base_dir + '/*.pt')
|
||||
if len(paths) > 0:
|
||||
return paths[0]
|
||||
else:
|
||||
logger.warning(f'.pt file not found in base_dir({base_dir})')
|
||||
return None
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
|
||||
"""
|
||||
This function disables printing when not in master process
|
||||
"""
|
||||
import builtins as __builtin__
|
||||
builtin_print = __builtin__.print
|
||||
|
||||
def print(*args, **kwargs):
|
||||
force = kwargs.pop('force', False)
|
||||
if is_master or force:
|
||||
builtin_print(*args, **kwargs)
|
||||
|
||||
__builtin__.print = print
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if args.dist_on_itp:
|
||||
logger.info('init_distributed_mode dist_on_itp')
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = os.environ["WORLD_SIZE"]
|
||||
args.gpu = int(os.environ["LOCAL_RANK"])
|
||||
args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
|
||||
elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
||||
logger.info('init_distributed_mode LOCAL_RANK')
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = int(os.environ['WORLD_SIZE'])
|
||||
args.gpu = int(os.environ['LOCAL_RANK'])
|
||||
elif 'SLURM_PROCID' in os.environ:
|
||||
logger.info('init_distributed_mode SLURM_PROCID')
|
||||
args.rank = int(os.environ['SLURM_PROCID'])
|
||||
args.gpu = args.rank % torch.cuda.device_count()
|
||||
else:
|
||||
logger.info('Not using distributed mode')
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
torch.set_num_threads(1)
|
||||
torch.multiprocessing.set_sharing_strategy('file_system')
|
||||
args.dist_backend = 'nccl'
|
||||
print('| distributed init (rank {}): {}, gpu {}'.format(
|
||||
args.rank, args.dist_url, args.gpu), flush=True)
|
||||
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
||||
world_size=args.world_size, rank=args.rank)
|
||||
torch.distributed.barrier()
|
||||
setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
def setup(args):
|
||||
# init dist
|
||||
init_distributed_mode(args)
|
||||
rank = get_rank()
|
||||
logger.info(f"rank={rank} init_distributed_mode done")
|
||||
|
||||
seed = args.seed + rank
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
logger.info(f"rank={rank} setup(args) done")
|
|
@ -0,0 +1,397 @@
|
|||
# import sys
|
||||
import os
|
||||
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
import json
|
||||
import math
|
||||
import time
|
||||
import gc
|
||||
import pandas as pd
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List
|
||||
|
||||
from timm.data.constants import *
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from torchvision.transforms import transforms, InterpolationMode
|
||||
|
||||
from vis_fm9g.dataset.data import register_data_path
|
||||
from utils import utils
|
||||
from vis_fm9g.dataset.utils import SkipBatchSampler
|
||||
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer
|
||||
from vis_fm9g.generation.vllm_fm9g import VLLMFM9GBeamSearch
|
||||
|
||||
import torch
|
||||
import datetime
|
||||
import deepspeed
|
||||
import timm
|
||||
import torch.distributed
|
||||
import torch.utils.data
|
||||
|
||||
from vis_fm9g.dataset.itembuilder import FM9GImageTextBuilder, FM9GCollater
|
||||
from vis_fm9g.dataset.datasets import SingleDataSourceDataset, MultiDataSourceDataset
|
||||
from vis_fm9g.model.vlu_fm9g import VLU_FM9G
|
||||
from vis_fm9g.model.fm9g import FM9GConfig, FM9GTorch
|
||||
|
||||
from deepspeed.utils import logger
|
||||
|
||||
from vis_fm9g.train import exporter, initializer
|
||||
from vis_fm9g.utils.constants import bot_indicator
|
||||
|
||||
import safetensors
|
||||
from safetensors.torch import load_file
|
||||
|
||||
|
||||
def collect_statsd_metric(name, time_monitor):
|
||||
time_monitor[name] = time.time()
|
||||
return time_monitor
|
||||
|
||||
|
||||
def convert_data_to_cuda(data: Dict):
|
||||
for k, v in data.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
data[k] = data[k].cuda()
|
||||
|
||||
if isinstance(v, List) and len(v) > 0 and isinstance(v[0], torch.Tensor):
|
||||
for i in range(len(v)):
|
||||
v[i] = v[i].cuda()
|
||||
return data
|
||||
|
||||
|
||||
def create_multi_data_source_dataset(file, item_builder):
|
||||
with open(file) as f:
|
||||
data_list = json.load(f)
|
||||
data_source_names = [i['data_source_name'] for i in data_list]
|
||||
data_source_weights = [i['data_source_weight'] for i in data_list]
|
||||
|
||||
ds_list = []
|
||||
for name in data_source_names:
|
||||
if os.path.isdir(name):
|
||||
ds = SingleDataSourceDataset(name, item_builder, *register_data_path[name](name))
|
||||
else:
|
||||
ds = SingleDataSourceDataset(name, item_builder, *register_data_path[name]())
|
||||
ds_list.append(ds)
|
||||
if len(ds_list) > 1:
|
||||
ds = MultiDataSourceDataset(ds_list, data_source_weights)
|
||||
return ds
|
||||
|
||||
|
||||
def get_transform(args):
|
||||
mean = IMAGENET_DEFAULT_MEAN
|
||||
std = IMAGENET_DEFAULT_STD
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((args.img_size, args.img_size), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std)
|
||||
])
|
||||
|
||||
return transform
|
||||
|
||||
|
||||
def get_dataloader(tokenizer, data_file, args, is_training=True):
|
||||
transform = get_transform(args)
|
||||
|
||||
builder = FM9GImageTextBuilder(
|
||||
tokenizer=tokenizer,
|
||||
max_len=args.max_len,
|
||||
transform=transform,
|
||||
query_len=args.query_num,
|
||||
min_resolution=0,
|
||||
skip_overlength=args.skip_overlength
|
||||
)
|
||||
|
||||
dataset = create_multi_data_source_dataset(data_file, builder)
|
||||
|
||||
datasampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=args.rank)
|
||||
|
||||
if args.skip_step > 0:
|
||||
datasampler = SkipBatchSampler(datasampler, args.skip_step)
|
||||
|
||||
unpad = (args.flash == 'cuda') and is_training
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
sampler=datasampler,
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
num_workers=2,
|
||||
collate_fn=FM9GCollater(tokenizer=tokenizer, max_len=args.max_len, unpad=unpad)
|
||||
)
|
||||
|
||||
return dataloader
|
||||
|
||||
|
||||
def load_llm_tokenizer(args):
|
||||
return FM9GTokenizer(args.vocabs_path)
|
||||
|
||||
|
||||
def train(vllm_model, args):
|
||||
vllm_model.train()
|
||||
if not args.tune_vision:
|
||||
vllm_model.vpm.requires_grad_(False)
|
||||
if not args.tune_resampler:
|
||||
vllm_model.resampler.requires_grad_(False)
|
||||
if not args.tune_llm:
|
||||
vllm_model.llm.requires_grad_(False)
|
||||
|
||||
if args.drop_vision_last_layer:
|
||||
vllm_model.norm.requires_grad_(True)
|
||||
|
||||
vllm_engine, vllm_optim, _, _ = deepspeed.initialize(
|
||||
args=args, model=vllm_model, model_parameters=vllm_model.parameters()
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
logger.info(f'rank={utils.get_rank()} load model successful')
|
||||
|
||||
tokenizer = load_llm_tokenizer(args)
|
||||
dataloader_train = get_dataloader(tokenizer, data_file=args.train_file, args=args, is_training=True)
|
||||
if args.eval and args.eval_file:
|
||||
dataloader_eval = get_dataloader(tokenizer, data_file=args.eval_file, args=args, is_training=False)
|
||||
else:
|
||||
dataloader_eval = None
|
||||
logger.info(f'rank={utils.get_rank()} load dataloader successful')
|
||||
|
||||
global_step = args.start_step
|
||||
log_loss = 0
|
||||
|
||||
if args.need_resume:
|
||||
load_path, client_state = vllm_engine.load_checkpoint(
|
||||
args.exp_ckpt_dir, tag=args.need_resume_tag)
|
||||
logger.info(f'Load pre-trained checkpoint from {load_path}, states: {client_state}')
|
||||
global_step = client_state['checkpoint_step']
|
||||
args.start_epoch = client_state.get('epoch', args.start_epoch)
|
||||
logger.info(f'rank={utils.get_rank()} load grad successful')
|
||||
|
||||
# init tensorboard writer
|
||||
if args.tensorboard is not None and utils.is_main_process():
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
writer = SummaryWriter(log_dir=args.tensorboard)
|
||||
else:
|
||||
writer = None
|
||||
|
||||
loss_fct = CrossEntropyLoss(reduction='mean', ignore_index=-100)
|
||||
for epoch in range(args.start_epoch, args.epochs):
|
||||
dataloader_train.sampler.set_epoch(epoch)
|
||||
|
||||
logger.info(f'start epoch={epoch}')
|
||||
time_monitor = {}
|
||||
collect_statsd_metric("init", time_monitor)
|
||||
for step, batch in enumerate(dataloader_train):
|
||||
batch = convert_data_to_cuda(batch)
|
||||
collect_statsd_metric('dataload', time_monitor)
|
||||
vllm_model.zero_grad()
|
||||
output = vllm_model(data=batch)
|
||||
|
||||
logits = output.logits.view(-1, output.logits.shape[-1]).contiguous()
|
||||
target = batch['target'].view(-1).type(torch.long).contiguous()
|
||||
loss = loss_fct(logits, target)
|
||||
# logger.info(f'epoch={epoch}, logits={logits}, target={target}, loss={loss}')
|
||||
collect_statsd_metric("forward", time_monitor)
|
||||
vllm_engine.backward(loss)
|
||||
collect_statsd_metric("backward", time_monitor)
|
||||
|
||||
vllm_engine.step()
|
||||
collect_statsd_metric("optim", time_monitor)
|
||||
|
||||
cost_info = f'dataload cost={(time_monitor["dataload"] - time_monitor["init"]): .2f} ' \
|
||||
+ f'forward cost={(time_monitor["forward"] - time_monitor["dataload"]): .2f} ' \
|
||||
+ f'backward cost={(time_monitor["backward"] - time_monitor["forward"]): .2f} ' \
|
||||
+ f'optim cost={(time_monitor["optim"] - time_monitor["backward"]): .2f}'
|
||||
|
||||
log_loss += loss.item()
|
||||
global_step += 1
|
||||
|
||||
if args.tensorboard is not None and utils.is_main_process():
|
||||
writer.add_scalar("Loss/train", loss.item(), global_step)
|
||||
|
||||
if global_step % args.log_step == 0:
|
||||
log_loss = utils.mean(utils.all_gather(log_loss))
|
||||
if utils.is_main_process():
|
||||
logger.info(
|
||||
f'Datetime: {datetime.datetime.now()} Step: {global_step-args.log_step} - {global_step}: loss: {log_loss/args.log_step: .4f}')
|
||||
logger.info(f'time cost info {cost_info}')
|
||||
log_loss = 0
|
||||
|
||||
if global_step % args.save_step == 0:
|
||||
exporter.export(vllm_engine, step, global_step, epoch, args)
|
||||
|
||||
# end step
|
||||
collect_statsd_metric('init', time_monitor)
|
||||
|
||||
if args.eval and global_step % args.eval_step == 0:
|
||||
evaluate(vllm_model, tokenizer, dataloader_eval, global_step, args)
|
||||
vllm_model.train()
|
||||
|
||||
# 最终模型
|
||||
exporter.export(vllm_engine, 0, global_step, args.epochs-1, args)
|
||||
if args.eval:
|
||||
evaluate(vllm_model, tokenizer, dataloader_eval, global_step, args)
|
||||
|
||||
|
||||
def evaluate(vllm_model, tokenizer, dataloader_eval, global_step, args):
|
||||
vllm_model.eval()
|
||||
torch.cuda.empty_cache()
|
||||
config = deepcopy(vllm_model.llm.config)
|
||||
# 推理不用 flash_attn, 新初始化一个 model
|
||||
config.use_flash_attn = False
|
||||
llm = FM9GTorch(config)
|
||||
|
||||
vpm = load_vpm(args)
|
||||
vision_dim = vpm.embed_dim
|
||||
eval_model = VLU_FM9G(llm, vpm, vision_dim, args.query_num)
|
||||
eval_model.eval().cuda()
|
||||
eval_model.load_state_dict(vllm_model.state_dict())
|
||||
|
||||
torch.cuda.synchronize()
|
||||
logger.info(f'rank={utils.get_rank()} start to eval')
|
||||
transform = get_transform(args)
|
||||
beam_search = VLLMFM9GBeamSearch(eval_model, tokenizer, transform)
|
||||
|
||||
results = []
|
||||
for step, batch in enumerate(dataloader_eval):
|
||||
batch = convert_data_to_cuda(batch)
|
||||
|
||||
for piexl_values, raw, source in zip(batch['pixel_values'], deepcopy(batch['raw_data']), batch['source']):
|
||||
raw = raw[3:-4] # 去 <s> </s>
|
||||
last_idx = raw.rfind(bot_indicator) + len(bot_indicator)
|
||||
data_list = [{'input': raw[:last_idx]}]
|
||||
gt = raw[last_idx:]
|
||||
|
||||
with torch.inference_mode():
|
||||
res = beam_search.generate(
|
||||
img_list=[piexl_values],
|
||||
data_list=data_list,
|
||||
use_transform=False,
|
||||
beam_size=1,
|
||||
max_length=1,
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
'y_pred': res[0],
|
||||
'y_true': gt,
|
||||
'source': source
|
||||
}
|
||||
)
|
||||
|
||||
compose_results = utils.all_gather(results)
|
||||
compose_results_flatten = []
|
||||
for r in compose_results:
|
||||
compose_results_flatten.extend(r)
|
||||
|
||||
df = pd.DataFrame.from_dict(compose_results_flatten)
|
||||
exporter.export_eval_file(df, global_step=global_step, args=args)
|
||||
|
||||
source = sorted(list(set(df['source'])))
|
||||
for ds in source:
|
||||
ds_df = df[df['source'] == ds]
|
||||
print(ds, '%.3f' % (sum(ds_df['y_pred'] == ds_df['y_true']) / len(ds_df)), len(ds_df))
|
||||
print('avg', '%.3f' % (sum(df['y_pred'] == df['y_true']) / len(df)), len(df))
|
||||
|
||||
del eval_model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
vllm_model.train()
|
||||
|
||||
|
||||
|
||||
def load_llm(args):
|
||||
config = FM9GConfig.from_json_file(args.llm_path)
|
||||
|
||||
if args.flash == "none":
|
||||
config.use_flash_attn = False
|
||||
else:
|
||||
config.use_flash_attn = True
|
||||
if args.flash == "1d":
|
||||
config.flash_attn_mask_shape = "1d"
|
||||
else:
|
||||
config.flash_attn_mask_shape = "2d"
|
||||
if args.flash == "triton":
|
||||
config.flash_impl = "triton"
|
||||
elif args.flash == "cuda":
|
||||
config.flash_impl = "cuda"
|
||||
|
||||
cpm_model = FM9GTorch(config)
|
||||
return cpm_model
|
||||
|
||||
|
||||
def load_vpm(args):
|
||||
model = timm.create_model(
|
||||
args.vision_encoder,
|
||||
pretrained=False,
|
||||
num_classes=0,
|
||||
dynamic_img_size=True,
|
||||
dynamic_img_pad=True
|
||||
)
|
||||
|
||||
if isinstance(model, timm.models.VisionTransformer):
|
||||
if model.attn_pool is not None:
|
||||
model.attn_pool = torch.nn.Identity()
|
||||
|
||||
if args.drop_vision_last_layer:
|
||||
model.blocks[-1] = torch.nn.Identity()
|
||||
return model
|
||||
|
||||
def load_sharded_safetensors(sharded_folder):
|
||||
safetensors_files = sorted([f for f in os.listdir(sharded_folder) if f.endswith('.safetensors')])
|
||||
|
||||
if not safetensors_files:
|
||||
raise FileNotFoundError(f"No safetensors files found in {sharded_folder}")
|
||||
|
||||
merged_state_dict = {}
|
||||
for file_name in safetensors_files:
|
||||
file_path = os.path.join(sharded_folder, file_name)
|
||||
logger.info(f"Loading safetensors shard from {file_path}")
|
||||
state_dict = load_file(file_path)
|
||||
merged_state_dict.update(state_dict)
|
||||
|
||||
return merged_state_dict
|
||||
|
||||
def setup_model(args):
|
||||
start = time.time()
|
||||
|
||||
llm = load_llm(args)
|
||||
vpm = load_vpm(args)
|
||||
vision_dim = vpm.embed_dim
|
||||
model = VLU_FM9G(llm, vpm, vision_dim, args.query_num)
|
||||
if args.model_checkpoint:
|
||||
logger.info(f'load model_checkpoint from {args.model_checkpoint}')
|
||||
|
||||
model_checkpoint = args.model_checkpoint
|
||||
file_extension = os.path.splitext(model_checkpoint)[1]
|
||||
|
||||
if file_extension == '.safetensors':
|
||||
logger.info(f"Loading safetensors checkpoint from {model_checkpoint}")
|
||||
state_dict = load_file(model_checkpoint)
|
||||
info = model.load_state_dict(state_dict, strict=True)
|
||||
logger.info(f"Loaded checkpoint info={info}")
|
||||
|
||||
elif os.path.isdir(model_checkpoint):
|
||||
logger.info(f"Loading safetensors from sharded folder: {model_checkpoint}")
|
||||
state_dict = load_sharded_safetensors(model_checkpoint)
|
||||
info = model.load_state_dict(state_dict, strict=True)
|
||||
logger.info(f"Loaded checkpoint info={info}")
|
||||
|
||||
else:
|
||||
state_dict = torch.load(args.model_checkpoint, map_location='cpu')
|
||||
info = model.load_state_dict(state_dict, strict=True)
|
||||
logger.info(f"load checkpoint info={info}" )
|
||||
|
||||
del state_dict
|
||||
gc.collect()
|
||||
|
||||
model.cuda()
|
||||
torch.cuda.empty_cache()
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
args = initializer.get_args()
|
||||
# setup file and device
|
||||
initializer.setup(args)
|
||||
# load model
|
||||
model = setup_model(args)
|
||||
# train
|
||||
train(model, args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,11 +1,7 @@
|
|||
import copy
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Union
|
||||
|
||||
from .log import logger
|
||||
import copy
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
|
||||
class Config(object):
|
|
@ -0,0 +1,6 @@
|
|||
from datetime import datetime
|
||||
|
||||
current_time = datetime.now().strftime("%Y年%m月%d日")
|
||||
usr_indicator = "<用户>"
|
||||
bot_indicator = "<AI>"
|
||||
SYSTEM = f"{usr_indicator}你叫九格,是由启元实验室研发的多模态大型语言模型。\n你的知识库截止至2022年4月,当前时间是{current_time}。"
|
|
@ -1,4 +0,0 @@
|
|||
# !/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024, QiYuan Inc
|