lodestones committed on
Commit
bc0d1c8
·
1 Parent(s): 3625530

Upload tagger_ui_server.py

Browse files
Files changed (1) hide show
  1. tagger_ui_server.py +173 -0
tagger_ui_server.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DINOv3 Tagger — FastAPI + Jinja2 Web UI
2
+
3
+ Usage
4
+ -----
5
+ python tagger_ui_server.py \
6
+ --checkpoint tagger_dino_v3/checkpoints/2026-03-28_22-57-47.safetensors \
7
+ --vocab tagger_dino_v3/tagger_vocab.json \
8
+ --host 0.0.0.0 \
9
+ --port 7860
10
+
11
+ Then open http://localhost:7860 in your browser.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import argparse
17
+ import io
18
+ from pathlib import Path
19
+
20
+ import torch
21
+ import uvicorn
22
+ from fastapi import FastAPI, File, HTTPException, Query, UploadFile
23
+ from fastapi.responses import HTMLResponse
24
+ from fastapi.templating import Jinja2Templates
25
+ from fastapi.requests import Request
26
+ from PIL import Image
27
+
28
+ # Reuse the standalone inference code — no other deps needed
29
+ from inference_tagger_standalone import Tagger, preprocess_image, _open_image
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # App setup
33
+ # ---------------------------------------------------------------------------
34
+
35
# Templates are resolved relative to this file: <this dir>/tagger_ui/templates.
_TEMPLATE_DIR = Path(__file__).parent / "tagger_ui" / "templates"

app = FastAPI(title="DINOv3 Tagger UI")
templates = Jinja2Templates(directory=_TEMPLATE_DIR)


def _format_number(value):
    """Jinja filter: format a number with comma thousands separators."""
    return f"{value:,}"


templates.env.filters["format_number"] = _format_number

