"""
Vision Transformer Branch for Hybrid Food Classifier
Uses DeiT-Base as backbone with custom head
"""
import torch
import torch.nn as nn
from transformers import DeiTModel, DeiTConfig
from typing import Tuple
|
|
class ViTBranch(nn.Module):
    """Vision Transformer branch using DeiT-Base.

    Wraps a Hugging Face ``DeiTModel`` and projects its token embeddings
    into a common feature space: a global vector from the CLS token and a
    spatial map from the patch tokens.

    Args:
        model_name: Hugging Face model id of the DeiT checkpoint.
        pretrained: If True, load pretrained weights; otherwise build the
            same architecture from the checkpoint's config, randomly
            initialised.
        freeze_early_layers: If True, freeze the first ``frozen_layers``
            encoder blocks (feature-extraction fine-tuning).
        dropout: Dropout probability used in every projection head.
        feature_dim: Dimensionality of the projected output features.
        frozen_layers: How many leading encoder blocks to freeze when
            ``freeze_early_layers`` is True (default 8, matching the
            previous hard-coded value).
    """

    def __init__(
        self,
        model_name: str = "facebook/deit-base-distilled-patch16-224",
        pretrained: bool = True,
        freeze_early_layers: bool = True,
        dropout: float = 0.1,
        feature_dim: int = 768,
        frozen_layers: int = 8,
    ):
        super().__init__()

        self.feature_dim = feature_dim

        # Backbone: pretrained weights, or random init from the same config.
        if pretrained:
            self.vit = DeiTModel.from_pretrained(model_name)
        else:
            config = DeiTConfig.from_pretrained(model_name)
            self.vit = DeiTModel(config)

        self.hidden_size = self.vit.config.hidden_size
        # Derive the patch count from the checkpoint's config instead of
        # hard-coding (224 // 16) ** 2, so other resolutions / patch sizes
        # stay consistent with get_num_patches().
        self.num_patches = (
            self.vit.config.image_size // self.vit.config.patch_size
        ) ** 2

        if freeze_early_layers:
            self._freeze_early_layers(frozen_layers)

        # Projection for the global (CLS) representation.
        self.feature_proj = nn.Sequential(
            nn.Linear(self.hidden_size, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Projection for the per-patch (spatial) representations.
        self.spatial_proj = nn.Sequential(
            nn.Linear(self.hidden_size, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Extra refinement head applied to the global features only.
        self.feature_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

    def _freeze_early_layers(self, layers_to_freeze: int = 8) -> None:
        """Freeze the first ``layers_to_freeze`` encoder blocks in place."""
        for i, layer in enumerate(self.vit.encoder.layer):
            if i < layers_to_freeze:
                for param in layer.parameters():
                    param.requires_grad = False

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            x: Input tensor [B, 3, H, W]

        Returns:
            spatial_features: Patch features [B, num_patches, feature_dim]
            global_features: CLS token features [B, feature_dim]
        """
        outputs = self.vit(pixel_values=x)
        last_hidden_states = outputs.last_hidden_state

        # DeiT prepends BOTH a CLS token (index 0) and a distillation token
        # (index 1) to the patch sequence, so real patch embeddings start at
        # index 2. The previous `[:, 1:]` slice leaked the distillation
        # token into the spatial features (197 tokens instead of the
        # advertised num_patches == 196).
        cls_token = last_hidden_states[:, 0]
        patch_tokens = last_hidden_states[:, 2:]

        global_features = self.feature_proj(cls_token)
        spatial_features = self.spatial_proj(patch_tokens)

        # Additional refinement of the global representation.
        global_features = self.feature_head(global_features)

        return spatial_features, global_features

    def get_feature_dim(self) -> int:
        """Return the projected feature dimensionality."""
        return self.feature_dim

    def get_num_patches(self) -> int:
        """Return the number of spatial patch tokens produced by forward()."""
        return self.num_patches