support pretraining of llava

This commit is contained in:
BUAADreamer 2024-05-21 08:57:14 +08:00
parent 2a67457e39
commit 29a6d5bdb8
4 changed files with 115 additions and 0 deletions

View File

@ -38,6 +38,20 @@
"assistant_tag": "assistant"
}
},
"mllm_pt_demo": {
"file_name": "mllm_pt_demo.json",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
"images": "images"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"alpaca_en": {
"hf_hub_url": "llamafactory/alpaca_en",
"ms_hub_url": "llamafactory/alpaca_en"

92
data/mllm_pt_demo.json Normal file
View File

@ -0,0 +1,92 @@
[
{
"messages": [
{
"content": "Render a clear and concise summary of the photo.",
"role": "user"
},
{
"content": "There are two soccer players on the field.",
"role": "assistant"
}
],
"images": [
"mllm_demo_data/1.jpg"
]
},
{
"messages": [
{
"content": "Write a terse but informative summary of the picture.",
"role": "user"
},
{
"content": "A soccer player is sliding on his knees to celebrate",
"role": "assistant"
}
],
"images": [
"mllm_demo_data/2.jpg"
]
},
{
"messages": [
{
"content": "What is this?",
"role": "user"
},
{
"content": "A man is giving a speech.",
"role": "assistant"
}
],
"images": [
"mllm_demo_data/3.jpg"
]
},
{
"messages": [
{
"content": "对照片进行简明扼要的概括。",
"role": "user"
},
{
"content": "两个足球运动员在场上",
"role": "assistant"
}
],
"images": [
"mllm_demo_data/1.jpg"
]
},
{
"messages": [
{
"content": "为图片写一个简短但内容丰富的摘要。",
"role": "user"
},
{
"content": "一个足球运动员在跪地滑行庆祝",
"role": "assistant"
}
],
"images": [
"mllm_demo_data/2.jpg"
]
},
{
"messages": [
{
"content": "这是什么?",
"role": "user"
},
{
"content": "一个男人在演讲",
"role": "assistant"
}
],
"images": [
"mllm_demo_data/3.jpg"
]
}
]

View File

@ -85,6 +85,10 @@ class ModelArguments:
default=False,
metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
)
tune_mm_proj: bool = field(
default=False,
metadata={"help": "Whethor or not only finetune mm_projector for MLLM."},
)
moe_aux_loss_coef: Optional[float] = field(
default=None,
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},

View File

@ -163,6 +163,11 @@ def load_model(
else:
model.train()
if model_args.visual_inputs and model_args.tune_mm_proj:
lm_params = [param for name, param in model.named_parameters() if "language_model" in name]
for param in lm_params:
param.requires_grad_(False)
trainable_params, all_param = count_parameters(model)
if is_trainable:
param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(