# Module-level state assigned by main() before the server starts serving.
_tagger: Tagger | None = None
_vocab_path: str = ""
41
+
42
+
43
+ # ---------------------------------------------------------------------------
44
+ # Routes
45
+ # ---------------------------------------------------------------------------
46
+
47
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the landing page.

    The template receives the number of tags known to the loaded tagger
    (0 when no tagger has been loaded) and the vocabulary path.
    """
    if _tagger:
        tag_count = _tagger.num_tags
    else:
        tag_count = 0

    context = {
        "request": request,
        "num_tags": tag_count,
        "vocab_path": _vocab_path,
    }
    return templates.TemplateResponse("index.html", context)
54
+
55
+
56
@app.post("/tag/url")
async def tag_url(
    url: str = Query(..., description="Image URL"),
    # Default is None (not 40) so that a request carrying only ``threshold``
    # is honoured; _run_tagger falls back to topk=40 when neither is given.
    topk: int | None = Query(default=None),
    threshold: float | None = Query(default=None),
    max_size: int = Query(default=1024),
):
    """Tag an image fetched from a URL.

    Query parameters:
        url: http(s) URL of the image to tag.
        topk: return the ``topk`` highest-scoring tags. Takes precedence
            over ``threshold`` when both are supplied.
        threshold: return every tag whose score is >= this value (used only
            when ``topk`` is omitted).
        max_size: long-edge cap applied before inference.

    Returns:
        The dict produced by ``_run_tagger``: ``{"tags": [...], "count": n}``.

    Raises:
        HTTPException(503): the tagger has not been loaded.
        HTTPException(400): the URL is not http(s) or the image could not
            be fetched/decoded.
    """
    import asyncio
    from urllib.parse import urlsplit

    # Explicit check rather than ``assert``: asserts are removed under -O.
    if _tagger is None:
        raise HTTPException(status_code=503, detail="Tagger is not loaded")

    # ``url`` is client-controlled and ``_open_image`` is defined elsewhere;
    # restricting the scheme keeps anything that is not a web URL (e.g. a
    # local path or file:// URI) from ever reaching it.
    if urlsplit(url).scheme.lower() not in ("http", "https"):
        raise HTTPException(
            status_code=400,
            detail="Could not fetch image: only http(s) URLs are supported",
        )

    try:
        # The fetch is blocking I/O; run it in a worker thread so the event
        # loop keeps serving other requests while it is in flight.
        img = await asyncio.to_thread(_open_image, url)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not fetch image: {e}") from e

    return _run_tagger(img, topk, threshold, max_size)
71
+
72
+
73
@app.post("/tag/upload")
async def tag_upload(
    file: UploadFile = File(...),
    # Default is None (not 40) so that a request carrying only ``threshold``
    # is honoured; _run_tagger falls back to topk=40 when neither is given.
    topk: int | None = Query(default=None),
    threshold: float | None = Query(default=None),
    max_size: int = Query(default=1024),
):
    """Tag an uploaded image file.

    Parameters:
        file: the uploaded image (multipart form field).
        topk: return the ``topk`` highest-scoring tags. Takes precedence
            over ``threshold`` when both are supplied.
        threshold: return every tag whose score is >= this value (used only
            when ``topk`` is omitted).
        max_size: long-edge cap applied before inference.

    Returns:
        The dict produced by ``_run_tagger``: ``{"tags": [...], "count": n}``.

    Raises:
        HTTPException(503): the tagger has not been loaded.
        HTTPException(400): the upload could not be read or decoded as an
            image.
    """
    # Explicit check rather than ``assert``: asserts are removed under -O.
    if _tagger is None:
        raise HTTPException(status_code=503, detail="Tagger is not loaded")

    try:
        data = await file.read()
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except Exception as e:
        # Chain the cause so the original traceback is kept in server logs.
        raise HTTPException(status_code=400, detail=f"Could not read image: {e}") from e

    return _run_tagger(img, topk, threshold, max_size)
89
+
90
+
91
+ # ---------------------------------------------------------------------------
92
+ # Shared inference helper
93
+ # ---------------------------------------------------------------------------
94
+
95
def _run_tagger(
    img: Image.Image,
    topk: int | None,
    threshold: float | None,
    max_size: int,
) -> dict:
    """Run the loaded tagger on a PIL image and return the selected tags.

    Args:
        img: image to tag.
        topk: number of highest-scoring tags to return, capped at the
            vocabulary size. Takes precedence over ``threshold``.
        threshold: when ``topk`` is None, return every tag whose score is
            >= this value.
        max_size: long-edge cap; the image is never upscaled (scale <= 1).

    If both ``topk`` and ``threshold`` are None, ``topk`` defaults to 40.

    Returns:
        ``{"tags": [{"tag": str, "score": float}, ...], "count": int}`` with
        tags ordered by descending score and scores rounded to 4 decimals.

    Raises:
        HTTPException(503): the tagger has not been loaded.
        HTTPException(400): ``max_size`` is not positive or ``topk`` is
            negative (both arrive straight from query parameters).
    """
    # Explicit checks rather than ``assert``: asserts are removed under -O,
    # and bad client values should be a 4xx, not an unhandled 500.
    if _tagger is None:
        raise HTTPException(status_code=503, detail="Tagger is not loaded")
    if max_size < 1:
        raise HTTPException(status_code=400, detail="max_size must be a positive integer")
    if topk is not None and topk < 0:
        raise HTTPException(status_code=400, detail="topk must be non-negative")

    if topk is None and threshold is None:
        topk = 40

    # Preprocess from PIL directly (avoids re-opening)
    from inference_tagger_standalone import _snap, PATCH_SIZE, _IMAGENET_MEAN, _IMAGENET_STD
    import torchvision.transforms.v2 as v2

    w, h = img.size
    # Shrink so the long edge is at most max_size; never enlarge.
    scale = min(1.0, max_size / max(w, h))
    # _snap is defined elsewhere — presumably aligns each side to a multiple
    # of PATCH_SIZE; confirm against inference_tagger_standalone.
    new_w = _snap(round(w * scale), PATCH_SIZE)
    new_h = _snap(round(h * scale), PATCH_SIZE)

    transform = v2.Compose([
        v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])
    # unsqueeze(0) adds a leading dimension of size 1 before the model call.
    pixel_values = transform(img).unsqueeze(0).to(_tagger.device)

    with torch.no_grad(), torch.autocast(device_type=_tagger.device.type, dtype=_tagger.dtype):
        # [0] takes the first entry of the model output (presumably the
        # single batch item's per-tag logits — confirm against Tagger.model).
        logits = _tagger.model(pixel_values)[0]

    # Cast to float32 before the sigmoid so scores are not computed in the
    # autocast dtype.
    scores = torch.sigmoid(logits.float())

    if topk is not None:
        # torch.topk returns values in descending order by default.
        values, indices = scores.topk(min(topk, _tagger.num_tags))
    else:
        # threshold is guaranteed non-None here by the fallback above.
        indices = (scores >= threshold).nonzero(as_tuple=True)[0]
        values = scores[indices]
        order = values.argsort(descending=True)
        indices, values = indices[order], values[order]

    tags = [
        {"tag": _tagger.idx2tag[i], "score": round(float(v), 4)}
        for i, v in zip(indices.tolist(), values.tolist())
    ]
    return {"tags": tags, "count": len(tags)}
142
+
143
+
144
+ # ---------------------------------------------------------------------------
145
+ # Entry point
146
+ # ---------------------------------------------------------------------------
147
+
148
def main():
    """Parse CLI options, load the tagger into module state, and serve the UI."""
    global _tagger, _vocab_path

    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ("--checkpoint", {"required": True, "help": "Path to .safetensors checkpoint"}),
        ("--vocab", {"required": True, "help": "Path to tagger_vocab.json"}),
        ("--device", {"default": "cuda", "help": "cuda / cpu (default: cuda)"}),
        ("--max-size", {"type": int, "default": 1024, "help": "Default long-edge cap"}),
        ("--host", {"default": "0.0.0.0"}),
        ("--port", {"type": int, "default": 7860}),
    )

    cli = argparse.ArgumentParser(description="DINOv3 Tagger Web UI")
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    opts = cli.parse_args()

    # Publish the loaded state for the route handlers before serving.
    _vocab_path = opts.vocab
    _tagger = Tagger(
        checkpoint_path=opts.checkpoint,
        vocab_path=opts.vocab,
        device=opts.device,
        max_size=opts.max_size,
    )

    host, port = opts.host, opts.port
    print(f"\n Tagger UI running at http://{host}:{port}\n")
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()