"""GuardLLM - Precompute Embeddings & t-SNE (resumable, MiniLM-based).

Uses sentence-transformers/all-MiniLM-L6-v2 (22M params) to compute embeddings
for t-SNE visualization. The downstream risk classifier (Llama Prompt Guard 2)
is *not* loaded here - it is loaded by the Gradio app on-demand when a user
clicks a point.

Work is split into stages selected by the STAGE environment variable:

* ``download`` - build (or reload) ``cache/samples.json``.
* ``embed``    - embed samples batch-by-batch, one ``<batch>.npy`` file per
  batch; new batches stop being started once TIME_BUDGET seconds have
  elapsed, so rerun to resume.
* ``tsne``     - concatenate the batch files, run t-SNE, write the final cache.
* ``auto``     - all of the above, stopping early if embedding is unfinished.
* ``status``   - print progress counters.

Every cache file is written to a temporary name and then renamed into place,
so an interrupted run never leaves a truncated file that a later run would
treat as complete.
"""
import json
import logging
import os
import random
import sys
import time
from contextlib import contextmanager
from pathlib import Path

import numpy as np
import torch

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("precompute")

CACHE_DIR = Path(__file__).parent / "cache"
CACHE_FILE = CACHE_DIR / "embeddings_tsne.npz"
META_FILE = CACHE_DIR / "metadata.json"
SAMPLES_FILE = CACHE_DIR / "samples.json"
# Separate folder for MiniLM chunks so older chunks don't collide.
EMB_CHUNKS_DIR = CACHE_DIR / "emb_chunks_mini"

EMB_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
DATASET_ID = "neuralchemy/Prompt-injection-dataset"
DATASET_CONFIG = "core"

BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "32"))
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "256"))
TSNE_PERPLEXITY = 30
TSNE_SEED = 42
# 0 (the default) means "use the whole dataset".
SAMPLE_SIZE = int(os.environ.get("SAMPLE_SIZE", "0")) or None
# Seconds of embedding work per run (checked before each batch).
TIME_BUDGET = int(os.environ.get("TIME_BUDGET", "35"))
STAGE = os.environ.get("STAGE", "auto")
VALID_STAGES = ("status", "download", "embed", "tsne", "auto")

# Per-sample fields copied into metadata.json.
_META_KEYS = ("text", "label", "category", "severity", "source", "split")


@contextmanager
def _atomic_open(path, mode="wb", **open_kwargs):
    """Open a temp file next to *path*; rename it onto *path* only on success.

    On any exception the temp file is removed and *path* is left untouched.
    """
    tmp = path.with_name(path.name + ".tmp")
    try:
        with open(tmp, mode, **open_kwargs) as fh:
            yield fh
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise
    os.replace(tmp, path)


def _num_batches(n_samples):
    """Return how many BATCH_SIZE batches are needed to cover *n_samples* items."""
    return (n_samples + BATCH_SIZE - 1) // BATCH_SIZE


def _completed_batches(num_batches):
    """Return the batch indices in ``range(num_batches)`` that already have a chunk file.

    Files whose stem is not an integer are ignored. So are indices
    >= *num_batches* (e.g. leftovers from a run with a different sample count
    or BATCH_SIZE).
    """
    if not EMB_CHUNKS_DIR.exists():
        return set()
    done = set()
    for p in EMB_CHUNKS_DIR.glob("*.npy"):
        if p.stem.isdigit() and int(p.stem) < num_batches:
            done.add(int(p.stem))
    return done


def _tsne_perplexity(n):
    """Return TSNE_PERPLEXITY, lowered so it stays strictly below *n* (t-SNE requirement)."""
    return max(1, min(TSNE_PERPLEXITY, n - 1))


def _stratified_sample(samples, k, seed=42):
    """Return a per-category proportional random subsample of *samples*.

    Each category keeps ``max(1, round(share * k))`` items (capped at its size),
    so the result may be slightly larger or smaller than *k*. A private
    ``random.Random(seed)`` makes the result reproducible without touching the
    global RNG.
    """
    rng = random.Random(seed)
    by_cat = {}
    for s in samples:
        by_cat.setdefault(s["category"], []).append(s)
    total = len(samples)
    sampled = []
    for items in by_cat.values():
        n = max(1, round(len(items) / total * k))
        sampled.extend(rng.sample(items, min(n, len(items))))
    rng.shuffle(sampled)
    return sampled


