| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class CrossModalAttention(nn.Module):
    """Attend from one modality (query) to another (key/value).

    Both streams are linearly projected, passed through multi-head
    attention, and the result is combined with the query via a residual
    connection followed by dropout and layer normalization.
    """

    def __init__(self, embed_dim=256, num_heads=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Per-stream projections applied before the attention module.
        self.query_proj = nn.Linear(embed_dim, embed_dim)
        self.key_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)

        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, query, key_value):
        """Fuse information from ``key_value`` into ``query``.

        Args:
            query: (B, seq_len_q, embed_dim) tensor of the attending modality.
            key_value: (B, seq_len_kv, embed_dim) tensor being attended to.

        Returns:
            Tuple of the fused output, shape (B, seq_len_q, embed_dim), and
            the attention weights produced by ``nn.MultiheadAttention``.
        """
        projected_q = self.query_proj(query)
        projected_k = self.key_proj(key_value)
        projected_v = self.value_proj(key_value)

        attended, weights = self.multihead_attn(projected_q, projected_k, projected_v)

        # Residual connection around the attention output, then normalize.
        fused = self.norm(query + self.dropout(attended))
        return fused, weights
|
|
class TemporalTransformer(nn.Module):
    """Stack of transformer encoder layers over a temporal feature sequence."""

    def __init__(self, embed_dim=256, num_layers=4, num_heads=8):
        super().__init__()
        # Independent encoder layers, applied sequentially in forward().
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=4 * embed_dim,
                dropout=0.1,
                batch_first=True,
            )
            for _ in range(num_layers)
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Encode a temporal sequence.

        Args:
            x: (B, seq_len, embed_dim) sequence of fused features over time.

        Returns:
            (B, seq_len, embed_dim) layer-normalized encoded sequence.
        """
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden)
        return self.norm(hidden)
|
|
class MultiModalFusion(nn.Module):
    """
    Complete fusion network combining vision, audio, text with temporal modeling.

    Produces emotion/intent logits, engagement/confidence scores in [0, 1],
    and a softmax estimate of each modality's contribution to the prediction.
    """
    def __init__(self, vision_dim=768, audio_dim=128, text_dim=768, embed_dim=256,
                 num_emotions=7, num_intents=5):
        super().__init__()
        self.embed_dim = embed_dim

        # Project each modality into the shared embedding space.
        self.vision_proj = nn.Linear(vision_dim, embed_dim)
        self.audio_proj = nn.Linear(audio_dim, embed_dim)
        self.text_proj = nn.Linear(text_dim, embed_dim)

        # Cyclic pairwise cross-modal attention: V attends to A, A to T, T to V.
        self.vision_to_audio_attn = CrossModalAttention(embed_dim)
        self.audio_to_text_attn = CrossModalAttention(embed_dim)
        self.text_to_vision_attn = CrossModalAttention(embed_dim)

        # Temporal modeling over per-frame fused features.
        self.temporal_transformer = TemporalTransformer(embed_dim)

        # Learned per-modality fusion weights (softmax-normalized in forward).
        self.modality_weights = nn.Parameter(torch.ones(3))

        # Task heads.
        self.emotion_classifier = nn.Linear(embed_dim, num_emotions)
        self.intent_classifier = nn.Linear(embed_dim, num_intents)
        self.engagement_regressor = nn.Linear(embed_dim, 1)
        self.confidence_regressor = nn.Linear(embed_dim, 1)

        # Estimates per-modality contribution from the concatenated projections.
        self.contribution_estimator = nn.Linear(embed_dim * 3, 3)

    def forward(self, vision_features, audio_features, text_features, temporal_seq=False):
        """
        Fuse the three modalities and run all task heads.

        Args:
            vision_features: (B, vision_dim) or, with temporal_seq, (B, T, vision_dim)
            audio_features: (B, audio_dim) or (B, T, audio_dim)
            text_features: (B, text_dim) or (B, T, text_dim)
            temporal_seq: whether inputs are temporal sequences (all three must
                share the same T).

        Returns:
            dict with:
                'emotion': (B, num_emotions) raw logits
                'intent': (B, num_intents) raw logits
                'engagement': (B,) sigmoid score in [0, 1]
                'confidence': (B,) sigmoid score in [0, 1]
                'contributions': (B, 3) softmax over (vision, audio, text)
        """
        v_proj = self.vision_proj(vision_features)
        a_proj = self.audio_proj(audio_features)
        t_proj = self.text_proj(text_features)

        if temporal_seq:
            B, T, _ = v_proj.shape

            # Treat every time step as an independent length-1 sequence so the
            # cross-attention modules see (B*T, 1, embed_dim).
            v_flat = v_proj.view(B * T, 1, -1)
            a_flat = a_proj.view(B * T, 1, -1)
            t_flat = t_proj.view(B * T, 1, -1)

            v_attn, _ = self.vision_to_audio_attn(v_flat, a_flat)
            a_attn, _ = self.audio_to_text_attn(a_flat, t_flat)
            t_attn, _ = self.text_to_vision_attn(t_flat, v_flat)

            # NOTE(review): the temporal path averages modalities uniformly
            # while the non-temporal path uses the learned modality_weights —
            # confirm this asymmetry is intentional.
            fused = (v_attn + a_attn + t_attn) / 3
            fused = fused.view(B, T, -1)

            temporal_out = self.temporal_transformer(fused)

            # Use the last time step as the sequence summary.
            pooled = temporal_out[:, -1, :]
        else:
            # Add a singleton sequence dimension so attention sees (B, 1, embed_dim).
            v_attn, _ = self.vision_to_audio_attn(v_proj.unsqueeze(1), a_proj.unsqueeze(1))
            a_attn, _ = self.audio_to_text_attn(a_proj.unsqueeze(1), t_proj.unsqueeze(1))
            t_attn, _ = self.text_to_vision_attn(t_proj.unsqueeze(1), v_proj.unsqueeze(1))

            # Learned convex combination of the three attended modalities.
            weights = F.softmax(self.modality_weights, dim=0)
            fused = weights[0] * v_attn.squeeze(1) + \
                    weights[1] * a_attn.squeeze(1) + \
                    weights[2] * t_attn.squeeze(1)

            pooled = fused

        emotion_logits = self.emotion_classifier(pooled)
        intent_logits = self.intent_classifier(pooled)
        engagement = torch.sigmoid(self.engagement_regressor(pooled))
        confidence = torch.sigmoid(self.confidence_regressor(pooled))

        # BUG FIX: the original reduced over the embedding dim on the temporal
        # path (producing (B, T, 3) -> shape mismatch in contribution_estimator,
        # which expects embed_dim * 3) and over the batch dim otherwise
        # (averaging away per-sample information). Pool over time only, keeping
        # one (B, embed_dim) summary per modality.
        if temporal_seq:
            v_summary = v_proj.mean(dim=1)
            a_summary = a_proj.mean(dim=1)
            t_summary = t_proj.mean(dim=1)
        else:
            v_summary, a_summary, t_summary = v_proj, a_proj, t_proj

        contributions = torch.softmax(
            self.contribution_estimator(
                torch.cat([v_summary, a_summary, t_summary], dim=-1)
            ),
            dim=-1,
        )

        return {
            'emotion': emotion_logits,
            'intent': intent_logits,
            # squeeze(-1) (not squeeze()) so a batch of size 1 keeps its batch dim.
            'engagement': engagement.squeeze(-1),
            'confidence': confidence.squeeze(-1),
            'contributions': contributions,
        }

    def get_modality_weights(self):
        """
        Return softmax-normalized modality weights for explainability.
        """
        return F.softmax(self.modality_weights, dim=0)