Compare commits

...

13 Commits

Author SHA1 Message Date
p20367984 52dbf63007 Update quick_start.md 2025-01-09 09:51:32 +08:00
p20367984 9aeefe95d5 0108 2025-01-08 15:20:38 +08:00
p20367984 a214708255 0108 2025-01-08 15:17:35 +08:00
p20367984 25c3643a05 0108 2025-01-08 15:15:33 +08:00
p20367984 45d7c9b99d 0108 2025-01-08 14:41:14 +08:00
p20367984 23d3844492 0108 2025-01-08 14:34:03 +08:00
carboncoo c89395164e 1119 2024-11-19 11:01:06 +08:00
carboncoo 8e693d5876 1119 2024-11-19 11:00:26 +08:00
carboncoo 415c624322 Read me 2024-11-19 10:48:12 +08:00
carboncoo 4139ba5dfe 1119 2024-11-19 10:44:01 +08:00
carboncoo a8d431c14f 1119 2024-11-19 10:42:49 +08:00
carboncoo 1857f60d1e 1118 2024-11-18 17:59:15 +08:00
carboncoo a041469104 1118 2024-11-18 17:47:43 +08:00
171 changed files with 4381 additions and 354589 deletions

193
FM9G-V/chat.py Normal file
View File

@ -0,0 +1,193 @@
import gc
from io import BytesIO
import requests
import timm
import torch
import random
import json
from PIL import Image
from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from torchvision.transforms import InterpolationMode, transforms
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
import os,sys
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
from vis_fm9g.generation.vllm_fm9g import VLLMFM9GBeamSearch
from vis_fm9g.model.fm9g import FM9GConfig, FM9GTorch
from vis_fm9g.model.vlu_fm9g import VLU_FM9G
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer
from vis_fm9g.utils.constants import SYSTEM
def random_selection_device(device_num):
device_list = [f"cuda:{i}" for i in range(device_num)]
if len(device_list) == 1:
return [device_list[0]] * 4
elif len(device_list) == 2:
a, b = device_list
return [a, b, a, b]
else:
selected = random.sample(device_list, 3)
repeated = random.choice(selected)
return selected + [repeated]
def has_pt_file(directory):
pt_files = []
files = os.listdir(directory)
for file in files:
if file.endswith(".pt"):
full_path = os.path.join(directory, file)
pt_files.append(full_path)
return pt_files
def load_checkpoint(model, pretrained_model_name_or_path, device_num=4):
pt_files = has_pt_file(pretrained_model_name_or_path)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if pt_files:
device_map = {}
state_dict = model.state_dict()
layer_names = list(state_dict.keys())
device_list = random_selection_device(device_num=device_num)
for i, layer_name in enumerate(layer_names):
if "vpm" in layer_name: # 假设 vpm 参数包含 "vpm" 前缀
device_map[layer_name] = device_list[0]
elif "llm" in layer_name: # 假设 llm 参数包含 "llm" 前缀
device_map[layer_name] = device_list[1]
else:
device_map[layer_name] = device_list[2]
model = load_checkpoint_and_dispatch(model, pt_files[0], device_map=device_map)
model.to(device)
return model
else:
model_checkpoint = os.path.join(pretrained_model_name_or_path, "sharded")
index_file = os.path.join(pretrained_model_name_or_path, "sharded", "model.safetensors.index.json")
with open(index_file, "r") as f:
index_data = json.load(f)
weight_map = index_data["weight_map"]
all_params = {name for name, _ in model.named_parameters()}
mapped_params = set(weight_map.keys())
unmapped_params = all_params - mapped_params
# Parse weight_map and build the device_map
# TODO: forcing the different branches onto different GPUs avoids computation issues, but it is ugly and needs further optimization
device_list = random_selection_device(device_num=device_num)
device_map = {}
for param_name, weight_file in weight_map.items():
if "vpm" in param_name: # 假设 vpm 参数包含 "vpm" 前缀
device_map[param_name] = device_list[0]
elif "llm" in param_name: # 假设 llm 参数包含 "llm" 前缀
device_map[param_name] = device_list[1]
else:
device_map[param_name] = device_list[2]
for param_name in unmapped_params:
device_map[param_name] = device_list[3] # randomly selected device
model = load_checkpoint_and_dispatch(
model,
model_checkpoint,
device_map=device_map,
).to(device)
return model
def chat(model, image, question, context, tokenizer, query_nums=64, vision_hidden_states=None, max_length=1024):
if not context:
question = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end + question
final_input = f'{SYSTEM}<用户>{question}<AI>'
else:
final_input = f'{context}<用户>{question}<AI>'
data_list = [
{'input': final_input}
]
res, vision_hidden_states = model.generate(
data_list=data_list,
max_inp_length=2048,
beam_size=3,
img_list=[[image]],
max_length=max_length,
repetition_penalty=1.1,
temperature=0.7,
length_penalty=3,
return_vision_hidden_states=True
)
answer = res[0]
context = final_input + answer
return answer, context, vision_hidden_states
def load_llm(llm_path):
config = FM9GConfig.from_json_file(llm_path)
config.use_flash_attn = False
cpm_model = FM9GTorch(config)
return cpm_model
def load_vpm(vision_encoder, drop_vision_last_layer=False):
model = timm.create_model(
vision_encoder,
pretrained=False,
num_classes=0,
dynamic_img_size=True,
dynamic_img_pad=True
)
if isinstance(model, timm.models.VisionTransformer):
if model.attn_pool is not None:
model.attn_pool = torch.nn.Identity()
if drop_vision_last_layer:
model.blocks[-1] = torch.nn.Identity()
return model
def load_vis_fm9g(llm_path, vision_encoder):
llm = load_llm(llm_path)
vpm = load_vpm(vision_encoder, drop_vision_last_layer=False)
vision_dim = vpm.embed_dim
model = VLU_FM9G(llm, vpm, vision_dim, query_num=256)
return model
def load_tokenizer(vocabs_path):
return FM9GTokenizer(vocabs_path)
def load_transform(img_size):
transform = transforms.Compose([
transforms.Resize((img_size, img_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD)
])
return transform
if __name__ == '__main__':
root = "checkpoint/"
llm_path = root + 'config.json'
vocabs_path = root + 'vocabs.txt'
vision_encoder = 'eva02_enormous_patch14_clip_224.laion2b_plus'
img_size = 448
device_num = 4
with init_empty_weights():
model = load_vis_fm9g(llm_path, vision_encoder)
model = load_checkpoint(model, root, device_num=device_num)
model.eval()
tokenizer = load_tokenizer(vocabs_path)
transform = load_transform(img_size)
beam_search = VLLMFM9GBeamSearch(model, tokenizer, transform)
# Image input
url = 'test.jpg'
image = Image.open(url).convert('RGB')
# Text prompt
prompt = '这幅图描述了什么?'
answer, context, _ = chat(
beam_search, image, prompt, context=None, tokenizer=tokenizer, query_nums=256
)
print(answer)
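A minimal multi-turn sketch built on the chat() helper above; the follow-up question is hypothetical, and beam_search, image and tokenizer are assumed to be the objects created in the __main__ block:

# First turn: with no prior context, chat() prepends the image placeholder tokens itself.
answer, context, _ = chat(
    beam_search, image, '这幅图描述了什么?', context=None,
    tokenizer=tokenizer, query_nums=256
)
# Second turn: pass the returned context back in so the conversation history is kept.
follow_up = '图中有哪些颜色?'  # hypothetical follow-up question
answer2, context, _ = chat(
    beam_search, image, follow_up, context=context,
    tokenizer=tokenizer, query_nums=256
)
print(answer, answer2)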

View File

@ -0,0 +1,44 @@
[
{
"id": "000000052347",
"conversations": [
{
"from": "human",
"value": "<image>\nPlease generate detailed descriptions of the given image."
},
{
"from": "gpt",
"value": "The image depicts a tranquil rural scene. In the foreground, there are two cows. The cow on the left is white and is standing, while the cow on the right is brown and is lying down. Both cows appear to be resting or grazing. The background of the image shows a vast landscape with rolling hills and a clear sky. The overall atmosphere of the image is peaceful and serene, capturing a moment of calm in the countryside."
}
],
"image": "00000/000000052347.jpg"
},
{
"id": "000000382625",
"conversations": [
{
"from": "human",
"value": "<image>\nPlease generate detailed descriptions of the given image."
},
{
"from": "gpt",
"value": "The image depicts a group of individuals, likely in a celebratory or formal setting. In the foreground, there is a table covered with a blue tablecloth, upon which sit three cakes. The cakes are decorated with various designs, including what appears to be a flag and other symbols, suggesting a theme related to a specific organization or event.\n\nBehind the table, there are six people standing, arranged in two rows. The front row consists of three individuals, two of whom are wearing chef's hats, indicating they may be chefs or cooks. The third person in the front row is wearing a military uniform, suggesting a connection to the military.\n\nThe back row features three more individuals, two of whom are also in military uniforms, while the third is wearing a tie, which could imply a formal or administrative role. All individuals are facing the camera and appear to be posing for the photograph.\n\nThe setting appears to be an indoor space, possibly a dining area or a kitchen, given the presence of the table and the attire of the individuals. The lighting is bright, and the overall atmosphere seems to be one of celebration or recognition.\n\nThere are no visible texts or logos that provide specific information about the event or the individuals. The style of the image is a standard, candid photograph capturing a moment during an event."
}
],
"image": "00000/000000382625.jpg"
},
{
"id": "000000559875",
"conversations": [
{
"from": "human",
"value": "<image>\nPlease generate detailed descriptions of the given image."
},
{
"from": "gpt",
"value": "In the image, a young girl is the main subject. She is dressed in a white tennis outfit, complete with a matching white visor. In her hands, she holds a tennis racket, which is predominantly yellow and black. The racket is adorned with the word \"HEAD\" in bold orange letters.\n\nThe girl is standing on a concrete surface, which suggests an outdoor setting. Behind her, there's a white wall, on which a black and white poster is hanging. The poster features a tennis player, further emphasizing the theme of the image.\n\nThe girl is also holding a small trophy in her left hand. The trophy is gold in color, adding a touch of elegance to the scene. The girl's smile suggests she is proud of her achievement.\n\nOverall, the image captures a moment of joy and accomplishment in the world of tennis."
}
],
"image": "00000/000000559875.jpg"
}
]

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 52 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

View File

@ -0,0 +1,44 @@
[
{
"id": "000000052347",
"conversations": [
{
"from": "human",
"value": "<image>\nPlease generate detailed descriptions of the given image."
},
{
"from": "gpt",
"value": "The image depicts a tranquil rural scene. In the foreground, there are two cows. The cow on the left is white and is standing, while the cow on the right is brown and is lying down. Both cows appear to be resting or grazing. The background of the image shows a vast landscape with rolling hills and a clear sky. The overall atmosphere of the image is peaceful and serene, capturing a moment of calm in the countryside."
}
],
"image": "00000/000000052347.jpg"
},
{
"id": "000000382625",
"conversations": [
{
"from": "human",
"value": "<image>\nPlease generate detailed descriptions of the given image."
},
{
"from": "gpt",
"value": "The image depicts a group of individuals, likely in a celebratory or formal setting. In the foreground, there is a table covered with a blue tablecloth, upon which sit three cakes. The cakes are decorated with various designs, including what appears to be a flag and other symbols, suggesting a theme related to a specific organization or event.\n\nBehind the table, there are six people standing, arranged in two rows. The front row consists of three individuals, two of whom are wearing chef's hats, indicating they may be chefs or cooks. The third person in the front row is wearing a military uniform, suggesting a connection to the military.\n\nThe back row features three more individuals, two of whom are also in military uniforms, while the third is wearing a tie, which could imply a formal or administrative role. All individuals are facing the camera and appear to be posing for the photograph.\n\nThe setting appears to be an indoor space, possibly a dining area or a kitchen, given the presence of the table and the attire of the individuals. The lighting is bright, and the overall atmosphere seems to be one of celebration or recognition.\n\nThere are no visible texts or logos that provide specific information about the event or the individuals. The style of the image is a standard, candid photograph capturing a moment during an event."
}
],
"image": "00000/000000382625.jpg"
},
{
"id": "000000559875",
"conversations": [
{
"from": "human",
"value": "<image>\nPlease generate detailed descriptions of the given image."
},
{
"from": "gpt",
"value": "In the image, a young girl is the main subject. She is dressed in a white tennis outfit, complete with a matching white visor. In her hands, she holds a tennis racket, which is predominantly yellow and black. The racket is adorned with the word \"HEAD\" in bold orange letters.\n\nThe girl is standing on a concrete surface, which suggests an outdoor setting. Behind her, there's a white wall, on which a black and white poster is hanging. The poster features a tennis player, further emphasizing the theme of the image.\n\nThe girl is also holding a small trophy in her left hand. The trophy is gold in color, adding a touch of elegance to the scene. The girl's smile suggests she is proud of her achievement.\n\nOverall, the image captures a moment of joy and accomplishment in the world of tennis."
}
],
"image": "00000/000000559875.jpg"
}
]

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 52 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

View File

@ -0,0 +1,3 @@
0
63557
137410

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,3 @@
0
63557
137410

File diff suppressed because one or more lines are too long

130
FM9G-V/json2tsv.py Normal file
View File

@ -0,0 +1,130 @@
import os
import json
import base64
def json_to_tsv_with_base64(json_file, output_base_path):
"""
JSON 文件转换为带 Base64 编码的 TSV 文件并生成自定义命名的 .lineidx 文件
Args:
json_file: 输入的 JSON 文件路径
output_base_path: 默认的输出路径在此路径下创建文件夹
data_weight: 数据权重定义
"""
# Get the directory containing the JSON file
json_dir = os.path.dirname(os.path.abspath(json_file))
image_root = os.path.join(json_dir, "images") # root directory + "images"
json_filename = os.path.splitext(os.path.basename(json_file))[0]
data_name = json_filename.split(".")[0]
# Create a new folder named after the JSON file
output_folder = os.path.join(output_base_path, json_filename)
os.makedirs(output_folder, exist_ok=True)
# Read the JSON data
with open(json_file, 'r') as jf:
data = json.load(jf)
# Get the number of entries
num_entries = len(data)
# Build the TSV and .lineidx file paths
tsv_file = os.path.join(output_folder, f"{json_filename}-{num_entries}.tsv")
lineidx_file = os.path.join(output_folder, f"{json_filename}-{num_entries}.lineidx")
# Write the TSV file
with open(tsv_file, 'w') as tsv_out:
for item in data:
# Extract the fields
item_id = item["id"]
image_id = item["image"]
# Image paths are resolved against the "images" subdirectory next to the JSON file
image_path = os.path.join(image_root, image_id)
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image not found: {image_path}")
# Encode the image as Base64
with open(image_path, "rb") as img_file:
image_base64 = base64.b64encode(img_file.read()).decode('utf-8')
# Keep the original conversations structure and encode it as Base64
conversations_json = json.dumps(item["conversations"], ensure_ascii=False)
conversations_base64 = base64.b64encode(conversations_json.encode('utf-8')).decode('utf-8')
# Write the TSV row
line = f"{data_name}\t{image_base64}\t{conversations_base64}\t{image_root}\t{image_root}\t{item_id}\t{image_id}\n"
tsv_out.write(line)
# Generate the custom-named .lineidx file
create_lineidx(tsv_file, lineidx_file)
print(f"TSV file generated: {tsv_file}")
print(f"Index file generated: {lineidx_file}")
return output_folder
def create_lineidx(filein, idxout):
"""
Generate a custom-named .lineidx index file from a TSV file.
Args:
filein: path to the input TSV file
idxout: path to the output index file
"""
idxout_tmp = idxout + '.tmp'
with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
fsize = os.fstat(tsvin.fileno()).st_size
fpos = 0
while fpos != fsize:
tsvout.write(str(fpos) + "\n")
tsvin.readline()
fpos = tsvin.tell()
os.rename(idxout_tmp, idxout)
def process_multiple_json(json_files, output_base_path, dataset_weights, train_data=True):
"""
Process a list of JSON files, converting each to TSV and creating a data.json file with weights.
Args:
json_files: List of paths to JSON files.
output_base_path: Base directory to store TSV files.
dataset_weights: Dictionary containing the weight for each dataset.
"""
data_sources = []
for json_file in json_files:
# Get the weight from the dataset_weights dictionary
json_filename = os.path.splitext(os.path.basename(json_file))[0]
data_weight = dataset_weights.get(json_filename, 1) # Default weight is 1
# Convert JSON to TSV and get the path
tsv_path = json_to_tsv_with_base64(json_file, output_base_path)
# Append to the data_sources list
data_sources.append({
"data_source_name": tsv_path,
"data_source_weight": data_weight
})
# Save the data JSON file (stored under vis_fm9g/config/data by default) for data indexing
if train_data:
data_json_path = os.path.join("vis_fm9g/config/data", "data.json")
else:
data_json_path = os.path.join("vis_fm9g/config/data", "eval_data.json")
with open(data_json_path, 'w') as f:
json.dump(data_sources, f, ensure_ascii=False, indent=4)
print(f"data.json file generated: {data_json_path}")
## Example usage ##
if __name__ == "__main__":
# Paths of the input JSON files and their data weights; if no weight is set, the default is 1
json_files = ["dataset/data_1/data_1.json", "dataset/data_2/data_2.json"]
# The JSON files are converted to TSV format; change tsv_file_path to control where the TSV files are stored
tsv_file_path = "dataset/train"
# Set the data weights
dataset_weights = {
"data_1": 1,
"data_2": 2
}
process_multiple_json(json_files, tsv_file_path, dataset_weights, train_data=True)
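As a sanity check, a small sketch (the output path is hypothetical) that reads back the first row of a generated TSV and decodes the Base64 columns, following the column order written by json_to_tsv_with_base64 (data_name, image_base64, conversations_base64, image_root, image_root, item_id, image_id):

import base64, io, json
from PIL import Image

# Hypothetical path: <tsv_file_path>/<json_name>/<json_name>-<num_entries>.tsv
with open("dataset/train/data_1/data_1-3.tsv", "r") as f:
    row = f.readline().rstrip("\n").split("\t")
data_name, image_b64, conv_b64, image_root, _, item_id, image_id = row
image = Image.open(io.BytesIO(base64.b64decode(image_b64))).convert("RGB")
conversations = json.loads(base64.b64decode(conv_b64).decode("utf-8"))
print(data_name, item_id, image_id, image.size, conversations[0]["from"])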

12
FM9G-V/readme.txt Normal file
View File

@ -0,0 +1,12 @@
## Environment setup ##
cd FM9G-V
pip install -r requirements.txt
## Demo ##
python chat.py
## Data preparation ##
python json2tsv.py
## Training ##
bash run_train.sh

23
FM9G-V/requirements.txt Normal file
View File

@ -0,0 +1,23 @@
torch==2.0.1
torchvision==0.15.2
transformers==4.31.0
tokenizers>=0.12.1,<0.14
sentencepiece==0.1.99
shortuuid
peft==0.4.0
bitsandbytes==0.41.0
pydantic<2,>=1
markdown2[all]
numpy==1.23.5
scikit-learn==1.2.2
gradio==3.35.2
gradio_client==0.2.9
requests
httpx==0.24.0
uvicorn
fastapi
einops==0.6.1
einops-exts==0.0.4
timm==0.9.8
deepspeed==0.11.1
flash-attn==2.3.3

69
FM9G-V/run_train.sh Normal file
View File

@ -0,0 +1,69 @@
#!/bin/bash
SELF_DIR=$(cd "$(dirname "$0")" || exit 1; pwd)
export PYTHONPATH=$(dirname $SELF_DIR):$PYTHONPATH
echo Working Directory at `pwd`
echo Bash at `which bash`
echo Python at `which python`
nvidia-smi
slave_or_master=$1
# Supports both single-node multi-GPU and multi-node multi-GPU training
GPUS_PER_NODE=8
WORLD_SIZE=${WORLD_SIZE:-1}
RANK=${RANK:-0}
MASTER_ADDR=${MASTER_ADDR:-`hostname`}
MASTER_PORT=${MASTER_PORT:-12345}
rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT
if [[ $slave_or_master == "slave" ]]; then
rdzv_endpoint=$2:$MASTER_PORT
fi
TRAIN_FILE=vis_fm9g/config/data/data.json
EVAL_FILE=vis_fm9g/config/data/eval_data.json
LLM_PATH=vis_fm9g/config/model/fm9g-7b.json
VOCAB_PATH=vis_fm9g/config/vocab/fm9g.txt
MODEL_CHECKPOINT=checkpoint/sharded
DEEPSPEED_CONFIG=vis_fm9g/config/deepspeed/sft.json
EXPORT_DIR=./saved
echo WORLD_SIZE=$WORLD_SIZE RANK=$RANK rdzv_endpoint=$rdzv_endpoint
# --------------- Runtime arguments ---------------
OPTS=""
OPTS+=" --self_dir ${SELF_DIR}"
OPTS+=" --train_file ${TRAIN_FILE}"
OPTS+=" --eval_file ${EVAL_FILE}"
OPTS+=" --llm_path ${LLM_PATH}"
OPTS+=" --vocabs_path ${VOCAB_PATH}"
OPTS+=" --model_checkpoint ${MODEL_CHECKPOINT}"
OPTS+=" --deepspeed_config ${DEEPSPEED_CONFIG}"
OPTS+=" --save_deepspeed"
OPTS+=" --export_dir ${EXPORT_DIR}"
OPTS+=" --exp_name vis-fm9g-train-eval"
OPTS+=" --flash cuda"
OPTS+=" --max_len 500"
OPTS+=" --batch_size 1"
OPTS+=" --save_step 1000"
OPTS+=" --epochs 1"
OPTS+=" --query_num 256"
OPTS+=" --vision_encoder eva02_enormous_patch14_clip_224.laion2b_plus"
OPTS+=" --tune_resampler"
OPTS+=" --tune_vision"
OPTS+=" --tune_llm"
OPTS+=" --eval"
RUNNER="torchrun --nnodes=${WORLD_SIZE} --nproc_per_node=${GPUS_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${rdzv_endpoint}"
CMD="$RUNNER ./vis_fm9g/train/train_vis_fm9g.py ${OPTS}"
echo "-------final CMD is------"
echo "${CMD}"
echo "-------final CMD end------"
# Run CMD
eval "${CMD}"

BIN
FM9G-V/test.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 128 KiB

Binary file not shown.

Binary file not shown.

Binary file not shown.

115
FM9G-V/utils/logger.py Normal file
View File

@ -0,0 +1,115 @@
""" Init a logger with options from env variables.
- set log level by ``LOG_LEVEL``, default: ``INFO``;
- output log message to file by ``LOG_FILE``, default: output to stdout.
TODO:
support setting log level and log file from config file.
"""
import logging
import os
_LOG_FMT = "[%(asctime)s][%(levelname).1s][%(process)d-%(name)s-%(filename)s:%(lineno)s]- %(message)s"
_DATE_FMT = "%Y-%m-%d,%H:%M:%S"
_logging_level = {
"CRITICAL": logging.CRITICAL,
"ERROR": logging.ERROR,
"WARNING": logging.WARNING,
"INFO": logging.INFO,
# Distributed Level, print log in main proc only by default, set this level to print all messages.
"DP": logging.INFO,
"DEBUG": logging.DEBUG,
None: logging.INFO,
}
_level = os.environ.get("LOG_LEVEL", "INFO").upper()
class ShortNameFormatter(logging.Formatter):
def format(self, record: logging.LogRecord):
raw = record.name # save and restore for other formatters if desired
parts = raw.split(".")
record.name = ".".join(p[:3] for p in parts) if len(parts) > 1 else raw # keep first char for module name.
result = super().format(record)
record.name = raw
return result
class StyleAdapter(logging.LoggerAdapter):
def __init__(self, logger, extra=None, style="default"):
super().__init__(logger, extra or {})
self._style = style
self._enable = self._enable()
@classmethod
def _enable(cls):
# Note: to make this Logger more standalone, perform basic check without extra deps, e.g. tf/torch et al.
worker = os.getenv("WORKER")
rank = os.getenv("RANK")
# not in DP/DDP mode or proc_id = "0"
is_main = (not worker and not rank) or (worker == "0" or rank == "0")
is_jeeves_job = os.getenv("JEEVES_JOB_ID")
return _level in ["DEBUG", "DP"] or is_jeeves_job or is_main
def _format(self, *msgs, color: str = None):
if self._style == "legacy":
if len(msgs) == 1:
msg_str = msgs[0]
else:
msg_str = msgs[0] % msgs[1:]
else:
msg_str = ", ".join([str(msg) for msg in msgs])
if color:
pass
return msg_str
def log(self, level, msg, *args, **kwargs):
color = kwargs.pop("color", None)
if self.isEnabledFor(level) and self._enable:
msg, kwargs = self.process(msg, kwargs)
msg_str = self._format(msg, *args, color=color)
# noinspection PyProtectedMember
self.logger._log(level, msg_str, (), **kwargs)
def init_logger(name="ai", filename=os.environ.get("LOG_FILE", ""), fmt=_LOG_FMT, level=_level, style="legacy"):
"""init logger
Args:
name(str): optional, default: ai.
filename(str): optional, default: "". Output log to file if specified, by default is set by env `LOG_FILE`.
fmt(str): optional, default: _LOG_FMT
level(str): optional, default: INFO
style(str): optional, choice from ["print", "legacy"]
- legacy: take first argument as a formatter, the remaining positional arguments as message values.
this is consistent with the constraint of `logging` pkg
- print: all positional arguments are message values which will be concatenated with ", "
Returns:
a logger instance
Examples:
>>> log = init_logger("log2stdout", level="INFO")
>>> log.error("info")
"""
logger = logging.getLogger(name)
logger.setLevel(_logging_level[level])
if fmt:
# formatter = logging.Formatter(fmt, datefmt=_DATE_FMT)
formatter = ShortNameFormatter(fmt, datefmt=_DATE_FMT)
else:
formatter = None
if not logger.hasHandlers():
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logging.basicConfig(format=fmt, level=_logging_level[_level], handlers=[handler])
if filename:
handler = logging.FileHandler(filename)
handler.setFormatter(formatter)
logger.addHandler(handler)
return StyleAdapter(logger, style=style)

125
FM9G-V/utils/utils.py Normal file
View File

@ -0,0 +1,125 @@
# --------------------------------------------------------
# BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers (https://arxiv.org/abs/2208.06366)
# Github source: https://github.com/microsoft/unilm/tree/master/beitv2
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Zhiliang Peng
# Based on BEiT, timm, DeiT and DINO code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
import pickle
import torch
import torch.distributed as dist
import torch.nn as nn
import matplotlib.pyplot as plt
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
def plot_images(images: dict):
x = images["input"]
reconstruction = images["rec"]
half_sample = images["half_sample"]
new_sample = images["new_sample"]
fig, axarr = plt.subplots(1, 4)
axarr[0].imshow(x.cpu().detach().numpy()[0].transpose(1, 2, 0))
axarr[1].imshow(reconstruction.cpu().detach().numpy()[0].transpose(1, 2, 0))
axarr[2].imshow(half_sample.cpu().detach().numpy()[0].transpose(1, 2, 0))
axarr[3].imshow(new_sample.cpu().detach().numpy()[0].transpose(1, 2, 0))
plt.show()
def get_model(model):
if isinstance(model, torch.nn.DataParallel) \
or isinstance(model, torch.nn.parallel.DistributedDataParallel):
return model.module
else:
return model
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.LongTensor([tensor.numel()]).to("cuda")
size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
if local_size != max_size:
padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def mean(lst):
return sum(lst) / len(lst)

Binary file not shown.

View File

View File

@ -0,0 +1,10 @@
[
{
"data_source_name": "dataset/train/data_1",
"data_source_weight": 1
},
{
"data_source_name": "dataset/train/data_2",
"data_source_weight": 2
}
]

View File

@ -0,0 +1,3 @@
[
{ "data_source_name": "pretrain_eval_eval", "data_source_weight": 1 }
]

View File

@ -0,0 +1,33 @@
{
"train_micro_batch_size_per_gpu": 16,
"gradient_accumulation_steps": 16,
"optimizer": {
"type": "AdamW",
"params": {
"lr": 5e-5,
"betas": [
0.9,
0.98
],
"weight_decay": 0.01
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 1e-6,
"warmup_max_lr": 1e-5,
"warmup_num_steps": 500
}
},
"fp16": {
"enabled": true,
"initial_scale_power": 10,
"auto_cast": true
},
"zero_optimization": {
"stage": 2
},
"steps_per_print": 50,
"gradient_clipping": 1.0
}

View File

@ -0,0 +1,33 @@
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 16,
"optimizer": {
"type": "AdamW",
"params": {
"lr": 1e-5,
"betas": [
0.9,
0.98
],
"weight_decay": 0.01
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 1e-6,
"warmup_max_lr": 1e-5,
"warmup_num_steps": 500
}
},
"fp16": {
"enabled": true,
"initial_scale_power": 10,
"auto_cast": true
},
"zero_optimization": {
"stage": 2
},
"steps_per_print": 50,
"gradient_clipping": 1.0
}
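For orientation: DeepSpeed's effective global batch size is train_micro_batch_size_per_gpu × gradient_accumulation_steps × number of GPUs. Assuming the 8 GPUs per node set in run_train.sh, the config directly above (micro-batch 1) trains on 1 × 16 × 8 = 128 samples per optimizer step, while the earlier micro-batch-16 config corresponds to 16 × 16 × 8 = 2048.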

View File

@ -3,25 +3,14 @@
"dropout_p": 0.0,
"eps": 1e-05,
"half": true,
"half_type": "bf16",
"use_flash_attn": true,
"flash_attn_mask_shape": "2d",
"dim_model": 4096,
"dim_ff": 14336,
"dim_ff": 11008,
"dim_head": 128,
"num_heads": 32,
"num_kv_heads": 32,
"num_layers": 32,
"activate_fn": "silu",
"init_std": 0.10,
"scale": false,
"scale_emb": 12,
"scale_depth": -1,
"model_type": "fm9g",
"architectures": [
"FM9GForCausalLM"
],
"qk_norm": false,
"tie_lm_head": false,
"ffn_gated": true
}
"scale": false
}

View File

@ -119687,10 +119687,10 @@
"𠳐"
"𥻗"
"𬉼"
"<|im_start|>"
"<|im_end|>"
"<pad_2>"
"<pad_3>"
"<pad_4>"
"<pad_5>"
"<pad_6>"
"<image>"
"</image>"
"<ref>"
"</ref>"
"<box>"
"</box>"
"<quad>"

View File

View File

@ -0,0 +1,347 @@
import io
import os
import re
import glob
import math
import json
import base64
import random
import copy
from PIL import Image
from typing import List
class Register(dict):
def __init__(self, *args, **kwargs):
super(Register, self).__init__(*args, **kwargs)
self._dict = {}
def register(self, target):
def add_register_item(keys, value):
if not callable(value):
raise Exception(
f"Register object must be callable! But received: {value} is not callable!")
if not isinstance(keys, list):
keys = [keys]
for key in keys:
if key in self._dict:
print(
f"error: \033[33m{value.__name__} has been registered before, so we will overriden it\033[0m")
exit()
self[key] = value
return value
if callable(target):
return add_register_item(target.__name__, target)
else:
return lambda x: add_register_item(target, x)
def __call__(self, target):
return self.register(target)
def __setitem__(self, key, value):
self._dict[key] = value
def __getitem__(self, key):
# if the key is among the registered processors, return it directly
if key in self._dict:
return self._dict[key]
else:
# if the key is not registered, fall back to the default processor
return self._dict['default']
def __contains__(self, key):
return key in self._dict
def __str__(self):
return str(self._dict)
def keys(self):
return self._dict.keys()
def values(self):
return self._dict.values()
def items(self):
return self._dict.items()
register_data_processor = Register()
register_data_path = Register()
def vqa_instruction_templates(question, idx=None):
instructions = [
"{Question} A short answer to the question is",
"Given the image, answer the following question with no more than three words. {Question}",
"Based on the image, respond to this question with a short answer: {Question} Answer:",
"Use the provided image to answer the question: {Question} Provide your answer as short as possible:",
]
if idx is None:
new_question = random.choice(
instructions).replace("{Question}", question)
else:
new_question = instructions[idx].replace("{Question}", question)
return new_question
def caption_instruction_templates():
instructions = [
"Describe the image concisely.",
"Provide a brief description of the given image.",
"Offer a succinct explanation of the picture presented.",
"Summarize the visual content of the image.",
"Give a short and clear explanation of the subsequent image.",
"Share a concise interpretation of the image provided.",
"Present a compact description of the photo's key features.",
"Relay a brief, clear account of the picture shown.",
"Render a clear and concise summary of the photo.",
"Write a terse but informative summary of the picture.",
"Create a compact narrative representing the image presented."
]
new_question = random.choice(instructions)
return new_question
def ocr_instruction_templates():
instructions = [
"Identify the text in the image with position.",
"Pinpoint and indicate the text and its location within the image.",
"Find the text in the image and identify its position.",
"Detect the text within the image and specify its position.",
"Locate the text in the image and detail its position.",
]
new_question = random.choice(instructions)
return new_question
def textvqa_instruction_templates(question):
instructions = [
"Answer the question shortly by reading the texts. {Question}",
"After reading the text in the image, {Question} A short answer to the question is",
"Given the text in the image, answer the following question with no more than three words. {Question}"
]
]
new_question = random.choice(instructions).replace("{Question}", question)
return new_question
def load_multimodal_conversation(text_b64, img_b64_buffer):
map_role = {
'human': 'human',
'gpt': 'gpt'
}
text = base64.b64decode(text_b64).decode('utf-8')
list_conv = json.loads(text)
out: List[dict] = []
for idx, sentence in enumerate(list_conv):
value = sentence['value']
if idx == 0 and '<image>' not in value:
value = f"<image>\n{value}"
if idx != 0 and '<image>' in value:
value = value.replace('<image>', '')
out.append({
'from': map_role[sentence['from']],
'value': value
})
img_io = io.BytesIO(base64.b64decode(img_b64_buffer))
img_io.seek(0)
image = Image.open(img_io).convert('RGB')
return image, out
def load_fm9g_multimodal_conversation(text_b64, img_b64_buffer):
map_role = {
'human': 'human',
'gpt': 'gpt'
}
text = base64.b64decode(text_b64).decode('utf-8')
list_conv = json.loads(text)
out: List[dict] = []
for idx, sentence in enumerate(list_conv):
value = sentence['value']
value = re.sub(r'<image>.+?</image>', '<image>', value)
out.append({
'from': map_role[sentence['from']],
'value': value.strip()
})
img_io = io.BytesIO(base64.b64decode(img_b64_buffer))
img_io.seek(0)
image = Image.open(img_io).convert('RGB')
return image, out
def load_pretrain_conversation(text_b64, img_b64_buffer):
map_role = {
'human': 'human',
'gpt': 'gpt'
}
text = base64.b64decode(text_b64).decode('utf-8')
list_conv = json.loads(text)
out: List[dict] = []
for idx, sentence in enumerate(list_conv):
print(sentence)
value = sentence['value']
value = re.sub(r'<image>.+?</image>', '<image>', value)
out.append({
'from': map_role[sentence['from']],
'value': value.strip()
})
img_io = io.BytesIO(base64.b64decode(img_b64_buffer))
img_io.seek(0)
image = Image.open(img_io).convert('RGB')
return image, out
def b64_to_PIL_image(img_b64_buffer):
img_io = io.BytesIO(base64.b64decode(img_b64_buffer))
img_io.seek(0)
image = Image.open(img_io).convert('RGB')
return image
def wrap_qa_to_single_turn_multimodal_conv(answer, question):
if '<image>' not in question:
question = f"<image>\n{question}"
out = [
{"from": "human", "value": question},
{"from": "gpt", "value": answer}
]
return question, out
def wrap_generation_single_turn_conv(out, template_func):
conv = [
{
"from": "human",
"value": f"<image>\n{template_func()}"
},
{
"from": "gpt",
"value": out
}
]
return conv
def wrap_ocr_generation_single_turn_conv(out):
return wrap_generation_single_turn_conv(out, ocr_instruction_templates)
def wrap_caption_generation_single_turn_conv(out):
return wrap_generation_single_turn_conv(out, caption_instruction_templates)
def gather_data_files_by_glob(root: str, pattern='*.tsv'):
filenames = []
for fullpath in glob.glob(f'{root}/{pattern}'):
filename = fullpath.split('/')[-1]
filenames.append(filename)
return root, filenames
@register_data_path('default')
def default_data_path(data_dir):
return gather_data_files_by_glob(data_dir, '*.tsv')
@register_data_processor('default')
def default_data_processor(img_b64_buffer, text_b64, origin_dataset, origin_split, origin_split_inner_idx, img_path,
intent, img_transformer=None):
if intent == 'pretrain' or intent == 'sft':
image, out = load_multimodal_conversation(text_b64, img_b64_buffer)
metainfo = {
"origin_dataset": origin_dataset, # llava folder name
"origin_split": origin_split, # llava parquet file name
"origin_idx": origin_split_inner_idx, # index in llava parquet file
"image_id": img_path, # cocoid
}
return {
'image': image,
'conversations': out,
'idx': origin_split_inner_idx,
'metainfo': metainfo,
}
else:
raise NotImplementedError
@register_data_path('llava')
def llava_instruct_data_path():
data_dir = "dataset/train/default"
return gather_data_files_by_glob(data_dir, '*.tsv')
@register_data_processor('llava')
def llava_instruct_processor(img_b64_buffer, text_b64, origin_dataset, origin_split, origin_split_inner_idx, img_path,
intent, img_transformer=None):
if intent == 'pretrain' or intent == 'sft':
image, out = load_multimodal_conversation(text_b64, img_b64_buffer)
metainfo = {
"origin_dataset": origin_dataset, # llava folder name
"origin_split": origin_split, # llava parquet file name
"origin_idx": origin_split_inner_idx, # index in llava parquet file
"image_id": img_path, # cocoid
}
return {
'image': image,
'conversations': out,
'idx': origin_split_inner_idx,
'metainfo': metainfo,
}
else:
raise NotImplementedError
@register_data_path('pretrain_eval_eval')
def pretrain_eval_train_data_path():
data_dir = "dataset/eval/default"
return gather_data_files_by_glob(data_dir, '*.tsv')
@register_data_processor('pretrain_eval_eval')
def unimmchat_processor(img_b64_buffer, text_b64, origin_dataset, origin_split, origin_split_inner_idx, img_path,
intent, img_transformer=None):
if intent == 'pretrain' or intent == 'sft' or intent == 'eval':
if img_b64_buffer == '<no_image>':
image = '<no_image>'
out = base64.b64decode(text_b64).decode('utf-8')
out = json.loads(out)
else:
image, out = load_multimodal_conversation(text_b64, img_b64_buffer)
metainfo = {
"origin_dataset": origin_dataset, # unimm-chat folder name
"origin_split": origin_split, # unimm-chat parquet file name
"origin_idx": origin_split_inner_idx, # index in unimm-chat parquet file
"image_id": img_path, # cocoid
}
return {
'image': image,
'conversations': out,
'idx': origin_split_inner_idx,
'metainfo': metainfo,
}
else:
raise NotImplementedError

View File

@ -0,0 +1,227 @@
import io
import json
import logging
import random
import numpy
import base64
import os.path as op
import torch.utils.data as torch_data
from PIL import Image
from typing import List, Iterator
from vis_fm9g.dataset.tsv_file import TSVFile
from vis_fm9g.dataset.data import register_data_processor
from vis_fm9g.dataset.itembuilder import ItemBuilder
logger = logging.getLogger(__file__)
class MultimodalQADataset(torch_data.Dataset):
def __init__(self, qa_file, question_process):
'''
qa_file: jsonl file that each line is a dict like {
'image': b64img,
'question': question_text
}
'''
super().__init__()
self.qa_file = qa_file
self.qa_data = [json.loads(line) for line in open(self.qa_file)]
if isinstance(self.qa_data[0], list):
self.qa_data = self.qa_data[0] # unwrap one-line json question file
self.question_process = question_process
def __getitem__(self, index):
item = self.qa_data[index]
img_b64 = item['image']
image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert('RGB')
raw_question = item['question']
question_text = self.question_process(raw_question)
return {
'image': image,
'raw_question': raw_question,
'question': question_text
}
def __len__(self):
return len(self.qa_data)
class SingleDataSourceDataset(torch_data.Dataset):
def __init__(self, ds_name, item_builder: ItemBuilder, data_dir, tsv_filenames: List[str], intent='sft') -> None:
super().__init__()
self.data_dir = data_dir
self.filenames = tsv_filenames
self.ds_name = ds_name
self.sizes = []
for filename in self.filenames:
try:
size = int(filename[:-4].split('-')[-1])
except:
raise ValueError(f'TSV data file {filename} is not valid: the last component separated by `-` must be the number of samples in the file')
self.sizes.append(size)
self.file_border_index = []
self.prepare_border_index()
self.item_builder = item_builder
self.files = self.filenames[:]
self.intent = intent
def prepare_border_index(self):
self.file_border_index = [0]
temp_sum = 0
for size in self.sizes:
temp_sum += size
self.file_border_index.append(temp_sum)
def get_file_idx_and_row_idx(self, index):
found = False
file_idx = -1
for border_idx, border in enumerate(self.file_border_index):
if index < border:
file_idx = border_idx - 1
found = True
break
if not found:
raise ValueError(f'Index {index} out of range for {self.file_border_index[-1]} samples')
offset = self.file_border_index[file_idx]
row_idx = index - offset
return file_idx, row_idx
def __len__(self):
return self.file_border_index[-1]
def __getitem__(self, index):
file_idx, row_idx = self.get_file_idx_and_row_idx(index)
try:
sample = self.fetch_sample(file_idx, row_idx)
item = self.item_builder.build_item(sample)
except:
logger.warning(f"data fetch error")
return self.__getitem__(random.randint(0, len(self)))
return item
def fetch_sample(self, file_idx, row_idx):
file = self.files[file_idx]
if isinstance(file, str):
self.prepare_file(file_idx)
file = self.files[file_idx]
assert isinstance(file, TSVFile), f'Expecting TSVFile but get {file} as {type(file)}'
# tsv line as tuple
sample = file[row_idx]
ds_name, *values = sample
# data dict
sample = register_data_processor[self.ds_name](*values, intent=self.intent)
if row_idx + 1 == len(file):
del file
self.files[file_idx] = self.filenames[file_idx]
return sample
def prepare_file(self, idx):
filename = self.filenames[idx]
file = TSVFile(op.join(self.data_dir, filename))
self.files[idx] = file
class IterableSingleDataSourceDataset(torch_data.IterableDataset):
def __init__(self) -> None:
super().__init__()
raise NotImplementedError
class MultiDataSourceDataset(torch_data.Dataset):
def __init__(self, data_sources: List[SingleDataSourceDataset], data_source_weights: List[int]):
super().__init__()
self.ds_list = data_sources
self.sum_weight = sum(data_source_weights)
self.ds_weights = data_source_weights
for weight in self.ds_weights:
assert isinstance(weight, int), 'weight must be integer'
self.offset2ds = {}
self.offset2wt = {}
self.offset2pd = {}
self.prepare_offset2ds()
ds_loops = []
for ds, wt in zip(self.ds_list, self.ds_weights):
ds_loop = len(ds) // wt
ds_loops.append(ds_loop)
max_loop = max(ds_loops)
self.size = max_loop * self.sum_weight
def prepare_offset2ds(self):
offset = 0
for ds, weight in zip(self.ds_list, self.ds_weights):
pd = offset
for _ in range(weight):
self.offset2ds[offset] = ds
self.offset2wt[offset] = weight
self.offset2pd[offset] = pd
offset += 1
def __getitem__(self, index):
n_loop = index // self.sum_weight
offset = index % self.sum_weight
ds = self.offset2ds[offset]
ds_inner_idx = n_loop * self.offset2wt[offset] + offset - self.offset2pd[offset]
ds_inner_idx = ds_inner_idx % len(ds)
return ds[ds_inner_idx]
def __len__(self):
return self.size
class IterableMultiDataSourceDataset(torch_data.IterableDataset):
def __init__(self, data_sources, data_source_weights):
super().__init__()
self.ds_list = data_sources
sum_weight = sum(data_source_weights)
self.ds_weights = [x / sum_weight for x in data_source_weights]
self.ds_consumption = [0 for _ in self.ds_list]
self.ds_sizes = [len(ds) for ds in self.ds_list]
def __next__(self):
ds_idx = numpy.random.choice(range(len(self.ds_list)), 1, p=self.ds_weights)[0]
data_source = self.ds_list[ds_idx]
self.ds_consumption[ds_idx] += 1
if self.ds_consumption[ds_idx] % self.ds_sizes[ds_idx] == 0:
self.report_consumption()
sample = next(data_source)
return sample
def __iter__(self) -> Iterator:
return self
def __len__(self):
return sum(self.ds_sizes)
def report_consumption(self):
for ds, consumption, size in zip(self.ds_list, self.ds_consumption, self.ds_sizes):
print(f'Data {ds} consumption: {consumption / size:.2f} epoch', flush=True)
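A toy sketch of how MultiDataSourceDataset interleaves sources by integer weight; the plain lists stand in for SingleDataSourceDataset objects, and the import path is an assumption:

from vis_fm9g.dataset.dataset import MultiDataSourceDataset  # module path assumed

ds_a = [f"a{i}" for i in range(4)]   # stand-in for a SingleDataSourceDataset
ds_b = [f"b{i}" for i in range(8)]
mixed = MultiDataSourceDataset([ds_a, ds_b], data_source_weights=[1, 2])

# Each block of sum(weights) = 3 consecutive indices draws 1 item from ds_a
# and 2 items from ds_b, wrapping around each source independently.
print([mixed[i] for i in range(6)])  # ['a0', 'b0', 'b1', 'a1', 'b2', 'b3']
print(len(mixed))                    # max(4 // 1, 8 // 2) * 3 = 12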

View File

@ -0,0 +1,308 @@
import io
import json
from typing import Dict, Tuple, List, Any
import torch
import pandas as pd
import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.data import default_collate
from utils.logger import init_logger
from vis_fm9g.dataset.utils import convert_data_to_id
from vis_fm9g.dataset.utils import convert_conversation_data_to_id
from vis_fm9g.dataset.utils import pad
import random
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer
from vis_fm9g.utils.constants import usr_indicator, bot_indicator
from vis_fm9g.dataset.prompts import caption_zh, caption_en
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
logger = init_logger()
def is_contain_chinese(check_str):
"""
Check whether a string contains Chinese characters.
:param check_str: {str} the string to check
:return: {bool} True if it contains Chinese characters, otherwise False
"""
for ch in check_str:
if u'\u4e00' <= ch <= u'\u9fff':
return True
return False
def maybe_select_text(raw_text):
candidates = raw_text.split('<cap_sep>')
return random.choice(candidates)
def maybe_parse_json(raw_text: str):
# VG raw
if raw_text.startswith('[{') and raw_text.endswith('}]'):
try:
data = json.loads(raw_text)
text_list = [x['phrase'] for x in data if x['height'] > 160 and x['width'] > 160]
if len(text_list) == 0:
return max(data, key=lambda x: len(x['phrase'].split()))['phrase']
else:
return random.choice(text_list)
except:
return raw_text
else:
return raw_text
def clean_text(raw_text):
text = raw_text.replace('<PERSON>', '')
text = maybe_parse_json(maybe_select_text(text))
return text
def check_text_valid(raw_text):
if pd.isna(raw_text):
return False
if not is_contain_chinese(raw_text) and len(raw_text.split()) <= 3:
return False
if '<img' in raw_text or '<a href' in raw_text:
return False
return True
def get_image_placeholder(tokenizer, query_len, use_im_start_end=False):
if use_im_start_end:
return tokenizer.im_start + tokenizer.unk_token * query_len + tokenizer.im_end
else:
return tokenizer.unk_token * query_len
class ItemBuilder():
def __init__(self, transform=None):
self.transform = transform
def build_item(self, data):
if self.transform is not None:
return self.transform(data)
return data
# --------------------- FM9G ---------------------
class FM9GBuilder(ItemBuilder):
def __init__(self, tokenizer: FM9GTokenizer, max_len, transform=None, skip_overlength=False):
super().__init__(transform)
self.tokenizer = tokenizer
self.max_len = max_len
self.skip_overlength = skip_overlength
def convert_data(self, inp_dicts: List[Dict], raw_data):
res = []
for inp_dict in inp_dicts:
input_ids, context = convert_data_to_id(self.tokenizer, data=inp_dict)
if len(input_ids) > self.max_len:
if self.skip_overlength:
if random.random() > 0.95:
logger.warn(f"overlength={len(input_ids)}, raw_inp={inp_dict}, skip data")
else:
logger.warn(f"overlength={len(input_ids)}, skip data")
continue
input_ids = input_ids[: self.max_len]
context = context[: self.max_len]
res.append({
'input_ids': torch.from_numpy(input_ids).unsqueeze(0),
'context': torch.from_numpy(context).unsqueeze(0),
'raw_data': raw_data,
})
return res
def convert_conversation_data(self, conversation_list: List[List]):
res = []
for conversation in conversation_list:
input_ids, context, raw = convert_conversation_data_to_id(self.tokenizer, data=conversation, predict_roles={bot_indicator})
if len(input_ids) > self.max_len:
if self.skip_overlength:
if random.random() > 0.95:
logger.warn(f"overlength={len(input_ids)}, raw_inp={conversation}, skip data")
else:
logger.warn(f"overlength={len(input_ids)}, skip data")
continue
input_ids = input_ids[: self.max_len]
context = context[: self.max_len]
res.append({
'input_ids': torch.from_numpy(input_ids).unsqueeze(0),
'context': torch.from_numpy(context).unsqueeze(0),
'raw_data': raw,
})
return res
def build_image_bound(self, res, images):
return_res = []
if isinstance(images, List) and len(images) > 0:
images = torch.stack(images)
for r in res:
# r['input_ids'] (1, len)
image_start_tokens = torch.where(r['input_ids'][0] == self.tokenizer.encoder[self.tokenizer.im_start])[0]
# skip the im_start token
image_start_tokens += 1
image_end_tokens = torch.where(r['input_ids'][0] == self.tokenizer.encoder[self.tokenizer.im_end])[0]
if len(image_start_tokens) != len(image_end_tokens) or len(image_start_tokens) > len(images):
continue
image_bound = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)])
r['pixel_values'] = images[:len(image_start_tokens)]
r['image_bound'] = image_bound
return_res.append(r)
return return_res
def build_item(self, data):
NotImplementedError("build_item is not implemented.")
class FM9GImageTextBuilder(FM9GBuilder):
def __init__(self, tokenizer: FM9GTokenizer, max_len, transform=None, query_len=64, min_resolution=0, skip_overlength=False):
super().__init__(tokenizer, max_len, transform, skip_overlength)
self.query_len = query_len
self.min_resolution = min_resolution
def build_item(self, data):
text = data['conversations']
image = data['image']
source = data.get('metainfo', {}).get('origin_dataset', 'unk')
image = self.transform(image)
raw_data = {'text': text}
image_placeholder = get_image_placeholder(self.tokenizer, self.query_len, use_im_start_end=True)
messages = []
for i in range(len(text)):
role = text[i]['from']
role = usr_indicator if role == 'human' else bot_indicator
value = self.tokenizer.escape(text[i]['value'])
if '<image>' in value:
value = value.replace('<image>', image_placeholder)
messages.append((role, value))
res = self.convert_conversation_data([messages])
self.build_image_bound(res, images=[image])
for r in res:
r['source'] = source
return res[0]
class FM9GCollater:
def __init__(self, tokenizer: FM9GTokenizer, max_len: int, unpad: bool = False, unilm: bool = False):
self.tokenizer = tokenizer
self._max_length = max_len
self._unpad = unpad
self._unilm = unilm
self.pad_keys = ['input_ids', 'context']
def __call__(self, batch):
batch_cnt = len(batch)
if self._unpad: # for flash_attention cuda
max_length = self._max_length * batch_cnt
batch_size = 1
else:
max_length = self._max_length
batch_size = batch_cnt
inputs = np.zeros((batch_size, max_length), dtype=np.int32)
context_origin = np.zeros((batch_size, max_length), dtype=np.int8)
context = np.zeros((batch_size, max_length), dtype=np.int8)
tgt = np.full((batch_size, max_length), -100, dtype=np.int32)
spans = np.zeros((batch_size, max_length), dtype=np.int32)
length = np.zeros((batch_size,), dtype=np.int32)
position_ids = np.zeros((batch_size, max_length), dtype=np.int32)
if self._unpad: # for flash_attention cuda force batch_size=1
flatten_input_ids = np.concatenate([batch[i]['input_ids'][0] for i in range(batch_cnt)], axis=0)
flatten_context = np.concatenate([batch[i]['context'][0] for i in range(batch_cnt)], axis=0)
instance_length = flatten_input_ids.shape[0]
inputs[0, : instance_length] = flatten_input_ids
context_origin[0, : instance_length] = flatten_context
length[0] = instance_length
if self._unilm:
context[0, : instance_length] = flatten_context
# flatten batch
_spans = [list(np.cumsum([batch[i]['input_ids'][0].shape[0] for i in range(batch_cnt)]))]
else:
for i in range(batch_cnt):
instance_length = batch[i]['input_ids'][0].shape[0]
inputs[i, :instance_length] = batch[i]['input_ids'][0]
context_origin[i, : instance_length] = batch[i]['context'][0]
length[i] = instance_length
if self._unilm:
context[i, :instance_length] = batch[i]['context'][0]
_spans = [[batch[i]['input_ids'][0].shape[0]] for i in range(batch_cnt)]
# cu_seqlens and max_seqlen are required by the flash_attention cuda path
if _spans[0][-1] != max_length:
cu_seqlens = np.array([0] + _spans[0] + [max_length], dtype=np.int32)
else:
cu_seqlens = np.array([0] + _spans[0], dtype=np.int32)
max_seqlen = int(np.max(cu_seqlens[1:] - cu_seqlens[:-1]))
raw_data_list: List[Any] = [batch[i]['raw_data'] for i in range(batch_cnt)]
source_list: List[Any] = [batch[i].get('source', 'unk') for i in range(batch_cnt)]
for i in range(batch_size):
instance_length = length[i]
span_begin = 0
for span_id, span_end in enumerate(_spans[i]):
spans[i, span_begin: span_end] = span_id
position_ids[i, span_begin:span_end] = np.arange(span_end - span_begin)
span_begin = span_end
for j in range(instance_length):
idx = inputs[i][j]
if j > 1:
if context_origin[i][j] == 0:
if idx != self.tokenizer.bos_id and inputs[i][j - 1] != self.tokenizer.eos_id:
tgt[i, j - 1] = idx
if context_origin[i][j] == 1 and context_origin[i][j-1] == 0:
if idx != self.tokenizer.bos_id and inputs[i][j - 1] != self.tokenizer.eos_id:
tgt[i, j - 1] = self.tokenizer.eos_id
data = {}
# image
if 'pixel_values' in batch[0]:
if self._unpad:
data['pixel_values'] = [torch.vstack([i['pixel_values'] for i in batch])]
else:
data['pixel_values'] = [i['pixel_values'] for i in batch]
# image_bound
if 'image_bound' in batch[0]:
if self._unpad:
image_bounds = []
for i in range(batch_cnt):
offset = _spans[0][i-1] if i > 0 else 0
image_bounds.append(batch[i]['image_bound'] + offset)
data['image_bound'] = [torch.vstack(image_bounds)]
else:
data['image_bound'] = [i['image_bound'] for i in batch]
data['input_ids'] = torch.from_numpy(inputs)
data['context'] = torch.from_numpy(context) > 0
data['length'] = torch.from_numpy(length)
data['spans'] = torch.from_numpy(spans)
data['cu_seqlens'] = torch.from_numpy(cu_seqlens)
data['max_seqlen'] = max_seqlen
data['position_ids'] = torch.from_numpy(position_ids)
data['target'] = torch.from_numpy(tgt)
data['raw_data'] = raw_data_list
data['source'] = source_list
return data

View File

@ -0,0 +1,24 @@
caption_en = [
'Describe the image concisely',
'Provide a brief description of the given image',
'Offer a succinct explanation of the picture presented',
'Summarize the visual content of the image',
'Share a concise interpretation of the image provided',
"Present a compact description of the photo's key features",
'Relay a brief and clear account of the picture shown',
'Render a clear and concise summary of the photo',
'Write a terse but informative summary of the picture',
'Create a compact narrative representing the image presented',
]
caption_zh = [
'简明扼要地描述图像',
'提供给定图像的简短描述',
'对所示的图片进行简要的解释',
'总结图像的视觉内容',
'对所提供的图像进行简要的解释',
'简明扼要并清楚地说明所示图片',
'对这张照片作一个简明扼要的总结',
'写一篇简洁但内容丰富的图片摘要',
'创造一个紧凑的叙事来代表所呈现的图像',
]

View File

@ -0,0 +1,106 @@
# Copyright (c) 2021 Microsoft Corporation. Licensed under the MIT license.
import os
import logging
import os.path as op
LARGEST_TSV_SIZE = 500_000
# LARGEST_TSV_SIZE = 10_000
def create_lineidx(filein, idxout):
idxout_tmp = idxout + '.tmp'
with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
fsize = os.fstat(tsvin.fileno()).st_size
fpos = 0
while fpos != fsize:
tsvout.write(str(fpos)+"\n")
tsvin.readline()
fpos = tsvin.tell()
os.rename(idxout_tmp, idxout)
def read_to_character(fp, c):
result = []
while True:
s = fp.read(32)
assert s != ''
if c in s:
result.append(s[: s.index(c)])
break
else:
result.append(s)
return ''.join(result)
class TSVFile(object):
def __init__(self, tsv_file, generate_lineidx=False):
self.tsv_file = tsv_file
self.lineidx = op.splitext(tsv_file)[0] + '.lineidx'
self._fp = None
self._lineidx = None
# remember the pid of the process that opened the file.
# If that pid differs from the current pid, the file is re-opened.
self.pid = None
# generate lineidx if not exist
if not op.isfile(self.lineidx) and generate_lineidx:
create_lineidx(self.tsv_file, self.lineidx)
def __del__(self):
if self._fp:
self._fp.close()
def __str__(self):
return "TSVFile(tsv_file='{}')".format(self.tsv_file)
def __repr__(self):
return str(self)
def num_rows(self):
self._ensure_lineidx_loaded()
assert len(
self._lineidx) <= LARGEST_TSV_SIZE, f'Do not support TSVFile larger than {LARGEST_TSV_SIZE} yet'
return len(self._lineidx)
def seek(self, idx):
self._ensure_tsv_opened()
self._ensure_lineidx_loaded()
try:
pos = self._lineidx[idx]
except:
logging.info('{}-{}'.format(self.tsv_file, idx))
raise
self._fp.seek(pos)
return [s.strip() for s in self._fp.readline().split('\t')]
def seek_first_column(self, idx):
self._ensure_tsv_opened()
self._ensure_lineidx_loaded()
pos = self._lineidx[idx]
self._fp.seek(pos)
return read_to_character(self._fp, '\t')
def get_key(self, idx):
return self.seek_first_column(idx)
def __getitem__(self, index):
return self.seek(index)
def __len__(self):
return self.num_rows()
def _ensure_lineidx_loaded(self):
if self._lineidx is None:
logging.debug('loading lineidx: {}'.format(self.lineidx))
with open(self.lineidx, 'r') as fp:
self._lineidx = [int(i.strip()) for i in fp.readlines()]
def _ensure_tsv_opened(self):
if self._fp is None:
self._fp = open(self.tsv_file, 'r')
self.pid = os.getpid()
if self.pid != os.getpid():
# logging.info('re-open {} because the process id changed'.format(self.tsv_file))
self._fp = open(self.tsv_file, 'r')
self.pid = os.getpid()

View File

@ -0,0 +1,170 @@
import importlib.machinery
import importlib.util
import types
from typing import Any, Set
from typing import Dict
from typing import List
from typing import Union
import torch
import numpy as np
from numpy.typing import NDArray
from torch.utils.data import BatchSampler
from typing_extensions import TypedDict
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer
from vis_fm9g.utils.constants import SYSTEM
FM9GInputType = Union[str, Dict[str, "FM9GInputType"]]
class _TransformFuncDict(TypedDict):
loader: importlib.machinery.SourceFileLoader
module: types.ModuleType
last_m: float
class FM9GBatch(TypedDict):
inputs: NDArray[np.int32]
length: NDArray[np.int32]
context: NDArray[np.bool_]
sample_ids: NDArray[np.int32]
spans: NDArray[np.int32]
target: NDArray[np.int32]
task_ids: NDArray[np.int32]
task_names: List[str]
raw_data: List[Any]
def convert_data_to_id(tokenizer: FM9GTokenizer, data: Any):
"""
data: {
'input': xxx,
'output': xxx
}
"""
input_ids = tokenizer.encode(data["input"])
output_ids = tokenizer.encode(data["output"])
ids = [tokenizer.bos_id] + input_ids + output_ids + [tokenizer.eos_id]
ids = np.array(ids, dtype=np.int32)
context = np.zeros((ids.shape[0],), dtype=np.int8)
# positions with context = 0 need a target (i.e., are predicted)
context[: len(input_ids) + 1] = 1
return ids, context
def convert_conversation_data_to_id(tokenizer: FM9GTokenizer, data: Any, predict_roles: Set):
"""
predict_roles: {'<AI>'}
data: [
('<用户>', xxxx),
('<AI>', xxxx)
]
"""
assert (set([i[0] for i in data]) & predict_roles)
if SYSTEM:
system = tokenizer.bos_token + SYSTEM + '\n'
else:
system = tokenizer.bos_token
sys_idx = tokenizer.encode(system)
ret = system
input_ids = [sys_idx] if sys_idx else []
context = [np.ones((len(sys_idx),), dtype=np.int8)]
for idx, (role, message) in enumerate(data):
prefix = role
# append eos to the last message
if idx == len(data)-1:
message = message + tokenizer.eos_token
prefix_ids = tokenizer.encode(prefix)
message_ids = tokenizer.encode(message)
input_ids.append(prefix_ids)
input_ids.append(message_ids)
context.append(np.ones((len(prefix_ids),), dtype=np.int8))
if role in predict_roles:
context.append(np.zeros((len(message_ids),), dtype=np.int8))
else:
context.append(np.ones((len(message_ids),), dtype=np.int8))
ret += (prefix + message)
ids = np.hstack(input_ids)
context = np.hstack(context)
return ids, context, ret
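
An illustrative call, assuming a loaded FM9GTokenizer (the constructor argument is hypothetical); it shows the expected conversation layout and how the returned context array separates prompt tokens (1) from tokens the loss is computed on (0).

tokenizer = FM9GTokenizer("path/to/vocab.txt")  # hypothetical path
conversation = [
    ("<用户>", "Describe the picture."),
    ("<AI>", "A cat sitting on a sofa."),
]
ids, context, raw_text = convert_conversation_data_to_id(
    tokenizer, conversation, predict_roles={"<AI>"}
)
# context == 1 for system/user/prefix tokens, context == 0 for the <AI> message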
def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"):
items = []
if isinstance(orig_items[0][key], list):
assert isinstance(orig_items[0][key][0], torch.Tensor)
for it in orig_items:
for tr in it[key]:
items.append({key: tr})
else:
assert isinstance(orig_items[0][key], torch.Tensor)
items = orig_items
batch_size = len(items)
shape = items[0][key].shape
dim = len(shape)
assert dim <= 3
if max_length is None:
max_length = 0
max_length = max(max_length, max(item[key].shape[-1] for item in items))
min_length = min(item[key].shape[-1] for item in items)
dtype = items[0][key].dtype
if dim == 1:
return torch.cat([item[key] for item in items], dim=0)
elif dim == 2:
if max_length == min_length:
return torch.cat([item[key] for item in items], dim=0)
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
else:
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
for i, item in enumerate(items):
if dim == 2:
if padding_side == "left":
tensor[i, -len(item[key][0]):] = item[key][0].clone()
else:
tensor[i, : len(item[key][0])] = item[key][0].clone()
elif dim == 3:
if padding_side == "left":
tensor[i, -len(item[key][0]):, :] = item[key][0].clone()
else:
tensor[i, : len(item[key][0]), :] = item[key][0].clone()
return tensor
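
A quick check of the left-padding behaviour on 2-D tensors (values are arbitrary):

import torch
batch = [
    {"input_ids": torch.tensor([[1, 2, 3]])},
    {"input_ids": torch.tensor([[4, 5]])},
]
padded = pad(batch, "input_ids", padding_value=0, padding_side="left")
# tensor([[1, 2, 3],
#         [0, 4, 5]])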
class SkipBatchSampler(BatchSampler):
"""
A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
"""
def __init__(self, batch_sampler, skip_batches=0):
self.batch_sampler = batch_sampler
self.skip_batches = skip_batches
self.first_epoch = True
def __iter__(self):
for index, samples in enumerate(self.batch_sampler):
if index >= self.skip_batches and self.first_epoch:
yield samples
self.first_epoch = False
@property
def total_length(self):
return len(self.batch_sampler)
def __len__(self):
return len(self.batch_sampler) - self.skip_batches

View File

View File

@ -0,0 +1,112 @@
import torch
import torch.nn.functional as F
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
# This function has been mostly taken from huggingface conversational ai code at
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
batch_size = logits.size()[0]
if top_p > 0.0:
logits = logits.view(batch_size, -1).contiguous()
for index in range(len(logits)):
sorted_logits, sorted_indices = torch.sort(logits[index].view(-1), descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[index][indices_to_remove] = filter_value
logits = logits.view(batch_size, -1).contiguous()
return logits
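
A toy check of the nucleus (top-p) filter: with top_p=0.6 only the two most probable tokens of this distribution survive, the rest are pushed to -inf.

import torch
logits = torch.log(torch.tensor([[0.5, 0.25, 0.15, 0.1]]))
filtered = top_k_top_p_filtering(logits.clone(), top_p=0.6)
# the entries for the 0.15 and 0.1 tokens are now -inf; the top two are unchanged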
def apply_repetition_penalty(
logits,
batch_size,
num_beams,
prev_output_tokens,
repetition_penalty,
start_idx=None,
end_idx=None,
window_size=None,
):
# only conduct repetition penalty for the output
assert repetition_penalty >= 1, "repetition penalty coefficient should be >= 1"
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
for i in range(batch_size * num_beams):
if start_idx is None or end_idx is None:
output_tokens = prev_output_tokens[i].tolist()
else:
if end_idx >= start_idx:
if window_size:
output_tokens = prev_output_tokens[i][
max(start_idx, end_idx + 1 - window_size): end_idx + 1
].tolist()
else:
output_tokens = prev_output_tokens[i][start_idx: end_idx + 1].tolist()
else:
output_tokens = []
for previous_token in set(output_tokens):
# if score < 0 then the repetition penalty has to be
# multiplied to reduce the previous token probability
if logits[i, previous_token] < 0:
logits[i, previous_token] *= repetition_penalty
else:
logits[i, previous_token] /= repetition_penalty
class BeamHypotheses:
def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
"""
Initialize n-best list of hypotheses.
"""
self.max_len = max_len
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.n_hyp = n_hyp
self.hyp = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.hyp)
def add(self, hyp, sum_logprobs):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / len(hyp) ** self.length_penalty
if len(self) < self.n_hyp or score > self.worst_score:
self.hyp.append((score, hyp))
if len(self) > self.n_hyp:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
del self.hyp[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs, cur_len):
"""
If there are enough hypotheses and none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.n_hyp:
return False
elif self.early_stopping:
return True
else:
return self.worst_score >= best_sum_logprobs / cur_len**self.length_penalty
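
A small illustration of the n-best bookkeeping: with n_hyp=2 the lowest-scoring hypothesis is evicted once a third, better one is added (scores are length-normalised sums of log-probs).

hyps = BeamHypotheses(n_hyp=2, max_len=10, length_penalty=1.0, early_stopping=False)
hyps.add([1, 2, 3], sum_logprobs=-3.0)     # score -1.0
hyps.add([4, 5], sum_logprobs=-4.0)        # score -2.0
hyps.add([6, 7, 8, 9], sum_logprobs=-2.0)  # score -0.5, evicts the -2.0 hypothesis
print(len(hyps), hyps.worst_score)         # 2 -1.0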

View File

@ -0,0 +1,425 @@
from typing import Any, List, Optional, Union
import torch
import torch.nn.functional as F
from vis_fm9g.generation.generation_utils import BeamHypotheses, apply_repetition_penalty, \
top_k_top_p_filtering
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer, LlamaTokenizerWrapper
from vis_fm9g.model.vlu_fm9g import VLU_FM9G
from vis_fm9g.dataset.utils import pad
class VLLMFM9GGeneration:
def __init__(self, model: VLU_FM9G, tokenizer: Union[FM9GTokenizer, LlamaTokenizerWrapper], transform):
model.eval()
self.model = model
self.tokenizer = tokenizer
self.transform = transform
def _convert_to_tensors(self, data: Any, max_inp_length: Optional[int] = None):
if isinstance(self.tokenizer, LlamaTokenizerWrapper) and self.tokenizer.add_bos_token:
input_ids = self.tokenizer.encode(data["input"])
else:
input_ids = [self.tokenizer.bos_id] + self.tokenizer.encode(data["input"])
if max_inp_length is not None:
input_ids = input_ids[: max_inp_length]
input_ids = torch.tensor(input_ids, dtype=torch.int32)
image_start_tokens = torch.where(input_ids == self.tokenizer.im_start_id)[0]
# skip past the im_start token itself
image_start_tokens += 1
image_end_tokens = torch.where(input_ids == self.tokenizer.im_end_id)[0]
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
image_bound = torch.hstack(
[image_start_tokens[: valid_image_nums].unsqueeze(-1),
image_end_tokens[:valid_image_nums].unsqueeze(-1)]
)
model_input = {}
model_input["input_ids"] = input_ids.unsqueeze(0)
model_input["context"] = torch.zeros(
(model_input["input_ids"].shape[0], model_input["input_ids"].shape[1]), dtype=torch.int16
)
model_input["span"] = torch.ones((model_input["input_ids"].shape[1],), dtype=torch.int16).unsqueeze(0)
model_input["length"] = torch.tensor([model_input["input_ids"].shape[1]], dtype=torch.int16)
model_input["image_bound"] = image_bound
return model_input
def _process_list(self, data_list: List[Any], max_inp_length: Optional[int] = None):
pad_keys = ['input_ids', 'context', 'span']
input_tensors = []
for data in data_list:
input_tensors.append(self._convert_to_tensors(data, max_inp_length))
padded = {}
for key in pad_keys:
padded[key] = pad(input_tensors, key, padding_side="left").cuda()
padded['length'] = torch.hstack([i['length'] for i in input_tensors]).cuda()
padded['image_bound'] = [i['image_bound'] for i in input_tensors]
return padded
def generate(
self,
data_list=None,
img_list=None,
max_inp_length: Optional[int] = None,
vision_hidden_states=None,
return_vision_hidden_states=False,
use_transform=True,
**kwargs
):
# img_list List[List[images]]
assert data_list is not None
bs = len(data_list)
if img_list is None:
img_list = [[] for i in range(bs)]
assert bs == len(img_list)
model_inputs = self._process_list(data_list, max_inp_length)
if vision_hidden_states is None:
if use_transform:
pixel_values = []
for i in range(bs):
img_inps = []
for img in img_list[i]:
img_inps.append(self.transform(img))
if img_inps:
pixel_values.append(torch.stack(img_inps).cuda())
else:
pixel_values.append([])
model_inputs['pixel_values'] = pixel_values
else:
pixel_values = img_list
model_inputs['pixel_values'] = pixel_values
else:
model_inputs['vision_hidden_states'] = vision_hidden_states
with torch.inference_mode():
model_inputs['hidden_states'], vision_hidden_states = self.model.get_vllm_embedding(
model_inputs)
result = self._decode(model_inputs, **kwargs)
if return_vision_hidden_states:
return result, vision_hidden_states
return result
def _decode(self, model_inputs, **kwargs):
raise NotImplementedError("_decode is not implemented.")
def _decode_text(self, result_ids):
result_text = []
for result in result_ids:
if result[-1] == self.tokenizer.eos_id:
result = result[:-1]
result_text.append(self.tokenizer.decode(result))
return result_text
class VLLMFM9GBeamSearch(VLLMFM9GGeneration):
def _decode(
self,
model_inputs,
beam_size=3,
max_length=100,
min_length=0,
repetition_penalty=1.0,
length_penalty=1.0,
temperature=1.0,
repetition_window=None,
):
"""
Beam search
Args:
model_inputs (dict): input ids.
beam_size (int, optional, defaults to 3): beam size of beam search.
max_length (int, optional, defaults to 100): maximum generation length.
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
""" # noqa: E501
# expand dimension
batch_size = model_inputs["input_ids"].size(0)
input: torch.Tensor = (
model_inputs["input_ids"]
.unsqueeze(1)
.expand(batch_size, beam_size, -1)
.contiguous()
.view(batch_size * beam_size, -1)
)
length = (
model_inputs["length"]
.unsqueeze(1)
.expand(batch_size, beam_size)
.contiguous()
.view(
batch_size * beam_size,
)
)
span: torch.Tensor = (
model_inputs["span"]
.unsqueeze(1)
.expand(batch_size, beam_size, -1)
.contiguous()
.view(batch_size * beam_size, -1)
)
context: torch.Tensor = (
model_inputs["context"]
.unsqueeze(1)
.expand(batch_size, beam_size, -1)
.contiguous()
.view(batch_size * beam_size, -1)
)
hidden_states: torch.Tensor = (
model_inputs["hidden_states"]
.unsqueeze(1)
.expand(batch_size, beam_size, *model_inputs["hidden_states"].shape[1:])
.contiguous()
.view(batch_size * beam_size, *model_inputs["hidden_states"].shape[1:])
)
done = [False for _ in range(batch_size)]
beam_scores = torch.zeros((batch_size, beam_size), dtype=torch.float, device=input.device)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1)
# generated hypotheses
generated_hyps = [
BeamHypotheses(beam_size, max_length, length_penalty=length_penalty, early_stopping=False)
for _ in range(batch_size)
]
pred_start_index = input.size(-1)
past_key_values = None
for i in range(max_length + 1):
if i == 0:
logits, _, past_key_values = self.model.llm.inference(
input=input,
context=context,
span=span,
length=length,
past_key_values=past_key_values,
hidden_states=hidden_states
)
else:
logits, _, past_key_values = self.model.llm.inference(
input=input[:, -1:],
context=context,
span=span,
length=length,
past_key_values=past_key_values
)
# skip all steps when we are done with each sentence
if all(done):
break
# (batch * beam, seqlen, model_dim)
logits = logits[:, -1, :]
if i == 0:
logits[:, self.tokenizer.eos_id] = -float("inf")
apply_repetition_penalty(
logits,
batch_size,
beam_size,
input,
repetition_penalty,
pred_start_index,
input.size(-1) - 1,
repetition_window,
)
logits = logits / temperature
if i < min_length:
logits[:, self.tokenizer.eos_id] = -float("inf")
scores = F.log_softmax(logits, dim=-1)
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * beam_size, vocab_size)
# re-organize to group the beam together (we are keeping the top hypotheses across beams)
next_scores = next_scores.view(batch_size, -1) # (batch_size, beam_size * vocab_size)
next_scores, next_words = torch.topk(next_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
assert next_scores.size() == next_words.size() == (batch_size, 2 * beam_size)
next_batch_beam = []
for sent_id in range(batch_size):
# if we are done with this sentence
done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item(), i)
if done[sent_id]:
next_batch_beam.extend([(0, 0, 0)] * beam_size) # pad the batch
continue
# next sentence beam content
next_sent_beam = []
# next words for this sentence
for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
# get beam and word IDs
beam_id = torch.div(idx, scores.size(-1), rounding_mode="floor")
word_id = idx % scores.size(-1)
# end of sentence, or next word
if word_id == self.tokenizer.eos_id or i == max_length:
generated_hyps[sent_id].add(
input[sent_id * beam_size + beam_id, pred_start_index:].clone().cpu().tolist(),
value.item(),
)
else:
next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))
# the beam for next step is full
if len(next_sent_beam) == beam_size:
break
# update next beam content
assert len(next_sent_beam) == (0 if i == max_length else beam_size)
if len(next_sent_beam) == 0:
next_sent_beam = [(0, 0, 0)] * beam_size # pad the batch
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == beam_size * (sent_id + 1)
# we have reached the last step
if i == max_length:
break
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * beam_size
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_words = input.new([x[1] for x in next_batch_beam])
beam_idx = torch.tensor([x[2] for x in next_batch_beam], device=input.device).long()
# re-order batch and internal states
input = input[beam_idx, :]
past_key_values["buffer"] = [list(each) if each is not None else each for each in past_key_values["buffer"]] # type: ignore # noqa: E501
for key_value_layer in past_key_values["buffer"]:
if key_value_layer is not None:
key_value_layer[0] = key_value_layer[0][beam_idx]
key_value_layer[1] = key_value_layer[1][beam_idx]
input = torch.cat([input, beam_words.unsqueeze(1)], dim=-1)
context = torch.cat(
[context, context[:, -1:]],
dim=-1,
)
length += 1
span = torch.cat([span, span[:, -1:]], dim=-1)
# select the best hypotheses
results = []
for i, hypotheses in enumerate(generated_hyps):
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
results.append(best_hyp)
result_text = self._decode_text(results)
return result_text
class VLLMFM9GRandomSampling(VLLMFM9GGeneration):
def _decode(
self,
model_inputs,
min_length=24,
max_length=50,
top_k=0.0,
top_p=1.0,
temperature=1,
repetition_penalty=1.0,
repetition_window=None
):
"""
Top-k and top-p sampling.
Args:
model_inputs (dict): input ids
max_length (int, optional, defaults to 50): maximum generation length.
top_k (int, optional, defaults to 0): keep only the top k tokens with highest probability. 0 means keeping all tokens.
top_p (float, optional, defaults to 1.0): keep the top tokens with cumulative probability >= top_p.
temperature (float, optional, defaults to 1.0): the value that can cool down the logits distribution.
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
""" # noqa: E501
max_length += 1
input = model_inputs["input_ids"]
context = model_inputs["context"]
length = model_inputs["length"]
span = model_inputs["span"]
hidden_states = model_inputs["hidden_states"]
batch_size = input.size(0)
pred_start_index = input.size(-1)
past_key_values = None
done = [False for _ in range(batch_size)]
results = [None for _ in range(batch_size)]
for i in range(max_length):
if i == 0:
logits, _, past_key_values = self.model.llm.inference(
input=input,
context=context,
length=length,
span=span,
past_key_values=past_key_values,
hidden_states=hidden_states
)
else:
logits, _, past_key_values = self.model.llm.inference(
input=input[:, -1:],
context=context,
length=length,
span=span,
past_key_values=past_key_values
)
logits = logits[:, -1, :]
if i == 0:
logits[:, self.tokenizer.eos_id] = -float("inf")
apply_repetition_penalty(
logits,
batch_size,
1,
input,
repetition_penalty,
pred_start_index,
input.size(-1) - 1,
repetition_window,
)
logits = logits / temperature
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
for idx in range(batch_size):
if not done[idx] and (next_token[idx].item() == self.tokenizer.eos_id or i == max_length - 1):
done[idx] = True
results[idx] = input[idx, pred_start_index:].clone().cpu().tolist() # type: ignore # noqa: E501
if sum(done) == batch_size:
break
# update input ids
input = torch.cat([input, next_token], dim=-1)
length += 1
context = torch.cat(
[context, context[:, -1:]],
dim=-1,
)
span = torch.cat(
[span, span[:, -1:]],
dim=-1,
)
result_text = self._decode_text(results)
return result_text
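
A minimal text-only sketch of how the wrapper is driven (it assumes a VLU_FM9G model, tokenizer and image transform are already loaded; with no images, img_list can simply be omitted):

searcher = VLLMFM9GBeamSearch(model, tokenizer, transform)
answers = searcher.generate(
    data_list=[{"input": "<用户>Introduce yourself briefly.<AI>"}],
    beam_size=3,
    max_length=100,
)
print(answers[0])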

View File

@ -11,5 +11,4 @@ from .position_embedding import ChatGLMRotaryEmbedding
from .position_embedding import RotaryEmbedding
from .position_embedding import RotaryEmbeddingESM
from .position_embedding import SegmentPositionEmbedding
from .transformer import Encoder
#from _attention_pp_sp import OpAttnPipeSP
from .transformer import Encoder

View File

@ -0,0 +1,283 @@
import math
from typing import Optional
from typing import Tuple
try:
from .flash_triton import FlashAttnFunc
except:
FlashAttnFunc = None
import torch
from einops import rearrange
from .linear import Linear
from .position_embedding import apply_chatglm_rotary_pos_emb
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except:
flash_attn_unpadded_func = None
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except:
flash_attn_varlen_func = None
try:
from flash_attn.bert_padding import pad_input
from flash_attn.bert_padding import unpad_input
except:
pad_input = None
unpad_input = None
class FlashSelfAttention(torch.nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
super().__init__()
assert flash_attn_unpadded_func is not None, (
"Please install FlashAttention first, " "e.g., with pip install flash-attn"
)
assert rearrange is not None, "Please install einops first, e.g., with pip install einops"
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def forward(self, q, k, v, attention_mask=None, length_mask=None, context_mask=None):
"""Implements the multihead softmax attention.
Arguments
---------
q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
"""
assert q.dtype in [torch.float16, torch.bfloat16], q.dtype
assert q.is_cuda
batch_size, seqlen = q.shape[0], q.shape[1]
d = q.shape[-1]
if length_mask is not None:
q, k, v = [rearrange(x, "b s h d -> b s (h d)") for x in [q, k, v]]
q, indices_q, cu_seqlens, max_s = unpad_input(q, length_mask)
k, _, _, _ = unpad_input(k, length_mask)
v, _, _, _ = unpad_input(v, length_mask)
q, k, v = [rearrange(x, "nnz (h d) -> nnz h d", d=d) for x in [q, k, v]]
output = flash_attn_unpadded_func(
q,
k,
v,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal=self.causal,
attention_mask=attention_mask,
context_mask=context_mask[:, :max_s],
)
# TODO reimplement (un)pad_input to remove redundant rearranges.
output = rearrange(output, "nnz h d -> nnz (h d)")
output = pad_input(output, indices_q, batch_size, seqlen)
output = rearrange(output, "b s (h d) -> b s h d", d=d)
else:
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
max_s = seqlen
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=q.device)
output = flash_attn_unpadded_func(
q,
k,
v,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal=self.causal,
attention_mask=attention_mask,
context_mask=context_mask,
)
output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
return output
class Attention(torch.nn.Module):
def __init__(
self,
dim_model: int,
num_heads: int,
num_kv_heads: int,
dim_head: int,
dtype: torch.dtype = torch.half,
dropout_p: Optional[float] = None,
scale: bool = True,
add_qkv_bias: bool = False,
use_flash_attn: bool = False,
) -> None:
super().__init__()
self.dim_model = dim_model
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_groups = num_heads // num_kv_heads
self.dim_head = dim_head
self.project_q = Linear(
self.dim_model, self.num_heads * self.dim_head, bias=add_qkv_bias, dtype=dtype, scale=scale
)
self.project_k = Linear(
self.dim_model, self.num_kv_heads * self.dim_head, bias=add_qkv_bias, dtype=dtype, scale=scale
)
self.project_v = Linear(
self.dim_model, self.num_kv_heads * self.dim_head, bias=add_qkv_bias, dtype=dtype, scale=scale
)
self.attention_out = Linear(self.num_heads * self.dim_head, self.dim_model, dtype=dtype, scale=scale)
self.softmax = torch.nn.Softmax(dim=-1)
if dropout_p is not None:
self.dropout = torch.nn.Dropout(p=dropout_p)
self.dropout_p = dropout_p
else:
self.dropout = None
# if use_flash_attn:
# self.core_attention_flash = FlashSelfAttention(causal=False, attention_dropout=0.0)
self.use_flash_attn = use_flash_attn
def forward(
self,
hidden_q: torch.Tensor,
hidden_kv: torch.Tensor,
attention_mask: torch.BoolTensor,
position_bias: torch.Tensor,
use_cache: bool = False,
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
pos_bias_type: Optional[str] = "relative",
length_mask: Optional[torch.Tensor] = None,
context_mask: Optional[torch.Tensor] = None,
attention_mask_bias: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: int = None,
position_ids: Optional[torch.Tensor] = None,
):
"""This model inherits from bmt.DistributedModule.
Args:
hidden_q (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix.
hidden_kv (:obj:`torch.Tensor` of shape ``(batch, len_k, dim_model)``): Length of input sequence before padding.
attention_mask (:obj:`torch.Tensor` of shape ``(batch, len_q, len_k)``): Used to avoid performing attention on padding token indices.
position_bias(:obj:`torch.Tensor` of shape ``(num_heads, len_q, len_k)`` or ``(1, num_heads, len_k, len_q)``): Provide positional information about tensor `key_value` and `query`.
Return:
out (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): The attention output.
""" # noqa: E501
batch_size = hidden_q.size(0)
len_q = hidden_q.size(1)
len_k = hidden_kv.size(1)
h_q = self.project_q(hidden_q)
h_k = self.project_k(hidden_kv)
h_v = self.project_v(hidden_kv)
if not self.use_flash_attn:
h_q = h_q / math.sqrt(math.sqrt(self.dim_head))
h_k = h_k / math.sqrt(math.sqrt(self.dim_head))
h_q = h_q.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
h_k = h_k.view(batch_size, len_k, self.num_kv_heads, self.dim_head).permute(0, 2, 1, 3)
h_v = h_v.view(batch_size, len_k, self.num_kv_heads, self.dim_head).permute(0, 2, 1, 3)
if pos_bias_type == "rotary":
# b h s d
h_q, h_k = position_bias(h_q, h_k, -2, offset=past_kv[0].size(-2) if past_kv is not None else 0)
elif pos_bias_type == "chatglm_rotary":
h_q = apply_chatglm_rotary_pos_emb(h_q, position_bias)
h_k = apply_chatglm_rotary_pos_emb(h_k, position_bias)
if past_kv is not None:
h_k = torch.cat([past_kv[0], h_k], dim=-2)
h_v = torch.cat([past_kv[1], h_v], dim=-2)
len_k = h_k.size(-2)
# (b, n_h, len_q, d_h) @ (b, n_h, d_h, len_k) -> (b, n_h, len_q, len_k)
# (b, n_kv_h, n_h_groups*len_q, d_h) @ (b, n_kv_h, d_h, len_k) -> (b, n_kv_h, n_h_groups*len_q, len_k) -> (b, n_h, len_q, len_k)
if self.head_groups == 1:
score = torch.matmul(h_q, h_k.transpose(-1, -2))  # 1/sqrt(dim_head) scaling already applied to h_q and h_k above
else:
score = torch.matmul(
h_q.reshape(batch_size, self.num_kv_heads, self.head_groups * len_q, self.dim_head),
h_k.transpose(-1, -2),
).view(
batch_size, self.num_heads, len_q, len_k
)  # 1/sqrt(dim_head) scaling already applied to h_q and h_k above
if pos_bias_type == "relative":
if len_q == 1: # inference with cache
if len(position_bias.size()) == 4:
position_bias = position_bias[:, :, -1:, :]
else:
position_bias = position_bias[:, -1:, :]
score = score + position_bias
score = torch.masked_fill(
score,
attention_mask.view(batch_size, 1, len_q, len_k) == False,
torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
)
score = self.softmax(score)
score = torch.masked_fill(
score,
attention_mask.view(batch_size, 1, len_q, len_k) == False,
torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
)
if self.dropout is not None:
score = self.dropout(score)
# (b, n_h, len_q, len_k) @ (b, n_h, len_k, d_h) -> (b, n_h, len_q, d_h)
# (b, n_kv_h, n_h_groups*len_q, len_k) @ (b, n_kv_h, len_k, d_h) -> (b, n_kv_h, n_h_groups*len_q, d_h) -> (b, n_h, len_q, d_h)
score = torch.matmul(score.view(batch_size, self.num_kv_heads, self.head_groups * len_q, len_k), h_v).view(
batch_size, self.num_heads, len_q, self.dim_head
)
score = score.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3)
score = score.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)
else:
if attention_mask_bias is not None:
assert pos_bias_type == "rotary"
h_q = h_q.view(batch_size, len_q, self.num_heads, self.dim_head) # .permute(0, 2, 1, 3)
h_k = h_k.view(batch_size, len_k, self.num_kv_heads, self.dim_head) # .permute(0, 2, 1, 3)
h_v = h_v.view(batch_size, len_k, self.num_kv_heads, self.dim_head) # .permute(0, 2, 1, 3)
h_q, h_k = position_bias(h_q, h_k, -3)
score = FlashAttnFunc.apply(h_q, h_k, h_v, attention_mask_bias, False, None)
else:
if pos_bias_type == "chatglm_rotary":
raise NotImplementedError("No FlashAttn version for ChatGLM at present!")
h_q = h_q.view(len_q, self.num_heads, self.dim_head) # .permute(0, 2, 1, 3)
h_k = h_k.view(len_k, self.num_kv_heads, self.dim_head) # .permute(0, 2, 1, 3)
h_v = h_v.view(len_k, self.num_kv_heads, self.dim_head) # .permute(0, 2, 1, 3)
h_q, h_k = position_bias(
h_q, h_k, -3, cu_seqlens=cu_seqlens, max_length=max_seqlen, position_ids=position_ids
)
score = flash_attn_varlen_func(
h_q, h_k, h_v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, self.dropout_p, causal=True
)
score = score.view(batch_size, len_q, self.num_heads * self.dim_head)
score = self.attention_out(score)
if use_cache:
return score, (h_k, h_v)
else:
return score
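
The grouped-query matmul above folds the query-head groups into the sequence dimension so that several query heads can share one key/value head; a self-contained shape check of that reshaping trick (sizes are arbitrary):

import torch
batch, num_heads, num_kv_heads, len_q, len_k, dim_head = 2, 8, 2, 5, 7, 16
head_groups = num_heads // num_kv_heads  # 4 query heads per kv head
h_q = torch.randn(batch, num_heads, len_q, dim_head)
h_k = torch.randn(batch, num_kv_heads, len_k, dim_head)
score = torch.matmul(
    h_q.reshape(batch, num_kv_heads, head_groups * len_q, dim_head),
    h_k.transpose(-1, -2),
).view(batch, num_heads, len_q, len_k)
assert score.shape == (batch, num_heads, len_q, len_k)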

View File

@ -2,7 +2,6 @@ from typing import Optional
from typing import Tuple
from typing import Union
import bmtrain as bmt
import torch
from .attention import Attention
@ -12,7 +11,7 @@ from .position_embedding import RotaryEmbedding
from .position_embedding import RotaryEmbeddingESM
class SelfAttentionBlock(bmt.DistributedModule):
class SelfAttentionBlock(torch.nn.Module):
"""The whole cross-attention block. A sequence of operation. Consists of layernorm, self-attention and residual connection.
Args:
@ -31,12 +30,11 @@ class SelfAttentionBlock(bmt.DistributedModule):
num_kv_heads: int,
dim_head: int,
dtype=torch.half,
eps: float = 1e-5,
eps: float = 1e-6,
dropout_p: Optional[float] = None,
scale: bool = True,
add_qkv_bias: bool = False,
use_flash_attn: bool = False,
tp: int = 0,
):
super().__init__()
@ -56,7 +54,6 @@ class SelfAttentionBlock(bmt.DistributedModule):
scale=scale,
add_qkv_bias=add_qkv_bias,
use_flash_attn=use_flash_attn,
tp=tp,
)
if dropout_p:
@ -73,6 +70,7 @@ class SelfAttentionBlock(bmt.DistributedModule):
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
pos_bias_type: Optional[str] = "relative",
length_mask: Optional[torch.Tensor] = None,
context_mask: Optional[torch.Tensor] = None,
attention_mask_bias: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: int = None,
@ -98,6 +96,7 @@ class SelfAttentionBlock(bmt.DistributedModule):
past_key_value,
pos_bias_type=pos_bias_type,
length_mask=length_mask,
context_mask=context_mask,
attention_mask_bias=attention_mask_bias,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
@ -137,7 +136,6 @@ class FFNBlock(torch.nn.Module):
eps: float = 1e-6,
dropout_p: Optional[float] = 0,
scale: bool = True,
tp: int = 0,
):
super().__init__()
@ -154,7 +152,6 @@ class FFNBlock(torch.nn.Module):
dtype=dtype,
dropout_p=dropout_p,
scale=scale,
tp=tp,
)
if dropout_p:
@ -178,7 +175,9 @@ class FFNBlock(torch.nn.Module):
x = self.ffn(x)
if self.dropout is not None:
x = self.dropout(x)
hidden_states = hidden_states + x # / 1.05
hidden_states = hidden_states + x
return hidden_states
@ -211,7 +210,6 @@ class TransformerBlock(torch.nn.Module):
mask_att: bool = False,
mask_ffn: bool = False,
use_flash_attn: bool = False,
tp: int = 0,
):
super().__init__()
self.mask_att = mask_att
@ -229,7 +227,6 @@ class TransformerBlock(torch.nn.Module):
scale=scale,
add_qkv_bias=add_qkv_bias,
use_flash_attn=use_flash_attn,
tp=tp,
)
if not self.mask_ffn:
@ -241,7 +238,6 @@ class TransformerBlock(torch.nn.Module):
eps=eps,
dropout_p=dropout_p,
scale=scale,
tp=tp,
)
def forward(
@ -253,6 +249,7 @@ class TransformerBlock(torch.nn.Module):
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
pos_bias_type: Optional[str] = "relative",
length_mask: Optional[torch.Tensor] = None,
context_mask: Optional[torch.Tensor] = None,
attention_mask_bias: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
@ -279,12 +276,12 @@ class TransformerBlock(torch.nn.Module):
past_key_value=past_key_value,
pos_bias_type=pos_bias_type,
length_mask=length_mask,
context_mask=context_mask,
attention_mask_bias=attention_mask_bias,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
position_ids=position_ids,
)
if use_cache:
hidden_states, current_key_value = hidden_states
else:
@ -297,4 +294,4 @@ class TransformerBlock(torch.nn.Module):
if use_cache:
return hidden_states, current_key_value
else:
return hidden_states
return hidden_states

View File

@ -1,14 +1,13 @@
import math
from typing import Optional
import bmtrain as bmt
import torch
import torch.nn.functional as F
from .position_embedding import RotaryEmbedding
class Embedding(bmt.DistributedModule):
class Embedding(torch.nn.Module):
def __init__(
self,
vocab_size: int,
@ -21,10 +20,7 @@ class Embedding(bmt.DistributedModule):
super().__init__()
self.dim_model = embedding_size
self.weight = bmt.DistributedParameter(
torch.empty(vocab_size, embedding_size, dtype=dtype),
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
)
self.weight = torch.nn.parameter.Parameter(torch.empty(vocab_size, embedding_size, dtype=dtype))
self.scale = scale
def forward(self, ids: torch.Tensor):
@ -39,7 +35,7 @@ class Embedding(bmt.DistributedModule):
embeds = F.embedding(ids, self.weight) / math.sqrt(self.dim_model)
else:
embeds = F.embedding(ids, self.weight)
return embeds
return embeds.clone()
def projection(self, x: torch.Tensor):
"""
@ -56,7 +52,7 @@ class Embedding(bmt.DistributedModule):
return logits
class EmbeddingExt(bmt.DistributedModule):
class EmbeddingExt(torch.nn.Module):
def __init__(
self,
vocab_size: int,
@ -71,10 +67,10 @@ class EmbeddingExt(bmt.DistributedModule):
self.dim_model = embedding_size
self.rotary_emb = RotaryEmbedding(dim=embedding_size, distance_scale=distance_scale, dtype=dtype)
self.weight = bmt.DistributedParameter(
self.weight = torch.nn.parameter.Parameter(
torch.empty(vocab_size, embedding_size, dtype=dtype),
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
)
torch.nn.init.normal_(self.weight, mean=init_mean, std=init_std)
def forward(self, ids: torch.Tensor, ids_sub: torch.Tensor):
"""
@ -101,69 +97,4 @@ class EmbeddingExt(bmt.DistributedModule):
if ext_table is not None:
logits_ext = F.linear(x, ext_table)
logits = torch.cat([logits, logits_ext], dim=-1)
return logits
class VocabParallelEmbedding(bmt.DistributedModule):
def __init__(
self,
vocab_size: int,
embedding_size: int,
dtype: torch.dtype = torch.half,
scale: bool = True,
init_mean: float = 0.0,
init_std: float = 1,
):
super().__init__()
self.dim_model = embedding_size
assert vocab_size % config["tp_size"] == 0
self.vocab_size_per_partition = vocab_size // config["tp_size"]
self.start_index = config["tp_rank"] * self.vocab_size_per_partition
self.end_index = (config["tp_rank"] + 1) * self.vocab_size_per_partition
self.weight = bmt.DistributedParameter(
torch.empty(self.vocab_size_per_partition, embedding_size, dtype=dtype),
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
tp_split_dim=0,
tp_mode=True,
)
def forward(self, ids: torch.Tensor, gather_input=True):
"""
Args:
ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.
gather_input (bool) : whether gather input is required between tensor parallel group)
Return:
:obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
""" # noqa: E501
if gather_input:
ids = all_gather(ids, comm=config["tp_comm"])
input_mask = (ids < self.start_index) | (ids >= self.end_index)
ids = ids.clone() - self.start_index
ids[input_mask] = 0
embeds = F.embedding(ids, self.weight)
embeds[input_mask, :] = 0.0
embeds = all_reduce(embeds, op="sum", comm=config["tp_comm"])
embed_list = embeds.chunk(config["tp_size"], dim=0)
embeds = embed_list[config["tp_rank"]].flatten(0, 1)
if self.scale:
embeds = embeds / math.sqrt(self.dim_model)
return embeds
def projection(self, x: torch.Tensor, gather_output=False, gather_input=True):
"""
Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size.
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
Returns:
:obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output.
""" # noqa: E501
if self.scale:
x = x / math.sqrt(self.dim_model)
out = bmt.nn.OpParallelLinear.apply(x, self.weight, None, gather_input, gather_output, False, None)
return out
return logits

View File

@ -1,13 +1,11 @@
from typing import Optional
import bmtrain as bmt
import torch
from .linear import LastLinear
from .linear import Linear
class DenseGatedACT(bmt.DistributedModule):
class DenseGatedACT(torch.nn.Module):
def __init__(
self,
dim_in: int,
@ -15,7 +13,6 @@ class DenseGatedACT(bmt.DistributedModule):
activate_fn: str = "gelu",
scale: bool = True,
dtype=torch.half,
tp: int = 0,
):
super().__init__()
@ -25,7 +22,6 @@ class DenseGatedACT(bmt.DistributedModule):
dtype=dtype,
scale=scale,
scale_before=False,
tp=tp,
)
self.w_1 = Linear(
@ -34,9 +30,7 @@ class DenseGatedACT(bmt.DistributedModule):
dtype=dtype,
scale=scale,
scale_before=False,
tp=tp,
)
if activate_fn == "gelu":
self.act = torch.nn.GELU()
elif activate_fn == "silu":
@ -45,8 +39,7 @@ class DenseGatedACT(bmt.DistributedModule):
raise NotImplementedError(f"{activate_fn} is not supported")
def forward(self, x: torch.Tensor):
"""This model inherits from bmt.DistributedModule.
Transform an input tensor from one feature space to another via a nonlinear operation
"""Transform an input tensor from one feature space to another via a nonlinear operation
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): Tensor that will be subject to nonlinear operations.
@ -62,7 +55,7 @@ class DenseGatedACT(bmt.DistributedModule):
return x
class FeedForward(bmt.DistributedModule):
class FeedForward(torch.nn.Module):
r"""FeedForward module
Args:
@ -85,7 +78,6 @@ class FeedForward(bmt.DistributedModule):
dtype=torch.half,
dropout_p: Optional[float] = None,
scale: bool = True,
tp: int = 0,
):
super().__init__()
@ -95,7 +87,6 @@ class FeedForward(bmt.DistributedModule):
activate_fn=activate_fn,
dtype=dtype,
scale=scale,
tp=tp,
)
if dropout_p is not None:
@ -103,13 +94,12 @@ class FeedForward(bmt.DistributedModule):
else:
self.dropout = None
self.w_out = LastLinear(
self.w_out = Linear(
dim_in=dim_ff,
dim_out=dim_model,
dtype=dtype,
scale=scale,
scale_before=False,
tp=tp * 2,
)
def forward(self, x: torch.Tensor):
@ -127,4 +117,4 @@ class FeedForward(bmt.DistributedModule):
x = self.w_out(x)
return x
return x

View File

@ -1063,4 +1063,4 @@ class FlashAttnFunc(torch.autograd.Function):
return dq, dk, dv, None, None, None
flash_attn_func = FlashAttnFunc.apply
flash_attn_func = FlashAttnFunc.apply

View File

@ -1,8 +1,7 @@
import bmtrain as bmt
import torch
@torch.jit.script
@torch.jit.script # type: ignore
def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
old_dtype = hidden.dtype
variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
@ -10,21 +9,22 @@ def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
return hidden * weight
class LayerNorm(bmt.DistributedModule):
class LayerNorm(torch.nn.Module):
"""RMS LayerNorm"""
def __init__(
self,
dim_norm: int,
dtype: torch.dtype = torch.half,
eps: float = 1e-5,
eps: float = 1e-6,
init_var: float = 1.0,
):
super().__init__()
self.eps = eps
self.dim_norm = dim_norm
self.weight = bmt.DistributedParameter(torch.full((dim_norm,), init_var, dtype=dtype))
self.weight = torch.nn.parameter.Parameter(torch.full((dim_norm,), init_var, dtype=dtype))
def forward(self, x: torch.Tensor):
"""
@ -34,4 +34,4 @@ class LayerNorm(bmt.DistributedModule):
:obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``: The layernorm output.
""" # noqa: E501
assert x.size(-1) == self.dim_norm
return rms_layernorm(x, self.weight, self.eps)
return rms_layernorm(x, self.weight, self.eps)
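
A sanity check of the RMS LayerNorm against the plain formula x * rsqrt(mean(x**2) + eps) * weight (the standard RMSNorm; the middle of rms_layernorm is collapsed in this diff, so the formula is stated here as an assumption):

import torch
ln = LayerNorm(dim_norm=8, dtype=torch.float32, eps=1e-6)
x = torch.randn(2, 3, 8)
ref = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6) * ln.weight
assert torch.allclose(ln(x), ref, atol=1e-5)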

View File

@ -0,0 +1,45 @@
import math
import torch
import torch.nn.functional as F
class Linear(torch.nn.Module):
def __init__(
self,
dim_in: int,
dim_out: int,
bias: bool = False,
dtype: torch.dtype = torch.half,
init_mean: float = 0.0,
init_std: float = 1,
scale: bool = True,
scale_before: bool = False,
):
super().__init__()
self.dim_in = self.in_features = dim_in
self.dim_out = self.out_features = dim_out
self.scale = scale
self.scale_before = scale_before
self.weight = torch.nn.parameter.Parameter(torch.empty((dim_out, dim_in), dtype=dtype))
self.bias = None
if bias:
self.bias = torch.nn.parameter.Parameter(torch.empty(dim_out, dtype=dtype))
def forward(self, x: torch.Tensor):
"""
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
Returns:
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
""" # noqa: E501
if self.scale:
if self.scale_before:
x = x / math.sqrt(self.dim_in)
x = F.linear(x, self.weight, bias=self.bias)
else:
x = F.linear(x, self.weight, bias=self.bias)
x = x / math.sqrt(self.dim_in)
else:
x = F.linear(x, self.weight, bias=self.bias)
return x
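
A quick numerical check of the post-scaling behaviour (scale=True, scale_before=False, the defaults): the layer computes x @ W.T / sqrt(dim_in). Note the weight is created with torch.empty, so it must be initialised before use.

import math
import torch
layer = Linear(dim_in=4, dim_out=3, dtype=torch.float32)
torch.nn.init.normal_(layer.weight, std=0.02)
x = torch.randn(1, 2, 4)
expected = torch.nn.functional.linear(x, layer.weight) / math.sqrt(4)
assert torch.allclose(layer(x), expected)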

View File

@ -2,17 +2,11 @@ import math
from typing import Tuple
from typing import Union
import bmtrain as bmt
import torch
import torch.nn.functional as F
try:
from flash_attn.layers.rotary import apply_rotary_emb_func
except:
apply_rotary_emb_func = None
class SegmentPositionEmbedding(bmt.DistributedModule):
class SegmentPositionEmbedding(torch.nn.Module):
def __init__(
self,
num_heads: int,
@ -32,10 +26,10 @@ class SegmentPositionEmbedding(bmt.DistributedModule):
self.bidirectional = bidirectional
self.num_segments = num_segments
self.relative_attention_bias = bmt.DistributedParameter(
torch.empty(num_segments * num_segments + num_buckets, num_heads, dtype=dtype),
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
self.relative_attention_bias = torch.nn.parameter.Parameter(
torch.empty(num_segments * num_segments + num_buckets, num_heads, dtype=dtype)
)
torch.nn.init.normal_(self.relative_attention_bias, mean=init_mean, std=init_std)
def forward(
self,
@ -107,7 +101,7 @@ class SegmentPositionEmbedding(bmt.DistributedModule):
return relative_buckets
class BucketPositionBias(bmt.DistributedModule):
class BucketPositionBias(torch.nn.Module):
def __init__(
self,
num_heads: int,
@ -125,10 +119,10 @@ class BucketPositionBias(bmt.DistributedModule):
self.num_segment_bucket = num_segment_bucket
self.max_distance = max_distance
self.relative_attention_bias = bmt.DistributedParameter(
torch.empty(num_buckets + num_segment_bucket, num_heads, dtype=dtype),
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
self.relative_attention_bias = torch.nn.parameter.Parameter(
torch.empty(num_buckets + num_segment_bucket, num_heads, dtype=dtype)
)
torch.nn.init.normal_(self.relative_attention_bias, mean=init_mean, std=init_std)
def forward(
self,
@ -186,11 +180,11 @@ class BucketPositionBias(bmt.DistributedModule):
return relative_buckets
class RotaryEmbedding(bmt.DistributedModule):
class RotaryEmbedding(torch.nn.Module):
def __init__(
self,
dim,
base: Union[int, float] = 10000,
base=10000,
distance_scale: Union[int, float] = 1,
dtype: torch.dtype = torch.half,
):
@ -226,19 +220,19 @@ def rotate_half(x):
def apply_rotary_pos_emb(x, cos, sin, seq_dim, offset):
if x.size(seq_dim) < cos.size(seq_dim): # == do not need narrow
if x.size(seq_dim) < cos.size(seq_dim):
cos = cos.narrow(seq_dim, offset, x.size(seq_dim))
sin = sin.narrow(seq_dim, offset, x.size(seq_dim))
return (x * cos) + (rotate_half(x) * sin)
def unpad_apply_rotary_pos_emb(x, cos, sin, seq_dim, position_ids):
cos = cos.index_select(seq_dim, position_ids.view(-1))
sin = sin.index_select(seq_dim, position_ids.view(-1))
cos = cos.index_select(seq_dim, position_ids.squeeze(0))
sin = sin.index_select(seq_dim, position_ids.squeeze(0))
return (x * cos) + (rotate_half(x) * sin)
class RotaryEmbeddingESM(bmt.DistributedModule):
class RotaryEmbeddingESM(torch.nn.Module):
"""
Rotary position embeddings based on those in
[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
@ -274,7 +268,8 @@ class RotaryEmbeddingESM(bmt.DistributedModule):
self.apply_rotary_pos_emb = apply_rotary_pos_emb
self.unpad_apply_rotary_pos_emb = unpad_apply_rotary_pos_emb
def _update_cos_sin_tables(self, x, seq_dim, seq_len):
def _update_cos_sin_tables(self, x, seq_dim, offset):
seq_len = x.size(seq_dim) + offset
if seq_len > self._seq_len_cached or self._cos_cached.device != x.device:
self._seq_len_cached = seq_len
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
@ -295,19 +290,11 @@ class RotaryEmbeddingESM(bmt.DistributedModule):
self, q: torch.Tensor, k: torch.Tensor, seq_dim, offset=0, cu_seqlens=None, max_length=None, position_ids=None
) -> Tuple[torch.Tensor, torch.Tensor]:
seq_dim = (seq_dim + k.dim()) % k.dim()
if cu_seqlens is None:
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dim, k.size(seq_dim) + offset)
return (
self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dim, offset),
self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dim, offset),
)
else:
assert offset == 0, "past kv is not supported in flash attn"
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dim, max_length)
return (
self.unpad_apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dim, position_ids),
self.unpad_apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dim, position_ids),
)
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dim, offset)
return (
self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dim, offset),
self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dim, offset),
)
@torch.jit.script
@ -334,7 +321,7 @@ def apply_chatglm_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> t
return ret
class ChatGLMRotaryEmbedding(bmt.DistributedModule):
class ChatGLMRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, device="cuda", dtype=torch.float16, persistent=True):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=dtype, device=device) / dim))
@ -365,4 +352,4 @@ class ChatGLMRotaryEmbedding(bmt.DistributedModule):
return cache
def forward(self, max_seq_len, offset: int = 0):
return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device)

View File

@ -2,14 +2,13 @@ from typing import List
from typing import Optional
from typing import Tuple
import bmtrain as bmt
import torch
from .blocks import TransformerBlock
from .layernorm import LayerNorm
class Encoder(bmt.DistributedModule):
class Encoder(torch.nn.Module):
"""Layers of encoder transformer blocks plus an final layernorm.
Args:
@ -19,7 +18,7 @@ class Encoder(bmt.DistributedModule):
num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
dtype (optional): Defaults to torch.half.
eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5.
eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
dropout_p (float, optional): Defaults to 0.
""" # noqa: E501
@ -33,49 +32,47 @@ class Encoder(bmt.DistributedModule):
num_kv_heads: int = -1,
activate_fn: str = "gelu",
dtype: torch.dtype = torch.half,
eps: float = 1e-5,
eps: float = 1e-6,
dropout_p: Optional[float] = None,
scale: bool = True,
add_qkv_bias: bool = False,
mask_modules: Optional[List[Tuple[bool, bool]]] = None,
use_flash_attn: bool = False,
tp: int = 0,
disabled_checkpoint: Optional[int] = None,
):
super().__init__()
if num_kv_heads == -1:
num_kv_heads = num_heads
self.num_layers = num_layers
if mask_modules is not None:
assert len(mask_modules) == num_layers, "The total number of masks should equal num_layers"
for mask_module in mask_modules:
assert len(mask_module) == 2, "For encoder, each mask should be (mask_att, mask_ffn)"
else:
mask_modules = [(False, False)] * num_layers
self.layers = bmt.TransformerBlockList(
self.layers = torch.nn.ModuleList(
[
bmt.CheckpointBlock(
TransformerBlock(
dim_model=dim_model,
dim_ff=dim_ff,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
dim_head=dim_head,
activate_fn=activate_fn,
dtype=dtype,
eps=eps,
dropout_p=dropout_p,
scale=scale,
add_qkv_bias=add_qkv_bias,
mask_att=mask_modules[ith][0],
mask_ffn=mask_modules[ith][1],
use_flash_attn=use_flash_attn,
tp=tp,
),
TransformerBlock(
dim_model=dim_model,
dim_ff=dim_ff,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
dim_head=dim_head,
activate_fn=activate_fn,
dtype=dtype,
eps=eps,
dropout_p=dropout_p,
scale=scale,
add_qkv_bias=add_qkv_bias,
mask_att=mask_modules[ith][0],
mask_ffn=mask_modules[ith][1],
use_flash_attn=use_flash_attn,
)
for ith in range(num_layers)
]
)
self.output_layernorm = LayerNorm(dim_norm=dim_model, dtype=dtype, eps=eps)
def forward(
@ -87,6 +84,7 @@ class Encoder(bmt.DistributedModule):
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
pos_bias_type: Optional[str] = "relative",
length_mask: Optional[torch.Tensor] = None,
context_mask: Optional[torch.Tensor] = None,
attention_mask_bias: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
@ -103,19 +101,21 @@ class Encoder(bmt.DistributedModule):
""" # noqa: E501
if not use_cache:
hidden_states = self.layers(
hidden_states,
attention_mask,
position_bias,
False,
None,
pos_bias_type,
length_mask,
attention_mask_bias,
cu_seqlens,
max_seqlen,
position_ids,
)
for i, module in enumerate(self.layers):
hidden_states = module(
hidden_states,
attention_mask,
position_bias,
False,
None,
pos_bias_type,
length_mask,
context_mask,
attention_mask_bias,
cu_seqlens,
max_seqlen,
position_ids,
)
hidden_states = self.output_layernorm(hidden_states)
return hidden_states
else:
@ -131,6 +131,7 @@ class Encoder(bmt.DistributedModule):
past_key_values[i] if past_key_values else None,
pos_bias_type,
length_mask,
context_mask,
attention_mask_bias,
)
if use_cache:
@ -141,4 +142,4 @@ class Encoder(bmt.DistributedModule):
if use_cache:
return hidden_states, current_key_values, current_hidden_states
else:
return hidden_states
return hidden_states

View File

Binary file not shown.

View File

@ -0,0 +1,237 @@
from typing import List
from typing import Optional
from typing import Tuple
import torch
from typing_extensions import TypedDict
from vis_fm9g.layers import Embedding
from vis_fm9g.layers import Encoder
from vis_fm9g.layers import RotaryEmbeddingESM
from vis_fm9g.utils.config import Config
class FM9GConfig(Config):
def __init__(
self,
vocab_size=32000,
dim_model=4096,
num_heads=32,
num_kv_heads=32,
dim_head=128,
dim_ff=11008,
num_layers=32,
dropout_p=0.0,
activate_fn="silu",
scale=True,
eps=1e-5,
bf16: bool = False,
half: bool = True,
mask_modules: Optional[List[Tuple[bool, bool]]] = None,
use_flash_attn: bool = True,
flash_attn_mask_shape="1d",
flash_impl="cuda",
base=10000,
):
super().__init__()
self.vocab_size = vocab_size
self.dim_model = dim_model
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.dim_head = dim_head
self.dim_ff = dim_ff
self.num_layers = num_layers
self.dropout_p = dropout_p
self.activate_fn = activate_fn
self.scale = scale
self.eps = eps
if bf16:
self.dtype = torch.bfloat16
elif half:
self.dtype = torch.float16
else:
self.dtype = torch.float
self.flash_impl = flash_impl
self.mask_modules = mask_modules
self.use_flash_attn = use_flash_attn
self.flash_attn_mask_shape = flash_attn_mask_shape
self.base = base
class FM9GInferenceState(TypedDict):
buffer_context: torch.Tensor
buffer_sample_ids: torch.Tensor
buffer: List[Tuple[torch.Tensor, torch.Tensor]]
class FM9GTorch(torch.nn.Module):
def __init__(self, config: FM9GConfig):
super().__init__()
self.config = config
self.encoder = Encoder(
num_layers=config.num_layers,
dim_model=config.dim_model,
dim_ff=config.dim_ff,
num_heads=config.num_heads,
num_kv_heads=config.num_kv_heads,
dim_head=config.dim_head,
activate_fn=config.activate_fn,
dtype=config.dtype,
eps=config.eps,
dropout_p=config.dropout_p,
scale=config.scale,
mask_modules=config.mask_modules,
use_flash_attn=config.use_flash_attn,
)
self.input_embedding = Embedding(
vocab_size=config.vocab_size,
embedding_size=config.dim_model,
scale=config.scale,
dtype=config.dtype,
init_std=0.02,
)
self.position_bias = RotaryEmbeddingESM(
dim=config.dim_head, dtype=config.dtype, base=config.base, persistent=False, mixed_precision=True
)
self.lm_head = Embedding(
vocab_size=config.vocab_size,
embedding_size=config.dim_model,
scale=config.scale,
dtype=config.dtype,
init_std=0.02,
)
self.flash_impl = config.flash_impl
self.use_flash_attn = config.use_flash_attn
self.flash_attn_mask_shape = config.flash_attn_mask_shape
def forward(
self,
input: torch.Tensor, # (batch, seqlen) int32
length: torch.Tensor = None, # (batch) int32
context: torch.Tensor = None, # (batch, seqlen) bool
span: torch.Tensor = None, # (batch, seqlen) int32
cu_seqlens: torch.Tensor = None, # (real_batch+2) int32
max_seqlen: int = None,
position_ids: torch.Tensor = None, # (batch, seqlen) int32
hidden_states: torch.Tensor = None
):
batch = input.size(0)
seqlen = input.size(1)
device = input.device
if length is not None and length.dim() == 1:
length = torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
# processing masks and position bias bucket
if not self.use_flash_attn or (self.flash_attn_mask_shape == "2d" and self.flash_impl == "triton"):
with torch.no_grad():
# directional mask
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(
-1, 1
)
# context mask
attention_mask = context[:, None, :] | (
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
)
# span mask
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
# length mask
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
if hidden_states is None:
hidden_states = self.input_embedding(input)
if self.use_flash_attn:
if self.flash_attn_mask_shape == "1d":
hidden_states = self.encoder(
hidden_states,
attention_mask=None,
position_bias=self.position_bias,
pos_bias_type="rotary",
length_mask=length,
context_mask=context.to(torch.int16) + 2 * (span.to(torch.int16) + length.to(torch.int16)),
)
else:
if self.flash_impl == "triton":
mask = attention_mask.unsqueeze(dim=1).contiguous()
attention_mask_bias = torch.zeros_like(mask, device="cuda", dtype=torch.float16)
attention_mask_bias[mask == False] -= torch.inf
else:
attention_mask_bias = None
assert cu_seqlens is not None, "cu_seqlens are needed in Flash Attention cuda impl"
hidden_states = self.encoder(
hidden_states,
attention_mask=None,
position_bias=self.position_bias,
pos_bias_type="rotary",
length_mask=None,
context_mask=None,
attention_mask_bias=attention_mask_bias,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
position_ids=position_ids,
)
else:
hidden_states = self.encoder(
hidden_states, attention_mask=attention_mask, position_bias=self.position_bias, pos_bias_type="rotary"
)
logits = self.lm_head.projection(hidden_states)
return logits, hidden_states
def inference(
self,
input: torch.Tensor, # (batch, len_q) int32
length: torch.Tensor, # (batch) int32
context: torch.Tensor, # (batch, seqlen) int16
span: torch.Tensor, # (batch, seqlen) int32
past_key_values: Optional[FM9GInferenceState] = None,
hidden_states: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor, FM9GInferenceState]:
batch = input.size(0)
len_q = input.size(1)
len_buffer = 0
if past_key_values is None:
present_buffer = None
else:
present_buffer = past_key_values["buffer"]
len_buffer = present_buffer[0][0].shape[-2]
seqlen = len_buffer + len_q
with torch.no_grad():
device = input.device
if length.dim() == 1:
length = (torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) + length[:, None]) >= seqlen
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
# context mask
attention_mask = context[:, None, :] | (
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
)
# span mask
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
# length mask
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
if hidden_states is None:
hidden_states = self.input_embedding(input)
hidden_states, present_key_values, _ = self.encoder(
hidden_states,
attention_mask=attention_mask[:, len_buffer:],
position_bias=self.position_bias,
use_cache=True,
past_key_values=present_buffer,
pos_bias_type="rotary",
)
logits = self.lm_head.projection(hidden_states)
return (
logits,
hidden_states,
{"buffer": present_key_values},
)
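For reference, here is a minimal, self-contained sketch of how the directional, context, span and length masks built in forward() and inference() above combine; all shapes and values below are illustrative and not taken from the repository:
import torch
seqlen = 5
length = torch.tensor([[1, 1, 1, 1, 0]], dtype=torch.bool)    # last position is padding
context = torch.tensor([[1, 1, 0, 0, 0]], dtype=torch.bool)   # prompt tokens attend bidirectionally
span = torch.zeros(1, seqlen, dtype=torch.int32)              # a single span
# causal (lower-triangular) mask: position i may attend to j only if j <= i
directional = torch.arange(seqlen) <= torch.arange(seqlen).view(-1, 1)
# context tokens are visible to every position; non-context tokens follow the causal rule
mask = context[:, None, :] | (context[:, :, None].logical_not() & directional.view(1, seqlen, seqlen))
# restrict attention to the same span and to valid (non-padding) positions
mask = mask & (span[:, None, :] == span[:, :, None])
mask = mask & length.view(1, seqlen, 1) & length.view(1, 1, seqlen)
print(mask.int())   # (1, 5, 5) boolean attention mask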

View File

@ -0,0 +1,158 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from functools import partial
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
src_size = int(math.sqrt(abs_pos.size(0)))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
return F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size, tgt_size),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
else:
return abs_pos
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
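A quick shape check of the helpers above (values are illustrative):
pe = get_2d_sincos_pos_embed(embed_dim=16, grid_size=4)
print(pe.shape)        # (16, 16): one 16-dim sin/cos embedding per cell of the 4x4 grid
pe_cls = get_2d_sincos_pos_embed(embed_dim=16, grid_size=4, cls_token=True)
print(pe_cls.shape)    # (17, 16): an all-zero row is prepended for the cls token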
class Resampler(nn.Module):
"""
A 2D perceiver-resampler network with one cross attention layers by
(grid_size**2) learnable queries and 2d sincos pos_emb
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
"""
def __init__(
self,
grid_size,
embed_dim,
num_heads,
kv_dim=None,
norm_layer=partial(nn.LayerNorm, eps=1e-6)
):
super().__init__()
self.num_queries = grid_size ** 2
self.embed_dim = embed_dim
self.num_heads = num_heads
self.pos_embed = nn.Parameter(
torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
).requires_grad_(False)
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
else:
self.kv_proj = nn.Identity()
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
self.ln_post = norm_layer(embed_dim)
self.proj = nn.Parameter((embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x, attn_mask=None):
pos_embed = get_abs_pos(self.pos_embed, x.size(1))
x = self.kv_proj(x)
x = self.ln_kv(x).permute(1, 0, 2)
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
x + pos_embed.unsqueeze(1),
x,
attn_mask=attn_mask)[0]
x = out.permute(1, 0, 2)
x = self.ln_post(x)
x = x @ self.proj
return x
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
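A minimal usage sketch of the Resampler (dimensions chosen purely for illustration; the import path follows the one used in vlu_fm9g.py):
import torch
from vis_fm9g.model.resampler import Resampler
resampler = Resampler(grid_size=8, embed_dim=256, num_heads=2, kv_dim=1024)
vision_tokens = torch.randn(2, 16 * 16, 1024)   # (batch, ViT patch tokens, ViT hidden dim)
queries = resampler(vision_tokens)
print(queries.shape)                            # torch.Size([2, 64, 256]): grid_size**2 queries per image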

View File

@ -0,0 +1,110 @@
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import timm
import torch
from transformers.utils import ModelOutput
from vis_fm9g.model.fm9g import FM9GTorch
from vis_fm9g.model.resampler import Resampler
@dataclass
class CausalVLLMOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class VLU_FM9G(torch.nn.Module):
def __init__(self, llm: FM9GTorch, vpm: timm.models.VisionTransformer, vision_dim, query_num) -> None:
super().__init__()
self.vpm = vpm
self.llm = llm
self.vision_dim = vision_dim
self.query_num = query_num
embed_dim = self.llm.config.dim_model
self.resampler = Resampler(
grid_size=int(math.sqrt(query_num)),
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=self.vpm.embed_dim,
)
def get_vision_embedding(self, pixel_values):
res = []
dtype = self.vpm.pos_embed.data.dtype
for pixel_value in pixel_values:
vision_embedding = self.vpm.forward_features(pixel_value.unsqueeze(0).type(dtype))
if hasattr(self.vpm, 'num_prefix_tokens') and self.vpm.num_prefix_tokens > 0:
vision_embedding = vision_embedding[:, self.vpm.num_prefix_tokens:]
res.append(self.resampler(vision_embedding))
return torch.vstack(res)
def get_vllm_embedding(self, data):
if 'vision_hidden_states' not in data:
pixel_values_list = data['pixel_values']
vision_hidden_states = []
for pixel_values in pixel_values_list:
if len(pixel_values) > 0:
vision_hidden_states.append(self.get_vision_embedding(pixel_values))
else:
vision_hidden_states.append([])
else:
vision_hidden_states = data['vision_hidden_states']
vllm_embedding = self.llm.input_embedding(data['input_ids'])
vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance(
i, torch.Tensor) else i for i in vision_hidden_states]
bs = len(data['input_ids'])
for i in range(bs):
cur_vs_hs = vision_hidden_states[i]
if len(cur_vs_hs) > 0:
cur_vllm_emb = vllm_embedding[i]
cur_image_bound = data['image_bound'][i]
image_indices = torch.stack(
[torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
).to(vllm_embedding.device)
cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
return vllm_embedding, vision_hidden_states
def forward(self, data, **kwargs):
vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
input_ids = data["input_ids"]
input_length = data["length"]
if self.llm.config.flash_impl == 'cuda':
cu_seqlens = data["cu_seqlens"]
max_seqlen = data["max_seqlen"]
position_ids = data["position_ids"]
logits, hidden_states = self.llm(
input_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
position_ids=position_ids,
hidden_states=vllm_embedding
)
else:
input_context = torch.zeros_like(input_ids).cuda().bool()
input_span = data["spans"]
logits, hidden_states = self.llm(
input_ids,
input_length,
input_context,
input_span,
hidden_states=vllm_embedding
)
return CausalVLLMOutput(
logits=logits,
hidden_states=hidden_states,
)
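The scatter_ call in get_vllm_embedding overwrites the placeholder positions given by image_bound with the resampled vision features; a toy, self-contained illustration of that step (shapes and values are made up):
import torch
vllm_emb = torch.zeros(8, 4)                    # 8 token embeddings of dim 4 for one sample
image_bound = torch.tensor([[2, 6]])            # the image placeholder occupies positions [2, 6)
vision_feats = torch.randn(4, 4)                # 4 resampled query features for that slot
indices = torch.stack([torch.arange(r[0], r[1], dtype=torch.long) for r in image_bound])
vllm_emb.scatter_(0, indices.view(-1, 1).repeat(1, vllm_emb.shape[-1]),
                  vision_feats.view(-1, vision_feats.shape[-1]))
print(vllm_emb)                                 # rows 2..5 now hold the vision features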

(file shown without a rendered diff)
View File

@ -1,12 +1,14 @@
import os
import io
import json
from typing import Dict
from typing import IO
from typing import List
import pkg_resources
from pytrie import StringTrie
from transformers import LlamaTokenizer
from transformers.tokenization_utils import Trie
file_path = os.path.dirname(__file__)
def load_vocab(fp: IO[bytes]) -> Dict[str, int]:
"""Loads a vocabulary file into a dictionary."""
@ -23,50 +25,66 @@ def load_vocab(fp: IO[bytes]) -> Dict[str, int]:
class FM9GTokenizer(object):
def __init__(self, path=None):
def __init__(self, vocabs_path=None):
self.unk_token = "<unk>"
self.bos_token = "<s>"
self.eos_token = "</s>"
self.im_start = "<image>"
self.im_end = "</image>"
self.ref_start = "<ref>"
self.ref_end = "</ref>"
self.box_start = "<quad>"
self.box_end = "</quad>"
self.quad_start = "<quad>"
self.quad_end = "</quad>"
self.tire = Trie()
self.byte_list = ["<0x0{}>".format(hex(i).upper()[2:]) for i in range(0x10)] + [
"<0x{}>".format(hex(i).upper()[2:]) for i in range(0x10, 0x100)
]
self._special_token_set = set([self.unk_token, self.bos_token, self.eos_token] + self.byte_list)
self._special_token_set = {self.unk_token, self.bos_token, self.eos_token,
self.im_start, self.im_end,
self.ref_start, self.ref_end,
self.box_start, self.box_end,
self.quad_start, self.quad_end}
if path:
all_tokens = load_vocab(io.FileIO(path, "rb"))
else:
all_tokens = load_vocab(pkg_resources.resource_stream("fm9g", "/fm9g/vocabs/fm9g.txt"))
self._byte_set = set(self.byte_list)
# never split special tokens
for t in self._special_token_set:
self.tire.add(t)
if not vocabs_path:
vocabs_path = os.path.join(file_path, "../config/caterpillar.txt")
all_tokens = load_vocab(io.FileIO(vocabs_path, "rb"))
self.encoder: Dict[str, int] = {}
self._special_encoder: Dict[str, int] = {}
for token, token_id in all_tokens.items():
if token in self._special_token_set:
self.encoder[token] = token_id
self._special_encoder[token] = token_id
elif token in self._byte_set:
self._special_encoder[token] = token_id
else:
self.encoder[token] = token_id
self.decoder = {v: k for k, v in self.encoder.items()}
self._byte_decoder = {self._special_encoder[token]: i for i, token in enumerate(self.byte_list)}
self._max_word_len = max([len(x) for x in self.encoder.keys()])
self._len_word_first = {}
for x in self.encoder.keys():
if not x[0] in self._len_word_first:
self._len_word_first[x[0]] = 1
if len(x) > self._len_word_first[x[0]]:
self._len_word_first[x[0]] = len(x)
self.tencoder = StringTrie(self.encoder)
def get_piece(self, text: str) -> str:
if text[0] in self._len_word_first:
text = text[: self._len_word_first[text[0]]]
len_text = len(text)
for i in range(len(text)):
sub = text[: len_text - i]
if sub in self.encoder:
return sub
text = text[: self._max_word_len]
len_text = len(text)
for i in range(len(text)):
sub = text[: len_text - i]
if sub in self.encoder:
return sub
return text[0]
@property
@ -85,16 +103,26 @@ class FM9GTokenizer(object):
def unk_id(self):
return self._special_encoder[self.unk_token]
@property
def im_start_id(self):
return self._special_encoder[self.im_start]
@property
def im_end_id(self):
return self._special_encoder[self.im_end]
def __len__(self):
return len(self.encoder) + len(self._special_encoder)
def tokenize(self, text: str) -> List[str]:
texts = self.tire.split(text)
output_tokens: List[str] = []
st = 0
while st < len(text):
piece = self.get_piece(text[st:])
output_tokens.append(piece)
st += len(piece)
for text in texts:
st = 0
while st < len(text):
piece = self.get_piece(text[st:])
output_tokens.append(piece)
st += len(piece)
return output_tokens
@staticmethod
@ -106,8 +134,6 @@ class FM9GTokenizer(object):
return text
def encode(self, text: str) -> List[int]:
#if len(text) > 20480:
# return [0 for _ in range(20480)]
ret = []
for x in self.tokenize(text):
if x in self.encoder:
@ -127,25 +153,6 @@ class FM9GTokenizer(object):
st += 1
elif tokens[st] in self._byte_decoder:
if (
st + 3 < len(tokens)
and tokens[st + 1] in self._byte_decoder
and tokens[st + 2] in self._byte_decoder
and tokens[st + 3] in self._byte_decoder
):
first_id = self._byte_decoder[tokens[st]]
plane_id = self._byte_decoder[tokens[st + 1]]
row_id = self._byte_decoder[tokens[st + 2]]
cell_id = self._byte_decoder[tokens[st + 3]]
int_bytes = int.to_bytes(first_id << 24 | plane_id << 16 | row_id << 8 | cell_id, 4, "big")
try:
decoded_str = int_bytes.decode("utf-8", errors="replace")
ret.append(decoded_str)
#print(decoded_str)
except UnicodeDecodeError as e:
print(f"UnicodeDecodeError: {e}")
st += 4
elif (
st + 2 < len(tokens)
and tokens[st + 1] in self._byte_decoder
and tokens[st + 2] in self._byte_decoder
@ -153,33 +160,16 @@ class FM9GTokenizer(object):
plane_id = self._byte_decoder[tokens[st]]
row_id = self._byte_decoder[tokens[st + 1]]
cell_id = self._byte_decoder[tokens[st + 2]]
int_bytes = int.to_bytes(plane_id << 16 | row_id << 8 | cell_id, 3, "big")
try:
decoded_str = int_bytes.decode("utf-8", errors="replace")
ret.append(decoded_str)
except UnicodeDecodeError as e:
print(f"UnicodeDecodeError: {e}")
ret.append(int.to_bytes(plane_id << 16 | row_id << 8 | cell_id, 3, "big").decode("utf-8"))
st += 3
elif st + 1 < len(tokens) and tokens[st + 1] in self._byte_decoder:
row_id = self._byte_decoder[tokens[st]]
cell_id = self._byte_decoder[tokens[st + 1]]
int_bytes = int.to_bytes(row_id << 8 | cell_id, 2, "big")
try:
decoded_str = int_bytes.decode("utf-8", errors="replace")
ret.append(decoded_str)
except UnicodeDecodeError as e:
print(f"UnicodeDecodeError: {e}")
#ret.append(int.to_bytes(row_id << 8 | cell_id, 2, "big").decode("utf-8"))
ret.append(int.to_bytes(row_id << 8 | cell_id, 2, "big").decode("utf-8"))
st += 2
else:
cell_id = self._byte_decoder[tokens[st]]
int_bytes = int.to_bytes(cell_id, 1, "big")
try:
decoded_str = int_bytes.decode("utf-8", errors="replace")
ret.append(decoded_str)
except UnicodeDecodeError as e:
print(f"UnicodeDecodeError: {e}")
#ret.append(int.to_bytes(cell_id, 1, "big").decode("utf-8"))
ret.append(int.to_bytes(cell_id, 1, "big").decode("utf-8"))
st += 1
elif tokens[st] == self.eos_id:
ret.append(self.eos_token)
@ -196,16 +186,53 @@ class FM9GTokenizer(object):
# wrap unicode encoding into a helper function
ids = []
utf8_id = token.encode("utf-8")
for _id in utf8_id:
ids.append(self._special_encoder[self.byte_list[_id]])
plane_id = utf8_id[-3] if len(utf8_id) >= 3 else 0
row_id = utf8_id[-2] if len(utf8_id) >= 2 else 0
cell_id = utf8_id[-1] if len(utf8_id) >= 1 else 0
if plane_id > 0:
ids.append(self._special_encoder[self.byte_list[plane_id]])
if row_id > 0:
ids.append(self._special_encoder[self.byte_list[row_id]])
ids.append(self._special_encoder[self.byte_list[cell_id]])
return ids
def next_token(self, text):
# fast next token matching
token, token_id = self.tencoder.longest_prefix_item(text, (None, None))
if token is None:
token = text[0]
token_ids = self._encode_unicode(token)
else:
token_ids = [token_id]
return token, token_ids
class LlamaTokenizerWrapper(LlamaTokenizer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.im_start = "<image>"
self.im_end = "</image>"
self.ref_start = "<ref>"
self.ref_end = "</ref>"
self.box_start = "<box>"
self.box_end = "</box>"
self.quad_start = "<quad>"
self.quad_end = "</quad>"
@property
def eos_id(self):
return self.sp_model.eos_id()
@property
def bos_id(self):
return self.sp_model.bos_id()
@property
def unk_id(self):
return self.sp_model.unk_id()
@property
def im_start_id(self):
return self._convert_token_to_id(self.im_start)
@property
def im_end_id(self):
return self._convert_token_to_id(self.im_end)
@staticmethod
def escape(text: str) -> str:
return text
@staticmethod
def unescape(text: str) -> str:
return text
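tokenize() above first splits the raw text on the registered special tokens via the transformers Trie and only then runs greedy longest-prefix matching on each remaining piece; a small illustration (the input string is made up):
from transformers.tokenization_utils import Trie
trie = Trie()
for tok in ("<s>", "</s>", "<image>", "</image>"):
    trie.add(tok)
pieces = trie.split("<s>Describe the image: <image>IMG</image></s>")
print(pieces)   # special tokens come back as standalone pieces: ['<s>', 'Describe the image: ', '<image>', 'IMG', '</image>', '</s>']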

View File

@ -0,0 +1,56 @@
import json
import os
import shutil
import pandas as pd
import torch
import torch.distributed
from deepspeed.utils import logger
from utils.utils import is_main_process
def export(vllm_engine, cur_epoch_step, global_step, epoch, args):
if args.save_deepspeed:
logger.info(f'start to save deepspeed ckpt, save_dir={args.exp_ckpt_dir}')
vllm_engine.save_checkpoint(save_dir=args.exp_ckpt_dir, tag=f'global_step{global_step}', client_state={
'checkpoint_step': global_step, 'epoch': epoch, 'cur_epoch_step': cur_epoch_step})
export_model_dir = os.path.join(args.exp_ckpt_dir, f'{args.exp_name}_epoch_{epoch}_ckpt_{global_step}')
os.makedirs(export_model_dir, exist_ok=True)
base_file_name = f'{args.exp_name}_{cur_epoch_step}_{global_step}'
# model files
if is_main_process():
model_state_dict_path = os.path.join(export_model_dir, base_file_name + '.pt')
# store the config and vocabs together with the model file
model_cfg_path = os.path.join(export_model_dir, 'config.json')
model_vocab_path = os.path.join(export_model_dir, 'vocabs.txt')
paths = [model_state_dict_path, model_cfg_path, model_vocab_path]
torch.save(vllm_engine.module.state_dict(), model_state_dict_path)
shutil.copy(args.llm_path, model_cfg_path)
shutil.copy(args.vocabs_path, model_vocab_path)
info = {
'global_step': global_step,
'epoch': epoch,
'cur_epoch_step': cur_epoch_step,
'last_ckpt': model_state_dict_path,
'config': model_cfg_path,
'vocab': model_vocab_path
}
with open(os.path.join(export_model_dir, 'lastest_info'), 'w') as f:
json.dump(info, f, indent=2)
logger.info(f'Successfully save model files! {paths}')
torch.distributed.barrier()
def export_eval_file(df, global_step, args):
export_dir = os.path.join(args.exp_ckpt_dir, 'pretrain_eval')
os.makedirs(export_dir, exist_ok=True)
base_file_name = f'{"_".join(args.model_checkpoint.split("/")[-2:])}_{args.vision_encoder.split("_")[0]}_{global_step}'
if is_main_process():
eval_result_path = os.path.join(export_dir, base_file_name + '.csv')
logger.info(f'save eval result file to {eval_result_path}')
df.to_csv(eval_result_path, index=False)

View File

@ -0,0 +1,212 @@
import os
import glob
import argparse
import numpy as np
from datetime import datetime
import torch
import torch.distributed as dist
from utils.logger import init_logger
logger = init_logger(__name__, level="INFO")
def get_args():
parser = argparse.ArgumentParser('VLLM pre-training script', add_help=False)
parser.add_argument('--self_dir', type=str)
parser.add_argument('--train_file', type=str)
parser.add_argument('--eval_file', type=str)
parser.add_argument('--test_file', type=str)
parser.add_argument('--exp_name', type=str, default='multimodal')
parser.add_argument('--batch_size', default=2, type=int)
parser.add_argument('--split_by_rank', action='store_true', default=False, help="split parquet by rank")
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--log_step', default=50, type=int)
parser.add_argument('--save_step', default=100, type=int)
parser.add_argument('--sft', action='store_true', help='whether to train all parameters')
parser.add_argument('--tune_vision', action='store_true', help='whether to train the vision encoder parameters')
parser.add_argument('--tune_resampler', action='store_true', help='whether to train the resampler parameters')
parser.add_argument('--tune_llm', action='store_true', help='whether to train the LLM parameters')
parser.add_argument('--delta_tuning', action='store_true', help='whether to use LoRA')
# Model parameters
parser.add_argument('--img_size', default=224, type=int)
parser.add_argument('--llm_path', default=None, help='Path to LLM model to use', type=str)
parser.add_argument('--vocabs_path', default=None, help='Path to vocabs to use', type=str)
parser.add_argument('--model_checkpoint', default=None, help='Path to VLLM model to use', type=str)
parser.add_argument('--llm_checkpoint', default=None, help='Path to LLM model to use', type=str)
parser.add_argument('--data_state_dict_path', default=None, help='Path to dataset state dict', type=str)
parser.add_argument('--vpm_path', help='Path to VPM model to use', type=str)
parser.add_argument('--vpm_checkpoint', help='Path to VPM model to use', type=str)
parser.add_argument('--vision_encoder', default='eva02_enormous_patch14_clip_224.laion2b_plus',
choices=['eva02_enormous_patch14_clip_224.laion2b_plus', 'vit_so400m_patch14_siglip_384.webli'], type=str)
parser.add_argument('--drop_vision_last_layer', action='store_true', help='whether to drop the last ViT layer')
parser.add_argument('--prefix', default=None, help='Path prefix to save file', type=str)
# deepspeed
parser.add_argument('--deepspeed_config', default=None, help='Path to deepspeed config to use', type=str)
parser.add_argument('--save_deepspeed', action='store_true', default=False, help="whether to save deepspeed checkpoints")
# vlu
parser.add_argument('--skip_overlength', action='store_true', default=False, help="whether to skip over-length samples")
parser.add_argument('--skip_no_image', action='store_true', default=False, help="whether to skip samples without images")
parser.add_argument("--flash", default="none", choices=["none", "1d", "triton", "cuda"])
parser.add_argument('--max_length', default=256, type=int, help='max length of input')
# ----- Training -----
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--query_num', default=32, type=int,
help='query numbers')
parser.add_argument('--max_len', default=96, type=int,
help='max len')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--start_epoch', default=0, type=int)
parser.add_argument('--start_step', default=0, type=int)
parser.add_argument('--skip_step', default=0, type=int)
parser.add_argument('--num_workers', default=5, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
help='Disable pinned CPU memory in DataLoader.')
parser.add_argument('--eval', action='store_true', default=False,
help="Perform evaluation only")
parser.add_argument('--eval_step', default=5000, type=int, help='evaluate step')
parser.add_argument('--save_dataset', action='store_true', default=False, help="whether to save the dataset state dict")
parser.add_argument('--export_dir', default=None, help='Path to export', type=str)
# After an interrupted run, indicates whether to resume from the previously saved deepspeed states (gradients, optimizer, etc.)
# Replaces the former load_ckpt_dir / load_ckpt_tag arguments
parser.add_argument('--need_resume', action='store_true', default=False,
help="resume with deepspeed states")
parser.add_argument('--need_resume_tag')
# ----- distributed training parameters -----
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
args = parser.parse_args()
# 1. file saving / export related settings
args.tensorboard = '{base}/{timestamp}'.format(
base=os.path.join(args.export_dir, args.exp_name, 'tensorboard'), timestamp=datetime.now().strftime("%Y%m%d%H%M%S"))
args.exp_ckpt_dir= os.path.join(args.export_dir, args.exp_name)
os.makedirs(args.tensorboard, exist_ok=True)
# ----- path arguments inside the repo -----
# model config, copied from the base model's config
if not args.llm_path:
args.llm_path = _check_default_path(os.path.join(args.self_dir, 'config/config.json'))
if not args.vocabs_path:
args.vocabs_path = _check_default_path(os.path.join(args.self_dir, 'config/vocabs.txt'))
if not args.deepspeed_config:
args.deepspeed_config = _check_default_path(os.path.join(args.self_dir, 'config/deepspeed.json'))
logger.info("get_args() done")
return args
def _check_default_path(path: str):
if os.path.exists(path):
return path
else:
return None
def _extract_ckpt_path(base_dir: str):
paths = glob.glob(base_dir + '/*.pt')
if len(paths) > 0:
return paths[0]
else:
logger.warning(f'.pt file not found in base_dir({base_dir})')
return None
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def init_distributed_mode(args):
if args.dist_on_itp:
logger.info('init_distributed_mode dist_on_itp')
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
logger.info('init_distributed_mode LOCAL_RANK')
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
logger.info('init_distributed_mode SLURM_PROCID')
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
else:
logger.info('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
torch.set_num_threads(1)
torch.multiprocessing.set_sharing_strategy('file_system')
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}, gpu {}'.format(
args.rank, args.dist_url, args.gpu), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
def setup(args):
# init dist
init_distributed_mode(args)
rank = get_rank()
logger.info(f"rank={rank} init_distributed_mode done")
seed = args.seed + rank
torch.manual_seed(seed)
np.random.seed(seed)
logger.info(f"rank={rank} setup(args) done")

View File

@ -0,0 +1,397 @@
# import sys
import os
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
import math
import time
import gc
import pandas as pd
from copy import deepcopy
from typing import Dict, List
from timm.data.constants import *
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.transforms import transforms, InterpolationMode
from vis_fm9g.dataset.data import register_data_path
from utils import utils
from vis_fm9g.dataset.utils import SkipBatchSampler
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer
from vis_fm9g.generation.vllm_fm9g import VLLMFM9GBeamSearch
import torch
import datetime
import deepspeed
import timm
import torch.distributed
import torch.utils.data
from vis_fm9g.dataset.itembuilder import FM9GImageTextBuilder, FM9GCollater
from vis_fm9g.dataset.datasets import SingleDataSourceDataset, MultiDataSourceDataset
from vis_fm9g.model.vlu_fm9g import VLU_FM9G
from vis_fm9g.model.fm9g import FM9GConfig, FM9GTorch
from deepspeed.utils import logger
from vis_fm9g.train import exporter, initializer
from vis_fm9g.utils.constants import bot_indicator
import safetensors
from safetensors.torch import load_file
def collect_statsd_metric(name, time_monitor):
time_monitor[name] = time.time()
return time_monitor
def convert_data_to_cuda(data: Dict):
for k, v in data.items():
if isinstance(v, torch.Tensor):
data[k] = data[k].cuda()
if isinstance(v, List) and len(v) > 0 and isinstance(v[0], torch.Tensor):
for i in range(len(v)):
v[i] = v[i].cuda()
return data
def create_multi_data_source_dataset(file, item_builder):
with open(file) as f:
data_list = json.load(f)
data_source_names = [i['data_source_name'] for i in data_list]
data_source_weights = [i['data_source_weight'] for i in data_list]
ds_list = []
for name in data_source_names:
if os.path.isdir(name):
ds = SingleDataSourceDataset(name, item_builder, *register_data_path[name](name))
else:
ds = SingleDataSourceDataset(name, item_builder, *register_data_path[name]())
ds_list.append(ds)
if len(ds_list) > 1:
ds = MultiDataSourceDataset(ds_list, data_source_weights)
return ds
def get_transform(args):
mean = IMAGENET_DEFAULT_MEAN
std = IMAGENET_DEFAULT_STD
transform = transforms.Compose([
transforms.Resize((args.img_size, args.img_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
return transform
def get_dataloader(tokenizer, data_file, args, is_training=True):
transform = get_transform(args)
builder = FM9GImageTextBuilder(
tokenizer=tokenizer,
max_len=args.max_len,
transform=transform,
query_len=args.query_num,
min_resolution=0,
skip_overlength=args.skip_overlength
)
dataset = create_multi_data_source_dataset(data_file, builder)
datasampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=args.rank)
if args.skip_step > 0:
datasampler = SkipBatchSampler(datasampler, args.skip_step)
unpad = (args.flash == 'cuda') and is_training
dataloader = DataLoader(
dataset,
sampler=datasampler,
batch_size=args.batch_size,
pin_memory=True,
num_workers=2,
collate_fn=FM9GCollater(tokenizer=tokenizer, max_len=args.max_len, unpad=unpad)
)
return dataloader
def load_llm_tokenizer(args):
return FM9GTokenizer(args.vocabs_path)
def train(vllm_model, args):
vllm_model.train()
if not args.tune_vision:
vllm_model.vpm.requires_grad_(False)
if not args.tune_resampler:
vllm_model.resampler.requires_grad_(False)
if not args.tune_llm:
vllm_model.llm.requires_grad_(False)
if args.drop_vision_last_layer:
vllm_model.vpm.norm.requires_grad_(True)
vllm_engine, vllm_optim, _, _ = deepspeed.initialize(
args=args, model=vllm_model, model_parameters=vllm_model.parameters()
)
torch.cuda.synchronize()
logger.info(f'rank={utils.get_rank()} load model successful')
tokenizer = load_llm_tokenizer(args)
dataloader_train = get_dataloader(tokenizer, data_file=args.train_file, args=args, is_training=True)
if args.eval and args.eval_file:
dataloader_eval = get_dataloader(tokenizer, data_file=args.eval_file, args=args, is_training=False)
else:
dataloader_eval = None
logger.info(f'rank={utils.get_rank()} load dataloader successful')
global_step = args.start_step
log_loss = 0
if args.need_resume:
load_path, client_state = vllm_engine.load_checkpoint(
args.exp_ckpt_dir, tag=args.need_resume_tag)
logger.info(f'Load pre-trained checkpoint from {load_path}, states: {client_state}')
global_step = client_state['checkpoint_step']
args.start_epoch = client_state.get('epoch', args.start_epoch)
logger.info(f'rank={utils.get_rank()} load grad successful')
# init tensorboard writer
if args.tensorboard is not None and utils.is_main_process():
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir=args.tensorboard)
else:
writer = None
loss_fct = CrossEntropyLoss(reduction='mean', ignore_index=-100)
for epoch in range(args.start_epoch, args.epochs):
dataloader_train.sampler.set_epoch(epoch)
logger.info(f'start epoch={epoch}')
time_monitor = {}
collect_statsd_metric("init", time_monitor)
for step, batch in enumerate(dataloader_train):
batch = convert_data_to_cuda(batch)
collect_statsd_metric('dataload', time_monitor)
vllm_model.zero_grad()
output = vllm_model(data=batch)
logits = output.logits.view(-1, output.logits.shape[-1]).contiguous()
target = batch['target'].view(-1).type(torch.long).contiguous()
loss = loss_fct(logits, target)
# logger.info(f'epoch={epoch}, logits={logits}, target={target}, loss={loss}')
collect_statsd_metric("forward", time_monitor)
vllm_engine.backward(loss)
collect_statsd_metric("backward", time_monitor)
vllm_engine.step()
collect_statsd_metric("optim", time_monitor)
cost_info = f'dataload cost={(time_monitor["dataload"] - time_monitor["init"]): .2f} ' \
+ f'forward cost={(time_monitor["forward"] - time_monitor["dataload"]): .2f} ' \
+ f'backward cost={(time_monitor["backward"] - time_monitor["forward"]): .2f} ' \
+ f'optim cost={(time_monitor["optim"] - time_monitor["backward"]): .2f}'
log_loss += loss.item()
global_step += 1
if args.tensorboard is not None and utils.is_main_process():
writer.add_scalar("Loss/train", loss.item(), global_step)
if global_step % args.log_step == 0:
log_loss = utils.mean(utils.all_gather(log_loss))
if utils.is_main_process():
logger.info(
f'Datetime: {datetime.datetime.now()} Step: {global_step-args.log_step} - {global_step}: loss: {log_loss/args.log_step: .4f}')
logger.info(f'time cost info {cost_info}')
log_loss = 0
if global_step % args.save_step == 0:
exporter.export(vllm_engine, step, global_step, epoch, args)
# end step
collect_statsd_metric('init', time_monitor)
if args.eval and global_step % args.eval_step == 0:
evaluate(vllm_model, tokenizer, dataloader_eval, global_step, args)
vllm_model.train()
# final model export
exporter.export(vllm_engine, 0, global_step, args.epochs-1, args)
if args.eval:
evaluate(vllm_model, tokenizer, dataloader_eval, global_step, args)
def evaluate(vllm_model, tokenizer, dataloader_eval, global_step, args):
vllm_model.eval()
torch.cuda.empty_cache()
config = deepcopy(vllm_model.llm.config)
# inference does not use flash_attn; initialize a fresh model
config.use_flash_attn = False
llm = FM9GTorch(config)
vpm = load_vpm(args)
vision_dim = vpm.embed_dim
eval_model = VLU_FM9G(llm, vpm, vision_dim, args.query_num)
eval_model.eval().cuda()
eval_model.load_state_dict(vllm_model.state_dict())
torch.cuda.synchronize()
logger.info(f'rank={utils.get_rank()} start to eval')
transform = get_transform(args)
beam_search = VLLMFM9GBeamSearch(eval_model, tokenizer, transform)
results = []
for step, batch in enumerate(dataloader_eval):
batch = convert_data_to_cuda(batch)
for pixel_values, raw, source in zip(batch['pixel_values'], deepcopy(batch['raw_data']), batch['source']):
raw = raw[3:-4]  # strip <s> and </s>
last_idx = raw.rfind(bot_indicator) + len(bot_indicator)
data_list = [{'input': raw[:last_idx]}]
gt = raw[last_idx:]
with torch.inference_mode():
res = beam_search.generate(
img_list=[pixel_values],
data_list=data_list,
use_transform=False,
beam_size=1,
max_length=1,
)
results.append(
{
'y_pred': res[0],
'y_true': gt,
'source': source
}
)
compose_results = utils.all_gather(results)
compose_results_flatten = []
for r in compose_results:
compose_results_flatten.extend(r)
df = pd.DataFrame.from_dict(compose_results_flatten)
exporter.export_eval_file(df, global_step=global_step, args=args)
source = sorted(list(set(df['source'])))
for ds in source:
ds_df = df[df['source'] == ds]
print(ds, '%.3f' % (sum(ds_df['y_pred'] == ds_df['y_true']) / len(ds_df)), len(ds_df))
print('avg', '%.3f' % (sum(df['y_pred'] == df['y_true']) / len(df)), len(df))
del eval_model
gc.collect()
torch.cuda.empty_cache()
vllm_model.train()
def load_llm(args):
config = FM9GConfig.from_json_file(args.llm_path)
if args.flash == "none":
config.use_flash_attn = False
else:
config.use_flash_attn = True
if args.flash == "1d":
config.flash_attn_mask_shape = "1d"
else:
config.flash_attn_mask_shape = "2d"
if args.flash == "triton":
config.flash_impl = "triton"
elif args.flash == "cuda":
config.flash_impl = "cuda"
cpm_model = FM9GTorch(config)
return cpm_model
def load_vpm(args):
model = timm.create_model(
args.vision_encoder,
pretrained=False,
num_classes=0,
dynamic_img_size=True,
dynamic_img_pad=True
)
if isinstance(model, timm.models.VisionTransformer):
if model.attn_pool is not None:
model.attn_pool = torch.nn.Identity()
if args.drop_vision_last_layer:
model.blocks[-1] = torch.nn.Identity()
return model
def load_sharded_safetensors(sharded_folder):
safetensors_files = sorted([f for f in os.listdir(sharded_folder) if f.endswith('.safetensors')])
if not safetensors_files:
raise FileNotFoundError(f"No safetensors files found in {sharded_folder}")
merged_state_dict = {}
for file_name in safetensors_files:
file_path = os.path.join(sharded_folder, file_name)
logger.info(f"Loading safetensors shard from {file_path}")
state_dict = load_file(file_path)
merged_state_dict.update(state_dict)
return merged_state_dict
def setup_model(args):
start = time.time()
llm = load_llm(args)
vpm = load_vpm(args)
vision_dim = vpm.embed_dim
model = VLU_FM9G(llm, vpm, vision_dim, args.query_num)
if args.model_checkpoint:
logger.info(f'load model_checkpoint from {args.model_checkpoint}')
model_checkpoint = args.model_checkpoint
file_extension = os.path.splitext(model_checkpoint)[1]
if file_extension == '.safetensors':
logger.info(f"Loading safetensors checkpoint from {model_checkpoint}")
state_dict = load_file(model_checkpoint)
info = model.load_state_dict(state_dict, strict=True)
logger.info(f"Loaded checkpoint info={info}")
elif os.path.isdir(model_checkpoint):
logger.info(f"Loading safetensors from sharded folder: {model_checkpoint}")
state_dict = load_sharded_safetensors(model_checkpoint)
info = model.load_state_dict(state_dict, strict=True)
logger.info(f"Loaded checkpoint info={info}")
else:
state_dict = torch.load(args.model_checkpoint, map_location='cpu')
info = model.load_state_dict(state_dict, strict=True)
logger.info(f"load checkpoint info={info}" )
del state_dict
gc.collect()
model.cuda()
torch.cuda.empty_cache()
return model
def main():
args = initializer.get_args()
# setup file and device
initializer.setup(args)
# load model
model = setup_model(args)
# train
train(model, args)
if __name__ == '__main__':
main()
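The file passed via --train_file (and --eval_file) is read by create_multi_data_source_dataset above as a JSON list of data sources; a sketch of writing such a file, with made-up source names that would have to be registered in vis_fm9g.dataset.data.register_data_path:
import json
train_spec = [
    {"data_source_name": "my_caption_data", "data_source_weight": 1.0},   # hypothetical source
    {"data_source_name": "my_vqa_data", "data_source_weight": 0.5},       # hypothetical source
]
with open("train_data.json", "w") as f:
    json.dump(train_spec, f, indent=2)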

View File

@ -1,11 +1,7 @@
import copy
import json
import os
from typing import Any
from typing import Dict
from typing import Union
from .log import logger
import copy
from typing import Any, Dict, Union
class Config(object):

View File

@ -0,0 +1,6 @@
from datetime import datetime
current_time = datetime.now().strftime("%Y年%m月%d")
usr_indicator = "<用户>"
bot_indicator = "<AI>"
SYSTEM = f"{usr_indicator}你叫九格,是由启元实验室研发的多模态大型语言模型。\n你的知识库截止至2022年4月当前时间是{current_time}"

View File

@ -1,4 +0,0 @@
# !/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright @2024, QiYuan Inc

Some files were not shown because too many files have changed in this diff.