lodestones committed on
Commit
ccc01da
·
1 Parent(s): 45887e5

Update tagger_ui_server.py

Browse files
Files changed (1) hide show
  1. tagger_ui_server.py +111 -66
tagger_ui_server.py CHANGED
@@ -1,14 +1,12 @@
1
- """DINOv3 Tagger — FastAPI + Jinja2 Web UI
2
 
3
  Usage
4
  -----
5
  python tagger_ui_server.py \
6
  --checkpoint tagger_dino_v3/checkpoints/2026-03-28_22-57-47.safetensors \
7
- --vocab tagger_dino_v3/tagger_vocab.json \
8
- --host 0.0.0.0 \
9
- --port 7860
10
-
11
- Then open http://localhost:7860 in your browser.
12
  """
13
 
14
  from __future__ import annotations
@@ -18,25 +16,56 @@ import io
18
  from pathlib import Path
19
 
20
  import torch
 
21
  import uvicorn
22
  from fastapi import FastAPI, File, HTTPException, Query, UploadFile
 
23
  from fastapi.responses import HTMLResponse
24
  from fastapi.templating import Jinja2Templates
25
- from fastapi.requests import Request
26
  from PIL import Image
27
 
28
- # Reuse the standalone inference code — no other deps needed
29
- from inference_tagger_standalone import Tagger, preprocess_image, _open_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  # ---------------------------------------------------------------------------
32
- # App setup
33
  # ---------------------------------------------------------------------------
34
 
35
  app = FastAPI(title="DINOv3 Tagger UI")
36
- templates = Jinja2Templates(directory=Path(__file__).parent / "tagger_ui" / "templates")
 
 
37
  templates.env.filters["format_number"] = lambda v: f"{v:,}"
38
 
39
  _tagger: Tagger | None = None
 
40
  _vocab_path: str = ""
41
 
42
 
@@ -47,98 +76,106 @@ _vocab_path: str = ""
47
  @app.get("/", response_class=HTMLResponse)
48
  async def index(request: Request):
49
  return templates.TemplateResponse("index.html", {
50
- "request": request,
51
- "num_tags": _tagger.num_tags if _tagger else 0,
52
- "vocab_path": _vocab_path,
 
53
  })
54
 
55
 
56
  @app.post("/tag/url")
57
  async def tag_url(
58
- url: str = Query(..., description="Image URL"),
59
- topk: int | None = Query(default=40),
60
- threshold: float | None = Query(default=None),
61
- max_size: int = Query(default=1024),
62
  ):
63
- """Tag an image from a URL."""
64
  assert _tagger is not None
65
  try:
 
66
  img = _open_image(url)
67
  except Exception as e:
68
  raise HTTPException(status_code=400, detail=f"Could not fetch image: {e}")
69
-
70
- return _run_tagger(img, topk, threshold, max_size)
71
 
72
 
73
  @app.post("/tag/upload")
74
  async def tag_upload(
75
- file: UploadFile = File(...),
76
- topk: int | None = Query(default=40),
77
- threshold: float | None = Query(default=None),
78
- max_size: int = Query(default=1024),
79
  ):
80
- """Tag an uploaded image file."""
81
  assert _tagger is not None
82
  try:
83
  data = await file.read()
84
  img = Image.open(io.BytesIO(data)).convert("RGB")
85
  except Exception as e:
86
  raise HTTPException(status_code=400, detail=f"Could not read image: {e}")
87
-
88
- return _run_tagger(img, topk, threshold, max_size)
89
 
90
 
91
  # ---------------------------------------------------------------------------
92
- # Shared inference helper
93
  # ---------------------------------------------------------------------------
94
 
95
  def _run_tagger(
96
- img: Image.Image,
97
- topk: int | None,
98
- threshold: float | None,
99
- max_size: int,
100
  ) -> dict:
 
 
 
101
  assert _tagger is not None
102
 