def prepare_samples():
    """Load cached samples, or download the dataset and build samples.json.

    Rows from the ``train``, ``validation`` and ``test`` splits (those that
    exist) are flattened into one list. When SAMPLE_SIZE is set and smaller
    than the total, the list is stratified-subsampled by category.

    Returns:
        list[dict]: records with keys ``text``, ``label``, ``category``,
        ``severity``, ``source`` and ``split``.
    """
    if SAMPLES_FILE.exists():
        with open(SAMPLES_FILE, "r", encoding="utf-8") as f:
            s = json.load(f)
        logger.info("Loaded existing samples.json (%d samples)", len(s))
        return s

    from datasets import load_dataset  # only needed on the first run

    logger.info("Downloading %s/%s", DATASET_ID, DATASET_CONFIG)
    ds = load_dataset(DATASET_ID, DATASET_CONFIG)
    all_samples = []
    for split_name in ("train", "validation", "test"):
        if split_name not in ds:
            continue
        for row in ds[split_name]:
            all_samples.append({
                "text": row["text"],
                "label": int(row["label"]),
                "category": row.get("category", "unknown"),
                "severity": row.get("severity", ""),
                "source": row.get("source", ""),
                "split": split_name,
            })
    logger.info("Total %d", len(all_samples))

    if SAMPLE_SIZE and SAMPLE_SIZE < len(all_samples):
        all_samples = _stratified_sample(all_samples, SAMPLE_SIZE, seed=42)
        logger.info("Subsampled to %d", len(all_samples))

    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    with _atomic_open(SAMPLES_FILE, "w", encoding="utf-8") as f:
        json.dump(all_samples, f, ensure_ascii=False)
    return all_samples


def mean_pool(last_hidden, attention_mask):
    """Average token vectors over positions where *attention_mask* is non-zero.

    Expects ``last_hidden`` shaped (batch, seq, hidden) and ``attention_mask``
    shaped (batch, seq). Returns a (batch, hidden) tensor. The denominator is
    clamped so fully-masked rows give zeros instead of NaN.
    """
    mask = attention_mask.unsqueeze(-1).float()
    s = (last_hidden * mask).sum(dim=1)
    d = mask.sum(dim=1).clamp(min=1e-9)
    return s / d


def embed_chunked(samples):
    """Embed *samples* batch-by-batch into EMB_CHUNKS_DIR, resuming earlier runs.

    Batch ``b`` covers ``samples[b*BATCH_SIZE:(b+1)*BATCH_SIZE]``. It is saved
    as ``<b>.npy`` and holds float32, L2-normalised, mean-pooled MiniLM
    embeddings. Batches that already have a chunk file are skipped. No new
    batch is started once TIME_BUDGET seconds have elapsed since the loop
    began.

    Returns:
        bool: True when every batch has a chunk file on disk.
    """
    EMB_CHUNKS_DIR.mkdir(parents=True, exist_ok=True)
    num_batches = _num_batches(len(samples))
    done = _completed_batches(num_batches)
    todo = [b for b in range(num_batches) if b not in done]
    logger.info("Batches: total=%d done=%d todo=%d", num_batches, len(done), len(todo))
    if not todo:
        return True

    from transformers import AutoTokenizer, AutoModel

    logger.info("Loading MiniLM model...")
    t0 = time.time()
    tok = AutoTokenizer.from_pretrained(EMB_MODEL_ID)
    mdl = AutoModel.from_pretrained(EMB_MODEL_ID)
    mdl.eval()
    logger.info("Model loaded in %.1fs", time.time() - t0)

    texts = [s["text"] for s in samples]
    start = time.monotonic()
    processed = 0
    for b in todo:
        if time.monotonic() - start > TIME_BUDGET:
            logger.info("Time budget reached after %d batches", processed)
            break
        i = b * BATCH_SIZE
        inputs = tok(texts[i:i + BATCH_SIZE], return_tensors="pt", truncation=True,
                     max_length=MAX_LENGTH, padding=True)
        with torch.no_grad():
            out = mdl(**inputs)
            emb = mean_pool(out.last_hidden_state, inputs["attention_mask"])
            emb = torch.nn.functional.normalize(emb, p=2, dim=1)
        emb = emb.cpu().numpy().astype(np.float32)
        # Write-then-rename: a killed run must not leave a truncated chunk
        # that the next run would count as done.
        with _atomic_open(EMB_CHUNKS_DIR / f"{b}.npy") as fh:
            np.save(fh, emb)
        processed += 1
        if processed % 10 == 0 or processed == len(todo):
            logger.info("batch %d/%d (this run=%d elapsed=%.1fs)",
                        b + 1, num_batches, processed, time.monotonic() - start)

    remaining = len(todo) - processed
    logger.info("This run: %d batches; remaining: %d", processed, remaining)
    return remaining == 0


