| import os
|
| import cv2
|
| import json
|
| import torch
|
| import numpy as np
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from torchvision import models, transforms
|
| from datetime import datetime
|
| from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
|
|
|
| try:
|
| from ucf101_config import UCF101_CLASSES
|
| UCF101_AVAILABLE = True
|
| except ImportError:
|
| UCF101_AVAILABLE = False
|
| print("[WARNING] UCF101 config not found, using default configuration")
|
|
|
| class CNN_GRU(nn.Module):
|
| def __init__(self, cnn_model='mobilenetv2', hidden_size=128, num_layers=1,
|
| num_classes=5, dropout=0.5, FREEZE_BACKBONE=True):
|
| super(CNN_GRU, self).__init__()
|
|
|
| if cnn_model == 'mobilenetv2':
|
| cnn = models.mobilenet_v2(pretrained=True)
|
| self.cnn_out_features = cnn.last_channel
|
| self.cnn = cnn.features
|
| elif cnn_model == 'efficientnet_b0':
|
| import timm
|
| cnn = timm.create_model('efficientnet_b0', pretrained=True)
|
| self.cnn_out_features = cnn.classifier.in_features
|
| cnn.classifier = nn.Identity()
|
| self.cnn = cnn
|
| else:
|
| raise ValueError("Invalid CNN model")
|
|
|
| if FREEZE_BACKBONE:
|
| for p in self.cnn.parameters():
|
| p.requires_grad = False
|
|
|
| self.gru = nn.GRU(self.cnn_out_features,
|
| hidden_size,
|
| num_layers=num_layers,
|
| batch_first=True, dropout=dropout if num_layers > 1 else 0)
|
|
|
| self.dropout = nn.Dropout(dropout)
|
| self.fc = nn.Linear(hidden_size, num_classes)
|
|
|
| def forward(self, x):
|
| b, t, c, h, w = x.size()
|
| x = x.view(b * t, c, h, w)
|
| feats = self.cnn(x)
|
| feats = F.adaptive_avg_pool2d(feats, 1).view(b, t, -1)
|
| out, _ = self.gru(feats)
|
| out = self.dropout(out[:, -1])
|
| return self.fc(out)
|
|
|
|
|
| def get_transform(resize=(112, 112), augment=False):
|
| transforms_list = [
|
| transforms.ToPILImage(),
|
| transforms.Resize(resize),
|
| ]
|
|
|
| if augment:
|
| transforms_list.extend([
|
| transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
|
| transforms.RandomHorizontalFlip(p=0.5),
|
| ])
|
| transforms_list.extend([
|
| transforms.ToTensor(),
|
| transforms.Normalize([0.485, 0.456, 0.406],
|
| [0.229, 0.224, 0.225]),
|
| ])
|
| return transforms.Compose(transforms_list)
|
|
|
|
|
| def preprocess_frames(frames, seq_len=16, resize=(112, 112), augment=False):
|
| transform = get_transform(resize=resize, augment=augment)
|
| rgb_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames]
|
| total_frames = len(rgb_frames)
|
|
|
| if total_frames >= seq_len:
|
| indices = np.linspace(0, total_frames - 1, seq_len, dtype=int)
|
| else:
|
| indices = np.pad(np.arange(total_frames), (0, seq_len - total_frames), mode='wrap')
|
|
|
| sampled_frames = [rgb_frames[i] for i in indices]
|
| transformed_frames = [transform(frame) for frame in sampled_frames]
|
| frames_tensor = torch.stack(transformed_frames)
|
| return frames_tensor
|
|
|
|
|
| def load_action_model(model_path="best_model.pt", device='cpu',
|
| num_classes=5, hidden_size=128):
|
| if not os.path.exists(model_path):
|
| print(f"[ERROR] Model file not found: {model_path}")
|
| return None
|
| model = CNN_GRU(num_classes=num_classes, hidden_size=hidden_size)
|
| checkpoint = torch.load(model_path, map_location=device)
|
| model.load_state_dict(checkpoint)
|
| model.to(device)
|
| model.eval()
|
| print(f"[INFO] Loaded model from {model_path} on {device}")
|
| return model
|
|
|
|
|
| def predict_action(model, frames_tensor, label_map_path="label_map.json", device="cpu", top_k=3):
|
| if model is None:
|
| return {"action": "Model not loaded", "confidence": 0.0, "top_predictions": []}
|
|
|
| idx_to_class = {}
|
| if os.path.exists(label_map_path):
|
| try:
|
| with open(label_map_path, 'r') as f:
|
| label_map = json.load(f)
|
| idx_to_class = {v: k for k, v in label_map.items()}
|
| print(f"[INFO] Loaded {len(idx_to_class)} classes from {label_map_path}")
|
| except Exception as e:
|
| print(f"[WARNING] Could not load label map: {e}")
|
|
|
| if not idx_to_class and UCF101_AVAILABLE:
|
| idx_to_class = {i: class_name for i, class_name in enumerate(UCF101_CLASSES)}
|
| print("[INFO] Using default UCF101 class mapping")
|
| elif not idx_to_class:
|
| idx_to_class = {0: 'CricketShot', 1: 'PlayingCello', 2: 'Punch',
|
| 3: 'ShavingBeard', 4: 'TennisSwing'}
|
| print("[WARNING] Using minimal default labels.")
|
|
|
| try:
|
| with torch.no_grad():
|
| frames_tensor = frames_tensor.unsqueeze(0).to(device)
|
| output = model(frames_tensor)
|
| probabilities = torch.softmax(output, dim=1)
|
| top_k_probs, top_k_indices = torch.topk(probabilities, min(top_k, probabilities.size(1)))
|
|
|
| predicted_idx = top_k_indices[0][0].item()
|
| predicted_class = idx_to_class.get(predicted_idx, f"Class_{predicted_idx}")
|
| confidence = top_k_probs[0][0].item()
|
|
|
| top_predictions = [
|
| {"class": idx_to_class.get(idx.item(), f"Class_{idx.item()}"),
|
| "confidence": prob.item()}
|
| for prob, idx in zip(top_k_probs[0], top_k_indices[0])
|
| ]
|
|
|
| return {
|
| "action": predicted_class,
|
| "confidence": confidence,
|
| "top_predictions": top_predictions
|
| }
|
| except Exception as e:
|
| print(f"[ERROR] Prediction failed: {e}")
|
| return {"action": "Error", "confidence": 0.0, "top_predictions": []}
|
|
|
|
|
| def log_action_prediction(action_label, confidence, log_file="logs/action_log.txt"):
|
| os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
| timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
|
| with open(log_file, 'a', encoding='utf-8') as f:
|
| f.write(f"[{timestamp}] ACTION: {action_label} (confidence: {confidence:.2f})\n") |