files_upload
Browse files- action_model.py +161 -0
- best_model.pt +3 -0
- label_map.json +7 -0
action_model.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import json
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torchvision import models, transforms
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from ucf101_config import UCF101_CLASSES
|
| 14 |
+
UCF101_AVAILABLE = True
|
| 15 |
+
except ImportError:
|
| 16 |
+
UCF101_AVAILABLE = False
|
| 17 |
+
print("[WARNING] UCF101 config not found, using default configuration")
|
| 18 |
+
|
| 19 |
+
class CNN_GRU(nn.Module):
|
| 20 |
+
def __init__(self, cnn_model='mobilenetv2', hidden_size=128, num_layers=1,
|
| 21 |
+
num_classes=5, dropout=0.5, FREEZE_BACKBONE=True):
|
| 22 |
+
super(CNN_GRU, self).__init__()
|
| 23 |
+
|
| 24 |
+
if cnn_model == 'mobilenetv2':
|
| 25 |
+
cnn = models.mobilenet_v2(pretrained=True)
|
| 26 |
+
self.cnn_out_features = cnn.last_channel
|
| 27 |
+
self.cnn = cnn.features
|
| 28 |
+
elif cnn_model == 'efficientnet_b0':
|
| 29 |
+
import timm
|
| 30 |
+
cnn = timm.create_model('efficientnet_b0', pretrained=True)
|
| 31 |
+
self.cnn_out_features = cnn.classifier.in_features
|
| 32 |
+
cnn.classifier = nn.Identity()
|
| 33 |
+
self.cnn = cnn
|
| 34 |
+
else:
|
| 35 |
+
raise ValueError("Invalid CNN model")
|
| 36 |
+
|
| 37 |
+
if FREEZE_BACKBONE:
|
| 38 |
+
for p in self.cnn.parameters():
|
| 39 |
+
p.requires_grad = False
|
| 40 |
+
|
| 41 |
+
self.gru = nn.GRU(self.cnn_out_features,
|
| 42 |
+
hidden_size,
|
| 43 |
+
num_layers=num_layers,
|
| 44 |
+
batch_first=True, dropout=dropout if num_layers > 1 else 0)
|
| 45 |
+
|
| 46 |
+
self.dropout = nn.Dropout(dropout)
|
| 47 |
+
self.fc = nn.Linear(hidden_size, num_classes)
|
| 48 |
+
|
| 49 |
+
def forward(self, x):
|
| 50 |
+
b, t, c, h, w = x.size()
|
| 51 |
+
x = x.view(b * t, c, h, w)
|
| 52 |
+
feats = self.cnn(x)
|
| 53 |
+
feats = F.adaptive_avg_pool2d(feats, 1).view(b, t, -1)
|
| 54 |
+
out, _ = self.gru(feats)
|
| 55 |
+
out = self.dropout(out[:, -1])
|
| 56 |
+
return self.fc(out)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_transform(resize=(112, 112), augment=False):
|
| 60 |
+
transforms_list = [
|
| 61 |
+
transforms.ToPILImage(),
|
| 62 |
+
transforms.Resize(resize),
|
| 63 |
+
]
|
| 64 |
+
|
| 65 |
+
if augment:
|
| 66 |
+
transforms_list.extend([
|
| 67 |
+
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
|
| 68 |
+
transforms.RandomHorizontalFlip(p=0.5),
|
| 69 |
+
])
|
| 70 |
+
transforms_list.extend([
|
| 71 |
+
transforms.ToTensor(),
|
| 72 |
+
transforms.Normalize([0.485, 0.456, 0.406],
|
| 73 |
+
[0.229, 0.224, 0.225]),
|
| 74 |
+
])
|
| 75 |
+
return transforms.Compose(transforms_list)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def preprocess_frames(frames, seq_len=16, resize=(112, 112), augment=False):
|
| 79 |
+
transform = get_transform(resize=resize, augment=augment)
|
| 80 |
+
rgb_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames]
|
| 81 |
+
total_frames = len(rgb_frames)
|
| 82 |
+
|
| 83 |
+
if total_frames >= seq_len:
|
| 84 |
+
indices = np.linspace(0, total_frames - 1, seq_len, dtype=int)
|
| 85 |
+
else:
|
| 86 |
+
indices = np.pad(np.arange(total_frames), (0, seq_len - total_frames), mode='wrap')
|
| 87 |
+
|
| 88 |
+
sampled_frames = [rgb_frames[i] for i in indices]
|
| 89 |
+
transformed_frames = [transform(frame) for frame in sampled_frames]
|
| 90 |
+
frames_tensor = torch.stack(transformed_frames) # [T, C, H, W]
|
| 91 |
+
return frames_tensor
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def load_action_model(model_path="best_model.pt", device='cpu',
|
| 95 |
+
num_classes=5, hidden_size=128):
|
| 96 |
+
if not os.path.exists(model_path):
|
| 97 |
+
print(f"[ERROR] Model file not found: {model_path}")
|
| 98 |
+
return None
|
| 99 |
+
model = CNN_GRU(num_classes=num_classes, hidden_size=hidden_size)
|
| 100 |
+
checkpoint = torch.load(model_path, map_location=device)
|
| 101 |
+
model.load_state_dict(checkpoint)
|
| 102 |
+
model.to(device)
|
| 103 |
+
model.eval()
|
| 104 |
+
print(f"[INFO] Loaded model from {model_path} on {device}")
|
| 105 |
+
return model
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def predict_action(model, frames_tensor, label_map_path="label_map.json", device="cpu", top_k=3):
|
| 109 |
+
if model is None:
|
| 110 |
+
return {"action": "Model not loaded", "confidence": 0.0, "top_predictions": []}
|
| 111 |
+
|
| 112 |
+
idx_to_class = {}
|
| 113 |
+
if os.path.exists(label_map_path):
|
| 114 |
+
try:
|
| 115 |
+
with open(label_map_path, 'r') as f:
|
| 116 |
+
label_map = json.load(f)
|
| 117 |
+
idx_to_class = {v: k for k, v in label_map.items()}
|
| 118 |
+
print(f"[INFO] Loaded {len(idx_to_class)} classes from {label_map_path}")
|
| 119 |
+
except Exception as e:
|
| 120 |
+
print(f"[WARNING] Could not load label map: {e}")
|
| 121 |
+
|
| 122 |
+
if not idx_to_class and UCF101_AVAILABLE:
|
| 123 |
+
idx_to_class = {i: class_name for i, class_name in enumerate(UCF101_CLASSES)}
|
| 124 |
+
print("[INFO] Using default UCF101 class mapping")
|
| 125 |
+
elif not idx_to_class:
|
| 126 |
+
idx_to_class = {0: 'CricketShot', 1: 'PlayingCello', 2: 'Punch',
|
| 127 |
+
3: 'ShavingBeard', 4: 'TennisSwing'}
|
| 128 |
+
print("[WARNING] Using minimal default labels.")
|
| 129 |
+
|
| 130 |
+
try:
|
| 131 |
+
with torch.no_grad():
|
| 132 |
+
frames_tensor = frames_tensor.unsqueeze(0).to(device) # [1, T, C, H, W]
|
| 133 |
+
output = model(frames_tensor)
|
| 134 |
+
probabilities = torch.softmax(output, dim=1)
|
| 135 |
+
top_k_probs, top_k_indices = torch.topk(probabilities, min(top_k, probabilities.size(1)))
|
| 136 |
+
|
| 137 |
+
predicted_idx = top_k_indices[0][0].item()
|
| 138 |
+
predicted_class = idx_to_class.get(predicted_idx, f"Class_{predicted_idx}")
|
| 139 |
+
confidence = top_k_probs[0][0].item()
|
| 140 |
+
|
| 141 |
+
top_predictions = [
|
| 142 |
+
{"class": idx_to_class.get(idx.item(), f"Class_{idx.item()}"),
|
| 143 |
+
"confidence": prob.item()}
|
| 144 |
+
for prob, idx in zip(top_k_probs[0], top_k_indices[0])
|
| 145 |
+
]
|
| 146 |
+
|
| 147 |
+
return {
|
| 148 |
+
"action": predicted_class,
|
| 149 |
+
"confidence": confidence,
|
| 150 |
+
"top_predictions": top_predictions
|
| 151 |
+
}
|
| 152 |
+
except Exception as e:
|
| 153 |
+
print(f"[ERROR] Prediction failed: {e}")
|
| 154 |
+
return {"action": "Error", "confidence": 0.0, "top_predictions": []}
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def log_action_prediction(action_label, confidence, log_file="logs/action_log.txt"):
|
| 158 |
+
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
| 159 |
+
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
|
| 160 |
+
with open(log_file, 'a', encoding='utf-8') as f:
|
| 161 |
+
f.write(f"[{timestamp}] ACTION: {action_label} (confidence: {confidence:.2f})\n")
|
best_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:59c240bbc50fb6027bd31a0e0470a95450ab95c82d07cb9b13ae1b87da32821c
|
| 3 |
+
size 11307065
|
label_map.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"CricketShot": 0,
|
| 3 |
+
"PlayingCello": 1,
|
| 4 |
+
"Punch": 2,
|
| 5 |
+
"ShavingBeard": 3,
|
| 6 |
+
"TennisSwing": 4
|
| 7 |
+
}
|