farrell236 committed on
Commit
325d063
·
verified ·
1 Parent(s): f505d7e

Upload 4 files

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. app.py +231 -48
  3. best.pt.enc +3 -0
  4. requirements.txt +16 -3
  5. secure_torch_load.py +56 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ best.pt.enc filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,67 +1,250 @@
1
- import os
2
- import traceback
3
  import gradio as gr
 
 
4
  import torch
5
- from huggingface_hub import hf_hub_download
6
 
7
- # Change this
8
- MODEL_REPO_ID = "farrell236/CephVIT"
9
- MODEL_FILENAME = "best.pt"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
 
 
 
11
 
12
- def inspect_checkpoint():
13
- try:
14
- hf_token = os.getenv("HF_TOKEN")
15
- if not hf_token:
16
- return "ERROR: HF_TOKEN is missing. Add it in Space Settings -> Secrets."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- local_path = hf_hub_download(
19
- repo_id=MODEL_REPO_ID,
20
- filename=MODEL_FILENAME,
21
- token=hf_token,
22
- )
23
 
24
- lines = []
25
- lines.append("Download successful.")
26
- lines.append(f"Local path: {local_path}")
27
 
28
- ckpt = torch.load(local_path, map_location="cpu")
29
 
30
- lines.append("")
31
- lines.append(f"Top-level object type: {type(ckpt).__name__}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- if isinstance(ckpt, dict):
34
- top_keys = list(ckpt.keys())
35
- lines.append(f"Top-level key count: {len(top_keys)}")
36
- lines.append("Top-level keys:")
37
- for k in top_keys[:50]:
38
- lines.append(f" - {k}")
39
 
40
- if "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
41
- sd_keys = list(ckpt["state_dict"].keys())
42
- lines.append("")
43
- lines.append(f"state_dict key count: {len(sd_keys)}")
44
- lines.append("First 20 state_dict keys:")
45
- for k in sd_keys[:20]:
46
- lines.append(f" - {k}")
47
 
48
- else:
49
- lines.append("Checkpoint is not a dict, so no keys to print.")
50
 
51
- return "\n".join(lines)
 
 
 
 
 
 
 
 
52
 
53
- except Exception as e:
54
- return f"ERROR:\n{type(e).__name__}: {e}\n\n{traceback.format_exc()}"
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- demo = gr.Interface(
58
- fn=inspect_checkpoint,
59
- inputs=None,
60
- outputs=gr.Textbox(label="Checkpoint inspection", lines=30),
61
- title="Private checkpoint test",
62
- description="Checks whether best.pt can be downloaded from a private Hugging Face repo and inspected.",
63
- )
64
 
65
 
66
  if __name__ == "__main__":
67
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
 
2
  import gradio as gr
3
+ import numpy as np
4
+ import cv2
5
  import torch
 
6
 
7
+ from model import SimpleHRNet, ViTHeatmap
8
+ from heatmap_utils import heatmaps_to_coords_dark
9
+ from secure_torch_load import secure_torch_load
10
+
11
+
12
+ def parse_args():
13
+ parser = argparse.ArgumentParser(description="Cephalogram landmark inference app")
14
+ parser.add_argument("--checkpoint", type=str, default="best.pt.enc", help="Path to model checkpoint")
15
+ parser.add_argument("--device", type=str, default=("cuda" if torch.cuda.is_available() else "cpu"), help="Torch device, e.g. cuda or cpu")
16
+ parser.add_argument("--server-port", type=int, default=44065, help="Port for Gradio app")
17
+ parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Host for Gradio app")
18
+ parser.add_argument("--share", action="store_true", help="Enable public Gradio share link")
19
+ parser.add_argument("--inbrowser", action="store_true", help="Open app in browser on launch")
20
+ return parser.parse_args()
21
+
22
+
23
def load_model(checkpoint_path, device):
    """Load a landmark-detection model from an encrypted checkpoint.

    Args:
        checkpoint_path: Path to the AES-GCM-encrypted checkpoint file.
        device: Torch device string the model is moved to.

    Returns:
        Tuple ``(model, args, landmark_symbols)`` where ``args`` is the
        training-time config dict stored in the checkpoint and
        ``landmark_symbols`` may be None if the checkpoint has no names.
    """
    # NOTE(review): the decrypted payload is unpickled by torch.load, so the
    # checkpoint must come from a trusted source; possession of the decryption
    # key is the effective trust gate here.
    ckpt = secure_torch_load(checkpoint_path, map_location="cpu")
    args = ckpt["args"]
    landmark_symbols = ckpt.get("landmark_symbols", None)

    # The architecture is selected from the training config saved alongside
    # the weights, so one app binary serves both model families.
    if args["model"] == "hrnet":
        model = SimpleHRNet(num_landmarks=args["num_landmarks"])
    else:
        model = ViTHeatmap(
            num_landmarks=args["num_landmarks"],
            model_name=args["vit_name"],
            pretrained=False,  # weights come from the checkpoint, not timm
            img_size=(args["input_height"], args["input_width"]),
        )

    model.load_state_dict(ckpt["model_state_dict"])
    model.to(device)
    model.eval()

    return model, args, landmark_symbols
44
+
45
+
46
def get_symbols(n, checkpoint_symbols):
    """Return landmark names: the checkpoint's list when its length matches
    ``n``, otherwise generic ``LM_<i>`` placeholders."""
    usable = checkpoint_symbols is not None and len(checkpoint_symbols) == n
    if usable:
        return checkpoint_symbols
    return ["LM_%d" % index for index in range(n)]
50
+
51
+
52
def preprocess(image, model_args, device):
    """Resize an HWC image to the model input size and build a 1xCxHxW
    float tensor scaled to [0, 1].

    Returns the tensor plus ``(h, w)`` tuples for the original and the
    network-input resolutions (used later to rescale predictions).
    """
    orig_h, orig_w = image.shape[:2]
    in_h = model_args["input_height"]
    in_w = model_args["input_width"]

    # cv2.resize expects (width, height) ordering.
    scaled = cv2.resize(image, (in_w, in_h))
    batch = torch.from_numpy(scaled).permute(2, 0, 1).float().div(255.0)
    batch = batch.unsqueeze(0).to(device)
    return batch, (orig_h, orig_w), (in_h, in_w)
61
+
62
+
63
def decode(pred_heatmaps, orig_size, input_size):
    """Convert DARK-decoded heatmap coordinates into original-image pixels.

    Scaling happens in two hops: heatmap grid -> network input resolution,
    then network input -> original image resolution.
    """
    h_orig, w_orig = orig_size
    h_in, w_in = input_size
    hm_h, hm_w = pred_heatmaps.shape[2], pred_heatmaps.shape[3]

    # Sub-pixel landmark positions in heatmap space, first batch item only.
    coords = heatmaps_to_coords_dark(pred_heatmaps)[0].clone()

    # heatmap grid -> network input space
    coords[:, 0] *= (w_in / hm_w)
    coords[:, 1] *= (h_in / hm_h)

    # network input space -> original image space
    coords[:, 0] *= (w_orig / w_in)
    coords[:, 1] *= (h_orig / h_in)

    return coords.cpu().numpy()
79
 
 
 
 
 
 
80
 
81
def compute_confidence(heatmaps):
    """Per-landmark confidence: the peak activation of each heatmap
    (first batch item), as a 1-D numpy array."""
    maps = heatmaps[0].detach().cpu().numpy()
    flattened = maps.reshape(len(maps), -1)
    return flattened.max(axis=1)
84
 
 
85
 
86
def draw_points(image, coords, symbols, color=(255, 0, 0)):
    """Draw labelled landmark markers on a copy of ``image``.

    Points that fall outside the frame are silently skipped. Returns the
    annotated copy; the input image is not modified.
    """
    canvas = image.copy()
    height, width = canvas.shape[:2]
    for idx, (px, py) in enumerate(coords):
        px = int(round(float(px)))
        py = int(round(float(py)))
        if not (0 <= px < width and 0 <= py < height):
            continue
        cv2.circle(canvas, (px, py), 4, color, -1, lineType=cv2.LINE_AA)
        # Label sits slightly up-right of the dot so it does not cover it.
        cv2.putText(
            canvas,
            symbols[idx],
            (px + 5, py - 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.4,
            color,
            1,
            cv2.LINE_AA,
        )
    return canvas
104
 
 
 
 
 
 
 
105
 
106
def heatmap_overlay(image, heatmap):
    """Blend a single-channel heatmap over an RGB image as a JET overlay."""
    height, width = image.shape[:2]
    scaled = cv2.resize(heatmap, (width, height), interpolation=cv2.INTER_LINEAR)
    # Normalize to [0, 1]; the epsilon guards against a flat heatmap.
    lo = scaled.min()
    hi = scaled.max()
    normed = (scaled - lo) / (hi - lo + 1e-6)
    colored = cv2.applyColorMap((normed * 255).astype(np.uint8), cv2.COLORMAP_JET)
    # applyColorMap yields BGR; the app works in RGB throughout.
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(image, 0.6, colored, 0.4, 0)
113
 
 
 
114
 
115
def make_single_landmark_view(orig, coords, symbols, hm_np, idx):
    """Render one landmark: its heatmap overlay plus a white labelled marker."""
    view = heatmap_overlay(orig, hm_np[idx])
    point = np.array([coords[idx]], dtype=np.float32)
    return draw_points(view, point, [symbols[idx]], color=(255, 255, 255))
124
 
 
 
125
 
126
def build_demo(model, model_args, checkpoint_symbols, device):
    """Build the Gradio Blocks UI for cephalogram landmark inference.

    Args:
        model: Landmark network, already on ``device`` and in eval mode.
        model_args: Training-time config dict (input size, num_landmarks, ...).
        checkpoint_symbols: Optional landmark-name list from the checkpoint.
        device: Torch device used for the input tensor.

    Returns:
        A ``gr.Blocks`` demo ready for ``.launch()``.
    """
    default_symbols = get_symbols(model_args["num_landmarks"], checkpoint_symbols)

    def run_inference(image):
        # No upload yet: blank all outputs; a bare gr.Dropdown() leaves the
        # dropdown's choices untouched.
        if image is None:
            return None, None, None, None, None, None, gr.Dropdown()

        orig = image.copy()
        tensor, orig_size, input_size = preprocess(orig, model_args, device)

        with torch.no_grad():
            heatmaps = model(tensor)

        # Landmark coordinates in original-image pixels, raw heatmaps for the
        # overlay views, and per-landmark confidences (heatmap peak values).
        coords = decode(heatmaps, orig_size, input_size)
        hm_np = heatmaps[0].detach().cpu().numpy()
        conf = compute_confidence(heatmaps)
        symbols = get_symbols(len(coords), checkpoint_symbols)

        pred_overlay = draw_points(orig, coords, symbols)
        summed_overlay = heatmap_overlay(orig, hm_np.sum(axis=0))
        # Initial single-landmark view shows landmark 0 until the user picks one.
        single_overlay = make_single_landmark_view(orig, coords, symbols, hm_np, 0)

        table = [
            [symbols[i], float(coords[i, 0]), float(coords[i, 1]), float(conf[i])]
            for i in range(len(symbols))
        ]

        # Cache everything the dropdown callback needs so switching landmarks
        # re-renders from memory instead of re-running the model.
        cache = {
            "orig": orig,
            "coords": coords,
            "symbols": symbols,
            "heatmaps": hm_np,
            "pred_overlay": pred_overlay,
            "summed_overlay": summed_overlay,
            "table": table,
        }

        # Refresh the dropdown with the symbols actually produced for this run.
        dropdown_update = gr.Dropdown(choices=symbols, value=symbols[0])

        return orig, pred_overlay, summed_overlay, single_overlay, table, cache, dropdown_update

    def update_selected_landmark(selected_landmark, cache):
        # Inference has not been run yet: keep the output image empty.
        if cache is None:
            return None

        symbols = cache["symbols"]
        # Fall back to landmark 0 if the selection is stale (e.g. after a new run).
        idx = symbols.index(selected_landmark) if selected_landmark in symbols else 0

        return make_single_landmark_view(
            cache["orig"],
            cache["coords"],
            cache["symbols"],
            cache["heatmaps"],
            idx,
        )

    with gr.Blocks() as demo:
        gr.Markdown("## Cephalogram Landmark Inference")

        # Per-session server-side store holding the last inference results.
        cache_state = gr.State()

        with gr.Row():
            with gr.Column(scale=1, min_width=320):
                input_image = gr.Image(type="numpy", label="Input Image", height=420)
                run_button = gr.Button("Run Inference", variant="primary")
                selected_landmark = gr.Dropdown(
                    choices=default_symbols,
                    value=default_symbols[0],
                    label="Landmark Heatmap Selector",
                )

            with gr.Column(scale=2):
                with gr.Row():
                    out_orig = gr.Image(label="Original", height=284)
                    out_pred = gr.Image(label="Predictions", height=284)
                with gr.Row():
                    out_sum = gr.Image(label="All-Landmark Heatmap Overlay", height=284)
                    out_single = gr.Image(label="Selected Landmark Heatmap Overlay", height=284)

        out_table = gr.Dataframe(
            headers=["Landmark", "X", "Y", "Confidence"],
            label="Predictions",
            interactive=False,
            wrap=True,
        )

        # Outputs are positional and must match run_inference's return order.
        run_button.click(
            fn=run_inference,
            inputs=[input_image],
            outputs=[
                out_orig,
                out_pred,
                out_sum,
                out_single,
                out_table,
                cache_state,
                selected_landmark,
            ],
        )

        selected_landmark.change(
            fn=update_selected_landmark,
            inputs=[selected_landmark, cache_state],
            outputs=[out_single],
        )

    return demo
 
 
 
 
 
 
233
 
234
 
235
if __name__ == "__main__":
    cli_args = parse_args()
    model, model_args, checkpoint_symbols = load_model(cli_args.checkpoint, cli_args.device)
    # NOTE(review): this hard-coded list overrides whatever symbols the
    # checkpoint provides; remove once the checkpoint stores the right names.
    checkpoint_symbols = [
        "A", "ANS", "B", "Me", "N", "Or", "Pog", "PNS", "Pn", "R",
        "S", "Ar", "Co", "Gn", "Go", "Po", "LPM", "LIT", "LMT", "UPM",
        "UIA", "UIT", "UMT", "LIA", "Li", "Ls", "N`", "Pog`", "Sn"
    ] # TEMPORARY HARD CODE

    demo = build_demo(model, model_args, checkpoint_symbols, cli_args.device)
    # NOTE(review): the server CLI options are parsed but deliberately(?) not
    # forwarded to launch() — presumably to keep hosting-platform defaults
    # (e.g. Hugging Face Spaces). Confirm before re-enabling them.
    demo.launch(
        # server_name=cli_args.server_name,
        # server_port=cli_args.server_port,
        # share=cli_args.share,
        # inbrowser=cli_args.inbrowser,
    )
best.pt.enc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:018b89108a18be63689aea6ce7d2cedbc22c09e1fe28558b5bddd901efb3f558
3
+ size 976710027
requirements.txt CHANGED
@@ -1,3 +1,16 @@
1
- gradio
2
- torch
3
- huggingface_hub
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu121
2
+
3
+ albumentations==1.3.1
4
+ cryptography==46.0.6
5
+ gradio==4.44.1
6
+ huggingface_hub==0.31.2
7
+ numpy==1.26.3
8
+ opencv-python==4.11.0.86
9
+ pandas==2.3.3
10
+ pillow==10.4.0
11
+ pydantic==2.10.6
12
+ timm==1.0.9
13
+ torch==2.5.1
14
+ torchvision==0.20.1
15
+ torchaudio==2.5.1
16
+ tqdm==4.66.5
secure_torch_load.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import io
3
+ import os
4
+ import torch
5
+ from typing import Optional
6
+ from cryptography.hazmat.primitives.ciphers.aead import AESGCM
7
+
8
+
9
+ def _parse_key(key_str: str) -> bytes:
10
+ key_str = key_str.strip()
11
+
12
+ try:
13
+ key = bytes.fromhex(key_str)
14
+ if len(key) == 32:
15
+ return key
16
+ except ValueError:
17
+ pass
18
+
19
+ key = key_str.encode("utf-8")
20
+ if len(key) == 32:
21
+ return key
22
+
23
+ raise ValueError("Key must be either a 64-character hex string or a 32-character raw string.")
24
+
25
+
26
+ def _get_key(key: Optional[str] = None, env_var: str = "MODEL_KEY") -> bytes:
27
+ if key is not None:
28
+ return _parse_key(key)
29
+
30
+ env_value = os.environ.get(env_var)
31
+ if not env_value:
32
+ raise RuntimeError("Missing key. Provide key=... or set environment variable {}.".format(env_var))
33
+ return _parse_key(env_value)
34
+
35
+
36
def decrypt_and_decompress_to_bytes(path: str, key: Optional[str] = None, env_var: str = "MODEL_KEY") -> bytes:
    """Decrypt an AES-GCM-encrypted file and gunzip its payload.

    Expected file layout: 12-byte nonce prefix followed by ciphertext+tag.

    Raises:
        ValueError: if the file is shorter than a nonce plus one byte.
    """
    cipher = AESGCM(_get_key(key=key, env_var=env_var))

    with open(path, "rb") as handle:
        blob = handle.read()

    # Need at least the 12-byte nonce and one byte of ciphertext.
    if len(blob) < 13:
        raise ValueError("Encrypted file is too short or invalid.")

    nonce, ciphertext = blob[:12], blob[12:]
    compressed = cipher.decrypt(nonce, ciphertext, None)
    return gzip.decompress(compressed)
51
+
52
+
53
def secure_torch_load(path: str, *args, key: Optional[str] = None, env_var: str = "MODEL_KEY", **kwargs):
    """Decrypt ``path`` in memory and hand the plaintext to ``torch.load``.

    Extra positional/keyword arguments are forwarded to ``torch.load``.
    Note: the decrypted plaintext is a pickle, so only load checkpoints
    whose origin (and decryption key) you trust.
    """
    payload = decrypt_and_decompress_to_bytes(path, key=key, env_var=env_var)
    return torch.load(io.BytesIO(payload), *args, **kwargs)