# app.py
#
# FastAPI service that runs a FlexiViT classifier adapted to a 3x3 patch grid
# (96px input, 32px patches) and exposes per-layer attention maps and
# logit-lens readouts as JSON for a frontend visualizer.

import io
import uuid
import json
import threading
import hashlib
from contextvars import ContextVar
from typing import Optional, Dict, Any

import torch
import torch.nn.functional as F
import timm
from PIL import Image
from fastapi import FastAPI, UploadFile, File, Query, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from timm.layers.pos_embed import resample_abs_pos_embed

# resample_patch_embed only exists in newer timm releases; fall back to a
# manual bicubic resize of the conv kernel when it is unavailable.
try:
    from timm.layers.patch_embed import resample_patch_embed
except Exception:
    resample_patch_embed = None

# -----------------------
# Config
# -----------------------
MODEL_NAME = "flexivit_large.300ep_in1k"
TARGET_IMG = 96       # input resolution fed to the model
TARGET_PATCH = 32     # patch size after adaptation
NEW_GRID = (TARGET_IMG // TARGET_PATCH, TARGET_IMG // TARGET_PATCH)  # (3,3)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ImageNet normalization (shaped [3,1,1] to broadcast over CHW tensors)
IMNET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMNET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


# -----------------------
# Load labels (local file recommended)
# -----------------------
def load_imagenet_labels(path: str = "imagenet_classes.txt") -> Optional[list]:
    """Return ImageNet class names from *path* (one per line), or None if absent."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return [line.strip() for line in f.readlines() if line.strip()]
    except FileNotFoundError:
        # If missing, still works but without names.
        return None


IMAGENET_LABELS = load_imagenet_labels()


# -----------------------
# Build & adapt model once
# -----------------------
def adapt_flexivit_to_3x3(model: torch.nn.Module):
    """Adapt *model* in place to TARGET_IMG input / TARGET_PATCH patches (NEW_GRID tokens).

    Resizes the patch-embedding conv kernel, updates the patch_embed metadata,
    and resamples the absolute positional embeddings. Returns the same model.
    """
    # --- Resize patch embedding conv weight ---
    with torch.no_grad():
        proj = model.patch_embed.proj
        w = proj.weight.detach().cpu()  # [embed_dim, in_chans, old_ps, old_ps]
        b = proj.bias.detach().cpu() if proj.bias is not None else None
        old_ps = w.shape[-1]
        if old_ps != TARGET_PATCH:
            if resample_patch_embed is not None:
                # timm's PI-resize — the FlexiViT-correct kernel resampling.
                w2 = resample_patch_embed(w, (TARGET_PATCH, TARGET_PATCH))
            else:
                # Fallback: plain bicubic resize of each 2D kernel slice.
                ed, ic, _, _ = w.shape
                w_ = w.reshape(ed * ic, 1, old_ps, old_ps)
                w_ = F.interpolate(
                    w_, size=(TARGET_PATCH, TARGET_PATCH),
                    mode="bicubic", align_corners=False,
                )
                w2 = w_.reshape(ed, ic, TARGET_PATCH, TARGET_PATCH)
        else:
            w2 = w

        embed_dim, in_chans, _, _ = w2.shape
        new_proj = torch.nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=TARGET_PATCH,
            stride=TARGET_PATCH,
            padding=0,
            bias=(b is not None),
        )
        new_proj.weight.copy_(w2)
        if b is not None:
            new_proj.bias.copy_(b)
        model.patch_embed.proj = new_proj.to(DEVICE)

        # Capture the ORIGINAL token grid before we overwrite the metadata:
        # the pos_embed resampling below needs the pre-adaptation grid.
        orig_grid = None
        if hasattr(model.patch_embed, "grid_size"):
            orig_grid = tuple(model.patch_embed.grid_size)

        # Update patch embed metadata if present
        if hasattr(model.patch_embed, "patch_size"):
            model.patch_embed.patch_size = (TARGET_PATCH, TARGET_PATCH)
        if hasattr(model.patch_embed, "img_size"):
            model.patch_embed.img_size = (TARGET_IMG, TARGET_IMG)
        if hasattr(model.patch_embed, "grid_size"):
            model.patch_embed.grid_size = NEW_GRID
        if hasattr(model.patch_embed, "num_patches"):
            model.patch_embed.num_patches = NEW_GRID[0] * NEW_GRID[1]

    # --- Resize absolute positional embeddings to 3x3 ---
    if hasattr(model, "pos_embed") and model.pos_embed is not None:
        with torch.no_grad():
            pe = model.pos_embed.detach()  # [1, N(+prefix), D]

            # infer prefix tokens (cls, dist, etc.)
            prefix = int(getattr(model, "num_prefix_tokens", 0))
            if prefix == 0 and hasattr(model, "cls_token") and model.cls_token is not None:
                prefix = 1

            # BUG FIX: use the grid captured BEFORE the metadata update above.
            # Previously this re-read model.patch_embed.grid_size, which had
            # already been overwritten with NEW_GRID, so the "old grid" was
            # wrong, the consistency checks always failed, prefix was forced
            # to 0, and the sqrt fallback misinferred the grid whenever the
            # pos_embed carries a cls-token row.
            old_grid = orig_grid
            if old_grid is not None:
                grid_tokens = old_grid[0] * old_grid[1]
                if pe.shape[1] == grid_tokens:
                    # pos_embed covers patch tokens only (no prefix rows).
                    prefix = 0
                elif pe.shape[1] == grid_tokens + prefix:
                    pass  # consistent: prefix rows + patch rows
                else:
                    # Metadata disagrees with pos_embed; fall back to sqrt inference.
                    prefix = 0
                    old_grid = None
            if old_grid is None:
                n = pe.shape[1] - prefix
                g = int(n ** 0.5)
                old_grid = (g, g)

            pe2 = resample_abs_pos_embed(
                pe,
                new_size=NEW_GRID,
                old_size=old_grid,
                num_prefix_tokens=prefix,
                interpolation="bicubic",
                antialias=True,
            )
            model.pos_embed = torch.nn.Parameter(pe2)

    return model


def build_model():
    """Create the pretrained FlexiViT, disable fused attention, and adapt it to NEW_GRID."""
    model = timm.create_model(MODEL_NAME, pretrained=True).to(DEVICE).eval()
    # (Recommended) disable fused attention if present (helps hooks):
    # with fused SDPA active, attn_drop is bypassed and the hook below
    # would never fire.
    for blk in model.blocks:
        if hasattr(blk.attn, "fused_attn"):
            blk.attn.fused_attn = False
    model = adapt_flexivit_to_3x3(model)
    return model


MODEL = build_model()
print(f"[server] model={MODEL_NAME} device={DEVICE} grid={NEW_GRID}")

# -----------------------
# Hooks using ContextVar (safe-ish for concurrent requests)
# -----------------------
_attn_var: ContextVar[Optional[list]] = ContextVar("_attn_var", default=None)
_tok_var: ContextVar[Optional[list]] = ContextVar("_tok_var", default=None)


def _save_attn_drop_input(module, inp, out):
    """Forward hook on attn_drop: capture its input (the attention matrix) on CPU.

    No-op unless the current context set a destination list via _attn_var.
    """
    lst = _attn_var.get()
    if lst is not None and len(inp) > 0 and torch.is_tensor(inp[0]):
        # inp[0]: [B, H, N, N]
        lst.append(inp[0].detach().cpu())


def _save_block_out(module, inp, out):
    """Forward hook on each transformer block: capture its output tokens.

    Tokens stay on DEVICE so the logit-lens pass can reuse them without copies.
    """
    lst = _tok_var.get()
    if lst is not None and torch.is_tensor(out):
        # out: [B, N, D]
        lst.append(out.detach())


# Register hooks once; handles kept so they could be removed later if needed.
ATTN_HOOKS = []
TOK_HOOKS = []
for blk in MODEL.blocks:
    ATTN_HOOKS.append(blk.attn.attn_drop.register_forward_hook(_save_attn_drop_input))
    TOK_HOOKS.append(blk.register_forward_hook(_save_block_out))
# -----------------------
# Preprocess
# -----------------------
def preprocess(pil_img: Image.Image) -> torch.Tensor:
    """Center-crop to a square, resize to TARGET_IMG, ImageNet-normalize.

    Returns a float32 tensor of shape [1, 3, TARGET_IMG, TARGET_IMG].
    """
    img = pil_img.convert("RGB")
    w, h = img.size
    s = min(w, h)
    left = (w - s) // 2
    top = (h - s) // 2
    img = img.crop((left, top, left + s, top + s)).resize(
        (TARGET_IMG, TARGET_IMG), Image.BICUBIC
    )
    # BUG FIX: the old torch.ByteTensor(torch.ByteStorage.from_buffer(...))
    # + numpy round-trip relies on deprecated typed-storage APIs.
    # torch.frombuffer reads the raw HWC uint8 bytes directly; wrapping in
    # bytearray keeps the buffer writable (avoids the read-only warning).
    x = torch.frombuffer(bytearray(img.tobytes()), dtype=torch.uint8)
    x = x.view(TARGET_IMG, TARGET_IMG, 3).to(torch.float32) / 255.0
    x = x.permute(2, 0, 1)  # HWC -> CHW
    x = (x - IMNET_MEAN) / IMNET_STD
    return x.unsqueeze(0)  # [1,3,H,W]


# -----------------------
# Compute logit lens + attention export
# -----------------------
def compute_logit_lens_from_tokens(tokens_per_layer, model):
    """Project every layer's CLS token through the final norm + head.

    tokens_per_layer: list of [B, N, D] tensors (one per block, on DEVICE).
    Returns (logits_per_layer, probs_per_layer), each [L, B, C] on CPU.
    NOTE(review): this uses model.norm + model.head; timm ViTs with a
    separate fc_norm would need that applied too — confirm for this config.
    """
    logits_list = []
    probs_list = []
    with torch.no_grad():
        for x_l in tokens_per_layer:
            x_ln = model.norm(x_l) if hasattr(model, "norm") and model.norm is not None else x_l
            cls_l = x_ln[:, 0]  # CLS token
            logits_l = model.head(cls_l)
            logits_list.append(logits_l.detach().cpu())
            probs_list.append(torch.softmax(logits_l, dim=-1).detach().cpu())
    logits_per_layer = torch.stack(logits_list, dim=0)  # [L,B,C]
    probs_per_layer = torch.stack(probs_list, dim=0)
    return logits_per_layer, probs_per_layer


def round_tensor(t: torch.Tensor, decimals: int) -> torch.Tensor:
    """Round *t* to *decimals* decimal places (shrinks the JSON payload)."""
    s = 10 ** decimals
    return torch.round(t * s) / s


MODEL_LOCK = threading.Lock()


def analyze_image(pil_img: Image.Image) -> Dict[str, Any]:
    """Run one image through MODEL and build the two JSON-serializable exports.

    Returns {"logit_lens_full": ..., "attention_full": ...}.
    Raises RuntimeError if the attention hooks captured nothing.
    """
    x = preprocess(pil_img).to(DEVICE)

    # Per-request storage; the ContextVars route the hook callbacks here so
    # concurrent requests don't mix captures.
    attn_maps = []
    layer_tokens = []
    tok_token = _tok_var.set(layer_tokens)
    attn_token = _attn_var.set(attn_maps)
    try:
        with torch.no_grad():
            # Lock recommended if you run multiple workers/threads with GPU,
            # and because we use shared model + hooks
            with MODEL_LOCK:
                logits_final = MODEL(x)

            # Final probs
            probs_final = torch.softmax(logits_final, dim=-1)[0].detach().cpu()
            probs_final = round_tensor(probs_final, 6)

            # Logit lens
            logits_by_layer, probs_by_layer = compute_logit_lens_from_tokens(layer_tokens, MODEL)

        # Export logit lens json
        export_logit = {
            "model": MODEL_NAME,
            "grid": [NEW_GRID[0], NEW_GRID[1]],
            "num_layers": int(logits_by_layer.shape[0]),
            "num_classes": int(logits_by_layer.shape[-1]),
            "class_names": IMAGENET_LABELS,
            "logits": [],
            "final_probs": probs_final.tolist(),
        }
        for l in range(logits_by_layer.shape[0]):
            v = logits_by_layer[l, 0]  # [C]
            v = round_tensor(v, 3)
            export_logit["logits"].append(v.tolist())

        # Attention json
        # attn_maps is list length L, each: [B,H,N,N] CPU
        attn_maps2 = [a.squeeze(0) for a in attn_maps]  # -> [H,N,N]
        if len(attn_maps2) == 0:
            raise RuntimeError("No attention captured. (Hook may not match this timm model/config)")
        attn_serializable = []
        for layer in attn_maps2:
            layer_data = []
            for head in layer:
                head = round_tensor(head, 4)
                layer_data.append(head.tolist())
            attn_serializable.append(layer_data)

        export_attn = {
            "num_layers": len(attn_serializable),
            "num_heads": len(attn_serializable[0]),
            "num_tokens": len(attn_serializable[0][0]),
            "grid": [NEW_GRID[0], NEW_GRID[1]],
            "attention": attn_serializable,
        }

        return {
            "logit_lens_full": export_logit,
            "attention_full": export_attn,
        }
    finally:
        # Always restore the ContextVars and drop captured tensors, even on error.
        _tok_var.reset(tok_token)
        _attn_var.reset(attn_token)
        layer_tokens.clear()
        attn_maps.clear()


# -----------------------
# FastAPI app
# -----------------------
app = FastAPI(title="ViT Explainer API", version="1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten in prod
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory store for "file-like endpoints" (job-based)
# NOTE(review): unbounded — entries are never evicted, so long-running
# servers will grow without limit; consider a TTL or LRU cap.
RESULTS: Dict[str, Dict[str, Any]] = {}

# In-memory store for "current files" (no-regenerate on GET)
CURRENT: Dict[str, Any] = {
    "hash": None,
    "attention_full": None,
    "logit_lens_full": None,
}


def _no_store(resp: JSONResponse) -> JSONResponse:
    """Stamp anti-caching headers so browsers always refetch these JSONs."""
    resp.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
    resp.headers["Pragma"] = "no-cache"
    return resp


@app.get("/health")
def health():
    """Liveness probe plus basic model/config info."""
    return {
        "status": "ok",
        "model": MODEL_NAME,
        "device": DEVICE,
        "grid": list(NEW_GRID),
        "has_current": CURRENT["attention_full"] is not None,
    }


# -----------------------
# Legacy: returns JSON directly OR job endpoints
# -----------------------
@app.post("/analyze")
async def analyze(
    file: UploadFile = File(...),
    store: int = Query(0, description="1 => guarda resultados y entrega endpoints /results/{id}/..."),
):
    """Analyze an uploaded image; with store=1, persist under a job id instead."""
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Please upload an image file.")
    raw = await file.read()
    try:
        img = Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Could not decode image.")

    try:
        out = analyze_image(img)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model inference failed: {e}")

    if store == 1:
        job_id = str(uuid.uuid4())
        RESULTS[job_id] = out
        return {
            "job_id": job_id,
            "endpoints": {
                "attention_full": f"/results/{job_id}/attention_full.json",
                "logit_lens_full": f"/results/{job_id}/logit_lens_full.json",
            }
        }
    return out


@app.get("/results/{job_id}/attention_full.json")
def get_attention(job_id: str):
    """Fetch a stored job's attention export."""
    if job_id not in RESULTS:
        raise HTTPException(status_code=404, detail="job_id not found")
    return _no_store(JSONResponse(RESULTS[job_id]["attention_full"]))


@app.get("/results/{job_id}/logit_lens_full.json")
def get_logit(job_id: str):
    """Fetch a stored job's logit-lens export."""
    if job_id not in RESULTS:
        raise HTTPException(status_code=404, detail="job_id not found")
    return _no_store(JSONResponse(RESULTS[job_id]["logit_lens_full"]))


# -----------------------
# Preferred: "current files" endpoints (keep frontend fetch paths stable)
# - POST /analyze_current only when image changes
# - GET /attention_full.json and /logit_lens_full.json are just readers
# -----------------------
@app.post("/analyze_current")
async def analyze_current(file: UploadFile = File(...)):
    """Recompute the 'current' exports only when the uploaded image changed."""
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Please upload an image file.")
    raw = await file.read()
    img_hash = hashlib.sha256(raw).hexdigest()

    # ✅ no regenerate if same image already processed
    if CURRENT["hash"] == img_hash and CURRENT["attention_full"] is not None:
        return {"status": "unchanged", "hash": img_hash}

    try:
        img = Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Could not decode image.")

    try:
        out = analyze_image(img)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model inference failed: {e}")

    CURRENT["hash"] = img_hash
    CURRENT["attention_full"] = out["attention_full"]
    CURRENT["logit_lens_full"] = out["logit_lens_full"]
    return {"status": "updated", "hash": img_hash}


@app.get("/attention_full.json")
def attention_full_current():
    """Serve the current attention export (404 until one exists)."""
    if CURRENT["attention_full"] is None:
        raise HTTPException(status_code=404, detail="No attention computed yet. Call POST /analyze_current first.")
    return _no_store(JSONResponse(CURRENT["attention_full"]))


@app.get("/logit_lens_full.json")
def logit_lens_current():
    """Serve the current logit-lens export (404 until one exists)."""
    if CURRENT["logit_lens_full"] is None:
        raise HTTPException(status_code=404, detail="No logit lens computed yet. Call POST /analyze_current first.")
    return _no_store(JSONResponse(CURRENT["logit_lens_full"]))