103
- if topk is None and threshold is None:
104
- topk = 40
105
-
106
- # Preprocess from PIL directly (avoids re-opening)
107
- from inference_tagger_standalone import _snap, PATCH_SIZE, _IMAGENET_MEAN, _IMAGENET_STD
108
- import torchvision.transforms.v2 as v2
109
-
110
  w, h = img.size
111
  scale = min(1.0, max_size / max(w, h))
112
  new_w = _snap(round(w * scale), PATCH_SIZE)
113
  new_h = _snap(round(h * scale), PATCH_SIZE)
114
 
115
- transform = v2.Compose([
116
  v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
117
  v2.ToImage(),
118
  v2.ToDtype(torch.float32, scale=True),
119
  v2.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
120
- ])
121
- pixel_values = transform(img).unsqueeze(0).to(_tagger.device)
122
 
123
  with torch.no_grad(), torch.autocast(device_type=_tagger.device.type, dtype=_tagger.dtype):
124
  logits = _tagger.model(pixel_values)[0]
125
 
126
  scores = torch.sigmoid(logits.float())
127
 
128
- if topk is not None:
129
- values, indices = scores.topk(min(topk, _tagger.num_tags))
130
- else:
131
- assert threshold is not None
132
- indices = (scores >= threshold).nonzero(as_tuple=True)[0]
133
- values = scores[indices]
134
- order = values.argsort(descending=True)
135
- indices, values = indices[order], values[order]
136
-
137
- tags = [
138
- {"tag": _tagger.idx2tag[i], "score": round(float(v), 4)}
139
- for i, v in zip(indices.tolist(), values.tolist())
140
- ]
141
- return {"tags": tags, "count": len(tags)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
 
144
  # ---------------------------------------------------------------------------
@@ -146,26 +183,34 @@ def _run_tagger(
146
  # ---------------------------------------------------------------------------
147
 
148
  def main():
149
- global _tagger, _vocab_path
 
150
 
151
  parser = argparse.ArgumentParser(description="DINOv3 Tagger Web UI")
152
- parser.add_argument("--checkpoint", required=True, help="Path to .safetensors checkpoint")
153
- parser.add_argument("--vocab", required=True, help="Path to tagger_vocab.json")
154
- parser.add_argument("--device", default="cuda", help="cuda / cpu (default: cuda)")
155
- parser.add_argument("--max-size", type=int, default=1024, help="Default long-edge cap")
156
- parser.add_argument("--host", default="0.0.0.0")
157
- parser.add_argument("--port", type=int, default=7860)
 
158
  args = parser.parse_args()
159
 
160
  _vocab_path = args.vocab
 
 
 
 
 
 
161
  _tagger = Tagger(
162
  checkpoint_path=args.checkpoint,
163
- vocab_path=args.vocab,
164
  device=args.device,
165
  max_size=args.max_size,
166
  )
167
 
168
- print(f"\n Tagger UI running at http://{args.host}:{args.port}\n")
169
  uvicorn.run(app, host=args.host, port=args.port)
170
 
171
 
 
1
+ """DINOv3 Tagger — FastAPI + Jinja2 Web UI (with category breakdown)
2
 
3
  Usage
4
  -----
5
  python tagger_ui_server.py \
6
  --checkpoint tagger_dino_v3/checkpoints/2026-03-28_22-57-47.safetensors \
7
+ --vocab tagger_dino_v3/tagger_vocab_with_categories.json \
8
+ --host 0.0.0.0 \
9
+ --port 7860
 
 
10
  """
11
 
12
  from __future__ import annotations
 
16
  from pathlib import Path
17
 
18
  import torch
19
+ import torchvision.transforms.v2 as v2
20
  import uvicorn
21
  from fastapi import FastAPI, File, HTTPException, Query, UploadFile
22
+ from fastapi.requests import Request
23
  from fastapi.responses import HTMLResponse
24
  from fastapi.templating import Jinja2Templates
 
25
  from PIL import Image
26
 
27
+ from inference_tagger_standalone import (
28
+ PATCH_SIZE,
29
+ Tagger,
30
+ _IMAGENET_MEAN,
31
+ _IMAGENET_STD,
32
+ _snap,
33
+ )
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Category metadata
37
+ # ---------------------------------------------------------------------------
38
+
39
+ # Raw category IDs from the vocab use -1 for unassigned.
40
+ # We offset every ID by +1 so all IDs are >= 0, avoiding negative
41
+ # numbers in HTML element IDs and JS inline handlers.
42
+ _CAT_OFFSET = 1
43
+
44
+ CATEGORY_META: dict[int, dict] = {
45
+ 0: {"name": "unassigned", "color": "#6b7280"}, # raw -1
46
+ 1: {"name": "general", "color": "#4ade80"}, # raw 0
47
+ 2: {"name": "artist", "color": "#f472b6"}, # raw 1
48
+ 3: {"name": "contributor", "color": "#a78bfa"}, # raw 2
49
+ 4: {"name": "copyright", "color": "#fb923c"}, # raw 3
50
+ 5: {"name": "character", "color": "#60a5fa"}, # raw 4
51
+ 6: {"name": "species/meta", "color": "#facc15"}, # raw 5
52
+ 7: {"name": "disambiguation", "color": "#94a3b8"}, # raw 6
53
+ 8: {"name": "meta", "color": "#e2e8f0"}, # raw 7
54
+ 9: {"name": "lore", "color": "#f87171"}, # raw 8
55
+ }
56
 
57
  # ---------------------------------------------------------------------------
58
+ # App
59
  # ---------------------------------------------------------------------------
60
 
61
  app = FastAPI(title="DINOv3 Tagger UI")
62
+ templates = Jinja2Templates(
63
+ directory=Path(__file__).parent / "tagger_ui" / "templates"
64
+ )
65
  templates.env.filters["format_number"] = lambda v: f"{v:,}"
66
 
67
  _tagger: Tagger | None = None
68
+ _tag2category: dict[str, int] = {}
69
  _vocab_path: str = ""
70
 
71
 
 
76
  @app.get("/", response_class=HTMLResponse)
77
  async def index(request: Request):
78
  return templates.TemplateResponse("index.html", {
79
+ "request": request,
80
+ "num_tags": _tagger.num_tags if _tagger else 0,
81
+ "vocab_path": _vocab_path,
82
+ "category_meta": CATEGORY_META,
83
  })
84
 
85
 
86
  @app.post("/tag/url")
87
  async def tag_url(
88
+ url: str = Query(...),
89
+ max_size: int = Query(default=1024),
90
+ floor: float = Query(default=0.05),
 
91
  ):
 
92
  assert _tagger is not None
93
  try:
94
+ from inference_tagger_standalone import _open_image
95
  img = _open_image(url)
96
  except Exception as e:
97
  raise HTTPException(status_code=400, detail=f"Could not fetch image: {e}")
98
+ return _run_tagger(img, max_size, floor)
 
99
 
100
 
101
  @app.post("/tag/upload")
102
  async def tag_upload(
103
+ file: UploadFile = File(...),
104
+ max_size: int = Query(default=1024),
105
+ floor: float = Query(default=0.05),
 
106
  ):
 
107
  assert _tagger is not None
108
  try:
109
  data = await file.read()
110
  img = Image.open(io.BytesIO(data)).convert("RGB")
111
  except Exception as e:
112
  raise HTTPException(status_code=400, detail=f"Could not read image: {e}")
113
+ return _run_tagger(img, max_size, floor)
 
114
 
115
 
116
  # ---------------------------------------------------------------------------
117
+ # Inference helper
118
  # ---------------------------------------------------------------------------
119
 
120
  def _run_tagger(
121
+ img: Image.Image,
122
+ max_size: int,
123
+ floor: float = 0.05,
 
124
  ) -> dict:
125
+ """Return every tag whose sigmoid score >= floor, sorted desc.
126
+ The frontend applies per-category topk / threshold on top of this.
127
+ """
128
  assert _tagger is not None
129
 
 
 
 
 
 
 
 
130
  w, h = img.size
131
  scale = min(1.0, max_size / max(w, h))
132
  new_w = _snap(round(w * scale), PATCH_SIZE)
133
  new_h = _snap(round(h * scale), PATCH_SIZE)
134
 
135
+ pixel_values = v2.Compose([
136
  v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
137
  v2.ToImage(),
138
  v2.ToDtype(torch.float32, scale=True),
139
  v2.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
140
+ ])(img).unsqueeze(0).to(_tagger.device)
 
141
 
142
  with torch.no_grad(), torch.autocast(device_type=_tagger.device.type, dtype=_tagger.dtype):
143
  logits = _tagger.model(pixel_values)[0]
144
 
145
  scores = torch.sigmoid(logits.float())
146
 
147
+ # Return all tags above the floor, sorted by score descending
148
+ indices = (scores >= floor).nonzero(as_tuple=True)[0]
149
+ values = scores[indices]
150
+ order = values.argsort(descending=True)
151
+ indices = indices[order]
152
+ values = values[order]
153
+
154
+ # Build per-category buckets
155
+ by_category: dict[int, list] = {}
156
+ all_tags = []
157
+ for i, v in zip(indices.tolist(), values.tolist()):
158
+ tag = _tagger.idx2tag[i]
159
+ cat = _tag2category.get(tag, -1) + _CAT_OFFSET
160
+ item = {"tag": tag, "score": round(v, 4), "category": cat}
161
+ all_tags.append(item)
162
+ by_category.setdefault(cat, []).append(item)
163
+
164
+ categories = []
165
+ for cat_id in sorted(by_category.keys()):
166
+ meta = CATEGORY_META.get(cat_id, {"name": str(cat_id), "color": "#6b7280"})
167
+ categories.append({
168
+ "id": cat_id,
169
+ "name": meta["name"],
170
+ "color": meta["color"],
171
+ "tags": by_category[cat_id],
172
+ })
173
+
174
+ return {
175
+ "tags": all_tags,
176
+ "categories": categories,
177
+ "count": len(all_tags),
178
+ }
179
 
180
 
181
  # ---------------------------------------------------------------------------
 
183
  # ---------------------------------------------------------------------------
184
 
185
  def main():
186
+ global _tagger, _tag2category, _vocab_path
187
+ import json
188
 
189
  parser = argparse.ArgumentParser(description="DINOv3 Tagger Web UI")
190
+ parser.add_argument("--checkpoint", required=True)
191
+ parser.add_argument("--vocab", required=True,
192
+ help="Path to tagger_vocab_with_categories.json")
193
+ parser.add_argument("--device", default="cuda")
194
+ parser.add_argument("--max-size", type=int, default=1024)
195
+ parser.add_argument("--host", default="0.0.0.0")
196
+ parser.add_argument("--port", type=int, default=7860)
197
  args = parser.parse_args()
198
 
199
  _vocab_path = args.vocab
200
+
201
+ # Load tag→category mapping from the enriched vocab file
202
+ with open(args.vocab) as f:
203
+ vocab_data = json.load(f)
204
+ _tag2category = vocab_data.get("tag2category", {})
205
+
206
  _tagger = Tagger(
207
  checkpoint_path=args.checkpoint,
208
+ vocab_path=args.vocab, # Tagger only reads idx2tag from this
209
  device=args.device,
210
  max_size=args.max_size,
211
  )
212
 
213
+ print(f"\n Tagger UI http://{args.host}:{args.port}\n")
214
  uvicorn.run(app, host=args.host, port=args.port)
215
 
216