| # -*- coding: utf-8 -*- | |
| """Scimplify.ipynb | |
| Automatically generated by Colab. | |
| Original file is located at | |
| https://colab.research.google.com/drive/11L85VXrmvxrfXd6A9FGuJjI53nVtM0tN | |
| # Scimplify | |
| A NeuroAI paper simplifier. You paste a paragraph and get a plain-language | |
| explanation back, with citations to the retrieved chunks the explanation | |
| came from. The system refuses to answer if it can't ground the claims. | |
| ## 1. Setup | |
| """ | |
| import os, json, re, io, textwrap, time | |
| from pathlib import Path | |
| from collections import Counter, defaultdict | |
| from typing import List, Dict, Tuple, Optional | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| import requests | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import openai | |
| import chromadb | |
| from chromadb.utils import embedding_functions | |
| import gradio as gr | |
| from sentence_transformers import SentenceTransformer | |
| from PyPDF2 import PdfReader | |
# Fail fast if the OpenAI API key is not configured.
api_key = os.environ.get("OPENAI_API_KEY")
assert api_key, "Set OPENAI_API_KEY env var"
| client_oai = openai.OpenAI() | |
| GENERATOR_MODEL = "gpt-4o-mini" | |
| JUDGE_MODEL = "gpt-4o-mini" | |
| GENERATOR_TEMPERATURE = 0.2 | |
| JUDGE_TEMPERATURE = 0.3 | |
| JUDGE_N_SAMPLES = 3 | |
| BOOTSTRAP_N = 2000 | |
| BOOTSTRAP_ALPHA = 0.05 | |
| _rng = np.random.default_rng(7) | |
RUN_EXPERIMENTS = False       # True: re-run all experiments; False: use cached results from the repo
LIVE_SEMANTIC_CHECK = True    # per-sentence semantic guard on live queries (adds ~1s per query)
JUDGE_PARALLELISM = 2         # max concurrent judge calls, kept low to respect rate limits
| print(f"generator: {GENERATOR_MODEL}") | |
| print(f"judge: {JUDGE_MODEL}") | |
| print(f"experiments: {'WILL RE-RUN' if RUN_EXPERIMENTS else 'using cached results'}") | |
| print(f"live semantic check: {'on' if LIVE_SEMANTIC_CHECK else 'off'}") | |
| """## 2. Data loading""" | |
| REPO_RAW_BASE = "https://raw.githubusercontent.com/martazavro/scimplify_data/main" | |
| LOCAL_DATA_DIR = Path("./data") | |
| def _load_json(filename): | |
| url = f"{REPO_RAW_BASE}/{filename}" | |
| try: | |
| r = requests.get(url, timeout=10) | |
| r.raise_for_status() | |
| print(f"loaded {filename} from repo") | |
| return r.json() | |
| except Exception as e: | |
| local = LOCAL_DATA_DIR / filename | |
| if local.exists(): | |
| print(f"repo fetch failed ({e.__class__.__name__}), loaded {filename} from local") | |
| return json.loads(local.read_text()) | |
| raise FileNotFoundError( | |
| f"Could not load {filename}. Set REPO_RAW_BASE correctly " | |
| f"or place the file in ./data/{filename}" | |
| ) | |
| neuroai_concepts = _load_json("concepts.json") | |
| print(f"loaded {len(neuroai_concepts)} concepts") | |
| def validate_validation_set(vs): | |
| items = vs["items"] | |
| ids = [x["id"] for x in items] | |
| assert len(set(ids)) == len(ids), "duplicate ids" | |
| required = {"id", "passage", "source", "key_terms", "category", "difficulty", "reference_explanation"} | |
| valid_cats = {"concepts_only", "recent_paper", "both", "neither"} | |
| for item in items: | |
| missing = required - set(item.keys()) | |
| assert not missing, f"item {item.get('id')} missing {missing}" | |
| assert item["category"] in valid_cats | |
| cat_counts = Counter(x["category"] for x in items) | |
| print(f"validation set: {len(items)} items") | |
| print(f" by category: {dict(cat_counts)}") | |
| validation_set = _load_json("validation_set.json") | |
| validate_validation_set(validation_set) | |
| """## 3. PDF extraction and chunking""" | |
| def extract_text_from_pdf(pdf_file): | |
| reader = PdfReader(pdf_file) | |
| text = "" | |
| for page in reader.pages: | |
| page_text = page.extract_text() | |
| if page_text: | |
| text += page_text + "\n" | |
| return text.strip() | |
| def chunk_text(text, chunk_size=300, overlap=50): | |
| paragraphs = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()] | |
| chunks, current, current_len = [], [], 0 | |
| for para in paragraphs: | |
| words = para.split() | |
| n = len(words) | |
| if n > chunk_size: | |
| if current: | |
| chunks.append(" ".join(current)) | |
| tail = current[-overlap:] if len(current) > overlap else current | |
| current = list(tail); current_len = len(current) | |
| for i in range(0, n, chunk_size - overlap): | |
| chunk = words[i:i+chunk_size] | |
| if len(chunk) > 30: | |
| chunks.append(" ".join(chunk)) | |
| current = []; current_len = 0 | |
| elif current_len + n > chunk_size: | |
| chunks.append(" ".join(current)) | |
| tail = current[-overlap:] if len(current) > overlap else current | |
| current = list(tail) + words; current_len = len(current) | |
| else: | |
| current.extend(words); current_len += n | |
| if current and len(current) > 30: | |
| chunks.append(" ".join(current)) | |
| return chunks | |
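# Quick sanity check of the chunker on synthetic text (the word counts below are assumptions,
# not real paper content): a 400-word paragraph is split with overlap, a 60-word one is kept whole.
_demo_text = ("word " * 400).strip() + "\n\n" + ("tail " * 60).strip()
print([len(c.split()) for c in chunk_text(_demo_text)])  # expect [300, 150, 60]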
| """## 4. Vector store setup""" | |
| chroma_client = chromadb.Client() | |
| ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2") | |
| def reset_concepts_collection(): | |
| try: | |
| chroma_client.delete_collection("neuroai_concepts") | |
| except Exception: | |
| pass | |
| coll = chroma_client.create_collection(name="neuroai_concepts", embedding_function=ef) | |
| for entry in neuroai_concepts: | |
| doc = ( | |
| f"Concept: {entry['concept']}\n" | |
| f"Definition: {entry['definition']}\n" | |
| f"Context: {entry['context']}\n" | |
| f"Typically found in: {entry['typical_usage']}" | |
| ) | |
| coll.add( | |
| documents=[doc], | |
| ids=[entry["id"]], | |
| metadatas=[{"concept_name": entry["concept"], "concept_id": entry["id"]}] | |
| ) | |
| return coll | |
| def reset_papers_collection(): | |
| try: | |
| chroma_client.delete_collection("neuroai_papers") | |
| except Exception: | |
| pass | |
| return chroma_client.create_collection(name="neuroai_papers", embedding_function=ef) | |
| concepts_collection = reset_concepts_collection() | |
| papers_collection = reset_papers_collection() | |
| print(f"concepts: {concepts_collection.count()}, papers: {papers_collection.count()}") | |
| """## 5. Recent papers ingestion""" | |
| PAPER_CHUNKS_URL = f"{REPO_RAW_BASE}/paper_chunks.json" | |
| def load_paper_chunks(): | |
| r = requests.get(PAPER_CHUNKS_URL, timeout=15) | |
| r.raise_for_status() | |
| return r.json() | |
| def ingest_paper_chunks_from_json(): | |
| chunks = load_paper_chunks() | |
| if not chunks: | |
| print("paper_chunks.json was empty") | |
| return 0 | |
| documents = [c["text"] for c in chunks] | |
| ids = [c["chunk_id"] for c in chunks] | |
| metadatas = [{ | |
| "source_name": c["source_name"], | |
| "source_type": c["source_type"], | |
| "arxiv_id": c["arxiv_id"], | |
| "title": c["title"], | |
| "chunk_idx": c["chunk_idx"], | |
| "chunk_id": c["chunk_id"], | |
| } for c in chunks] | |
| papers_collection.add(documents=documents, ids=ids, metadatas=metadatas) | |
| by_paper = {} | |
| for c in chunks: | |
| by_paper[c["arxiv_id"]] = by_paper.get(c["arxiv_id"], 0) + 1 | |
| for aid, n in by_paper.items(): | |
| print(f" {aid}: {n} chunks") | |
| print(f"papers_collection now has {papers_collection.count()} total chunks") | |
| return len(chunks) | |
| ingest_paper_chunks_from_json() | |
| """## 6. arXiv ingestion""" | |
| import arxiv | |
| def _existing_arxiv_ids(): | |
| if papers_collection.count() == 0: | |
| return set() | |
| metas = papers_collection.get()["metadatas"] | |
| return {m.get("arxiv_id") for m in metas if m.get("arxiv_id")} | |
| def ingest_from_arxiv(query="neuroAI OR (neural AND brain AND deep learning)", | |
| max_results=10, | |
| sort_by_recent=True, | |
| verbose=True): | |
| """Search arXiv, download PDFs, chunk them, add to papers_collection. | |
| Returns dict with stats: {n_papers, n_chunks, n_skipped, errors}. | |
| Already-ingested papers (matched by arxiv_id) are skipped. | |
| """ | |
| sort_by = arxiv.SortCriterion.SubmittedDate if sort_by_recent else arxiv.SortCriterion.Relevance | |
| arxiv_client = arxiv.Client(page_size=20, delay_seconds=3.0, num_retries=3) | |
| search = arxiv.Search(query=query, max_results=max_results, sort_by=sort_by) | |
| existing = _existing_arxiv_ids() | |
| download_dir = Path("./arxiv_papers") | |
| download_dir.mkdir(exist_ok=True) | |
| n_papers, n_chunks, n_skipped = 0, 0, 0 | |
| errors = [] | |
| for result in arxiv_client.results(search): | |
| # arxiv.org/abs/2509.23566v1 -> "2509.23566" | |
| full_id = result.entry_id.rsplit("/", 1)[-1] | |
| arxiv_id = full_id.split("v")[0] | |
| if arxiv_id in existing: | |
| n_skipped += 1 | |
| if verbose: | |
| print(f" skip {arxiv_id} (already ingested)") | |
| continue | |
| try: | |
| if verbose: | |
| print(f" fetching {arxiv_id}: {result.title[:60]}...") | |
| pdf_path = result.download_pdf(dirpath=str(download_dir), | |
| filename=f"{arxiv_id}.pdf") | |
| text = extract_text_from_pdf(pdf_path) | |
| chunks = chunk_text(text) | |
| if not chunks: | |
| errors.append(f"{arxiv_id}: no chunks extracted") | |
| continue | |
| chunk_ids = [f"arxiv_{arxiv_id.replace('.', '_')}::c{i}" for i in range(len(chunks))] | |
| metadatas = [{ | |
| "source_name": result.title, | |
| "source_type": "arxiv_paper", | |
| "arxiv_id": arxiv_id, | |
| "title": result.title, | |
| "chunk_idx": i, | |
| "chunk_id": chunk_ids[i], | |
| } for i in range(len(chunks))] | |
| papers_collection.add(documents=chunks, ids=chunk_ids, metadatas=metadatas) | |
| existing.add(arxiv_id) # avoid double-add within the same call | |
| n_papers += 1 | |
| n_chunks += len(chunks) | |
| if verbose: | |
| print(f" -> added {len(chunks)} chunks") | |
| except Exception as e: | |
| errors.append(f"{arxiv_id}: {e.__class__.__name__}: {e}") | |
| if verbose: | |
| print(f" ERROR: {e}") | |
| summary = { | |
| "n_papers": n_papers, | |
| "n_chunks": n_chunks, | |
| "n_skipped": n_skipped, | |
| "errors": errors, | |
| "total_in_kb": papers_collection.count(), | |
| } | |
| if verbose: | |
| print(f"\ningested {n_papers} papers ({n_chunks} chunks), skipped {n_skipped} duplicates") | |
| if errors: | |
| print(f"errors: {len(errors)}") | |
| print(f"total in knowledge base: {summary['total_in_kb']} chunks") | |
| return summary | |
# Pull the 15 most recent arXiv matches for "NeuroAI"; already-ingested arxiv_ids are skipped.
ingest_from_arxiv(query="NeuroAI", max_results=15)
| """## 7. Retrieval variants""" | |
| def _flexible_last_word(word): | |
| if len(word) < 4: | |
| return re.escape(word) | |
| stem = re.escape(word[:-2]) | |
| return stem + r"[a-z]{0,4}" | |
| def build_concept_patterns(concept_entry): | |
| name = concept_entry["concept"] | |
| abbrev_match = re.search(r"\(([^)]+)\)", name) | |
| abbrev = abbrev_match.group(1) if abbrev_match else None | |
| base = re.sub(r"\s*\([^)]+\)", "", name).strip() | |
| patterns = [] | |
| words = base.split() | |
| if len(words) == 1: | |
| long_re = r"\b" + _flexible_last_word(words[0]) + r"\b" | |
| else: | |
| parts = [re.escape(w) for w in words[:-1]] + [_flexible_last_word(words[-1])] | |
| long_re = r"\b" + r"\s+".join(parts) + r"\b" | |
| patterns.append(re.compile(long_re, re.IGNORECASE)) | |
| if abbrev: | |
| patterns.append(re.compile(r"\b" + re.escape(abbrev) + r"s?\b")) | |
| return patterns | |
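# Illustration of the regex tier on a hypothetical concept entry (not from concepts.json):
# the long-form pattern tolerates inflections of the last word, the short pattern matches the abbreviation.
_demo_entry = {"concept": "Recurrent Neural Network (RNN)"}
_demo_patterns = build_concept_patterns(_demo_entry)
_demo_sentence = "Training recurrent neural networks (RNNs) is notoriously hard."
print([bool(p.search(_demo_sentence)) for p in _demo_patterns])  # expect [True, True]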
| CONCEPT_PATTERNS = [(entry, build_concept_patterns(entry)) for entry in neuroai_concepts] | |
| def _concept_doc_text(entry): | |
| return ( | |
| f"Concept: {entry['concept']}\n" | |
| f"Definition: {entry['definition']}\n" | |
| f"Context: {entry['context']}\n" | |
| f"Typically found in: {entry['typical_usage']}" | |
| ) | |
| def regex_retrieve(passage): | |
| hits = [] | |
| for entry, patterns in CONCEPT_PATTERNS: | |
| if any(p.search(passage) for p in patterns): | |
| hits.append({ | |
| "type": "regex_concept", | |
| "concept_name": entry["concept"], | |
| "concept_id": entry["id"], | |
| "chunk_id": entry["id"], | |
| "content": _concept_doc_text(entry), | |
| "distance": 0.0, | |
| "source_method": "regex", | |
| }) | |
| return hits | |
| def retrieve_concepts_embedding(passage, n_results=3): | |
| results = concepts_collection.query(query_texts=[passage], n_results=n_results) | |
| out = [] | |
| for doc, meta, dist in zip(results["documents"][0], results["metadatas"][0], results["distances"][0]): | |
| out.append({ | |
| "type": "concept", | |
| "concept_name": meta["concept_name"], | |
| "concept_id": meta.get("concept_id"), | |
| "chunk_id": meta.get("concept_id"), | |
| "content": doc, | |
| "distance": round(dist, 3), | |
| "source_method": "embedding", | |
| }) | |
| return out | |
| def retrieve_paper_chunks(passage, n_results=3): | |
| if papers_collection.count() == 0: | |
| return [] | |
| results = papers_collection.query( | |
| query_texts=[passage], | |
| n_results=min(n_results, papers_collection.count()) | |
| ) | |
| out = [] | |
| for doc, meta, dist in zip(results["documents"][0], results["metadatas"][0], results["distances"][0]): | |
| out.append({ | |
| "type": "paper_chunk", | |
| "source_name": meta["source_name"], | |
| "source_type": meta["source_type"], | |
| "chunk_id": meta.get("chunk_id"), | |
| "content": doc, | |
| "distance": round(dist, 3), | |
| "source_method": "embedding", | |
| }) | |
| return out | |
| def hybrid_retrieve_concepts(passage, n_embedding=3, max_total=6): | |
| rgx = regex_retrieve(passage) | |
| seen_names = {h["concept_name"] for h in rgx} | |
| out = list(rgx) | |
| if len(out) < max_total: | |
| emb = retrieve_concepts_embedding(passage, n_results=n_embedding) | |
| for hit in emb: | |
| if hit["concept_name"] not in seen_names: | |
| out.append(hit) | |
| seen_names.add(hit["concept_name"]) | |
| else: | |
| for r in out: | |
| if r["concept_name"] == hit["concept_name"]: | |
| r["source_method"] = "both" | |
| break | |
| if len(out) >= max_total: | |
| break | |
| return out | |
| def retrieve_for_variant(passage, variant, n_concepts=3, n_papers=3): | |
| if variant == "no_rag": | |
| return [], [] | |
| elif variant == "embedding_only": | |
| return (retrieve_concepts_embedding(passage, n_results=n_concepts), | |
| retrieve_paper_chunks(passage, n_results=n_papers)) | |
| elif variant == "regex_only": | |
| return (regex_retrieve(passage), | |
| retrieve_paper_chunks(passage, n_results=n_papers)) | |
| elif variant == "hybrid": | |
| return (hybrid_retrieve_concepts(passage, n_embedding=n_concepts), | |
| retrieve_paper_chunks(passage, n_results=n_papers)) | |
| else: | |
| raise ValueError(f"unknown variant: {variant}") | |
| """## 8. Citation-enforced generation with semantic guard | |
| """ | |
| CITED_SYSTEM_PROMPT = """You are a scientific reading assistant that helps people understand passages from NeuroAI research papers. | |
| You have access to retrieved context. Each source has a stable ID in square brackets like [c004] (for a concept definition) or [arxiv_2511_12345::c3] (for a paper chunk). | |
| Your job: | |
| 1. Read the passage. | |
| 2. Rewrite it in plain language an undergraduate could follow. | |
| 3. For EVERY factual sentence in your explanation, append one or more citations in square brackets, drawn ONLY from the IDs of the retrieved sources shown to you. | |
| 4. Do not invent citation IDs. Do not cite sources you were not shown. | |
| 5. If the retrieved context does not contain enough information to answer faithfully, output EXACTLY this string and nothing else: | |
| I don't have enough evidence in the retrieved context. | |
| Format: | |
| **Key terms:** short definitions of technical terms, each with its citation | |
| **Plain-language version:** the passage rewritten clearly, with citations on every factual sentence | |
| **What this means in context:** 1-2 sentences on why this matters, with citations | |
| """ | |
| ABSTAIN_MESSAGE = "I don't have enough evidence in the retrieved context." | |
| CITATION_PATTERN = re.compile(r"\[([a-zA-Z0-9_\-:]+)\]") | |
| SEMANTIC_FAIL_THRESHOLD = 0.5 | |
| def _format_context_block(concept_results, paper_results): | |
| lines = [] | |
| if concept_results: | |
| lines.append("CONCEPT DEFINITIONS:") | |
| for r in concept_results: | |
| cid = r.get("chunk_id") or r.get("concept_id") | |
| lines.append(f"\n[{cid}] {r['content']}") | |
| lines.append("---") | |
| if paper_results: | |
| lines.append("\nPAPER/ARTICLE CONTEXT:") | |
| for r in paper_results: | |
| cid = r.get("chunk_id") | |
| lines.append(f"\n[{cid}] (from {r['source_name']}): {r['content']}") | |
| lines.append("---") | |
| if not concept_results and not paper_results: | |
| lines.append("(no context retrieved)") | |
| return "\n".join(lines) | |
| def _collect_allowed_ids(concept_results, paper_results): | |
| ids = set() | |
| for r in concept_results + paper_results: | |
| cid = r.get("chunk_id") or r.get("concept_id") | |
| if cid: | |
| ids.add(cid) | |
| return ids | |
| def _build_chunk_lookup(concept_results, paper_results): | |
| """Map citation_id -> chunk content. Used by the semantic check.""" | |
| lookup = {} | |
| for r in concept_results + paper_results: | |
| cid = r.get("chunk_id") or r.get("concept_id") | |
| if cid: | |
| lookup[cid] = r["content"] | |
| return lookup | |
| def generate_cited_explanation(passage, concept_results, paper_results, model=None): | |
| model = model or GENERATOR_MODEL | |
| context_block = _format_context_block(concept_results, paper_results) | |
| user_msg = f"{context_block}\n\nPASSAGE TO EXPLAIN:\n{passage}" | |
| resp = client_oai.chat.completions.create( | |
| model=model, | |
| temperature=GENERATOR_TEMPERATURE, | |
| messages=[ | |
| {"role": "system", "content": CITED_SYSTEM_PROMPT}, | |
| {"role": "user", "content": user_msg}, | |
| ], | |
| ) | |
| return resp.choices[0].message.content | |
| def validate_citations(answer: str, allowed_ids: set) -> Tuple[bool, List[str]]: | |
| """Lexical guard: every citation ID in the answer must be in the allowed set.""" | |
| if ABSTAIN_MESSAGE in answer: | |
| return True, [] | |
| cited = CITATION_PATTERN.findall(answer) | |
| issues = [] | |
| if not cited: | |
| issues.append("No citations found in non-abstain answer") | |
| for cid in cited: | |
| if cid not in allowed_ids: | |
| issues.append(f"Invalid citation: {cid}") | |
| return len(issues) == 0, issues | |
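# Minimal illustration of the lexical guard with hypothetical citation IDs:
# a citation outside the allowed set is flagged and would force an abstention.
print(validate_citations("LFADS infers latent factors [c004].", {"c004"}))  # (True, [])
print(validate_citations("LFADS infers latent factors [c999].", {"c004"}))  # (False, ['Invalid citation: c999'])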
| def _split_into_sentences(text): | |
| """Cheap sentence splitter that keeps citation brackets attached.""" | |
| # split on . ! ? followed by space and a capital, keeping the punctuation | |
| parts = re.split(r"(?<=[.!?])\s+(?=[A-Z*])", text.strip()) | |
| return [p.strip() for p in parts if p.strip()] | |
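# Illustrative split: boundaries are detected after ./!/? and each sentence keeps its citation brackets.
print(_split_into_sentences("LFADS infers latent factors [c004]. It models single-trial activity [c007]."))
# expect two sentences, each ending with its bracketed citation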
| def _strip_citations(sentence): | |
| return CITATION_PATTERN.sub("", sentence).strip() | |
| def check_sentence_supported(sentence_text, cited_chunks): | |
| claim = _strip_citations(sentence_text) | |
| if len(claim) < 10 or not cited_chunks: | |
| return {"label": "skipped", "reason": "no claim or no chunks"} | |
| evidence = "\n\n".join(f"[{cid}]: {text}" for cid, text in cited_chunks) | |
| return verify_claim_against_evidence(claim, [evidence]) | |
| def semantic_per_sentence_check(answer, chunk_lookup): | |
| if ABSTAIN_MESSAGE in answer: | |
| return [] | |
| sentences = _split_into_sentences(answer) | |
| findings = [] | |
| for sent in sentences: | |
| cited_ids = CITATION_PATTERN.findall(sent) | |
| if not cited_ids: | |
| continue | |
| cited_chunks = [(cid, chunk_lookup[cid]) for cid in cited_ids if cid in chunk_lookup] | |
| if not cited_chunks: | |
| continue | |
| result = check_sentence_supported(sent, cited_chunks) | |
| findings.append({ | |
| "sentence": sent, | |
| "citations": cited_ids, | |
| "label": result["label"], | |
| "reason": result["reason"], | |
| }) | |
| return findings | |
| def annotate_unsupported_sentences(answer, findings): | |
| """Mark unsupported sentences in the rendered output.""" | |
| for f in findings: | |
| if f["label"] in ("contradicted", "insufficient"): | |
| marker = "⚠️ " | |
| if marker not in f["sentence"]: | |
| answer = answer.replace(f["sentence"], marker + f["sentence"], 1) | |
| return answer | |
| def generate_with_citation_guard(passage, concept_results, paper_results, model=None, | |
| allow_no_context_bypass=False, | |
| do_semantic_check=None): | |
| do_semantic_check = (do_semantic_check if do_semantic_check is not None | |
| else LIVE_SEMANTIC_CHECK) | |
| if allow_no_context_bypass and not concept_results and not paper_results: | |
| resp = client_oai.chat.completions.create( | |
| model=model or GENERATOR_MODEL, | |
| temperature=GENERATOR_TEMPERATURE, | |
| messages=[ | |
| {"role": "system", "content": "You are a scientific reading assistant. Explain the given passage in plain language that an undergraduate could follow. Be concise."}, | |
| {"role": "user", "content": f"PASSAGE:\n{passage}"}, | |
| ], | |
| ) | |
| return { | |
| "answer": resp.choices[0].message.content, | |
| "valid_citations": None, | |
| "guard_triggered": False, | |
| "issues": [], | |
| "abstained": False, | |
| "semantic_findings": [], | |
| "semantic_fail_rate": np.nan, | |
| } | |
| raw = generate_cited_explanation(passage, concept_results, paper_results, model=model) | |
| allowed_ids = _collect_allowed_ids(concept_results, paper_results) | |
| ok, issues = validate_citations(raw, allowed_ids) | |
| # lexical guard | |
| if not ok: | |
| return { | |
| "answer": ABSTAIN_MESSAGE, | |
| "valid_citations": False, | |
| "guard_triggered": True, | |
| "issues": issues, | |
| "abstained": True, | |
| "raw_rejected": raw, | |
| "semantic_findings": [], | |
| "semantic_fail_rate": np.nan, | |
| } | |
| # semantic per-sentence check | |
| findings = [] | |
| semantic_fail_rate = np.nan | |
| if do_semantic_check and ABSTAIN_MESSAGE not in raw: | |
| chunk_lookup = _build_chunk_lookup(concept_results, paper_results) | |
| findings = semantic_per_sentence_check(raw, chunk_lookup) | |
| if findings: | |
| n_failed = sum(1 for f in findings if f["label"] in ("contradicted", "insufficient")) | |
| semantic_fail_rate = n_failed / len(findings) | |
| if semantic_fail_rate > SEMANTIC_FAIL_THRESHOLD: | |
| return { | |
| "answer": ABSTAIN_MESSAGE, | |
| "valid_citations": True, | |
| "guard_triggered": True, | |
| "issues": [f"semantic check failed: {n_failed}/{len(findings)} sentences unsupported"], | |
| "abstained": True, | |
| "raw_rejected": raw, | |
| "semantic_findings": findings, | |
| "semantic_fail_rate": semantic_fail_rate, | |
| } | |
| raw = annotate_unsupported_sentences(raw, findings) | |
| return { | |
| "answer": raw, | |
| "valid_citations": True, | |
| "guard_triggered": False, | |
| "issues": [], | |
| "abstained": ABSTAIN_MESSAGE in raw, | |
| "semantic_findings": findings, | |
| "semantic_fail_rate": semantic_fail_rate, | |
| } | |
| """## 9. LLM-as-judge metrics | |
| """ | |
| def _coerce_score(x): | |
| try: | |
| v = int(float(x)) | |
| except Exception: | |
| v = 0 | |
| return max(0, min(2, v)) | |
| def _single_judge_call(system_prompt, user_prompt): | |
| try: | |
| resp = client_oai.chat.completions.create( | |
| model=JUDGE_MODEL, | |
| temperature=JUDGE_TEMPERATURE, | |
| response_format={"type": "json_object"}, | |
| messages=[ | |
| {"role": "system", "content": system_prompt}, | |
| {"role": "user", "content": user_prompt}, | |
| ], | |
| ) | |
| data = json.loads(resp.choices[0].message.content) | |
| return { | |
| "score": _coerce_score(data.get("score", 0)), | |
| "reason": str(data.get("reason", "")).strip(), | |
| } | |
| except Exception as e: | |
| return {"score": None, "reason": f"ERROR: {e}"} | |
| def _judge_call_parallel(system_prompt, user_prompt, n=None): | |
| """Run n judge calls in parallel via ThreadPoolExecutor.""" | |
| n = n or JUDGE_N_SAMPLES | |
| results = [None] * n | |
| with ThreadPoolExecutor(max_workers=min(n, JUDGE_PARALLELISM)) as ex: | |
| futures = {ex.submit(_single_judge_call, system_prompt, user_prompt): i | |
| for i in range(n)} | |
| for fut in as_completed(futures): | |
| i = futures[fut] | |
| results[i] = fut.result() | |
| return results | |
| def _aggregate(runs): | |
| valid = [r for r in runs if r["score"] is not None] | |
| if not valid: | |
| return {"score": None, "reasons": [r["reason"] for r in runs], "n_valid": 0} | |
| return { | |
| "score": sum(r["score"] for r in valid) / len(valid), | |
| "reasons": [r["reason"] for r in valid], | |
| "n_valid": len(valid), | |
| } | |
| CORRECTNESS_SYSTEM = """You are evaluating answer correctness for a question about a NeuroAI paper passage. | |
| Given a passage, a reference explanation (gold-standard), and a system explanation, score the system explanation's correctness using ONLY the information in the passage and reference. | |
| Return ONLY a JSON object: | |
| {"score": <int 0/1/2>, "reason": "<one sentence>"} | |
| Scoring scale: | |
| - 0 = wrong (contradicts the passage or says something incorrect) | |
| - 1 = partly correct (captures some but not all of the main idea, or adds unsupported claims) | |
| - 2 = correct (faithful to what the passage actually says) | |
| """ | |
| def score_correctness(passage, reference, candidate): | |
| user = f"PASSAGE:\n{passage}\n\nREFERENCE:\n{reference}\n\nSYSTEM EXPLANATION:\n{candidate}" | |
| runs = _judge_call_parallel(CORRECTNESS_SYSTEM, user) | |
| return _aggregate(runs) | |
| EVIDENCE_SYSTEM = """You are evaluating whether a system explanation's key claims are supported by retrieved context. | |
| Given a passage, the retrieved context that was shown to the system, and the system's explanation, score whether the explanation's factual claims are well-supported by the retrieved context. | |
| Return ONLY a JSON object: | |
| {"score": <int 0/1/2>, "reason": "<one sentence>"} | |
| Scoring scale: | |
| - 0 = unsupported (most claims cannot be found in retrieved context) | |
| - 1 = partly supported (some claims supported, others require outside knowledge) | |
| - 2 = well supported (claims are traceable to retrieved context) | |
| If the retrieved context is empty (no RAG baseline), score 0. | |
| """ | |
| def score_evidence_support(passage, retrieved_context, candidate): | |
| user = f"PASSAGE:\n{passage}\n\nRETRIEVED CONTEXT:\n{retrieved_context}\n\nSYSTEM EXPLANATION:\n{candidate}" | |
| runs = _judge_call_parallel(EVIDENCE_SYSTEM, user) | |
| return _aggregate(runs) | |
| CITATION_SYSTEM = """You are evaluating whether citations in a system explanation are faithful. | |
| The system was asked to cite each factual sentence with an ID from the retrieved context (like [c004] or [arxiv_2511_12345::c3]). Given the retrieved context and the system explanation with citations, score whether the citations are relevant and the cited material actually supports the adjacent claim. | |
| Return ONLY a JSON object: | |
| {"score": <int 0/1/2>, "reason": "<one sentence>"} | |
| Scoring scale: | |
| - 0 = unfaithful (citations invented, missing, or do not support adjacent claims) | |
| - 1 = mixed (some citations support their claims, others do not) | |
| - 2 = faithful (citations are present, relevant, and support adjacent claims) | |
| If the answer is the abstention message ("I don't have enough evidence..."), score 2 (correctly declined). | |
| """ | |
| def score_citation_faithfulness(retrieved_context, candidate): | |
| user = f"RETRIEVED CONTEXT:\n{retrieved_context}\n\nSYSTEM EXPLANATION:\n{candidate}" | |
| runs = _judge_call_parallel(CITATION_SYSTEM, user) | |
| return _aggregate(runs) | |
| def score_all_metrics(passage, reference, retrieved_context, candidate): | |
| """Run all three metrics in parallel.""" | |
| with ThreadPoolExecutor(max_workers=3) as ex: | |
| f_c = ex.submit(score_correctness, passage, reference, candidate) | |
| f_e = ex.submit(score_evidence_support, passage, retrieved_context, candidate) | |
| f_f = ex.submit(score_citation_faithfulness, retrieved_context, candidate) | |
| return { | |
| "correctness": f_c.result(), | |
| "evidence_support": f_e.result(), | |
| "citation_faithfulness": f_f.result(), | |
| } | |
| """## 10. Claim-based faithfulness | |
| """ | |
| CLAIM_EXTRACTION_SYSTEM = """Extract atomic factual claims from the given answer. | |
| Return ONLY a JSON object: | |
| {"claims": ["claim 1", "claim 2", ...]} | |
| Rules: | |
| - Each claim should be a single, minimal factual assertion | |
| - Ignore pure formatting, headers, or meta-commentary | |
| - Skip citation markers like [c004] when extracting claims | |
| - If there are no factual claims, return {"claims": []} | |
| """ | |
| EVIDENCE_EXTRACTION_SYSTEM = """Extract factual assertions from the given text chunk. | |
| Return ONLY a JSON object: | |
| {"assertions": ["assertion 1", "assertion 2", ...]} | |
| Rules: | |
| - One atomic factual assertion per entry | |
| - Skip anything that is a question, opinion, or example | |
| - If there are no assertions, return {"assertions": []} | |
| """ | |
| CLAIM_VERIFICATION_SYSTEM = """Classify if a claim is supported, contradicted, or insufficient given evidence. | |
| Return ONLY a JSON object: | |
| {"label": "supported" | "contradicted" | "insufficient", "reason": "<one short sentence>"} | |
| Definitions: | |
| - supported: the evidence directly supports the claim | |
| - contradicted: the evidence contradicts the claim | |
| - insufficient: the evidence is silent or unclear on the claim | |
| """ | |
| def _json_call(system_prompt, user_prompt, model=None): | |
| model = model or JUDGE_MODEL | |
| resp = client_oai.chat.completions.create( | |
| model=model, | |
| temperature=JUDGE_TEMPERATURE, | |
| response_format={"type": "json_object"}, | |
| messages=[ | |
| {"role": "system", "content": system_prompt}, | |
| {"role": "user", "content": user_prompt}, | |
| ], | |
| ) | |
| try: | |
| return json.loads(resp.choices[0].message.content) | |
| except Exception: | |
| return {} | |
| def extract_claims(answer): | |
| data = _json_call(CLAIM_EXTRACTION_SYSTEM, f"ANSWER:\n{answer}") | |
| return [c for c in data.get("claims", []) if c and isinstance(c, str)] | |
| _ASSERTION_CACHE = {} | |
| def extract_assertions_from_chunk(chunk): | |
| key = hash(chunk) | |
| if key in _ASSERTION_CACHE: | |
| return _ASSERTION_CACHE[key] | |
| data = _json_call(EVIDENCE_EXTRACTION_SYSTEM, f"CHUNK:\n{chunk}") | |
| out = [a for a in data.get("assertions", []) if a and isinstance(a, str)] | |
| _ASSERTION_CACHE[key] = out | |
| return out | |
| def _normalize_label(label): | |
| x = (label or "").strip().lower() | |
| if "support" in x: return "supported" | |
| if "contrad" in x: return "contradicted" | |
| return "insufficient" | |
| def verify_claim_against_evidence(claim, assertions): | |
| evidence_blob = "\n".join(assertions) if assertions else "NO_EVIDENCE" | |
| data = _json_call( | |
| CLAIM_VERIFICATION_SYSTEM, | |
| f"CLAIM:\n{claim}\n\nEVIDENCE:\n{evidence_blob}" | |
| ) | |
| return { | |
| "label": _normalize_label(data.get("label")), | |
| "reason": str(data.get("reason", "")).strip(), | |
| } | |
| def claim_based_faithfulness(answer, retrieved_chunks): | |
| if ABSTAIN_MESSAGE in answer: | |
| return { | |
| "n_claims": 0, | |
| "support_rate": np.nan, | |
| "contradiction_rate": np.nan, | |
| "unsupported_rate": np.nan, | |
| "abstained": True, | |
| "details": [], | |
| } | |
| claims = extract_claims(answer) | |
| if not claims: | |
| return { | |
| "n_claims": 0, | |
| "support_rate": np.nan, | |
| "contradiction_rate": np.nan, | |
| "unsupported_rate": np.nan, | |
| "abstained": False, | |
| "details": [], | |
| } | |
| with ThreadPoolExecutor(max_workers=JUDGE_PARALLELISM) as ex: | |
| all_assertions_lists = list(ex.map(extract_assertions_from_chunk, retrieved_chunks)) | |
| all_assertions = [a for sub in all_assertions_lists for a in sub] | |
| with ThreadPoolExecutor(max_workers=JUDGE_PARALLELISM) as ex: | |
| verify_results = list(ex.map( | |
| lambda c: verify_claim_against_evidence(c, all_assertions), | |
| claims | |
| )) | |
| labels = [r["label"] for r in verify_results] | |
| details = [{"claim": c, **r} for c, r in zip(claims, verify_results)] | |
| n = len(labels) | |
| return { | |
| "n_claims": n, | |
| "support_rate": sum(1 for l in labels if l == "supported") / n, | |
| "contradiction_rate": sum(1 for l in labels if l == "contradicted") / n, | |
| "unsupported_rate": sum(1 for l in labels if l == "insufficient") / n, | |
| "abstained": False, | |
| "details": details, | |
| } | |
| """## 11. Retrieval precision@k / recall@k and bootstrap CIs""" | |
| def precision_recall_at_k(retrieved_chunks, gold_facts, k=3): | |
| if not gold_facts: | |
| return np.nan, np.nan | |
| top_k = retrieved_chunks[:k] | |
| if not top_k: | |
| return 0.0, 0.0 | |
| rel_flags = [] | |
| for chunk in top_k: | |
| c = chunk.lower() | |
| is_rel = any(fact.lower() in c for fact in gold_facts) | |
| rel_flags.append(is_rel) | |
| precision = float(np.mean(rel_flags)) | |
| covered = 0 | |
| for fact in gold_facts: | |
| if any(fact.lower() in chunk.lower() for chunk in top_k): | |
| covered += 1 | |
| recall = covered / len(gold_facts) | |
| return precision, recall | |
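# Toy example with made-up chunks and gold terms: one of two retrieved chunks mentions the gold facts,
# so precision@2 is 0.5, while both facts are covered, so recall@2 is 1.0.
_p_at_2, _r_at_2 = precision_recall_at_k(
    ["LFADS models latent factors of population activity.",
     "Croissants are laminated pastry."],
    gold_facts=["latent factors", "population activity"],
    k=2,
)
print(_p_at_2, _r_at_2)  # 0.5 1.0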
| def bootstrap_ci(values, n_boot=None, alpha=None): | |
| n_boot = n_boot or BOOTSTRAP_N | |
| alpha = alpha or BOOTSTRAP_ALPHA | |
| values = np.array(values, dtype=float) | |
| values = values[~np.isnan(values)] | |
| if len(values) == 0: | |
| return np.nan, np.nan, np.nan | |
| boots = np.empty(n_boot) | |
| n = len(values) | |
| for i in range(n_boot): | |
| sample = _rng.choice(values, size=n, replace=True) | |
| boots[i] = sample.mean() | |
| lo = np.percentile(boots, 100 * (alpha / 2)) | |
| hi = np.percentile(boots, 100 * (1 - alpha / 2)) | |
| return float(values.mean()), float(lo), float(hi) | |
| def format_ci(values, digits=3): | |
| m, lo, hi = bootstrap_ci(values) | |
| return f"{m:.{digits}f} [{lo:.{digits}f}, {hi:.{digits}f}]" | |
| """## 12. Logging""" | |
| EVAL_LOG_DIR = Path("./eval_logs") | |
| EVAL_LOG_DIR.mkdir(exist_ok=True) | |
| def log_eval_row(experiment_id, passage_id, variant, retrieved_sources, | |
| generation_result, judge_scores, extra=None): | |
| row = { | |
| "experiment_id": experiment_id, | |
| "passage_id": passage_id, | |
| "variant": variant, | |
| "model": GENERATOR_MODEL, | |
| "n_retrieved": len(retrieved_sources), | |
| "retrieved_chunk_ids": ";".join( | |
| str(r.get("chunk_id") or r.get("concept_id") or "?") for r in retrieved_sources | |
| ), | |
| "guard_triggered": int(generation_result.get("guard_triggered", False)), | |
| "abstained": int(generation_result.get("abstained", False)), | |
| "answer_chars": len(generation_result.get("answer", "")), | |
| "generated_text": generation_result.get("answer", ""), | |
| "correctness": judge_scores.get("correctness", {}).get("score"), | |
| "evidence_support": judge_scores.get("evidence_support", {}).get("score"), | |
| "citation_faithfulness": judge_scores.get("citation_faithfulness", {}).get("score"), | |
| "semantic_fail_rate": generation_result.get("semantic_fail_rate", np.nan), | |
| } | |
| if extra: | |
| row.update(extra) | |
| path = EVAL_LOG_DIR / f"{experiment_id}.csv" | |
| pd.DataFrame([row]).to_csv( | |
| path, mode="a", header=not path.exists(), index=False | |
| ) | |
| return row | |
| def load_or_run_experiment(experiment_id, runner_fn): | |
| local_path = EVAL_LOG_DIR / f"{experiment_id}.csv" | |
| if RUN_EXPERIMENTS: | |
| if local_path.exists(): | |
| local_path.unlink() | |
| print(f"running {experiment_id} from scratch...") | |
| return runner_fn() | |
| url = f"{REPO_RAW_BASE}/eval_logs/{experiment_id}.csv" | |
| try: | |
| df = pd.read_csv(url) | |
| print(f"loaded {experiment_id} from repo cache: {len(df)} rows") | |
| df.to_csv(local_path, index=False) | |
| return df | |
| except Exception: | |
| pass | |
| if local_path.exists(): | |
| df = pd.read_csv(local_path) | |
| print(f"loaded {experiment_id} from local cache: {len(df)} rows") | |
| return df | |
| print(f"⚠ no cached results for {experiment_id}. Set RUN_EXPERIMENTS=True to generate.") | |
| return None | |
| """## 13. Judge calibration | |
| """ | |
RUN_CALIBRATION = False  # set True to run the interactive human-vs-judge calibration below
| def calibrate_judge(n_items=5): | |
| items = [x for x in validation_set["items"]] | |
| sample = items[:n_items] | |
| diffs = {"correctness": [], "evidence_support": [], "citation_faithfulness": []} | |
| for item in sample: | |
| c, p = retrieve_for_variant(item["passage"], "hybrid") | |
| result = generate_with_citation_guard(item["passage"], c, p, do_semantic_check=False) | |
| explanation = result["answer"] | |
| context_text = _format_context_block(c, p) | |
| print("=" * 70) | |
| print(f"ITEM {item['id']}") | |
| print(f"PASSAGE: {item['passage'][:300]}") | |
| print(f"\nREFERENCE: {item['reference_explanation']}") | |
| print(f"\nSYSTEM EXPLANATION:\n{explanation}") | |
| print("\nScore each metric 0/1/2 (0=bad, 1=partial, 2=good):") | |
| try: | |
| human = { | |
| "correctness": int(input(" correctness: ")), | |
| "evidence_support": int(input(" evidence_support: ")), | |
| "citation_faithfulness": int(input(" citation_faithfulness: ")), | |
| } | |
| except (ValueError, EOFError): | |
| print("aborted") | |
| return None | |
| all_scores = score_all_metrics( | |
| item["passage"], item["reference_explanation"], context_text, explanation | |
| ) | |
| scores_clean = {k: all_scores[k]["score"] for k in all_scores} | |
| for k in diffs: | |
| if scores_clean[k] is not None: | |
| diffs[k].append(abs(human[k] - scores_clean[k])) | |
| print("\n=== CALIBRATION RESULTS ===") | |
| for k, vals in diffs.items(): | |
| if vals: | |
| mad = sum(vals) / len(vals) | |
| flag = " ⚠ DISAGREES" if mad > 0.5 else " ok" | |
| print(f" {k}: mean abs diff = {mad:.2f}{flag}") | |
| return diffs | |
| if RUN_CALIBRATION: | |
| calibrate_judge(n_items=5) | |
| else: | |
| print("calibration skipped (RUN_CALIBRATION=False)") | |
| print("Last calibration: correctness MAD=0.60 (DISAGREES), evidence MAD=0.40, citation MAD=0.20") | |
| """## 14. Experiment A — retrieval ablation | |
| **Question.** Does RAG help, and does the regex tier earn its place? | |
| **Hypothesis.** All RAG variants will beat the no-RAG baseline on claim_support_rate and evidence_support. Hybrid will beat either single-tier variant. | |
| **Variable changed.** Retrieval method ∈ {no_rag, embedding_only, regex_only, hybrid}. Everything else held constant. | |
| """ | |
| def run_experiment_A(): | |
| items = [x for x in validation_set["items"]] | |
| variants = ["no_rag", "embedding_only", "regex_only", "hybrid"] | |
| total_runs = len(items) * len(variants) | |
| print(f"running experiment A: {len(items)} items × {len(variants)} variants = {total_runs} runs") | |
| for i, item in enumerate(items): | |
| for variant in variants: | |
| try: | |
| c, p = retrieve_for_variant(item["passage"], variant) | |
| retrieved = c + p | |
| context_text = _format_context_block(c, p) | |
| result = generate_with_citation_guard( | |
| item["passage"], c, p, | |
| allow_no_context_bypass=(variant == "no_rag"), | |
| do_semantic_check=False, | |
| ) | |
| scores = score_all_metrics( | |
| item["passage"], item["reference_explanation"], | |
| context_text, result["answer"], | |
| ) | |
| cb = claim_based_faithfulness( | |
| result["answer"], [r["content"] for r in retrieved], | |
| ) | |
| rp, rr = precision_recall_at_k( | |
| [r["content"] for r in retrieved], item["key_terms"], k=3, | |
| ) | |
| log_eval_row( | |
| "experiment_A", item["id"], variant, | |
| retrieved, result, scores, | |
| extra={ | |
| "category": item["category"], | |
| "claim_support_rate": cb["support_rate"], | |
| "claim_contradiction_rate": cb["contradiction_rate"], | |
| "claim_unsupported_rate": cb["unsupported_rate"], | |
| "n_claims": cb["n_claims"], | |
| "retrieval_precision_at_3": rp, | |
| "retrieval_recall_at_3": rr, | |
| } | |
| ) | |
| except Exception as e: | |
| print(f" ERROR {item['id']}/{variant}: {e}") | |
| print(f" done {item['id']} ({i+1}/{len(items)})") | |
| return pd.read_csv(EVAL_LOG_DIR / "experiment_A.csv") | |
| experiment_A_df = load_or_run_experiment("experiment_A", run_experiment_A) | |
| def analyze_experiment_A(): | |
| df = pd.read_csv(EVAL_LOG_DIR / "experiment_A.csv") | |
| metric_cols = ["correctness", "evidence_support", "citation_faithfulness", | |
| "claim_support_rate", "retrieval_recall_at_3", "abstained"] | |
| print("=" * 70) | |
| print("OVERALL means with 95% bootstrap CIs") | |
| print("=" * 70) | |
| for variant in ["no_rag", "embedding_only", "regex_only", "hybrid"]: | |
| sub = df[df.variant == variant] | |
| print(f"\n{variant}") | |
| for m in metric_cols: | |
| if m in sub.columns: | |
| print(f" {m:28s} {format_ci(sub[m].values)}") | |
| print("\n" + "=" * 70) | |
| print("HEADLINE METRIC: claim_support_rate (correctness saturates — see report)") | |
| print("=" * 70) | |
| for variant in ["no_rag", "embedding_only", "regex_only", "hybrid"]: | |
| sub = df[df.variant == variant] | |
| if "claim_support_rate" in sub.columns: | |
| print(f" {variant:18s} {format_ci(sub['claim_support_rate'].values)}") | |
| return df | |
| def plot_experiment_A(): | |
| df = pd.read_csv(EVAL_LOG_DIR / "experiment_A.csv") | |
| variant_order = ["no_rag", "embedding_only", "regex_only", "hybrid"] | |
| colors = ["#888", "#4c72b0", "#dd8452", "#55a868"] | |
| fig, axes = plt.subplots(1, 3, figsize=(16, 5)) | |
| means, los, his = [], [], [] | |
| for v in variant_order: | |
| sub = df[df.variant == v] | |
| if "claim_support_rate" in sub.columns: | |
| m, lo, hi = bootstrap_ci(sub["claim_support_rate"].values) | |
| else: | |
| m, lo, hi = 0, 0, 0 | |
| means.append(m); los.append(m - lo); his.append(hi - m) | |
| axes[0].bar(variant_order, means, yerr=[los, his], color=colors, capsize=5) | |
| axes[0].set_title("Claim support rate (headline)") | |
| axes[0].set_ylabel("Fraction of claims supported") | |
| axes[0].set_ylim(0, 1) | |
| axes[0].tick_params(axis="x", rotation=20) | |
| means, los, his = [], [], [] | |
| for v in variant_order: | |
| sub = df[df.variant == v] | |
| if "retrieval_recall_at_3" in sub.columns: | |
| m, lo, hi = bootstrap_ci(sub["retrieval_recall_at_3"].values) | |
| else: | |
| m, lo, hi = 0, 0, 0 | |
| means.append(m); los.append(m - lo); his.append(hi - m) | |
| axes[1].bar(variant_order, means, yerr=[los, his], color=colors, capsize=5) | |
| axes[1].set_title("Retrieval recall@3") | |
| axes[1].set_ylabel("Fraction of gold key_terms covered") | |
| axes[1].set_ylim(0, 1) | |
| axes[1].tick_params(axis="x", rotation=20) | |
| abs_by_var = df.groupby("variant")["abstained"].mean().reindex(variant_order) | |
| axes[2].bar(variant_order, abs_by_var.values, color=colors) | |
| axes[2].set_title("Abstention rate") | |
| axes[2].set_ylabel("Fraction of items guard triggered") | |
| axes[2].set_ylim(0, 1) | |
| axes[2].tick_params(axis="x", rotation=20) | |
| plt.tight_layout() | |
| plt.show() | |
| if experiment_A_df is not None: | |
| analyze_experiment_A() | |
| plot_experiment_A() | |
| """### Release gate""" | |
| def release_gate_A(variant="hybrid"): | |
| df = pd.read_csv(EVAL_LOG_DIR / "experiment_A.csv") | |
| sub = df[df.variant == variant] | |
| thresholds = { | |
| "claim_support_rate": 0.70, # primary | |
| "evidence_support": 1.40, | |
| "citation_faithfulness": 1.40, | |
| "retrieval_recall_at_3": 0.60, | |
| "abstained": 0.30, | |
| } | |
| lower_is_better = {"abstained"} | |
| agg = {k: float(np.nanmean(sub[k].values)) for k in thresholds if k in sub.columns} | |
| print(f"Release gate for variant: {variant}") | |
| print("=" * 60) | |
| all_pass = True | |
| for k, t in thresholds.items(): | |
| if k not in agg: continue | |
| v = agg[k] | |
| ok = (v <= t) if k in lower_is_better else (v >= t) | |
| direction = "≤" if k in lower_is_better else "≥" | |
| status = "PASS" if ok else "FAIL" | |
| print(f" {k:28s} {v:.3f} (need {direction} {t}) {status}") | |
| all_pass = all_pass and ok | |
| print(f"\nFINAL: {'PASS' if all_pass else 'FAIL'}") | |
| return all_pass | |
| if experiment_A_df is not None: | |
| release_gate_A(variant="hybrid") | |
| print() | |
| release_gate_A(variant="regex_only") | |
| """## 16. Experiment B — top-k sweep | |
| **Question:** How does the number of retrieved sources (top-k) affect answer correctness? | |
| **Hypothesis:** Performance peaks somewhere in the middle. k=1 misses context; large k dilutes the prompt with irrelevant chunks. | |
| **Variable changed:** `top_k ∈ {1, 3, 5, 7}`, applied to both retrieval tiers. Hybrid retrieval; everything else held constant. | |
| """ | |
| def run_experiment_B(top_k_values=(1, 3, 5, 7)): | |
| items = [x for x in validation_set["items"]] | |
| print(f"running experiment B: {len(items)} items × {len(top_k_values)} top-k values " | |
| f"= {len(items) * len(top_k_values)} runs") | |
| for k in top_k_values: | |
| time.sleep(1.5) | |
| print(f"\n--- top_k = {k} ---") | |
| for item in items: | |
| try: | |
| c, p = retrieve_for_variant( | |
| item["passage"], "hybrid", n_concepts=k, n_papers=k, | |
| ) | |
| retrieved = c + p | |
| context_text = _format_context_block(c, p) | |
| result = generate_with_citation_guard( | |
| item["passage"], c, p, do_semantic_check=False | |
| ) | |
| scores = score_all_metrics( | |
| item["passage"], item["reference_explanation"], | |
| context_text, result["answer"] | |
| ) | |
                dists = [r["distance"] for r in retrieved if r["distance"] > 0]
                avg_dist = float(np.mean(dists)) if dists else None
| log_eval_row( | |
| "experiment_B", item["id"], f"topk_{k}", | |
| retrieved, result, scores, | |
| extra={ | |
| "category": item["category"], | |
| "top_k": k, | |
| "n_retrieved_total": len(retrieved), | |
| "avg_distance": avg_dist, | |
| } | |
| ) | |
| except Exception as e: | |
| print(f" ERROR {item['id']}: {e}") | |
| return pd.read_csv(EVAL_LOG_DIR / "experiment_B.csv") | |
| def analyze_experiment_B(): | |
| df = pd.read_csv(EVAL_LOG_DIR / "experiment_B.csv") | |
| print("mean correctness by top-k (with 95% CI):") | |
| for k in sorted(df["top_k"].unique()): | |
| sub = df[df.top_k == k] | |
| print(f" top_k={k} " | |
| f"correctness={format_ci(sub['correctness'].values)} " | |
| f"evidence={format_ci(sub['evidence_support'].values)} " | |
| f"avg_n_retrieved={sub['n_retrieved_total'].mean():.1f}") | |
| return df | |
| def plot_experiment_B(): | |
| df = pd.read_csv(EVAL_LOG_DIR / "experiment_B.csv") | |
| ks = sorted(df["top_k"].unique()) | |
| means, los, his = [], [], [] | |
| for k in ks: | |
| m, lo, hi = bootstrap_ci(df[df.top_k == k]["correctness"].values) | |
| means.append(m); los.append(m - lo); his.append(hi - m) | |
| fig, ax = plt.subplots(figsize=(8, 5)) | |
| ax.errorbar(ks, means, yerr=[los, his], marker="o", linewidth=2, capsize=5) | |
| ax.set_xlabel("top-k (per retrieval tier)") | |
| ax.set_ylabel("Mean correctness (0-2)") | |
| ax.set_title("Experiment B — Correctness vs top-k (95% CI)") | |
| ax.set_xticks(ks) | |
| ax.set_ylim(0, 2) | |
| ax.grid(alpha=0.3) | |
| plt.tight_layout() | |
| plt.show() | |
| experiment_B_df = load_or_run_experiment("experiment_B", run_experiment_B) | |
| if experiment_B_df is not None: | |
| analyze_experiment_B() | |
| plot_experiment_B() | |
| """## 17. Experiment C — confidence threshold tuning | |
| **Question.** Where should the low-confidence threshold sit so that warnings correlate with wrong answers? | |
| **Hypothesis.** The default 1.3 threshold was a guess. The F1-maximizing threshold is probably lower. | |
| **Variable changed.** Threshold ∈ [0.6, 1.6] by 0.1. | |
| """ | |
| def run_experiment_C(): | |
| items = [x for x in validation_set["items"]] | |
| for item in items: | |
| try: | |
| time.sleep(1.5) | |
| c, p = retrieve_for_variant(item["passage"], "hybrid") | |
| all_dists = [r["distance"] for r in (c + p) if r["distance"] > 0] | |
| best_dist = float(min(all_dists)) if all_dists else 999.0 | |
| context_text = _format_context_block(c, p) | |
| result = generate_with_citation_guard( | |
| item["passage"], c, p, do_semantic_check=False | |
| ) | |
| scores = score_all_metrics( | |
| item["passage"], item["reference_explanation"], context_text, result["answer"] | |
| ) | |
| corr = scores["correctness"]["score"] | |
| log_eval_row( | |
| "experiment_C", item["id"], "default_system", | |
| c + p, result, scores, | |
| extra={ | |
| "category": item["category"], | |
| "best_distance": best_dist, | |
| "correctness_raw": corr, | |
| } | |
| ) | |
| except Exception as e: | |
| print(f" ERROR {item['id']}: {e}") | |
| return pd.read_csv(EVAL_LOG_DIR / "experiment_C.csv") | |
| def analyze_experiment_C(): | |
| df = pd.read_csv(EVAL_LOG_DIR / "experiment_C.csv") | |
| if "correctness_raw" not in df.columns: | |
| df["correctness_raw"] = df["correctness"] | |
| df = df.dropna(subset=["best_distance", "correctness_raw"]) | |
| strict_pos = (df["correctness_raw"] < 1.0).sum() | |
| if strict_pos > 0: | |
| df["is_wrong"] = (df["correctness_raw"] < 1.0).astype(int) | |
| wrongness_def = "correctness < 1.0 (strict)" | |
| else: | |
| df["is_wrong"] = (df["correctness_raw"] < 2.0).astype(int) | |
| wrongness_def = "correctness < 2.0 (saturation fallback)" | |
| print(f"using wrongness definition: {wrongness_def}") | |
| print(f" positive class size: {df['is_wrong'].sum()}/{len(df)}") | |
| if df["is_wrong"].sum() == 0: | |
| print("⚠ WARNING: no wrong answers in eval set. Tuning is meaningless on this data.") | |
| return None, None, None | |
| thresholds = [round(0.6 + 0.1 * i, 2) for i in range(11)] | |
| rows = [] | |
| for t in thresholds: | |
| warns = df["best_distance"] > t | |
| wrong = df["is_wrong"] == 1 | |
| tp = int(((warns) & (wrong)).sum()) | |
| fp = int(((warns) & (~wrong)).sum()) | |
| fn = int(((~warns) & (wrong)).sum()) | |
| tn = int(((~warns) & (~wrong)).sum()) | |
| precision = tp / (tp + fp) if (tp + fp) > 0 else 1.0 | |
| recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 | |
| f1 = 2*precision*recall / (precision + recall) if (precision + recall) > 0 else 0.0 | |
| tpr = recall | |
| fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0 | |
| rows.append({"threshold": t, "tp": tp, "fp": fp, "fn": fn, "tn": tn, | |
| "refusal_precision": round(precision, 3), | |
| "refusal_recall": round(recall, 3), | |
| "f1": round(f1, 3), | |
| "tpr": round(tpr, 3), "fpr": round(fpr, 3)}) | |
| sweep = pd.DataFrame(rows) | |
| print(sweep) | |
| best = sweep.loc[sweep["f1"].idxmax()] | |
| print(f"\nF1-maximizing threshold: {best['threshold']} (F1={best['f1']})") | |
| print(f" refusal precision: {best['refusal_precision']}") | |
| print(f" refusal recall: {best['refusal_recall']}") | |
| s = sweep.sort_values("fpr") | |
| auc = 0.0 | |
| for i in range(1, len(s)): | |
| auc += (s.iloc[i]["fpr"] - s.iloc[i-1]["fpr"]) * (s.iloc[i]["tpr"] + s.iloc[i-1]["tpr"]) / 2 | |
| print(f"approx ROC AUC: {auc:.3f}") | |
| return sweep, best, auc | |
| def plot_experiment_C(): | |
| out = analyze_experiment_C() | |
| if out is None or out[0] is None: | |
| return | |
| sweep, best, auc = out | |
| fig, ax = plt.subplots(figsize=(7, 7)) | |
| s = sweep.sort_values("fpr") | |
| ax.plot(s["fpr"], s["tpr"], marker="o", linewidth=2, label=f"ROC (AUC≈{auc:.3f})") | |
| ax.plot([0, 1], [0, 1], "--", color="gray", alpha=0.5, label="chance") | |
| ax.scatter([best["fpr"]], [best["tpr"]], s=200, color="red", zorder=5, | |
| label=f"best F1 @ threshold={best['threshold']}") | |
| ax.set_xlabel("False positive rate") | |
| ax.set_ylabel("True positive rate (refusal recall)") | |
| ax.set_title("Experiment C — Abstention threshold ROC") | |
| ax.set_xlim(-0.05, 1.05); ax.set_ylim(-0.05, 1.05) | |
| ax.legend(loc="lower right") | |
| ax.grid(alpha=0.3) | |
| plt.tight_layout() | |
| plt.show() | |
| experiment_C_df = load_or_run_experiment("experiment_C", run_experiment_C) | |
| if experiment_C_df is not None: | |
| plot_experiment_C() | |
| DEFAULT_CONFIDENCE_THRESHOLD = 1.3 | |
| try: | |
| if experiment_C_df is not None: | |
| out = analyze_experiment_C() | |
| if out and out[1] is not None: | |
| _, best, _ = out | |
| TUNED_CONFIDENCE_THRESHOLD = float(best["threshold"]) | |
| else: | |
| TUNED_CONFIDENCE_THRESHOLD = DEFAULT_CONFIDENCE_THRESHOLD | |
| else: | |
| TUNED_CONFIDENCE_THRESHOLD = DEFAULT_CONFIDENCE_THRESHOLD | |
| except Exception as e: | |
| print(f"falling back to default threshold ({e})") | |
| TUNED_CONFIDENCE_THRESHOLD = DEFAULT_CONFIDENCE_THRESHOLD | |
| print(f"TUNED_CONFIDENCE_THRESHOLD = {TUNED_CONFIDENCE_THRESHOLD}") | |
| """## 18. Main pipeline (with citations + tuned threshold + semantic check)""" | |
| def check_input_quality(text): | |
| if len(text.strip()) < 20: | |
| return False, "That's pretty short — try pasting a full sentence or paragraph from a paper." | |
| if len(text.strip()) > 3000: | |
| return False, "That's a lot of text. Try pasting just 1-2 paragraphs at a time." | |
| if len(text.split()) < 5: | |
| return False, "Try a longer passage — at least a full sentence from a paper." | |
| return True, "ok" | |
| def assess_retrieval_confidence(concept_results, paper_results, threshold=None): | |
| threshold = threshold if threshold is not None else TUNED_CONFIDENCE_THRESHOLD | |
| dists = [r["distance"] for r in (concept_results + paper_results) if r["distance"] > 0] | |
| if not dists: | |
| return "low", "I couldn't find any relevant context in my knowledge base." | |
| best = min(dists) | |
| if best < 0.8: | |
| return "high", "" | |
| elif best < threshold: | |
| return "medium", ("Note: my knowledge base has some related material, but the match isn't perfect. " | |
| "Double-check against the paper's own definitions.") | |
| else: | |
| return "low", "Heads up: the concepts in this passage don't match well with my current knowledge base." | |
| SCOPE_DISCLAIMER = ( | |
| "---\n" | |
| "*This tool helps you understand papers; it doesn't replace them. " | |
| "Every factual sentence above is cited to a specific retrieved source. " | |
| "⚠️ marks indicate the semantic guard flagged that sentence as not fully supported by its citation. " | |
| "Always check the original paper.*" | |
| ) | |
| def scimplify(passage, variant="hybrid"): | |
| is_ok, msg = check_input_quality(passage) | |
| if not is_ok: | |
| return msg | |
| c, p = retrieve_for_variant(passage, variant) | |
| confidence, warning = assess_retrieval_confidence(c, p) | |
| result = generate_with_citation_guard(passage, c, p) | |
| parts = [] | |
| if result["guard_triggered"]: | |
| which = "semantic" if any("semantic" in i for i in result.get("issues", [])) else "lexical" | |
| parts.append(f"⚠️ The {which} citation guard triggered. Returning abstention rather than a potentially ungrounded answer.") | |
| if result.get("issues"): | |
| parts.append(f"\n*Reason: {'; '.join(result['issues'])}*") | |
| parts.append(f"\n{result['answer']}") | |
| parts.append(f"\n{SCOPE_DISCLAIMER}") | |
| return "\n".join(parts) | |
| if result["answer"].strip() == ABSTAIN_MESSAGE or result["abstained"]: | |
| parts = [result["answer"], SCOPE_DISCLAIMER] | |
| return "\n".join(parts) | |
| if confidence == "low": | |
| parts.append(f"⚠️ {warning}\n") | |
| elif confidence == "medium": | |
| parts.append(f"ℹ️ {warning}\n") | |
| parts.append(result["answer"]) | |
| # show retrieved sources | |
| concept_names = [r["concept_name"] for r in c if "concept_name" in r] | |
| if concept_names: | |
| parts.append(f"\n\n**Retrieved concepts:** {', '.join(concept_names)}") | |
| if p: | |
| sources = sorted(set(r["source_name"] for r in p)) | |
| parts.append(f"**Paper sources:** {', '.join(sources)}") | |
| # surface semantic check stats if any sentences were checked | |
| findings = result.get("semantic_findings", []) | |
| if findings: | |
| n_total = len(findings) | |
| n_unsupported = sum(1 for f in findings if f["label"] in ("contradicted", "insufficient")) | |
| if n_unsupported > 0: | |
| parts.append(f"\n*Semantic guard: {n_unsupported}/{n_total} cited sentences flagged as not fully supported.*") | |
| else: | |
| parts.append(f"\n*Semantic guard: all {n_total} cited sentences supported by their citations ✓*") | |
| parts.append(f"\n{SCOPE_DISCLAIMER}") | |
| return "\n".join(parts) | |
| """## 19. Gradio UI | |
| """ | |
| def add_pdf_to_kb(pdf_file, source_name, source_type): | |
| if pdf_file is None: | |
| return "Please upload a PDF file." | |
| if not source_name.strip(): | |
| return "Please provide a name for this source." | |
| try: | |
| text = extract_text_from_pdf(pdf_file) | |
| chunks = chunk_text(text) | |
| base = source_name.strip().replace(" ", "_") | |
| ids = [f"user_{base}::c{i}" for i in range(len(chunks))] | |
| metas = [{ | |
| "source_name": source_name.strip(), | |
| "source_type": source_type, | |
| "chunk_id": ids[i], | |
| } for i in range(len(chunks))] | |
| if chunks: | |
| papers_collection.add(documents=chunks, ids=ids, metadatas=metas) | |
| return f"Added {len(chunks)} chunks. Total: {papers_collection.count()}" | |
| except Exception as e: | |
| return f"Error: {e}" | |
| def pull_from_arxiv_ui(query, max_results): | |
| """Gradio handler for arXiv ingestion.""" | |
| try: | |
| max_results = int(max_results) | |
| if max_results < 1 or max_results > 25: | |
| return "Please pick a max_results between 1 and 25." | |
| summary = ingest_from_arxiv(query=query, max_results=max_results, verbose=False) | |
| msg = ( | |
| f"✅ Ingested {summary['n_papers']} new paper(s), " | |
| f"{summary['n_chunks']} chunks. " | |
| f"Skipped {summary['n_skipped']} duplicates. " | |
| f"Total in KB: {summary['total_in_kb']} chunks." | |
| ) | |
| if summary["errors"]: | |
| msg += f"\n\n⚠️ Errors: {'; '.join(summary['errors'][:3])}" | |
| return msg | |
| except Exception as e: | |
| return f"Error: {e}" | |
| def get_kb_status(): | |
| n_concepts = concepts_collection.count() | |
| n_papers = papers_collection.count() | |
| status = f"**Concept definitions:** {n_concepts}\n\n**Paper chunks:** {n_papers}\n" | |
| if n_papers > 0: | |
| metas = papers_collection.get()["metadatas"] | |
| sources = Counter(m["source_name"] for m in metas) | |
| status += "\n**Ingested sources:**\n" | |
| for name, count in sources.most_common(): | |
| status += f"- {name} — {count} chunks\n" | |
| return status | |
| DEMO_CLEAN = "The Diels-Alder reaction is a [4+2] cycloaddition between a conjugated diene and a dienophile, producing a six-membered ring with up to four new stereocenters. The reaction proceeds through a concerted, suprafacial transition state and is highly stereospecific: cis-dienophiles yield cis-substituted cyclohexenes. Electron-withdrawing groups on the dienophile dramatically accelerate the reaction." | |
| DEMO_PAPER = "Multi-region neural population dynamics in the brain have been studied using techniques like LFADS to model the latent factors driving observed activity across regions." | |
| DEMO_ABSTAIN = "Laminated pastry dough is created by repeatedly folding butter into flour-water dough, producing alternating layers that puff up during baking as steam expands between them. Croissants are the canonical example." | |
| with gr.Blocks(title="Scimplify") as app: | |
| gr.Markdown("# Scimplify — NeuroAI Paper Simplifier") | |
| gr.Markdown( | |
| "Paste a NeuroAI paragraph; get a plain-language explanation with citations. " | |
| "Every factual sentence is grounded in a retrieved source. The lexical guard rejects " | |
| "invented citation IDs, and the semantic guard verifies that each cited chunk actually " | |
| "supports the claim. If neither passes, the system abstains rather than hallucinate." | |
| ) | |
| with gr.Tab("Explain Passage"): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| inp = gr.Textbox(label="Passage", lines=8, | |
| placeholder="Paste a paragraph from a paper...") | |
| btn = gr.Button("Explain", variant="primary") | |
| gr.Examples( | |
| examples=[ | |
| [DEMO_CLEAN], | |
| [DEMO_PAPER], | |
| [DEMO_ABSTAIN], | |
| ], | |
| inputs=[inp], | |
| label="Demo passages (clean / paper-chunk / out-of-scope)", | |
| ) | |
| with gr.Column(scale=2): | |
| out = gr.Markdown(label="Explanation") | |
| btn.click(fn=lambda x: scimplify(x), inputs=[inp], outputs=[out]) | |
| with gr.Tab("Add Papers (PDF)"): | |
| pdf_in = gr.File(label="PDF", file_types=[".pdf"]) | |
| name_in = gr.Textbox(label="Source name") | |
| type_in = gr.Radio(["paper", "article", "review"], label="Type", value="paper") | |
| add_btn = gr.Button("Add to knowledge base") | |
| add_out = gr.Textbox(label="Status") | |
| add_btn.click(fn=add_pdf_to_kb, inputs=[pdf_in, name_in, type_in], outputs=[add_out]) | |
| with gr.Tab("Pull from arXiv"): | |
| gr.Markdown( | |
| "Fetch recent NeuroAI papers from arXiv directly. " | |
| "Skips papers already in the knowledge base (matched by arxiv_id)." | |
| ) | |
| arxiv_query = gr.Textbox( | |
| label="arXiv query", | |
| value="NeuroAI", | |
| placeholder="e.g. NeuroAI, brain-inspired deep learning, neural population dynamics", | |
| ) | |
| arxiv_n = gr.Slider(label="Max papers", minimum=1, maximum=20, value=5, step=1) | |
| arxiv_btn = gr.Button("Pull from arXiv", variant="primary") | |
| arxiv_out = gr.Markdown() | |
| arxiv_btn.click(fn=pull_from_arxiv_ui, inputs=[arxiv_query, arxiv_n], outputs=[arxiv_out]) | |
| with gr.Tab("Knowledge Base"): | |
| status_out = gr.Markdown(value=get_kb_status()) | |
| refresh_btn = gr.Button("Refresh") | |
| refresh_btn.click(fn=get_kb_status, outputs=[status_out]) | |
| app.launch(share=True) |