| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch.nn as nn |
| import re |
|
|
|
|
class IdentityMap(nn.Module):
    """Projector that performs no transformation and has no parameters.

    ``forward`` hands back its first argument as-is; any additional
    positional or keyword arguments are accepted only so the call signature
    is interchangeable with other projectors, and are ignored.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        # Extra args/kwargs are intentionally unused.
        return x

    @property
    def config(self):
        """Configuration dict identifying this projector type."""
        return dict(mm_projector_type='identity')
|
|
|
|
class SimpleResBlock(nn.Module):
    """LayerNorm followed by a residual two-layer GELU MLP over ``channels`` features.

    Note: the residual is added to the *normalized* input, i.e. the output is
    ``norm(x) + mlp(norm(x))`` rather than ``x + mlp(norm(x))``.
    """

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)
        # Layer order (Linear, GELU, Linear) fixes the Sequential indices
        # 0 and 2 that appear in state_dict keys.
        layers = [
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        normed = self.pre_norm(x)
        update = self.proj(normed)
        return normed + update
|
|
|
|
def build_vision_projector(mm_hidden_size=1024, hidden_size=4096, projector_type="mlp2x_gelu"):
    """Build the module that projects vision features to ``hidden_size``.

    Supported ``projector_type`` values:

    * ``"mlpNx_gelu"`` with N >= 1: an ``nn.Sequential`` of N ``nn.Linear``
      layers with an ``nn.GELU`` between consecutive layers. The first layer
      maps ``mm_hidden_size`` -> ``hidden_size``; the remaining layers map
      ``hidden_size`` -> ``hidden_size``.
    * ``"linear"``: a single ``nn.Linear(mm_hidden_size, hidden_size)``.
    * ``"identity"``: an ``IdentityMap`` that returns its input unchanged.

    Args:
        mm_hidden_size: Input feature size of the first linear layer.
        hidden_size: Output feature size of every linear layer.
        projector_type: Projector specification string (see above).

    Returns:
        An ``nn.Module`` implementing the requested projector.

    Raises:
        ValueError: If ``projector_type`` is not recognized, or requests an
            MLP depth of 0.
    """
    # fullmatch instead of match(r'^...$'): '$' also matches just before a
    # trailing newline, which would silently accept e.g. "mlp2x_gelu\n".
    mlp_gelu_match = re.fullmatch(r'mlp(\d+)x_gelu', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        if mlp_depth < 1:
            # "mlp0x_gelu" used to fall through to a 1-layer projector.
            raise ValueError(
                f'Invalid projector type: {projector_type} (MLP depth must be >= 1)'
            )
        modules = [nn.Linear(mm_hidden_size, hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(hidden_size, hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'linear':
        return nn.Linear(mm_hidden_size, hidden_size)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')
|
|