| """ |
| CNN Branch for Hybrid Food Classifier |
| Uses ResNet50 as backbone with adaptive pooling |
| """ |
| import torch |
| import torch.nn as nn |
| import torchvision.models as models |
| from typing import Tuple |
|
|
class CNNBranch(nn.Module):
    """CNN branch that extracts spatial and global features via a ResNet50 backbone.

    Produces both a fixed 7x7 spatial feature map (for downstream attention /
    fusion) and a pooled global feature vector.
    """

    def __init__(
        self,
        backbone: str = "resnet50",
        pretrained: bool = True,
        freeze_early_layers: bool = True,
        dropout: float = 0.3,
        feature_dim: int = 2048
    ):
        """
        Args:
            backbone: Backbone architecture name; only "resnet50" is supported.
            pretrained: Load ImageNet-pretrained weights for the backbone.
            freeze_early_layers: Freeze the first backbone stages (through layer2).
            dropout: Dropout probability used in the projection and feature head.
            feature_dim: Channel / vector size of the produced features.

        Raises:
            ValueError: If ``backbone`` is not a supported architecture.
        """
        super().__init__()

        self.feature_dim = feature_dim

        if backbone == "resnet50":
            # torchvision >= 0.13 deprecates (and newer releases remove) the
            # ``pretrained=`` flag in favor of the ``weights=`` enum API.
            # Prefer the new API; fall back for older torchvision versions
            # where ``ResNet50_Weights`` does not exist.
            try:
                weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
                resnet = models.resnet50(weights=weights)
            except AttributeError:
                resnet = models.resnet50(pretrained=pretrained)
            # Drop the final avgpool + fc so the backbone emits a
            # [B, 2048, H/32, W/32] feature map instead of class logits.
            self.backbone = nn.Sequential(*list(resnet.children())[:-2])
            backbone_dim = 2048
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")

        if freeze_early_layers:
            self._freeze_early_layers()

        # Normalize arbitrary input resolutions to a fixed 7x7 grid.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7))

        # 1x1 conv projection from backbone channels to feature_dim.
        self.feature_proj = nn.Sequential(
            nn.Conv2d(backbone_dim, feature_dim, kernel_size=1),
            nn.BatchNorm2d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout)
        )

        # Collapse the 7x7 grid into a single global descriptor.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # MLP head over the pooled vector.
        # NOTE: BatchNorm1d requires batch size > 1 while training.
        self.feature_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout)
        )

    def _freeze_early_layers(self):
        """Freeze the first six children of the backbone (conv1 .. layer2).

        After the ``nn.Sequential`` wrapping, ResNet50's children are:
        conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4 — so this
        freezes everything through layer2, leaving the two deepest stages
        trainable.
        """
        layers_to_freeze = 6
        for i, child in enumerate(self.backbone.children()):
            if i < layers_to_freeze:
                for param in child.parameters():
                    param.requires_grad = False

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            x: Input tensor [B, 3, H, W]

        Returns:
            spatial_features: Spatial features [B, feature_dim, 7, 7]
            global_features: Global features [B, feature_dim]
        """
        features = self.backbone(x)              # [B, 2048, H/32, W/32]
        features = self.adaptive_pool(features)  # [B, 2048, 7, 7]

        spatial_features = self.feature_proj(features)  # [B, feature_dim, 7, 7]

        global_features = self.global_pool(spatial_features)
        global_features = global_features.flatten(1)    # [B, feature_dim]
        global_features = self.feature_head(global_features)

        return spatial_features, global_features

    def get_feature_dim(self) -> int:
        """Return the configured feature dimension."""
        return self.feature_dim