import base64
import contextlib
import hmac
import io
import json
import logging
import os
import tempfile
import threading

import open_clip
import torch
from huggingface_hub import CommitOperationAdd, create_commit, hf_hub_download
from PIL import Image
from safetensors.torch import load_file, save_file

from reparam import reparameterize_model
|
|
# Runtime configuration, read once from the process environment at import time.
# Every value falls back to an empty string when its variable is not set.

# Shared secret that admin operations must present in the request payload.
ADMIN_TOKEN = os.environ.get("ADMIN_TOKEN", "")
# Hugging Face repository id that stores the published label snapshots.
HF_LABEL_REPO = os.environ.get("HF_LABEL_REPO", "")
# Token used when committing snapshots to HF_LABEL_REPO.
HF_WRITE_TOKEN = os.environ.get("HF_WRITE_TOKEN", "")
# Token used when downloading snapshots; defaults to the write token.
HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN", HF_WRITE_TOKEN)
|
|
|
|
def _fingerprint(device: str, dtype: torch.dtype) -> dict:
    """Describe the model/runtime combination that produces the embeddings.

    The returned dict is stored next to published embeddings and compared
    for equality when a snapshot is loaded back.

    Args:
        device: Accepted for signature compatibility; not used in the result.
        dtype: Runtime dtype, recorded as its string form.

    Returns:
        A dict of plain str / float / None values.
    """
    # The CUDA version is only reported when a CUDA device is actually usable.
    cuda_version = None
    if torch.cuda.is_available():
        cuda_version = torch.version.cuda

    return dict(
        model_id="MobileCLIP-B",
        pretrained="datacompdr",
        open_clip=getattr(open_clip, "__version__", "unknown"),
        torch=torch.__version__,
        cuda=cuda_version,
        dtype_runtime=str(dtype),
        text_norm="L2",
        logit_scale=100.0,
    )
