"""DINOv3 Tagger — FastAPI + Jinja2 Web UI Usage ----- python tagger_ui_server.py \ --checkpoint tagger_dino_v3/checkpoints/2026-03-28_22-57-47.safetensors \ --vocab tagger_dino_v3/tagger_vocab.json \ --host 0.0.0.0 \ --port 7860 Then open http://localhost:7860 in your browser. """ from __future__ import annotations import argparse import io from pathlib import Path import torch import uvicorn from fastapi import FastAPI, File, HTTPException, Query, UploadFile from fastapi.responses import HTMLResponse from fastapi.templating import Jinja2Templates from fastapi.requests import Request from PIL import Image # Reuse the standalone inference code — no other deps needed from inference_tagger_standalone import Tagger, preprocess_image, _open_image # --------------------------------------------------------------------------- # App setup # --------------------------------------------------------------------------- app = FastAPI(title="DINOv3 Tagger UI") templates = Jinja2Templates(directory=Path(__file__).parent / "tagger_ui" / "templates") templates.env.filters["format_number"] = lambda v: f"{v:,}" _tagger: Tagger | None = None _vocab_path: str = "" # --------------------------------------------------------------------------- # Routes # --------------------------------------------------------------------------- @app.get("/", response_class=HTMLResponse) async def index(request: Request): return templates.TemplateResponse("index.html", { "request": request, "num_tags": _tagger.num_tags if _tagger else 0, "vocab_path": _vocab_path, }) @app.post("/tag/url") async def tag_url( url: str = Query(..., description="Image URL"), topk: int | None = Query(default=40), threshold: float | None = Query(default=None), max_size: int = Query(default=1024), ): """Tag an image from a URL.""" assert _tagger is not None try: img = _open_image(url) except Exception as e: raise HTTPException(status_code=400, detail=f"Could not fetch image: {e}") return _run_tagger(img, topk, threshold, max_size) @app.post("/tag/upload") async def tag_upload( file: UploadFile = File(...), topk: int | None = Query(default=40), threshold: float | None = Query(default=None), max_size: int = Query(default=1024), ): """Tag an uploaded image file.""" assert _tagger is not None try: data = await file.read() img = Image.open(io.BytesIO(data)).convert("RGB") except Exception as e: raise HTTPException(status_code=400, detail=f"Could not read image: {e}") return _run_tagger(img, topk, threshold, max_size) # --------------------------------------------------------------------------- # Shared inference helper # --------------------------------------------------------------------------- def _run_tagger( img: Image.Image, topk: int | None, threshold: float | None, max_size: int, ) -> dict: assert _tagger is not None if topk is None and threshold is None: topk = 40 # Preprocess from PIL directly (avoids re-opening) from inference_tagger_standalone import _snap, PATCH_SIZE, _IMAGENET_MEAN, _IMAGENET_STD import torchvision.transforms.v2 as v2 w, h = img.size scale = min(1.0, max_size / max(w, h)) new_w = _snap(round(w * scale), PATCH_SIZE) new_h = _snap(round(h * scale), PATCH_SIZE) transform = v2.Compose([ v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS), v2.ToImage(), v2.ToDtype(torch.float32, scale=True), v2.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD), ]) pixel_values = transform(img).unsqueeze(0).to(_tagger.device) with torch.no_grad(), torch.autocast(device_type=_tagger.device.type, dtype=_tagger.dtype): logits = _tagger.model(pixel_values)[0] scores = torch.sigmoid(logits.float()) if topk is not None: values, indices = scores.topk(min(topk, _tagger.num_tags)) else: assert threshold is not None indices = (scores >= threshold).nonzero(as_tuple=True)[0] values = scores[indices] order = values.argsort(descending=True) indices, values = indices[order], values[order] tags = [ {"tag": _tagger.idx2tag[i], "score": round(float(v), 4)} for i, v in zip(indices.tolist(), values.tolist()) ] return {"tags": tags, "count": len(tags)} # --------------------------------------------------------------------------- # Entry point # --------------------------------------------------------------------------- def main(): global _tagger, _vocab_path parser = argparse.ArgumentParser(description="DINOv3 Tagger Web UI") parser.add_argument("--checkpoint", required=True, help="Path to .safetensors checkpoint") parser.add_argument("--vocab", required=True, help="Path to tagger_vocab.json") parser.add_argument("--device", default="cuda", help="cuda / cpu (default: cuda)") parser.add_argument("--max-size", type=int, default=1024, help="Default long-edge cap") parser.add_argument("--host", default="0.0.0.0") parser.add_argument("--port", type=int, default=7860) args = parser.parse_args() _vocab_path = args.vocab _tagger = Tagger( checkpoint_path=args.checkpoint, vocab_path=args.vocab, device=args.device, max_size=args.max_size, ) print(f"\n Tagger UI running at http://{args.host}:{args.port}\n") uvicorn.run(app, host=args.host, port=args.port) if __name__ == "__main__": main()