| """ |
| Adaptive Fusion Module for Hybrid Food Classifier |
| Combines CNN and ViT features using cross-attention mechanism |
| """ |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Tuple |
|
|
class AdaptiveFusionModule(nn.Module):
    """Fuse CNN and ViT feature streams with bidirectional cross-attention.

    Spatial path: the CNN feature map is flattened into a token sequence,
    both streams are projected, attended to each other (CNN->ViT and
    ViT->CNN) with residual connections, concatenated along the sequence
    axis, and refined by a final self-attention pass.

    Global path: the concatenated global vectors are mixed by an MLP and
    additionally recombined with learned per-sample CNN/ViT weights, then
    projected down to ``hidden_dim``.

    NOTE: the spatial output keeps width ``feature_dim`` while the global
    output has width ``hidden_dim`` (what ``get_output_dim`` reports).
    """

    def __init__(
        self,
        feature_dim: int = 768,
        hidden_dim: int = 512,
        num_heads: int = 8,
        dropout: float = 0.2,
        spatial_size: int = 7,
    ) -> None:
        """
        Args:
            feature_dim: Channel width shared by both input streams.
            hidden_dim: Width of the fused global output vector.
            num_heads: Attention heads; must evenly divide ``feature_dim``.
            dropout: Dropout probability used throughout the module.
            spatial_size: Expected side length of the CNN spatial grid.
                Currently stored for reference only — ``forward`` infers the
                grid from the input tensor. Kept for interface compatibility.
        """
        super().__init__()

        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.spatial_size = spatial_size  # unused in forward; see docstring

        # CNN tokens query the ViT tokens.
        self.cnn_to_vit_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # ViT tokens query the CNN tokens.
        self.vit_to_cnn_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Joint refinement over the concatenated (CNN + ViT) token sequence.
        self.self_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Per-stream token projections applied before cross-attention.
        self.cnn_spatial_proj = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        self.vit_spatial_proj = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # MLP mixing of the concatenated global vectors: 2*feature_dim -> feature_dim.
        self.global_fusion = nn.Sequential(
            nn.Linear(feature_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Predicts a per-sample (cnn, vit) convex weighting (softmax over 2).
        self.adaptive_weight = nn.Sequential(
            nn.Linear(feature_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
            nn.Softmax(dim=-1),
        )

        # Projects the fused global vector to the module's output width.
        self.final_proj = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        cnn_spatial: torch.Tensor,
        cnn_global: torch.Tensor,
        vit_spatial: torch.Tensor,
        vit_global: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Fuse spatial and global features from both backbones.

        Args:
            cnn_spatial: CNN spatial features [B, feature_dim, H, W]
                (typically H = W = 7).
            cnn_global: CNN global features [B, feature_dim].
            vit_spatial: ViT patch features [B, num_patches, feature_dim].
            vit_global: ViT CLS token features [B, feature_dim].

        Returns:
            fused_spatial: Fused token sequence
                [B, H*W + num_patches, feature_dim].
            fused_global: Fused global features [B, hidden_dim].
        """
        # [B, C, H, W] -> [B, H*W, C]: treat each spatial cell as a token.
        cnn_spatial_seq = cnn_spatial.flatten(2).transpose(1, 2)

        cnn_spatial_proj = self.cnn_spatial_proj(cnn_spatial_seq)
        vit_spatial_proj = self.vit_spatial_proj(vit_spatial)

        # CNN tokens attend over ViT tokens.
        cnn_attended, _ = self.cnn_to_vit_attention(
            query=cnn_spatial_proj,
            key=vit_spatial_proj,
            value=vit_spatial_proj,
        )

        # ViT tokens attend over CNN tokens.
        vit_attended, _ = self.vit_to_cnn_attention(
            query=vit_spatial_proj,
            key=cnn_spatial_proj,
            value=cnn_spatial_proj,
        )

        # Residual add per stream, then concatenate along the sequence axis.
        combined_spatial = torch.cat(
            [
                cnn_attended + cnn_spatial_proj,
                vit_attended + vit_spatial_proj,
            ],
            dim=1,
        )

        # Joint self-attention refinement over the combined token sequence.
        fused_spatial, _ = self.self_attention(
            query=combined_spatial,
            key=combined_spatial,
            value=combined_spatial,
        )

        # Global branch: MLP fusion of the concatenated global vectors.
        global_concat = torch.cat([cnn_global, vit_global], dim=-1)
        fused_global_base = self.global_fusion(global_concat)

        # Per-sample convex weights over the two streams (sum to 1).
        weights = self.adaptive_weight(global_concat)
        cnn_weight = weights[:, 0:1]
        vit_weight = weights[:, 1:2]

        # Average the adaptive weighted combination with the MLP-fused base
        # (the /2 averages the two feature_dim-wide contributions).
        fused_global = (cnn_weight * cnn_global +
                        vit_weight * vit_global +
                        fused_global_base) / 2

        fused_global = self.final_proj(fused_global)

        return fused_spatial, fused_global

    def get_output_dim(self) -> int:
        """Return the width of ``fused_global`` (the spatial output stays
        at ``feature_dim``)."""
        return self.hidden_dim