# taggerine / tagger_ui_server.py
# lodestones's picture
# Upload tagger_ui_server.py
# bc0d1c8
# raw
# history blame
# 5.72 kB
"""DINOv3 Tagger — FastAPI + Jinja2 Web UI

Usage
-----
    python tagger_ui_server.py \
        --checkpoint tagger_dino_v3/checkpoints/2026-03-28_22-57-47.safetensors \
        --vocab tagger_dino_v3/tagger_vocab.json \
        --host 0.0.0.0 \
        --port 7860

Then open http://localhost:7860 in your browser.
"""
from __future__ import annotations
import argparse
import io
from pathlib import Path
import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.requests import Request
from PIL import Image
# Reuse the standalone inference code — no other deps needed
from inference_tagger_standalone import Tagger, preprocess_image, _open_image
# ---------------------------------------------------------------------------
# App setup
# ---------------------------------------------------------------------------
def _format_number(value):
    """Jinja2 filter: render *value* with thousands separators."""
    return f"{value:,}"


# Templates live next to this file, under tagger_ui/templates.
_TEMPLATE_DIR = Path(__file__).parent / "tagger_ui" / "templates"

app = FastAPI(title="DINOv3 Tagger UI")

templates = Jinja2Templates(directory=_TEMPLATE_DIR)
templates.env.filters["format_number"] = _format_number

