ktejeshnaidu commited on
Commit
1dc99e7
·
verified ·
1 Parent(s): 91d7a92

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +3 -56
model.py CHANGED
@@ -1,113 +1,60 @@
1
  import torch
2
-
3
  import torch.nn as nn
4
-
5
  import torch.nn.functional as F
6
-
7
  import pickle
8
-
9
  from torchvision import transforms
10
 
11
-
12
-
13
  import numpy as np
14
-
15
  from PIL import Image
16
 
17
 
18
-
19
-
20
-
21
  class FaceClassifier(nn.Module):
22
-
23
  def __init__(self, num_classes):
24
-
25
  super().__init__()
26
-
27
  self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
28
-
29
  self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
30
-
31
  self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
32
-
33
  self.pool = nn.MaxPool2d(2, 2)
34
-
35
 
36
-
37
  self.dropout = nn.Dropout(0.1)
38
-
39
  self.fc1 = nn.Linear(128 * 16 * 16, 512)
40
-
41
  self.fc2 = nn.Linear(512, num_classes)
42
 
43
-
44
-
45
  def forward(self, x):
46
-
47
  x = self.pool(F.relu(self.conv1(x)))
48
-
49
  x = self.pool(F.relu(self.conv2(x)))
50
-
51
  x = self.pool(F.relu(self.conv3(x)))
52
-
53
  x = x.view(-1, 128 * 16 * 16)
54
-
55
  x = self.dropout(F.relu(self.fc1(x)))
56
-
57
  x = self.fc2(x)
58
-
59
  return x
60
 
61
 
62
 
63
-
64
-
65
-
66
-
67
  class EmotionPredictor:
68
-
69
  def __init__(self):
70
-
71
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
72
 
73
-
74
-
75
  with open("classes.pkl", "rb") as f:
76
-
77
  self.classes = pickle.load(f)
78
 
79
-
80
-
81
  self.model = FaceClassifier(len(self.classes))
82
-
83
  self.model.load_state_dict(
84
-
85
  torch.load("face_classifier.pth", map_location=self.device)
86
-
87
  )
88
-
89
  self.model.to(self.device).eval()
90
 
91
-
92
-
93
  self.transform = transforms.Compose([
94
-
95
  transforms.Resize((128, 128)),
96
-
97
  transforms.ToTensor(),
98
-
99
  transforms.Normalize((0.5,), (0.5,))
100
-
101
  ])
102
-
103
  @torch.inference_mode()
104
-
105
  def predict(self, image_np: np.ndarray) -> str:
106
-
107
  img = Image.fromarray(image_np)
108
-
109
  tensor = self.transform(img).unsqueeze(0).to(self.device)
110
-
111
  output = self.model(tensor)
 
 
 
112
 
113
- return self.classes[output.argmax(1).item()]
 
1
  import torch
 
2
  import torch.nn as nn
 
3
  import torch.nn.functional as F
 
4
  import pickle
 
5
  from torchvision import transforms
6
 
 
 
7
  import numpy as np
 
8
  from PIL import Image
9
 
10
 
 
 
 
11
  class FaceClassifier(nn.Module):
 
12
  def __init__(self, num_classes):
 
13
  super().__init__()
 
14
  self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
 
15
  self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
 
16
  self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
 
17
  self.pool = nn.MaxPool2d(2, 2)
 
18
 
 
19
  self.dropout = nn.Dropout(0.1)
 
20
  self.fc1 = nn.Linear(128 * 16 * 16, 512)
 
21
  self.fc2 = nn.Linear(512, num_classes)
22
 
 
 
23
  def forward(self, x):
 
24
  x = self.pool(F.relu(self.conv1(x)))
 
25
  x = self.pool(F.relu(self.conv2(x)))
 
26
  x = self.pool(F.relu(self.conv3(x)))
 
27
  x = x.view(-1, 128 * 16 * 16)
 
28
  x = self.dropout(F.relu(self.fc1(x)))
 
29
  x = self.fc2(x)
 
30
  return x
31
 
32
 
33
 
 
 
 
 
34
  class EmotionPredictor:
 
35
  def __init__(self):
 
36
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
 
 
 
38
  with open("classes.pkl", "rb") as f:
 
39
  self.classes = pickle.load(f)
40
 
 
 
41
  self.model = FaceClassifier(len(self.classes))
 
42
  self.model.load_state_dict(
 
43
  torch.load("face_classifier.pth", map_location=self.device)
 
44
  )
 
45
  self.model.to(self.device).eval()
46
 
 
 
47
  self.transform = transforms.Compose([
 
48
  transforms.Resize((128, 128)),
 
49
  transforms.ToTensor(),
 
50
  transforms.Normalize((0.5,), (0.5,))
 
51
  ])
 
52
  @torch.inference_mode()
 
53
  def predict(self, image_np: np.ndarray) -> str:
 
54
  img = Image.fromarray(image_np)
 
55
  tensor = self.transform(img).unsqueeze(0).to(self.device)
 
56
  output = self.model(tensor)
57
+ return self.classes[output.argmax(1).item()]
58
+
59
+
60