add support for Yi-VL

This commit is contained in:
BUAADreamer 2024-05-14 14:03:19 +08:00
parent 45654ebedb
commit ab3464ce65
1 changed files with 5 additions and 2 deletions

View File

@ -43,10 +43,13 @@ class LlavaMultiModalProjectorYiVL(nn.Module):
self.linear_3 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_4 = nn.LayerNorm(config.text_config.hidden_size, bias=True)
self.act = nn.GELU()
self.proj = nn.Sequential(*[self.linear_1, self.linear_2, self.act, self.linear_3, self.linear_4])
def forward(self, image_features):
hidden_states = self.proj(image_features)
hidden_states = self.linear_1(image_features)
hidden_states = self.linear_2(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_3(hidden_states)
hidden_states = self.linear_4(hidden_states)
return hidden_states