def assemble_and_tsne(samples):
    """Concatenate all embedding chunks, project them to 2-D with t-SNE, write the cache.

    Writes CACHE_FILE (arrays ``embeddings`` and ``tsne_2d``) and META_FILE
    (per-point sample metadata, in the same order as the embedding rows).

    Raises:
        ValueError: if *samples* is empty, or if a chunk's row count does not
            match its batch (e.g. BATCH_SIZE or samples.json changed after the
            chunk was written).
        FileNotFoundError: if a batch chunk is missing (embedding unfinished).
    """
    from sklearn.manifold import TSNE

    if not samples:
        raise ValueError("No samples to project; samples.json is empty")

    num_batches = _num_batches(len(samples))
    parts = []
    for b in range(num_batches):
        path = EMB_CHUNKS_DIR / f"{b}.npy"
        if not path.exists():
            raise FileNotFoundError(f"Missing embedding chunk {path}; run STAGE=embed until it completes")
        chunk = np.load(path)
        expected = len(samples[b * BATCH_SIZE:(b + 1) * BATCH_SIZE])
        if chunk.shape[0] != expected:
            raise ValueError(
                f"Chunk {path} has {chunk.shape[0]} rows, expected {expected}; "
                f"delete {EMB_CHUNKS_DIR} and re-run STAGE=embed"
            )
        parts.append(chunk)
    emb = np.concatenate(parts, axis=0)
    logger.info("Embeddings shape %s", emb.shape)

    n = emb.shape[0]
    perp = _tsne_perplexity(n)
    logger.info("t-SNE perp=%d...", perp)
    t0 = time.time()
    tsne_kwargs = dict(n_components=2, perplexity=perp, random_state=TSNE_SEED,
                       learning_rate="auto", init="pca")
    try:
        tsne = TSNE(max_iter=1000, **tsne_kwargs)
    except TypeError:
        # scikit-learn < 1.5 names this parameter ``n_iter``.
        tsne = TSNE(n_iter=1000, **tsne_kwargs)
    coords = tsne.fit_transform(emb)
    logger.info("t-SNE done %.1fs", time.time() - t0)

    with _atomic_open(CACHE_FILE) as fh:
        np.savez_compressed(fh, embeddings=emb, tsne_2d=coords)
    meta = [{k: s[k] for k in _META_KEYS} for s in samples]
    with _atomic_open(META_FILE, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False)
    logger.info("Cache complete at %s", CACHE_DIR)


def status():
    """Print sample count, finished/total embedding batches, and whether the final cache exists."""
    n_samples = 0
    if SAMPLES_FILE.exists():
        with open(SAMPLES_FILE, "r", encoding="utf-8") as f:
            n_samples = len(json.load(f))
    n_batches = _num_batches(n_samples)
    n_done = len(_completed_batches(n_batches))
    cache_done = CACHE_FILE.exists() and META_FILE.exists()
    print(f"samples={n_samples} batches_done={n_done}/{n_batches} final_cache={cache_done}")


def main():
    """Run the stage(s) selected by STAGE; exit non-zero on an unusable configuration."""
    if STAGE not in VALID_STAGES:
        logger.error("Unknown STAGE=%r (expected one of: %s)", STAGE, ", ".join(VALID_STAGES))
        sys.exit(2)
    if STAGE == "status":
        status()
        return

    if STAGE in ("download", "auto"):
        samples = prepare_samples()
        if STAGE == "download":
            return
    else:
        if not SAMPLES_FILE.exists():
            logger.error("%s not found; run STAGE=download first", SAMPLES_FILE)
            sys.exit(1)
        with open(SAMPLES_FILE, "r", encoding="utf-8") as f:
            samples = json.load(f)

    if STAGE in ("embed", "auto"):
        all_done = embed_chunked(samples)
        if STAGE == "embed" or not all_done:
            return

    if STAGE in ("tsne", "auto"):
        assemble_and_tsne(samples)


if __name__ == "__main__":
    main()