# Module-level state; both are populated by main() before the server starts.
_tagger: Tagger | None = None
_vocab_path: str = ""
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render ``index.html`` with the tag count and the vocabulary path.

    The tag count is reported as 0 while no tagger is loaded.
    """
    tag_total = 0 if not _tagger else _tagger.num_tags
    context = {
        "request": request,
        "num_tags": tag_total,
        "vocab_path": _vocab_path,
    }
    return templates.TemplateResponse("index.html", context)
@app.post("/tag/url")
async def tag_url(
    url: str = Query(..., description="Image URL"),
    topk: int | None = Query(default=40),
    threshold: float | None = Query(default=None),
    max_size: int = Query(default=1024),
):
    """Tag an image fetched from an http(s) URL.

    Query parameters:
        url: Location of the image; must start with ``http://`` or ``https://``.
        topk: Number of highest-scoring tags to return (default 40).
        threshold: Minimum score, forwarded to ``_run_tagger``.
        max_size: Long-edge cap in pixels, forwarded to ``_run_tagger``.

    Returns:
        The dict produced by ``_run_tagger`` (``{"tags": [...], "count": n}``).

    Raises:
        HTTPException: 503 when no tagger is loaded, 400 when the URL is not
            http(s) or the image cannot be fetched.

    NOTE(review): ``topk`` defaults to 40 here and ``_run_tagger`` prefers
    ``topk`` over ``threshold``, so a request that only supplies ``threshold``
    still gets top-k selection — confirm this is intended.
    """
    import asyncio

    # Explicit check instead of `assert`: asserts vanish under `python -O`, and
    # the app can be served without main() ever loading a model.
    if _tagger is None:
        raise HTTPException(status_code=503, detail="Tagger model is not loaded")

    # The value is untrusted client input. How `_open_image` treats strings
    # that are not web URLs is not visible from this file, so only http(s) is
    # let through. This does not stop requests to internal HTTP hosts.
    url = url.strip()
    if not url.lower().startswith(("http://", "https://")):
        raise HTTPException(
            status_code=400,
            detail="Only http:// and https:// image URLs are supported",
        )

    try:
        # `_open_image` is a plain synchronous call; run it in a worker thread
        # so a slow download does not stall the event loop.
        img = await asyncio.to_thread(_open_image, url)
    except Exception as e:
        raise HTTPException(
            status_code=400, detail=f"Could not fetch image: {e}"
        ) from e
    return _run_tagger(img, topk, threshold, max_size)
@app.post("/tag/upload")
async def tag_upload(
    file: UploadFile = File(...),
    topk: int | None = Query(default=40),
    threshold: float | None = Query(default=None),
    max_size: int = Query(default=1024),
):
    """Tag an uploaded image file.

    Parameters:
        file: The uploaded image (multipart form field ``file``).
        topk: Number of highest-scoring tags to return (default 40).
        threshold: Minimum score, forwarded to ``_run_tagger``.
        max_size: Long-edge cap in pixels, forwarded to ``_run_tagger``.

    Returns:
        The dict produced by ``_run_tagger`` (``{"tags": [...], "count": n}``).

    Raises:
        HTTPException: 503 when no tagger is loaded, 400 when the upload
            cannot be decoded as an image.
    """
    # Explicit check instead of `assert`: asserts vanish under `python -O`, and
    # the app can be served without main() ever loading a model.
    if _tagger is None:
        raise HTTPException(status_code=503, detail="Tagger model is not loaded")

    try:
        data = await file.read()
        # Decode from memory and force three channels for the model.
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except Exception as e:
        raise HTTPException(
            status_code=400, detail=f"Could not read image: {e}"
        ) from e
    return _run_tagger(img, topk, threshold, max_size)
# ---------------------------------------------------------------------------
# Shared inference helper
# ---------------------------------------------------------------------------
def _run_tagger(
    img: Image.Image,
    topk: int | None,
    threshold: float | None,
    max_size: int,
) -> dict:
    """Run the loaded tagger on a PIL image and return the selected tags.

    Args:
        img: Image to tag; converted to RGB if it is in another mode.
        topk: Return the ``topk`` highest-scoring tags. Takes precedence over
            ``threshold`` whenever it is not ``None``.
        threshold: Used only when ``topk`` is ``None``: return every tag whose
            sigmoid score is ``>= threshold``.
        max_size: Cap on the image's longer edge; images are only ever scaled
            down, never up.

    Returns:
        ``{"tags": [{"tag": str, "score": float}, ...], "count": int}`` with
        tags ordered by descending score and scores rounded to 4 decimals.

    Raises:
        HTTPException: 503 when no tagger is loaded; 400 when ``max_size`` is
            not positive or ``topk`` is negative.
    """
    # Preprocess from PIL directly (avoids re-opening)
    from inference_tagger_standalone import _snap, PATCH_SIZE, _IMAGENET_MEAN, _IMAGENET_STD
    import torchvision.transforms.v2 as v2

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    tagger = _tagger
    if tagger is None:
        raise HTTPException(status_code=503, detail="Tagger model is not loaded")
    if max_size < 1:
        raise HTTPException(status_code=400, detail="max_size must be a positive integer")
    if topk is not None and topk < 0:
        raise HTTPException(status_code=400, detail="topk must not be negative")

    if topk is None and threshold is None:
        topk = 40

    # Normalisation below uses three-channel statistics, so make sure the
    # image really is RGB regardless of which caller produced it.
    if img.mode != "RGB":
        img = img.convert("RGB")

    w, h = img.size
    scale = min(1.0, max_size / max(w, h))  # never upscale
    new_w = _snap(round(w * scale), PATCH_SIZE)
    new_h = _snap(round(h * scale), PATCH_SIZE)

    transform = v2.Compose([
        v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])
    pixel_values = transform(img).unsqueeze(0).to(tagger.device)

    with torch.no_grad(), torch.autocast(device_type=tagger.device.type, dtype=tagger.dtype):
        logits = tagger.model(pixel_values)[0]
    scores = torch.sigmoid(logits.float())

    if topk is not None:
        # torch.topk returns its values already sorted in descending order.
        values, indices = scores.topk(min(topk, tagger.num_tags))
    else:
        # Reaching here means threshold is not None (see the default above).
        indices = (scores >= threshold).nonzero(as_tuple=True)[0]
        values = scores[indices]
        order = values.argsort(descending=True)
        indices, values = indices[order], values[order]

    tags = [
        {"tag": tagger.idx2tag[i], "score": round(v, 4)}
        for i, v in zip(indices.tolist(), values.tolist())
    ]
    return {"tags": tags, "count": len(tags)}
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main():
    """Parse command-line options, load the tagger, and serve the web UI.

    Stores the vocabulary path and the constructed ``Tagger`` in the
    module-level ``_vocab_path`` / ``_tagger`` before starting uvicorn.
    """
    global _tagger, _vocab_path

    cli = argparse.ArgumentParser(description="DINOv3 Tagger Web UI")
    # One (flag, keyword-arguments) entry per supported option.
    option_table = (
        ("--checkpoint", {"required": True, "help": "Path to .safetensors checkpoint"}),
        ("--vocab", {"required": True, "help": "Path to tagger_vocab.json"}),
        ("--device", {"default": "cuda", "help": "cuda / cpu (default: cuda)"}),
        ("--max-size", {"type": int, "default": 1024, "help": "Default long-edge cap"}),
        ("--host", {"default": "0.0.0.0"}),
        ("--port", {"type": int, "default": 7860}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    opts = cli.parse_args()

    _vocab_path = opts.vocab
    _tagger = Tagger(
        checkpoint_path=opts.checkpoint,
        vocab_path=opts.vocab,
        device=opts.device,
        max_size=opts.max_size,
    )

    banner = f"\n Tagger UI running at http://{opts.host}:{opts.port}\n"
    print(banner)
    uvicorn.run(app, host=opts.host, port=opts.port)


if __name__ == "__main__":
    main()