| """DINOv3 Tagger — FastAPI + Jinja2 Web UI |
| |
| Usage |
| ----- |
| python tagger_ui_server.py \ |
| --checkpoint tagger_dino_v3/checkpoints/2026-03-28_22-57-47.safetensors \ |
| --vocab tagger_dino_v3/tagger_vocab.json \ |
| --host 0.0.0.0 \ |
| --port 7860 |
| |
| Then open http://localhost:7860 in your browser. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import io |
| from pathlib import Path |
|
|
| import torch |
| import uvicorn |
| from fastapi import FastAPI, File, HTTPException, Query, UploadFile |
| from fastapi.responses import HTMLResponse |
| from fastapi.templating import Jinja2Templates |
| from fastapi.requests import Request |
| from PIL import Image |
|
|
| |
| from inference_tagger_standalone import Tagger, preprocess_image, _open_image |
|
|
| |
| |
| |
|
|
# ASGI application object; main() hands it to uvicorn.
app = FastAPI(title="DINOv3 Tagger UI")

# Templates live next to this file, under tagger_ui/templates.
templates = Jinja2Templates(
    directory=Path(__file__).parent.joinpath("tagger_ui", "templates")
)
# Jinja filter that renders a number with thousands separators (12345 -> "12,345").
templates.env.filters["format_number"] = lambda value: format(value, ",")

# Process-wide state, filled in by main() before the server starts.
_tagger: Tagger | None = None  # loaded model wrapper; None until main() runs
_vocab_path: str = ""  # vocabulary path from the command line, shown on the index page
|
|
|
|
| |
| |
| |
|
|
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the landing page.

    The template receives the current request, the number of tags reported by
    the loaded tagger (0 when no tagger is set) and the vocabulary path.
    """
    tag_count = _tagger.num_tags if _tagger else 0
    context = {
        "request": request,
        "num_tags": tag_count,
        "vocab_path": _vocab_path,
    }
    return templates.TemplateResponse("index.html", context)
|
|
|
|
@app.post("/tag/url")
async def tag_url(
    url: str = Query(..., description="Image URL"),
    topk: int | None = Query(default=40),
    threshold: float | None = Query(default=None),
    max_size: int = Query(default=1024),
):
    """Tag an image from a URL.

    Query parameters:
        url: Location of the image; handed to ``_open_image``.
        topk, threshold, max_size: Forwarded unchanged to ``_run_tagger``.

    Returns:
        The dictionary produced by ``_run_tagger``.

    Raises:
        HTTPException: 503 when no tagger has been loaded, 400 when the image
            cannot be opened.
    """
    # Function-scope import so this handler does not depend on the file header.
    import asyncio

    # Explicit check instead of ``assert``: asserts are stripped under ``-O``
    # and would otherwise surface as an opaque 500 when the app is served
    # without going through main().
    if _tagger is None:
        raise HTTPException(status_code=503, detail="Tagger model is not loaded")

    # NOTE(review): ``url`` is client-controlled and passed to _open_image
    # unvalidated. Confirm which schemes/paths _open_image accepts and restrict
    # them if this server is reachable by untrusted clients.
    try:
        # _open_image is a synchronous call; run it in a worker thread so the
        # event loop keeps serving other requests while it runs.
        img = await asyncio.to_thread(_open_image, url)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not fetch image: {e}") from e

    return _run_tagger(img, topk, threshold, max_size)
