farrell236 committed on
Commit
c320b82
·
verified ·
1 Parent(s): ff1ccfe

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +125 -0
model.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import timm
5
+
6
+
7
+ # =========================
8
+ # Simple HRNet baseline
9
+ # =========================
10
class SimpleHRNet(nn.Module):
    """Lightweight convolutional heatmap baseline (HRNet-style).

    A two-conv stride-2 stem reduces spatial resolution by 4x, three plain
    conv blocks refine the 64-channel features at constant resolution, and
    a 1x1 head emits one heatmap channel per landmark.
    """

    def __init__(self, num_landmarks=29, in_chans=3):
        super().__init__()

        # Stem: two stride-2 conv-BN-ReLU stages -> H/4 x W/4, 64 channels.
        stem_layers = []
        for c_in in (in_chans, 64):
            stem_layers += [
                nn.Conv2d(c_in, 64, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            ]
        self.stem = nn.Sequential(*stem_layers)

        self.block1 = self._make_block(64, 64)
        self.block2 = self._make_block(64, 64)
        self.block3 = self._make_block(64, 64)

        # 1x1 projection from features to per-landmark heatmap logits.
        self.head = nn.Conv2d(64, num_landmarks, kernel_size=1)

    def _make_block(self, in_ch, out_ch):
        """Build two 3x3 conv-BN-ReLU layers at constant resolution."""
        layers = []
        for c_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(c_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return heatmap logits of shape (B, num_landmarks, H/4, W/4)."""
        feat = self.stem(x)
        for block in (self.block1, self.block2, self.block3):
            feat = block(feat)
        return self.head(feat)
46
+
47
+
48
+ # =========================
49
+ # ViT + Heatmap Head
50
+ # =========================
51
class ViTHeatmap(nn.Module):
    """ViT backbone with a convolutional decoder for landmark heatmaps.

    Patch tokens from the timm ViT backbone are reshaped into a 2D feature
    map, projected to 256 channels, and upsampled 4x by the decoder head.
    For a patch-16 backbone this yields heatmaps at 1/4 input resolution.

    Args:
        num_landmarks: number of output heatmap channels.
        model_name: any timm ViT model name.
        pretrained: load pretrained backbone weights.
        img_size: size the backbone is configured for (dynamic_img_size
            still allows other input sizes at forward time).
    """

    def __init__(
        self,
        num_landmarks=29,
        model_name="vit_base_patch16_224",
        pretrained=True,
        img_size=(512, 512),
    ):
        super().__init__()

        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            img_size=img_size,
            dynamic_img_size=True,
            num_classes=0,
            global_pool="",
        )

        # Derive the patch size and prefix-token count (CLS / register
        # tokens) from the backbone instead of hard-coding 16 and 1, so
        # non-patch16 and register-token ViT variants also work. The
        # fallbacks (16, 1) reproduce the original behavior for
        # vit_base_patch16_224.
        patch = getattr(
            getattr(self.backbone, "patch_embed", None), "patch_size", (16, 16)
        )
        if isinstance(patch, int):
            patch = (patch, patch)
        self.patch_size = tuple(patch)
        self.num_prefix_tokens = getattr(self.backbone, "num_prefix_tokens", 1)

        embed_dim = self.backbone.num_features
        self.conv_proj = nn.Conv2d(embed_dim, 256, kernel_size=1)

        # Decoder: two bilinear 2x upsamples with conv refinement, then a
        # 1x1 projection to per-landmark heatmap logits.
        self.head = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, num_landmarks, kernel_size=1),
        )

    def forward(self, x):
        """Return heatmap logits of shape (B, num_landmarks, 4*H/p, 4*W/p).

        Raises:
            ValueError: if the token count does not match the patch grid
                implied by the input size (e.g. input not divisible by the
                patch size).
        """
        B = x.shape[0]

        tokens = self.backbone.forward_features(x)

        if isinstance(tokens, (list, tuple)):
            tokens = tokens[-1]

        # Keep patch tokens only; drop CLS / register prefix tokens.
        tokens = tokens[:, self.num_prefix_tokens:, :]

        num_patches = tokens.shape[1]
        ph, pw = self.patch_size
        h = x.shape[2] // ph
        w = x.shape[3] // pw

        if h * w != num_patches:
            raise ValueError(
                f"Patch grid mismatch: input {(x.shape[2], x.shape[3])}, "
                f"expected {h}x{w}={h*w} patches, got {num_patches}"
            )

        # (B, N, C) -> (B, C, h, w) grid, then decode to heatmaps.
        feat = tokens.transpose(1, 2).reshape(B, -1, h, w)
        feat = self.conv_proj(feat)
        return self.head(feat)
111
+
112
+
113
+ # =========================
114
+ # model test
115
+ # =========================
116
if __name__ == "__main__":
    # Quick smoke test: run both models on a dummy batch and report shapes.
    sample = torch.randn(2, 3, 224, 224)

    hrnet = SimpleHRNet(num_landmarks=29)
    print("HRNet output:", hrnet(sample).shape)

    vit = ViTHeatmap(num_landmarks=29)
    print("ViT output:", vit(sample).shape)