"""
Hybrid CNN-ViT Food Classifier
Combines ResNet50 and DeiT-Base with adaptive fusion
"""
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Dict, Any, Optional |
|
|
| from .cnn_branch import CNNBranch |
| from .vit_branch import ViTBranch |
| from .fusion_module import AdaptiveFusionModule |
|
|
class HybridFoodClassifier(nn.Module):
    """Hybrid CNN-ViT model for food classification.

    Runs a CNN branch (ResNet50) and a ViT branch (DeiT-Base) in parallel,
    fuses their spatial and global features with an adaptive fusion module,
    and classifies the fused global representation. During training, two
    auxiliary classifiers (one per branch) provide deep supervision.
    """

    def __init__(
        self,
        num_classes: int = 101,
        feature_dim: int = 768,
        hidden_dim: int = 512,
        dropout: float = 0.2,
        pretrained: bool = True,
        freeze_early_layers: bool = True
    ):
        """
        Args:
            num_classes: Number of food categories (101 for Food-101).
            feature_dim: Dimensionality of each branch's feature space.
            hidden_dim: Width of the fused representation / main head input.
            dropout: Dropout probability used in all classifier heads.
            pretrained: Load pretrained backbone weights in both branches.
            freeze_early_layers: Freeze early backbone layers in both branches.
        """
        super().__init__()

        self.num_classes = num_classes
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim

        # Parallel feature extractors; each returns (spatial, global) features.
        self.cnn_branch = CNNBranch(
            pretrained=pretrained,
            freeze_early_layers=freeze_early_layers,
            dropout=dropout,
            feature_dim=feature_dim
        )
        self.vit_branch = ViTBranch(
            pretrained=pretrained,
            freeze_early_layers=freeze_early_layers,
            dropout=dropout,
            feature_dim=feature_dim
        )

        # Combines the two branches into fused spatial/global representations.
        self.fusion_module = AdaptiveFusionModule(
            feature_dim=feature_dim,
            hidden_dim=hidden_dim,
            dropout=dropout
        )

        # Main head on the fused global feature.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        # Auxiliary per-branch heads (used only while self.training is True).
        # Built by one shared helper so the two heads cannot drift apart.
        self.cnn_aux_classifier = self._build_aux_classifier(
            feature_dim, hidden_dim, num_classes, dropout
        )
        self.vit_aux_classifier = self._build_aux_classifier(
            feature_dim, hidden_dim, num_classes, dropout
        )

        self._initialize_weights()

    @staticmethod
    def _build_aux_classifier(
        feature_dim: int,
        hidden_dim: int,
        num_classes: int,
        dropout: float
    ) -> nn.Sequential:
        """Build one auxiliary classifier head: feature_dim -> num_classes."""
        return nn.Sequential(
            nn.Linear(feature_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def _initialize_weights(self):
        """Xavier-initialize all Linear layers of the classifier heads."""
        for head in [self.classifier, self.cnn_aux_classifier, self.vit_aux_classifier]:
            for layer in head:
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_uniform_(layer.weight)
                    if layer.bias is not None:
                        nn.init.constant_(layer.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
        return_features: bool = False,
        use_aux_loss: bool = True
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            x: Input tensor [B, 3, H, W]
            return_features: Whether to return intermediate features
            use_aux_loss: Whether to compute auxiliary logits (training only)

        Returns:
            Dictionary with 'logits'; plus 'cnn_aux_logits'/'vit_aux_logits'
            when use_aux_loss and self.training; plus all intermediate
            features when return_features.
        """
        cnn_spatial, cnn_global = self.cnn_branch(x)
        vit_spatial, vit_global = self.vit_branch(x)

        fused_spatial, fused_global = self.fusion_module(
            cnn_spatial, cnn_global, vit_spatial, vit_global
        )

        logits = self.classifier(fused_global)
        output = {'logits': logits}

        # Auxiliary heads only contribute during training.
        if use_aux_loss and self.training:
            output.update({
                'cnn_aux_logits': self.cnn_aux_classifier(cnn_global),
                'vit_aux_logits': self.vit_aux_classifier(vit_global)
            })

        if return_features:
            output.update({
                'cnn_spatial': cnn_spatial,
                'cnn_global': cnn_global,
                'vit_spatial': vit_spatial,
                'vit_global': vit_global,
                'fused_spatial': fused_spatial,
                'fused_global': fused_global
            })

        return output

    def get_attention_maps(self, x: torch.Tensor, output_size: int = 224) -> Dict[str, torch.Tensor]:
        """Produce coarse per-branch saliency heatmaps for visualization.

        Args:
            x: Input tensor [B, 3, H, W].
            output_size: Side length the heatmaps are upsampled to
                (default 224, matching the previous hard-coded behavior).

        Returns:
            Dict with 'cnn_attention' and 'vit_attention', each
            [B, 1, output_size, output_size].
        """
        with torch.no_grad():
            output = self.forward(x, return_features=True, use_aux_loss=False)

            # CNN saliency: channel-mean of the spatial feature map.
            cnn_attention = output['cnn_spatial'].mean(dim=1, keepdim=True)
            cnn_attention = F.interpolate(
                cnn_attention,
                size=(output_size, output_size),
                mode='bilinear',
                align_corners=False
            )

            # ViT saliency: feature-mean over patch tokens (token 0 assumed to
            # be CLS — TODO confirm against ViTBranch), reshaped onto the patch
            # grid. The grid side is inferred from the token count instead of
            # being hard-coded to 14, so non-224 inputs no longer break here;
            # the grid is presumed square (e.g. 196 tokens -> 14x14).
            vit_patches = output['vit_spatial'][:, 1:]
            grid = int(round(vit_patches.size(1) ** 0.5))
            vit_attention = vit_patches.mean(dim=-1).view(-1, 1, grid, grid)
            vit_attention = F.interpolate(
                vit_attention,
                size=(output_size, output_size),
                mode='bilinear',
                align_corners=False
            )

            return {
                'cnn_attention': cnn_attention,
                'vit_attention': vit_attention
            }

    def freeze_backbone(self):
        """Freeze both backbone networks (CNN backbone and ViT)."""
        for param in self.cnn_branch.backbone.parameters():
            param.requires_grad = False
        for param in self.vit_branch.vit.parameters():
            param.requires_grad = False

    def unfreeze_backbone(self):
        """Unfreeze both backbone networks (CNN backbone and ViT)."""
        for param in self.cnn_branch.backbone.parameters():
            param.requires_grad = True
        for param in self.vit_branch.vit.parameters():
            param.requires_grad = True

    def get_model_size(self) -> Dict[str, int]:
        """Return parameter counts: total, trainable, and per component."""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        cnn_params = sum(p.numel() for p in self.cnn_branch.parameters())
        vit_params = sum(p.numel() for p in self.vit_branch.parameters())
        fusion_params = sum(p.numel() for p in self.fusion_module.parameters())
        classifier_params = sum(p.numel() for p in self.classifier.parameters())

        return {
            'total_params': total_params,
            'trainable_params': trainable_params,
            'cnn_params': cnn_params,
            'vit_params': vit_params,
            'fusion_params': fusion_params,
            'classifier_params': classifier_params
        }