|
|
|
|
@app.post("/tag/upload")
async def tag_upload(
    file: UploadFile = File(...),
    topk: int | None = Query(default=40),
    threshold: float | None = Query(default=None),
    max_size: int = Query(default=1024),
):
    """Tag an uploaded image file.

    Parameters:
        file: Uploaded image; its bytes are decoded with PIL and converted to RGB.
        topk, threshold, max_size: Forwarded unchanged to ``_run_tagger``.

    Returns:
        The dictionary produced by ``_run_tagger``.

    Raises:
        HTTPException: 503 when no tagger has been loaded, 400 when the upload
            cannot be read or decoded as an image.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``-O``
    # and would otherwise surface as an opaque 500 when the app is served
    # without going through main().
    if _tagger is None:
        raise HTTPException(status_code=503, detail="Tagger model is not loaded")

    try:
        data = await file.read()
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except Exception as e:
        # Chain the cause so the original traceback is kept in server logs.
        raise HTTPException(status_code=400, detail=f"Could not read image: {e}") from e

    return _run_tagger(img, topk, threshold, max_size)
|
|
|
|
| |
| |
| |
|
|
def _run_tagger(
    img: Image.Image,
    topk: int | None,
    threshold: float | None,
    max_size: int,
) -> dict:
    """Run the loaded tagger on one image and return its scored tags.

    Args:
        img: Image to tag.
        topk: Number of best-scoring tags to return. Takes precedence over
            ``threshold`` whenever it is not None.
        threshold: Minimum score for a tag to be returned. Only consulted when
            ``topk`` is None.
        max_size: Target cap for the longer image edge; the scale factor is
            limited to 1.0 before the edges are snapped with ``_snap``.

    Returns:
        ``{"tags": [{"tag": str, "score": float}, ...], "count": int}`` with
        tags ordered by descending score and scores rounded to 4 decimals.

    Raises:
        HTTPException: 503 when no tagger has been loaded, 400 when ``topk``
            is negative.
    """
    tagger = _tagger
    # Explicit checks instead of ``assert`` (stripped under ``-O``).
    if tagger is None:
        raise HTTPException(status_code=503, detail="Tagger model is not loaded")
    # A negative k would otherwise fail inside Tensor.topk as an opaque 500.
    if topk is not None and topk < 0:
        raise HTTPException(status_code=400, detail="topk must be non-negative")

    # With neither selector given, fall back to the 40 best tags.
    if topk is None and threshold is None:
        topk = 40

    # Imported here, as in the original, so the module header is left untouched.
    from inference_tagger_standalone import _snap, PATCH_SIZE, _IMAGENET_MEAN, _IMAGENET_STD
    import torchvision.transforms.v2 as v2

    # Scale so the long edge is at most max_size (scale factor capped at 1.0),
    # then snap both edges with _snap/PATCH_SIZE.
    w, h = img.size
    scale = min(1.0, max_size / max(w, h))
    new_w = _snap(round(w * scale), PATCH_SIZE)
    new_h = _snap(round(h * scale), PATCH_SIZE)

    transform = v2.Compose([
        v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])
    # Add a leading batch dimension and move to the tagger's device.
    pixel_values = transform(img).unsqueeze(0).to(tagger.device)

    with torch.no_grad(), torch.autocast(device_type=tagger.device.type, dtype=tagger.dtype):
        logits = tagger.model(pixel_values)[0]

    # Sigmoid is evaluated in float32 regardless of the autocast dtype.
    scores = torch.sigmoid(logits.float())

    if topk is not None:
        # Tensor.topk returns values already sorted in descending order.
        values, indices = scores.topk(min(topk, tagger.num_tags))
    else:
        indices = (scores >= threshold).nonzero(as_tuple=True)[0]
        values = scores[indices]
        order = values.argsort(descending=True)
        indices, values = indices[order], values[order]

    tags = [
        {"tag": tagger.idx2tag[i], "score": round(float(v), 4)}
        for i, v in zip(indices.tolist(), values.tolist())
    ]
    return {"tags": tags, "count": len(tags)}
|
|
|
|
| |
| |
| |
|
|
def main():
    """Parse command-line options, load the tagger, and serve the UI.

    Sets the module-level ``_tagger`` and ``_vocab_path`` that the request
    handlers read, then runs uvicorn on the requested host and port.
    """
    global _tagger, _vocab_path

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--checkpoint", {"required": True, "help": "Path to .safetensors checkpoint"}),
        ("--vocab", {"required": True, "help": "Path to tagger_vocab.json"}),
        ("--device", {"default": "cuda", "help": "cuda / cpu (default: cuda)"}),
        ("--max-size", {"type": int, "default": 1024, "help": "Default long-edge cap"}),
        ("--host", {"default": "0.0.0.0"}),
        ("--port", {"type": int, "default": 7860}),
    )
    cli = argparse.ArgumentParser(description="DINOv3 Tagger Web UI")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    _vocab_path = opts.vocab
    _tagger = Tagger(
        checkpoint_path=opts.checkpoint,
        vocab_path=opts.vocab,
        device=opts.device,
        max_size=opts.max_size,
    )

    host, port = opts.host, opts.port
    print(f"\n Tagger UI running at http://{host}:{port}\n")
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()
|
|