ktejeshnaidu commited on
Commit
a8faaa3
·
verified ·
1 Parent(s): 9e86c28

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +66 -60
model.py CHANGED
@@ -1,60 +1,66 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import pickle
5
- from torchvision import transforms
6
-
7
- import numpy as np
8
- from PIL import Image
9
-
10
-
11
- class FaceClassifier(nn.Module):
12
- def __init__(self, num_classes):
13
- super().__init__()
14
- self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
15
- self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
16
- self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
17
- self.pool = nn.MaxPool2d(2, 2)
18
-
19
- self.dropout = nn.Dropout(0.1)
20
- self.fc1 = nn.Linear(128 * 16 * 16, 512)
21
- self.fc2 = nn.Linear(512, num_classes)
22
-
23
- def forward(self, x):
24
- x = self.pool(F.relu(self.conv1(x)))
25
- x = self.pool(F.relu(self.conv2(x)))
26
- x = self.pool(F.relu(self.conv3(x)))
27
- x = x.view(-1, 128 * 16 * 16)
28
- x = self.dropout(F.relu(self.fc1(x)))
29
- x = self.fc2(x)
30
- return x
31
-
32
-
33
-
34
- class EmotionPredictor:
35
- def __init__(self):
36
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
-
38
- with open("classes.pkl", "rb") as f:
39
- self.classes = pickle.load(f)
40
-
41
- self.model = FaceClassifier(len(self.classes))
42
- self.model.load_state_dict(
43
- torch.load("face_classifier.pth", map_location=self.device)
44
- )
45
- self.model.to(self.device).eval()
46
-
47
- self.transform = transforms.Compose([
48
- transforms.Resize((128, 128)),
49
- transforms.ToTensor(),
50
- transforms.Normalize((0.5,), (0.5,))
51
- ])
52
- @torch.inference_mode()
53
- def predict(self, image_np: np.ndarray) -> str:
54
- img = Image.fromarray(image_np)
55
- tensor = self.transform(img).unsqueeze(0).to(self.device)
56
- output = self.model(tensor)
57
- return self.classes[output.argmax(1).item()]
58
-
59
-
60
-
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ import pickle
4
+ import numpy as np
5
+ # Import your specific model architecture here if you saved a state_dict!
6
+ # from your_network_file import YourCNNClass
7
+
8
+ class EmotionPredictor:
9
+ def __init__(self, model_path='face_classifier.pth', classes_path='classes.pkl'):
10
+ # 1. Device Management: Automatically fall back to CPU for Hugging Face Spaces
11
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+
13
+ # 2. Weight & Class Caching: Load these ONLY ONCE when the server starts
14
+ self.classes = self._load_classes(classes_path)
15
+ self.model = self._load_model(model_path)
16
+
17
+ # 3. Pre-compiled Tensor Transformations
18
+ # (Adjust the Resize dimensions to match what you used in Train_model.ipynb)
19
+ self.transform = transforms.Compose([
20
+ transforms.ToPILImage(),
21
+ transforms.Resize((224, 224)),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
24
+ ])
25
+
26
+ def _load_classes(self, path):
27
+ try:
28
+ with open(path, 'rb') as f:
29
+ return pickle.load(f)
30
+ except Exception as e:
31
+ print(f"Warning: Could not load {path}. Defaulting to standard classes. Error: {e}")
32
+ return ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
33
+
34
+ def _load_model(self, path):
35
+ try:
36
+ # OPTION A: If you saved the ENTIRE model in your Jupyter Notebook
37
+ model = torch.load(path, map_location=self.device)
38
+
39
+ # OPTION B: If you saved ONLY the state_dict (Best Practice)
40
+ # Uncomment and use this if Option A throws an architecture error:
41
+ # model = YourCNNClass(num_classes=len(self.classes))
42
+ # model.load_state_dict(torch.load(path, map_location=self.device))
43
+
44
+ model.to(self.device)
45
+
46
+ # CRITICAL: Put the model in evaluation mode to disable dropout/batchnorm
47
+ model.eval()
48
+ return model
49
+
50
+ except Exception as e:
51
+ raise RuntimeError(f"Failed to load PyTorch model: {e}")
52
+
53
+ def predict(self, face_image_rgb):
54
+ """
55
+ Expects an RGB numpy array of the cropped face from OpenCV.
56
+ """
57
+ # Apply transforms and add the batch dimension (B, C, H, W)
58
+ tensor = self.transform(face_image_rgb).unsqueeze(0).to(self.device)
59
+
60
+ # Disable gradient calculation for significantly faster CPU inference
61
+ with torch.no_grad():
62
+ outputs = self.model(tensor)
63
+ _, predicted = torch.max(outputs, 1)
64
+
65
+ return self.classes[predicted.item()]
66
+