|
|
|
|
class EndpointHandler:
    """Zero-shot image classifier exposed as a callable request handler.

    Label text embeddings are computed with the MobileCLIP-B text tower and
    kept in memory (a float32 master copy on CPU plus a copy on the inference
    device). Besides classification, the handler accepts three admin
    operations, selected by the payload's ``"op"`` key and protected by
    ``ADMIN_TOKEN``:

    * ``upsert_labels`` - add labels whose id and name are not yet known;
    * ``remove_labels`` - drop labels by id;
    * ``reload_labels`` - replace the label set with a published snapshot.

    Successful upserts/removals are published to ``HF_LABEL_REPO`` as a new
    numbered snapshot.
    """

    _log = logging.getLogger(__name__)

    def __init__(self, path: str = ""):
        """Load the model and the initial label set.

        Args:
            path: Directory containing ``items.json``; used only when no
                snapshot could be loaded from the hub. Empty means the
                current working directory.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained="datacompdr"
        )
        model.eval()
        model = reparameterize_model(model)
        model.to(self.device)
        if self.device == "cuda":
            model = model.to(torch.float16)
        self.model = model
        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        self.fingerprint = _fingerprint(self.device, self.dtype)
        # Guards text_features_cpu / text_features / class_ids / class_names.
        self._lock = threading.Lock()

        # Prefer the latest published snapshot; this is best-effort, so any
        # failure falls back to the bundled items.json (but is logged).
        loaded = False
        if HF_LABEL_REPO:
            try:
                loaded = self._load_snapshot_from_hub_latest()
            except Exception:
                self._log.warning(
                    "could not load latest label snapshot; using items.json",
                    exc_info=True,
                )
        if not loaded:
            items_path = "items.json" if not path else f"{path}/items.json"
            with open(items_path, "r", encoding="utf-8") as f:
                items = json.load(f)
            self.class_ids = [int(it["id"]) for it in items]
            self.class_names = [it["name"] for it in items]
            feats = self._encode_text([it["prompt"] for it in items])
            self.text_features_cpu = feats.detach().cpu().to(torch.float32).contiguous()
            self._to_device()
            self.labels_version = 1

    def __call__(self, data):
        """Dispatch one request.

        Args:
            data: Request dict; the payload is ``data["inputs"]`` when that
                key exists, otherwise ``data`` itself.

        Returns:
            For admin operations a status dict; for classification a list of
            ``{"id", "label", "score"}`` dicts sorted by descending score.
        """
        payload = data.get("inputs", data)
        op = payload.get("op")

        if op in ("upsert_labels", "reload_labels", "remove_labels"):
            if not self._is_admin(payload):
                return {"error": "unauthorized"}
            if op == "upsert_labels":
                return self._handle_upsert(payload)
            if op == "reload_labels":
                return self._handle_reload(payload)
            return self._handle_remove(payload)

        return self._classify(payload)

    @staticmethod
    def _is_admin(payload) -> bool:
        """Return True only when the payload carries the configured admin token.

        An unset/empty ADMIN_TOKEN disables admin operations entirely, so an
        empty ``"token"`` can never authenticate. The comparison is
        constant-time.
        """
        supplied = payload.get("token")
        if not ADMIN_TOKEN or not isinstance(supplied, str):
            return False
        return hmac.compare_digest(
            supplied.encode("utf-8"), ADMIN_TOKEN.encode("utf-8")
        )

    def _handle_upsert(self, payload):
        """Handle ``op == "upsert_labels"`` (caller already authenticated)."""
        items = payload.get("items", []) or []
        try:
            added = self._upsert_items(items)
        except ValueError as exc:
            return {"error": "invalid_items", "detail": str(exc)}
        return self._publish_change("added", added)

    def _handle_remove(self, payload):
        """Handle ``op == "remove_labels"`` (caller already authenticated)."""
        try:
            ids_to_remove = {int(raw) for raw in (payload.get("ids") or [])}
        except (TypeError, ValueError):
            return {"error": "invalid_ids"}
        if not ids_to_remove:
            return {"error": "no_ids_provided"}
        removed = self._remove_items(ids_to_remove)
        return self._publish_change("removed", removed)

    def _handle_reload(self, payload):
        """Handle ``op == "reload_labels"`` (caller already authenticated)."""
        try:
            ver = int(payload.get("version"))
        except (TypeError, ValueError, OverflowError):
            return {"error": "invalid_version"}
        try:
            ok = self._load_snapshot_from_hub_version(ver)
        except Exception as exc:
            # Download or fingerprint failures are reported like the other
            # admin operations report theirs, instead of escaping.
            return {
                "status": "error",
                "detail": str(exc),
                "labels_version": getattr(self, "labels_version", 0),
            }
        return {
            "status": "ok" if ok else "nochange",
            "labels_version": getattr(self, "labels_version", 0),
        }

    def _publish_change(self, field, count):
        """Publish a new snapshot after a mutation and build the response.

        Args:
            field: Response key for the change count ("added" / "removed").
            count: Number of labels changed; nothing is published when 0.

        Note:
            When publishing fails, the in-memory label set keeps the change
            while ``labels_version`` stays at its previous value.
        """
        if count > 0:
            new_ver = int(getattr(self, "labels_version", 1)) + 1
            try:
                self._persist_snapshot_to_hub(new_ver)
                self.labels_version = new_ver
            except Exception as e:
                return {"status": "error", field: count, "detail": str(e)}
        return {
            "status": "ok",
            field: count,
            "labels_version": getattr(self, "labels_version", 1),
        }

    def _classify(self, payload):
        """Score the base64 image in ``payload["image"]`` against all labels."""
        # Optionally catch up to a newer snapshot first; best-effort.
        min_ver = payload.get("min_labels_version")
        if isinstance(min_ver, int) and min_ver > getattr(self, "labels_version", 0):
            try:
                self._load_snapshot_from_hub_version(min_ver)
            except Exception:
                self._log.warning(
                    "could not load label snapshot v%s", min_ver, exc_info=True
                )

        img_b64 = payload["image"]
        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        if self.device == "cuda":
            img_tensor = img_tensor.to(torch.float16)

        # Take one consistent view of the label state so that a concurrent
        # upsert/remove/reload cannot pair scores with the wrong labels.
        with self._lock:
            text_features = self.text_features
            class_ids = list(self.class_ids)
            class_names = list(self.class_names)
        if not class_ids:
            return []

        with torch.no_grad():
            img_feat = self.model.encode_image(img_tensor)
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            probs = (100.0 * img_feat @ text_features.T).softmax(dim=-1)[0]
        scores = probs.detach().cpu().tolist()

        # A missing/None top_k means "all"; negative values are clamped to 0
        # rather than being interpreted as a slice from the end.
        top_k = payload.get("top_k")
        top_k = len(class_ids) if top_k is None else max(int(top_k), 0)

        ranked = sorted(
            (
                {"id": i, "label": name, "score": float(p)}
                for i, name, p in zip(class_ids, class_names, scores)
            ),
            key=lambda x: x["score"],
            reverse=True,
        )
        return ranked[:top_k]

    def _encode_text(self, prompts):
        """Return L2-normalised text features for ``prompts`` on ``self.device``."""
        with torch.no_grad():
            toks = self.tokenizer(prompts).to(self.device)
            feats = self.model.encode_text(toks)
            feats = feats / feats.norm(dim=-1, keepdim=True)
        return feats

    def _to_device(self):
        """Refresh ``self.text_features`` from the float32 CPU master copy."""
        self.text_features = self.text_features_cpu.to(
            self.device,
            dtype=(torch.float16 if self.device == "cuda" else torch.float32),
        )

    def _upsert_items(self, new_items):
        """Add labels that are not yet known; existing ones are left untouched.

        An item is skipped when its id, or its name compared
        case-insensitively, matches a known label or an earlier item of the
        same call.

        Args:
            new_items: Iterable of dicts with "id", "name" and "prompt".

        Returns:
            Number of labels actually added.

        Raises:
            ValueError: An item that would be added is malformed. Nothing is
                modified in that case.
        """
        if not new_items:
            return 0
        with self._lock:
            known_ids = set(getattr(self, "class_ids", []))
            known_names = {n.lower() for n in getattr(self, "class_names", [])}

            batch = []
            for it in new_items:
                try:
                    item_id = int(it.get("id"))
                except (AttributeError, TypeError, ValueError) as exc:
                    raise ValueError(f"invalid label item: {it!r}") from exc
                if item_id in known_ids:
                    continue
                item_name = it.get("name")
                if not isinstance(item_name, str):
                    raise ValueError(f"label {item_id} needs a string 'name'")
                name_key = item_name.lower()
                if name_key in known_names:
                    continue
                if not isinstance(it.get("prompt"), str):
                    raise ValueError(f"label {item_id} needs a string 'prompt'")
                # Remember accepted entries so later duplicates within this
                # same request are skipped as well.
                known_ids.add(item_id)
                known_names.add(name_key)
                batch.append((item_id, item_name, it["prompt"]))

            if not batch:
                return 0

            feats = self._encode_text([prompt for _, _, prompt in batch])
            feats = feats.detach().cpu().to(torch.float32)
            new_ids = [item_id for item_id, _, _ in batch]
            new_names = [name for _, name, _ in batch]

            if not hasattr(self, "text_features_cpu"):
                self.text_features_cpu = feats.contiguous()
                self.class_ids = new_ids
                self.class_names = new_names
            else:
                self.text_features_cpu = torch.cat(
                    [self.text_features_cpu, feats], dim=0
                ).contiguous()
                self.class_ids.extend(new_ids)
                self.class_names.extend(new_names)

            self._to_device()
            return len(batch)

    def _remove_items(self, ids_to_remove):
        """Remove every label whose id is in ``ids_to_remove``.

        Args:
            ids_to_remove: Iterable of values convertible with ``int()``.

        Returns:
            Number of labels removed.
        """
        if not ids_to_remove or not hasattr(self, "class_ids"):
            return 0
        with self._lock:
            doomed = {int(raw) for raw in ids_to_remove}
            keep = [
                idx for idx, class_id in enumerate(self.class_ids)
                if class_id not in doomed
            ]
            removed_count = len(self.class_ids) - len(keep)
            if removed_count == 0:
                return 0

            if keep:
                self.text_features_cpu = self.text_features_cpu[keep].contiguous()
                self.class_ids = [self.class_ids[i] for i in keep]
                self.class_names = [self.class_names[i] for i in keep]
            else:
                # Everything was removed: keep an empty (0, dim) matrix so the
                # feature width is still known.
                self.text_features_cpu = torch.empty(
                    (0, self.text_features_cpu.shape[1]), dtype=torch.float32
                )
                self.class_ids = []
                self.class_names = []

            self._to_device()
            return removed_count

    def _persist_snapshot_to_hub(self, version: int):
        """Publish the current label set as snapshot ``version``.

        Uploads ``snapshots/v{version}/embeddings.safetensors`` and
        ``snapshots/v{version}/meta.json``, and points
        ``snapshots/latest.json`` at ``version``, in a single commit.

        Raises:
            RuntimeError: ``HF_LABEL_REPO`` or ``HF_WRITE_TOKEN`` is not set.
        """
        if not HF_LABEL_REPO:
            raise RuntimeError("HF_LABEL_REPO not set")
        if not HF_WRITE_TOKEN:
            raise RuntimeError("HF_WRITE_TOKEN not set for publishing")

        # Capture a consistent view; the id/name lists are mutated in place
        # by upserts, so they are copied.
        with self._lock:
            feats = self.text_features_cpu.to(torch.float32)
            class_ids = list(self.class_ids)
            class_names = list(self.class_names)

        meta = {
            "items": [{"id": int(i), "name": n} for i, n in zip(class_ids, class_names)],
            "fingerprint": self.fingerprint,
            "dims": int(feats.shape[1]),
            "count": int(feats.shape[0]),
            "version": int(version),
        }
        latest_bytes = io.BytesIO(
            json.dumps({"version": int(version)}).encode("utf-8")
        )

        # A private temporary directory per call: concurrent publishes must
        # not overwrite each other's files, and nothing is left behind.
        with tempfile.TemporaryDirectory() as tmp_dir:
            emb_path = os.path.join(tmp_dir, "embeddings.safetensors")
            meta_path = os.path.join(tmp_dir, "meta.json")
            save_file({"embeddings": feats}, emb_path)
            with open(meta_path, "w", encoding="utf-8") as f:
                json.dump(meta, f)

            ops = [
                CommitOperationAdd(
                    path_in_repo=f"snapshots/v{version}/embeddings.safetensors",
                    path_or_fileobj=emb_path,
                ),
                CommitOperationAdd(
                    path_in_repo=f"snapshots/v{version}/meta.json",
                    path_or_fileobj=meta_path,
                ),
                CommitOperationAdd(
                    path_in_repo="snapshots/latest.json",
                    path_or_fileobj=latest_bytes,
                ),
            ]
            # The files must still exist while the commit is created.
            create_commit(
                repo_id=HF_LABEL_REPO,
                repo_type="dataset",
                operations=ops,
                token=HF_WRITE_TOKEN,
                commit_message=f"labels v{version}",
            )

    def _load_snapshot_from_hub_version(self, version: int) -> bool:
        """Replace the in-memory label set with published snapshot ``version``.

        Returns:
            False when ``HF_LABEL_REPO`` is not set, True once the snapshot
            has been installed.

        Raises:
            RuntimeError: The snapshot's fingerprint differs from this
                handler's, or its embedding count does not match its items.
            Exception: Whatever the download or parsing raises.
        """
        if not HF_LABEL_REPO:
            return False

        # An empty token is passed as None rather than as an empty string.
        token = HF_READ_TOKEN or None
        # Download and validate before touching shared state, so the lock is
        # never held across network I/O and a bad snapshot changes nothing.
        emb_p = hf_hub_download(
            HF_LABEL_REPO,
            f"snapshots/v{version}/embeddings.safetensors",
            repo_type="dataset",
            token=token,
            force_download=True,
        )
        meta_p = hf_hub_download(
            HF_LABEL_REPO,
            f"snapshots/v{version}/meta.json",
            repo_type="dataset",
            token=token,
            force_download=True,
        )
        with open(meta_p, "r", encoding="utf-8") as f:
            meta = json.load(f)
        if meta.get("fingerprint") != self.fingerprint:
            raise RuntimeError("Embedding/model fingerprint mismatch")

        feats = load_file(emb_p)["embeddings"].contiguous()
        items = meta.get("items", [])
        class_ids = [int(x["id"]) for x in items]
        class_names = [x["name"] for x in items]
        if feats.shape[0] != len(class_ids):
            raise RuntimeError("Snapshot embeddings/items count mismatch")

        with self._lock:
            self.text_features_cpu = feats
            self.class_ids = class_ids
            self.class_names = class_names
            self.labels_version = int(meta.get("version", version))
            self._to_device()
        return True

    def _load_snapshot_from_hub_latest(self) -> bool:
        """Load the snapshot that ``snapshots/latest.json`` points to.

        Returns:
            False when no repo is configured, the pointer file cannot be
            downloaded, or it names no positive version; otherwise the result
            of loading that version.
        """
        if not HF_LABEL_REPO:
            return False
        try:
            latest_p = hf_hub_download(
                HF_LABEL_REPO,
                "snapshots/latest.json",
                repo_type="dataset",
                token=HF_READ_TOKEN or None,
            )
        except Exception:
            self._log.warning("could not download snapshots/latest.json", exc_info=True)
            return False
        with open(latest_p, "r", encoding="utf-8") as f:
            latest = json.load(f)
        ver = int(latest.get("version", 0))
        if ver <= 0:
            return False
        return self._load_snapshot_from_hub_version(ver)
|
|
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
|
|
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
|
|