| """
|
| BaramNuri (바람누리) - Lightweight Driver Behavior Detection Model
|
|
|
| A hybrid architecture combining:
|
| - Video Swin Transformer (Stage 1-3) for spatial features
|
| - Selective State Space Model (SSM) for temporal modeling
|
|
|
| Trained via Knowledge Distillation from Video Swin-T teacher.
|
|
|
| Author: C-Team
|
| License: Apache-2.0
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from torchvision.models.video import swin3d_t, Swin3D_T_Weights
|
| from typing import Dict, Tuple
|
|
|
|
|
class SelectiveSSM(nn.Module):
    """
    Selective State Space Model layer (Mamba-style) with a residual connection.

    The input-dependent quantities B, C and delta are produced from the
    (convolved, gated) input by a single linear projection, so the amount of
    state that is kept or forgotten at each timestep depends on the content:
    - Important information is remembered
    - Less important information is quickly forgotten

    Note that ``delta`` is a single scalar per timestep (shared by all
    channels), and the input term ``B * x`` is not scaled by ``delta``.
    """

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2, dropout: float = 0.1):
        """
        Args:
            d_model: Size of the last dimension of the input/output.
            d_state: Size of the recurrent state per inner channel.
            d_conv: Kernel size of the depthwise temporal convolution.
            expand: Inner width multiplier (``d_inner = d_model * expand``).
            dropout: Dropout probability applied after the output projection.
        """
        super().__init__()

        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = d_model * expand

        # Produces the SSM branch and the gate branch in one matmul.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # Depthwise conv over time; padding of d_conv - 1 plus truncation to
        # the input length in forward() makes it causal.
        self.conv1d = nn.Conv1d(
            self.d_inner, self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner
        )

        # Per-timestep B (d_state), C (d_state) and a scalar delta (1).
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)

        # A is parameterised as -exp(A_log), so it is always negative.
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1, dtype=torch.float32)))
        # Skip coefficient applied to the SSM input.
        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, D]
        Returns:
            y: [B, T, D]
        """
        residual = x
        x = self.layer_norm(x)

        # Unpacking three values also rejects inputs that are not rank-3.
        batch, seq_len, _ = x.shape

        x, z = self.in_proj(x).chunk(2, dim=-1)

        # Conv1d wants [B, C, T]; keep only the first T outputs (causal part).
        x = self.conv1d(x.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        x = F.silu(x)

        x_ssm = self.x_proj(x)
        B_t = x_ssm[:, :, :self.d_state]                    # [B, T, N]
        C_t = x_ssm[:, :, self.d_state:self.d_state * 2]    # [B, T, N]
        delta = F.softplus(x_ssm[:, :, -1:])                # [B, T, 1], > 0

        A = -torch.exp(self.A_log)                          # [N], < 0

        # Everything that does not depend on the recurrent state is computed
        # once for all timesteps instead of inside the Python loop.
        decay = torch.exp(delta * A).unsqueeze(2)           # [B, T, 1, N]
        drive = B_t.unsqueeze(2) * x.unsqueeze(-1)          # [B, T, d_inner, N]

        h = torch.zeros(batch, self.d_inner, self.d_state, device=x.device, dtype=x.dtype)
        states = []
        for t in range(seq_len):
            # Only the state recurrence itself is inherently sequential.
            h = h * decay[:, t] + drive[:, t]
            states.append(h)
        hidden = torch.stack(states, dim=1)                 # [B, T, d_inner, N]

        # Readout with C plus the D skip path, for all timesteps at once.
        y = (C_t.unsqueeze(2) * hidden).sum(dim=-1) + self.D * x

        # Gate with the second half of in_proj.
        y = y * F.silu(z)

        y = self.out_proj(y)
        y = self.dropout(y)

        return y + residual
|
|
|
|
|
class TemporalSSMBlock(nn.Module):
    """
    Temporal modelling block for video features.

    Runs a [B, T, C] sequence through a stack of SelectiveSSM layers and
    then averages over the time axis to get one vector per sample.
    """

    def __init__(self, d_model: int, d_state: int = 16, n_layers: int = 2, dropout: float = 0.1):
        """
        Args:
            d_model: Feature size of the sequence elements.
            d_state: State size passed to every SelectiveSSM layer.
            n_layers: Number of stacked SelectiveSSM layers.
            dropout: Dropout probability passed to every SelectiveSSM layer.
        """
        super().__init__()

        stack = []
        for _ in range(n_layers):
            stack.append(SelectiveSSM(d_model, d_state=d_state, dropout=dropout))
        self.ssm_layers = nn.ModuleList(stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, D] sequence
        Returns:
            y: [B, D] final representation (mean over T)
        """
        hidden = x
        for layer in self.ssm_layers:
            hidden = layer(hidden)

        # Collapse the time axis.
        return torch.mean(hidden, dim=1)
|
|
|
|
|
class BaramNuri(nn.Module):
    """
    BaramNuri (바람누리) - Lightweight Driver Behavior Detection Model

    Architecture:
    1. Video Swin-T backbone, truncated to the patch embedding plus the first
       five entries of ``features`` (Stages 1-3, 384-dim output)
    2. Selective SSM Block (temporal modeling)
    3. Classification Head

    Figures reported by the authors:
    Parameters: 14.20M (49% reduction from teacher's 27.86M)
    Performance: 96.17% accuracy, 0.9504 Macro F1
    """

    # Index-aligned class labels (Korean / English).
    CLASS_NAMES = ["정상", "졸음운전", "물건찾기", "휴대폰 사용", "운전자 폭행"]
    CLASS_NAMES_EN = ["normal", "drowsy_driving", "searching_object", "phone_usage", "driver_assault"]

    def __init__(
        self,
        num_classes: int = 5,
        pretrained: bool = True,
        d_state: int = 16,
        ssm_layers: int = 2,
        dropout: float = 0.2,
    ):
        """
        Args:
            num_classes: Number of output logits.
            pretrained: If True, initialise the Swin backbone from the
                Kinetics-400 weights (this may download them).
            d_state: State size of the temporal SSM layers.
            ssm_layers: Number of stacked SSM layers.
            dropout: Dropout probability for the SSM layers and the head.
        """
        super().__init__()

        self.num_classes = num_classes

        # Build the full Swin3D-T only to harvest its early layers.
        if pretrained:
            print("Loading Swin backbone (Kinetics-400 pretrained)...")
            full_swin = swin3d_t(weights=Swin3D_T_Weights.KINETICS400_V1)
        else:
            full_swin = swin3d_t(weights=None)

        self.patch_embed = full_swin.patch_embed

        # Keep only the first five feature modules of the backbone.
        self.features = nn.Sequential(*[full_swin.features[i] for i in range(5)])

        # Channel width of the truncated backbone output.
        self.feature_dim = 384

        # Not used by the forward paths in this class; kept so the module
        # structure stays unchanged for existing code.
        self.avgpool = nn.AdaptiveAvgPool3d(output_size=1)

        self.temporal_ssm = TemporalSSMBlock(
            d_model=self.feature_dim,
            d_state=d_state,
            n_layers=ssm_layers,
            dropout=dropout,
        )

        self.head = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(p=dropout),
            nn.Linear(self.feature_dim, num_classes),
        )

        self._init_head()

        # Drop the reference to the unused later stages of the backbone.
        del full_swin

    def _init_head(self):
        """Initialize head weights (truncated normal weights, zero biases)."""
        for m in self.head.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract features (for knowledge distillation)

        Args:
            x: [B, C, T, H, W]
        Returns:
            features: [B, feature_dim]
        """
        x = self.patch_embed(x)
        x = self.features(x)

        # Backbone output is treated as channels-last [B, T', H', W', C]:
        # average out the two spatial axes to get a [B, T', C] sequence.
        x = x.mean(dim=[2, 3])

        # Temporal modelling + pooling over time -> [B, C].
        return self.temporal_ssm(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            x: [B, C, T, H, W] video tensor
        Returns:
            logits: [B, num_classes]
        """
        return self.head(self.extract_features(x))

    def forward_with_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return both logits and features (for knowledge distillation)

        Returns:
            (logits, features) - note the order: logits first.
        """
        features = self.extract_features(x)
        logits = self.head(features)
        return logits, features

    def predict(self, x: torch.Tensor, return_english: bool = False) -> Dict:
        """
        Inference prediction

        Switches the model to eval mode (and leaves it there). Only the first
        sample of the batch is reported. Assumes the model has as many
        outputs as there are class names.

        Args:
            x: [1, C, T, H, W] single video
            return_english: Return English class names
        Returns:
            dict with class, confidence, class_name, all_probs
        """
        self.eval()
        with torch.no_grad():
            logits = self.forward(x)
            probs = F.softmax(logits, dim=-1)[0]
            class_idx = probs.argmax().item()
            # One transfer to Python floats instead of one .item() per class.
            prob_list = probs.tolist()

        class_names = self.CLASS_NAMES_EN if return_english else self.CLASS_NAMES

        return {
            "class": class_idx,
            "confidence": prob_list[class_idx],
            "class_name": class_names[class_idx],
            "all_probs": {
                name: prob_list[i]
                for i, name in enumerate(class_names)
            },
        }

    @classmethod
    def from_pretrained(cls, checkpoint_path: str, device: str = 'cpu') -> "BaramNuri":
        """
        Load pretrained model from checkpoint

        The checkpoint may be either a raw state dict or a dict holding the
        state dict under the ``'model_state_dict'`` key.

        Args:
            checkpoint_path: Path to .pth file
            device: 'cpu' or 'cuda'
        Returns:
            Loaded model in eval mode
        """
        # The state dict is loaded strictly below, so every weight is
        # overwritten by the checkpoint. Building with pretrained=False avoids
        # downloading Kinetics-400 weights that would be discarded anyway
        # (and lets loading work offline).
        model = cls(num_classes=5, pretrained=False)

        # NOTE(security): torch.load unpickles the file. Only load checkpoints
        # from trusted sources, or pass weights_only=True where the installed
        # PyTorch version and the checkpoint contents allow it.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)

        model = model.to(device)
        model.eval()

        return model
|
|
|
|
|
def count_parameters(model: nn.Module) -> int:
    """Return the total number of elements across all parameters of ``model``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the model, report its size, and run a forward pass
    # plus a Korean and an English prediction on random input.
    banner = "=" * 60
    print(banner)
    print("BaramNuri Model Test")
    print(banner)

    # Builds with the Kinetics-400 backbone weights (may trigger a download).
    model = BaramNuri(num_classes=5, pretrained=True)

    total_params = count_parameters(model)
    print(f"\nTotal parameters: {total_params:,} ({total_params/1e6:.2f}M)")

    # Batch of two random clips: 3 channels, 30 frames, 224x224.
    dummy_input = torch.randn(2, 3, 30, 224, 224)
    print(f"\nInput shape: {dummy_input.shape}")

    model.eval()
    with torch.no_grad():
        output = model(dummy_input)
    print(f"Output shape: {output.shape}")

    # Single-clip prediction, first with Korean then with English labels.
    single_input = torch.randn(1, 3, 30, 224, 224)

    prediction = model.predict(single_input)
    print(f"\nPrediction (Korean): {prediction['class_name']} ({prediction['confidence']:.2%})")

    prediction_en = model.predict(single_input, return_english=True)
    print(f"Prediction (English): {prediction_en['class_name']} ({prediction_en['confidence']:.2%})")

    print("\nModel test passed!")
|
|
|