# -*- coding: utf-8 -*-
"""Scimplify.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/11L85VXrmvxrfXd6A9FGuJjI53nVtM0tN

# Scimplify

A NeuroAI paper simplifier. You paste a paragraph and get a plain-language explanation back,
with citations to the retrieved chunks the explanation came from. The system refuses to answer
if it can't ground the claims.

## 1. Setup
"""

import os, json, re, io, textwrap, time
from pathlib import Path
from collections import Counter, defaultdict
from typing import List, Dict, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

import requests
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import openai
import chromadb
from chromadb.utils import embedding_functions
import gradio as gr
from sentence_transformers import SentenceTransformer
from PyPDF2 import PdfReader

# openai.OpenAI() reads OPENAI_API_KEY from the environment; fail fast if it is missing.
api_key = os.environ.get("OPENAI_API_KEY")
assert api_key, "Set OPENAI_API_KEY env var"

client_oai = openai.OpenAI()

GENERATOR_MODEL = "gpt-4o-mini"
JUDGE_MODEL = "gpt-4o-mini"
GENERATOR_TEMPERATURE = 0.2
JUDGE_TEMPERATURE = 0.3
JUDGE_N_SAMPLES = 3
BOOTSTRAP_N = 2000
BOOTSTRAP_ALPHA = 0.05
_rng = np.random.default_rng(7)

RUN_EXPERIMENTS = False      # set True to re-run experiments from scratch
LIVE_SEMANTIC_CHECK = True   # per-sentence guard; adds ~1 s per query
JUDGE_PARALLELISM = 2        # cap on concurrent judge calls (rate limits)

print(f"generator: {GENERATOR_MODEL}")
print(f"judge: {JUDGE_MODEL}")
print(f"experiments: {'WILL RE-RUN' if RUN_EXPERIMENTS else 'using cached results'}")
print(f"live semantic check: {'on' if LIVE_SEMANTIC_CHECK else 'off'}")

"""## 2. Data loading"""

REPO_RAW_BASE = "https://raw.githubusercontent.com/martazavro/scimplify_data/main"
LOCAL_DATA_DIR = Path("./data")

def _load_json(filename):
    url = f"{REPO_RAW_BASE}/{filename}"
    try:
        r = requests.get(url, timeout=10)
        r.raise_for_status()
        print(f"loaded {filename} from repo")
        return r.json()
    except Exception as e:
        local = LOCAL_DATA_DIR / filename
        if local.exists():
            print(f"repo fetch failed ({e.__class__.__name__}), loaded {filename} from local")
            return json.loads(local.read_text())
        raise FileNotFoundError(
            f"Could not load {filename}. Set REPO_RAW_BASE correctly "
            f"or place the file in ./data/{filename}"
        )

neuroai_concepts = _load_json("concepts.json")
print(f"loaded {len(neuroai_concepts)} concepts")

def validate_validation_set(vs):
    items = vs["items"]
    ids = [x["id"] for x in items]
    assert len(set(ids)) == len(ids), "duplicate ids"
    required = {"id", "passage", "source", "key_terms", "category", "difficulty", "reference_explanation"}
    valid_cats = {"concepts_only", "recent_paper", "both", "neither"}
    for item in items:
        missing = required - set(item.keys())
        assert not missing, f"item {item.get('id')} missing {missing}"
        assert item["category"] in valid_cats
    cat_counts = Counter(x["category"] for x in items)
    print(f"validation set: {len(items)} items")
    print(f"  by category: {dict(cat_counts)}")

validation_set = _load_json("validation_set.json")
validate_validation_set(validation_set)

"""## 3.
PDF extraction and chunking""" def extract_text_from_pdf(pdf_file): reader = PdfReader(pdf_file) text = "" for page in reader.pages: page_text = page.extract_text() if page_text: text += page_text + "\n" return text.strip() def chunk_text(text, chunk_size=300, overlap=50): paragraphs = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()] chunks, current, current_len = [], [], 0 for para in paragraphs: words = para.split() n = len(words) if n > chunk_size: if current: chunks.append(" ".join(current)) tail = current[-overlap:] if len(current) > overlap else current current = list(tail); current_len = len(current) for i in range(0, n, chunk_size - overlap): chunk = words[i:i+chunk_size] if len(chunk) > 30: chunks.append(" ".join(chunk)) current = []; current_len = 0 elif current_len + n > chunk_size: chunks.append(" ".join(current)) tail = current[-overlap:] if len(current) > overlap else current current = list(tail) + words; current_len = len(current) else: current.extend(words); current_len += n if current and len(current) > 30: chunks.append(" ".join(current)) return chunks """## 4. Vector store setup""" chroma_client = chromadb.Client() ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2") def reset_concepts_collection(): try: chroma_client.delete_collection("neuroai_concepts") except Exception: pass coll = chroma_client.create_collection(name="neuroai_concepts", embedding_function=ef) for entry in neuroai_concepts: doc = ( f"Concept: {entry['concept']}\n" f"Definition: {entry['definition']}\n" f"Context: {entry['context']}\n" f"Typically found in: {entry['typical_usage']}" ) coll.add( documents=[doc], ids=[entry["id"]], metadatas=[{"concept_name": entry["concept"], "concept_id": entry["id"]}] ) return coll def reset_papers_collection(): try: chroma_client.delete_collection("neuroai_papers") except Exception: pass return chroma_client.create_collection(name="neuroai_papers", embedding_function=ef) concepts_collection = reset_concepts_collection() papers_collection = reset_papers_collection() print(f"concepts: {concepts_collection.count()}, papers: {papers_collection.count()}") """## 5. Recent papers ingestion""" PAPER_CHUNKS_URL = f"{REPO_RAW_BASE}/paper_chunks.json" def load_paper_chunks(): r = requests.get(PAPER_CHUNKS_URL, timeout=15) r.raise_for_status() return r.json() def ingest_paper_chunks_from_json(): chunks = load_paper_chunks() if not chunks: print("paper_chunks.json was empty") return 0 documents = [c["text"] for c in chunks] ids = [c["chunk_id"] for c in chunks] metadatas = [{ "source_name": c["source_name"], "source_type": c["source_type"], "arxiv_id": c["arxiv_id"], "title": c["title"], "chunk_idx": c["chunk_idx"], "chunk_id": c["chunk_id"], } for c in chunks] papers_collection.add(documents=documents, ids=ids, metadatas=metadatas) by_paper = {} for c in chunks: by_paper[c["arxiv_id"]] = by_paper.get(c["arxiv_id"], 0) + 1 for aid, n in by_paper.items(): print(f" {aid}: {n} chunks") print(f"papers_collection now has {papers_collection.count()} total chunks") return len(chunks) ingest_paper_chunks_from_json() """## 6. arXiv ingestion""" import arxiv def _existing_arxiv_ids(): if papers_collection.count() == 0: return set() metas = papers_collection.get()["metadatas"] return {m.get("arxiv_id") for m in metas if m.get("arxiv_id")} def ingest_from_arxiv(query="neuroAI OR (neural AND brain AND deep learning)", max_results=10, sort_by_recent=True, verbose=True): """Search arXiv, download PDFs, chunk them, add to papers_collection. 
Returns dict with stats: {n_papers, n_chunks, n_skipped, errors}. Already-ingested papers (matched by arxiv_id) are skipped. """ sort_by = arxiv.SortCriterion.SubmittedDate if sort_by_recent else arxiv.SortCriterion.Relevance arxiv_client = arxiv.Client(page_size=20, delay_seconds=3.0, num_retries=3) search = arxiv.Search(query=query, max_results=max_results, sort_by=sort_by) existing = _existing_arxiv_ids() download_dir = Path("./arxiv_papers") download_dir.mkdir(exist_ok=True) n_papers, n_chunks, n_skipped = 0, 0, 0 errors = [] for result in arxiv_client.results(search): # arxiv.org/abs/2509.23566v1 -> "2509.23566" full_id = result.entry_id.rsplit("/", 1)[-1] arxiv_id = full_id.split("v")[0] if arxiv_id in existing: n_skipped += 1 if verbose: print(f" skip {arxiv_id} (already ingested)") continue try: if verbose: print(f" fetching {arxiv_id}: {result.title[:60]}...") pdf_path = result.download_pdf(dirpath=str(download_dir), filename=f"{arxiv_id}.pdf") text = extract_text_from_pdf(pdf_path) chunks = chunk_text(text) if not chunks: errors.append(f"{arxiv_id}: no chunks extracted") continue chunk_ids = [f"arxiv_{arxiv_id.replace('.', '_')}::c{i}" for i in range(len(chunks))] metadatas = [{ "source_name": result.title, "source_type": "arxiv_paper", "arxiv_id": arxiv_id, "title": result.title, "chunk_idx": i, "chunk_id": chunk_ids[i], } for i in range(len(chunks))] papers_collection.add(documents=chunks, ids=chunk_ids, metadatas=metadatas) existing.add(arxiv_id) # avoid double-add within the same call n_papers += 1 n_chunks += len(chunks) if verbose: print(f" -> added {len(chunks)} chunks") except Exception as e: errors.append(f"{arxiv_id}: {e.__class__.__name__}: {e}") if verbose: print(f" ERROR: {e}") summary = { "n_papers": n_papers, "n_chunks": n_chunks, "n_skipped": n_skipped, "errors": errors, "total_in_kb": papers_collection.count(), } if verbose: print(f"\ningested {n_papers} papers ({n_chunks} chunks), skipped {n_skipped} duplicates") if errors: print(f"errors: {len(errors)}") print(f"total in knowledge base: {summary['total_in_kb']} chunks") return summary ingest_from_arxiv(query="NeuroAI", max_results=2) ingest_from_arxiv(query="NeuroAI", max_results=15) """## 7. 
Retrieval variants""" def _flexible_last_word(word): if len(word) < 4: return re.escape(word) stem = re.escape(word[:-2]) return stem + r"[a-z]{0,4}" def build_concept_patterns(concept_entry): name = concept_entry["concept"] abbrev_match = re.search(r"\(([^)]+)\)", name) abbrev = abbrev_match.group(1) if abbrev_match else None base = re.sub(r"\s*\([^)]+\)", "", name).strip() patterns = [] words = base.split() if len(words) == 1: long_re = r"\b" + _flexible_last_word(words[0]) + r"\b" else: parts = [re.escape(w) for w in words[:-1]] + [_flexible_last_word(words[-1])] long_re = r"\b" + r"\s+".join(parts) + r"\b" patterns.append(re.compile(long_re, re.IGNORECASE)) if abbrev: patterns.append(re.compile(r"\b" + re.escape(abbrev) + r"s?\b")) return patterns CONCEPT_PATTERNS = [(entry, build_concept_patterns(entry)) for entry in neuroai_concepts] def _concept_doc_text(entry): return ( f"Concept: {entry['concept']}\n" f"Definition: {entry['definition']}\n" f"Context: {entry['context']}\n" f"Typically found in: {entry['typical_usage']}" ) def regex_retrieve(passage): hits = [] for entry, patterns in CONCEPT_PATTERNS: if any(p.search(passage) for p in patterns): hits.append({ "type": "regex_concept", "concept_name": entry["concept"], "concept_id": entry["id"], "chunk_id": entry["id"], "content": _concept_doc_text(entry), "distance": 0.0, "source_method": "regex", }) return hits def retrieve_concepts_embedding(passage, n_results=3): results = concepts_collection.query(query_texts=[passage], n_results=n_results) out = [] for doc, meta, dist in zip(results["documents"][0], results["metadatas"][0], results["distances"][0]): out.append({ "type": "concept", "concept_name": meta["concept_name"], "concept_id": meta.get("concept_id"), "chunk_id": meta.get("concept_id"), "content": doc, "distance": round(dist, 3), "source_method": "embedding", }) return out def retrieve_paper_chunks(passage, n_results=3): if papers_collection.count() == 0: return [] results = papers_collection.query( query_texts=[passage], n_results=min(n_results, papers_collection.count()) ) out = [] for doc, meta, dist in zip(results["documents"][0], results["metadatas"][0], results["distances"][0]): out.append({ "type": "paper_chunk", "source_name": meta["source_name"], "source_type": meta["source_type"], "chunk_id": meta.get("chunk_id"), "content": doc, "distance": round(dist, 3), "source_method": "embedding", }) return out def hybrid_retrieve_concepts(passage, n_embedding=3, max_total=6): rgx = regex_retrieve(passage) seen_names = {h["concept_name"] for h in rgx} out = list(rgx) if len(out) < max_total: emb = retrieve_concepts_embedding(passage, n_results=n_embedding) for hit in emb: if hit["concept_name"] not in seen_names: out.append(hit) seen_names.add(hit["concept_name"]) else: for r in out: if r["concept_name"] == hit["concept_name"]: r["source_method"] = "both" break if len(out) >= max_total: break return out def retrieve_for_variant(passage, variant, n_concepts=3, n_papers=3): if variant == "no_rag": return [], [] elif variant == "embedding_only": return (retrieve_concepts_embedding(passage, n_results=n_concepts), retrieve_paper_chunks(passage, n_results=n_papers)) elif variant == "regex_only": return (regex_retrieve(passage), retrieve_paper_chunks(passage, n_results=n_papers)) elif variant == "hybrid": return (hybrid_retrieve_concepts(passage, n_embedding=n_concepts), retrieve_paper_chunks(passage, n_results=n_papers)) else: raise ValueError(f"unknown variant: {variant}") """## 8. 
Citation-enforced generation with semantic guard """ CITED_SYSTEM_PROMPT = """You are a scientific reading assistant that helps people understand passages from NeuroAI research papers. You have access to retrieved context. Each source has a stable ID in square brackets like [c004] (for a concept definition) or [arxiv_2511_12345::c3] (for a paper chunk). Your job: 1. Read the passage. 2. Rewrite it in plain language an undergraduate could follow. 3. For EVERY factual sentence in your explanation, append one or more citations in square brackets, drawn ONLY from the IDs of the retrieved sources shown to you. 4. Do not invent citation IDs. Do not cite sources you were not shown. 5. If the retrieved context does not contain enough information to answer faithfully, output EXACTLY this string and nothing else: I don't have enough evidence in the retrieved context. Format: **Key terms:** short definitions of technical terms, each with its citation **Plain-language version:** the passage rewritten clearly, with citations on every factual sentence **What this means in context:** 1-2 sentences on why this matters, with citations """ ABSTAIN_MESSAGE = "I don't have enough evidence in the retrieved context." CITATION_PATTERN = re.compile(r"\[([a-zA-Z0-9_\-:]+)\]") SEMANTIC_FAIL_THRESHOLD = 0.5 def _format_context_block(concept_results, paper_results): lines = [] if concept_results: lines.append("CONCEPT DEFINITIONS:") for r in concept_results: cid = r.get("chunk_id") or r.get("concept_id") lines.append(f"\n[{cid}] {r['content']}") lines.append("---") if paper_results: lines.append("\nPAPER/ARTICLE CONTEXT:") for r in paper_results: cid = r.get("chunk_id") lines.append(f"\n[{cid}] (from {r['source_name']}): {r['content']}") lines.append("---") if not concept_results and not paper_results: lines.append("(no context retrieved)") return "\n".join(lines) def _collect_allowed_ids(concept_results, paper_results): ids = set() for r in concept_results + paper_results: cid = r.get("chunk_id") or r.get("concept_id") if cid: ids.add(cid) return ids def _build_chunk_lookup(concept_results, paper_results): """Map citation_id -> chunk content. Used by the semantic check.""" lookup = {} for r in concept_results + paper_results: cid = r.get("chunk_id") or r.get("concept_id") if cid: lookup[cid] = r["content"] return lookup def generate_cited_explanation(passage, concept_results, paper_results, model=None): model = model or GENERATOR_MODEL context_block = _format_context_block(concept_results, paper_results) user_msg = f"{context_block}\n\nPASSAGE TO EXPLAIN:\n{passage}" resp = client_oai.chat.completions.create( model=model, temperature=GENERATOR_TEMPERATURE, messages=[ {"role": "system", "content": CITED_SYSTEM_PROMPT}, {"role": "user", "content": user_msg}, ], ) return resp.choices[0].message.content def validate_citations(answer: str, allowed_ids: set) -> Tuple[bool, List[str]]: """Lexical guard: every citation ID in the answer must be in the allowed set.""" if ABSTAIN_MESSAGE in answer: return True, [] cited = CITATION_PATTERN.findall(answer) issues = [] if not cited: issues.append("No citations found in non-abstain answer") for cid in cited: if cid not in allowed_ids: issues.append(f"Invalid citation: {cid}") return len(issues) == 0, issues def _split_into_sentences(text): """Cheap sentence splitter that keeps citation brackets attached.""" # split on . ! ? 
followed by space and a capital, keeping the punctuation parts = re.split(r"(?<=[.!?])\s+(?=[A-Z*])", text.strip()) return [p.strip() for p in parts if p.strip()] def _strip_citations(sentence): return CITATION_PATTERN.sub("", sentence).strip() def check_sentence_supported(sentence_text, cited_chunks): claim = _strip_citations(sentence_text) if len(claim) < 10 or not cited_chunks: return {"label": "skipped", "reason": "no claim or no chunks"} evidence = "\n\n".join(f"[{cid}]: {text}" for cid, text in cited_chunks) return verify_claim_against_evidence(claim, [evidence]) def semantic_per_sentence_check(answer, chunk_lookup): if ABSTAIN_MESSAGE in answer: return [] sentences = _split_into_sentences(answer) findings = [] for sent in sentences: cited_ids = CITATION_PATTERN.findall(sent) if not cited_ids: continue cited_chunks = [(cid, chunk_lookup[cid]) for cid in cited_ids if cid in chunk_lookup] if not cited_chunks: continue result = check_sentence_supported(sent, cited_chunks) findings.append({ "sentence": sent, "citations": cited_ids, "label": result["label"], "reason": result["reason"], }) return findings def annotate_unsupported_sentences(answer, findings): """Mark unsupported sentences in the rendered output.""" for f in findings: if f["label"] in ("contradicted", "insufficient"): marker = "⚠️ " if marker not in f["sentence"]: answer = answer.replace(f["sentence"], marker + f["sentence"], 1) return answer def generate_with_citation_guard(passage, concept_results, paper_results, model=None, allow_no_context_bypass=False, do_semantic_check=None): do_semantic_check = (do_semantic_check if do_semantic_check is not None else LIVE_SEMANTIC_CHECK) if allow_no_context_bypass and not concept_results and not paper_results: resp = client_oai.chat.completions.create( model=model or GENERATOR_MODEL, temperature=GENERATOR_TEMPERATURE, messages=[ {"role": "system", "content": "You are a scientific reading assistant. Explain the given passage in plain language that an undergraduate could follow. 
Be concise."}, {"role": "user", "content": f"PASSAGE:\n{passage}"}, ], ) return { "answer": resp.choices[0].message.content, "valid_citations": None, "guard_triggered": False, "issues": [], "abstained": False, "semantic_findings": [], "semantic_fail_rate": np.nan, } raw = generate_cited_explanation(passage, concept_results, paper_results, model=model) allowed_ids = _collect_allowed_ids(concept_results, paper_results) ok, issues = validate_citations(raw, allowed_ids) # lexical guard if not ok: return { "answer": ABSTAIN_MESSAGE, "valid_citations": False, "guard_triggered": True, "issues": issues, "abstained": True, "raw_rejected": raw, "semantic_findings": [], "semantic_fail_rate": np.nan, } # semantic per-sentence check findings = [] semantic_fail_rate = np.nan if do_semantic_check and ABSTAIN_MESSAGE not in raw: chunk_lookup = _build_chunk_lookup(concept_results, paper_results) findings = semantic_per_sentence_check(raw, chunk_lookup) if findings: n_failed = sum(1 for f in findings if f["label"] in ("contradicted", "insufficient")) semantic_fail_rate = n_failed / len(findings) if semantic_fail_rate > SEMANTIC_FAIL_THRESHOLD: return { "answer": ABSTAIN_MESSAGE, "valid_citations": True, "guard_triggered": True, "issues": [f"semantic check failed: {n_failed}/{len(findings)} sentences unsupported"], "abstained": True, "raw_rejected": raw, "semantic_findings": findings, "semantic_fail_rate": semantic_fail_rate, } raw = annotate_unsupported_sentences(raw, findings) return { "answer": raw, "valid_citations": True, "guard_triggered": False, "issues": [], "abstained": ABSTAIN_MESSAGE in raw, "semantic_findings": findings, "semantic_fail_rate": semantic_fail_rate, } """## 9. LLM-as-judge metrics """ def _coerce_score(x): try: v = int(float(x)) except Exception: v = 0 return max(0, min(2, v)) def _single_judge_call(system_prompt, user_prompt): try: resp = client_oai.chat.completions.create( model=JUDGE_MODEL, temperature=JUDGE_TEMPERATURE, response_format={"type": "json_object"}, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ], ) data = json.loads(resp.choices[0].message.content) return { "score": _coerce_score(data.get("score", 0)), "reason": str(data.get("reason", "")).strip(), } except Exception as e: return {"score": None, "reason": f"ERROR: {e}"} def _judge_call_parallel(system_prompt, user_prompt, n=None): """Run n judge calls in parallel via ThreadPoolExecutor.""" n = n or JUDGE_N_SAMPLES results = [None] * n with ThreadPoolExecutor(max_workers=min(n, JUDGE_PARALLELISM)) as ex: futures = {ex.submit(_single_judge_call, system_prompt, user_prompt): i for i in range(n)} for fut in as_completed(futures): i = futures[fut] results[i] = fut.result() return results def _aggregate(runs): valid = [r for r in runs if r["score"] is not None] if not valid: return {"score": None, "reasons": [r["reason"] for r in runs], "n_valid": 0} return { "score": sum(r["score"] for r in valid) / len(valid), "reasons": [r["reason"] for r in valid], "n_valid": len(valid), } CORRECTNESS_SYSTEM = """You are evaluating answer correctness for a question about a NeuroAI paper passage. Given a passage, a reference explanation (gold-standard), and a system explanation, score the system explanation's correctness using ONLY the information in the passage and reference. 
Return ONLY a JSON object: {"score": , "reason": ""} Scoring scale: - 0 = wrong (contradicts the passage or says something incorrect) - 1 = partly correct (captures some but not all of the main idea, or adds unsupported claims) - 2 = correct (faithful to what the passage actually says) """ def score_correctness(passage, reference, candidate): user = f"PASSAGE:\n{passage}\n\nREFERENCE:\n{reference}\n\nSYSTEM EXPLANATION:\n{candidate}" runs = _judge_call_parallel(CORRECTNESS_SYSTEM, user) return _aggregate(runs) EVIDENCE_SYSTEM = """You are evaluating whether a system explanation's key claims are supported by retrieved context. Given a passage, the retrieved context that was shown to the system, and the system's explanation, score whether the explanation's factual claims are well-supported by the retrieved context. Return ONLY a JSON object: {"score": , "reason": ""} Scoring scale: - 0 = unsupported (most claims cannot be found in retrieved context) - 1 = partly supported (some claims supported, others require outside knowledge) - 2 = well supported (claims are traceable to retrieved context) If the retrieved context is empty (no RAG baseline), score 0. """ def score_evidence_support(passage, retrieved_context, candidate): user = f"PASSAGE:\n{passage}\n\nRETRIEVED CONTEXT:\n{retrieved_context}\n\nSYSTEM EXPLANATION:\n{candidate}" runs = _judge_call_parallel(EVIDENCE_SYSTEM, user) return _aggregate(runs) CITATION_SYSTEM = """You are evaluating whether citations in a system explanation are faithful. The system was asked to cite each factual sentence with an ID from the retrieved context (like [c004] or [arxiv_2511_12345::c3]). Given the retrieved context and the system explanation with citations, score whether the citations are relevant and the cited material actually supports the adjacent claim. Return ONLY a JSON object: {"score": , "reason": ""} Scoring scale: - 0 = unfaithful (citations invented, missing, or do not support adjacent claims) - 1 = mixed (some citations support their claims, others do not) - 2 = faithful (citations are present, relevant, and support adjacent claims) If the answer is the abstention message ("I don't have enough evidence..."), score 2 (correctly declined). """ def score_citation_faithfulness(retrieved_context, candidate): user = f"RETRIEVED CONTEXT:\n{retrieved_context}\n\nSYSTEM EXPLANATION:\n{candidate}" runs = _judge_call_parallel(CITATION_SYSTEM, user) return _aggregate(runs) def score_all_metrics(passage, reference, retrieved_context, candidate): """Run all three metrics in parallel.""" with ThreadPoolExecutor(max_workers=3) as ex: f_c = ex.submit(score_correctness, passage, reference, candidate) f_e = ex.submit(score_evidence_support, passage, retrieved_context, candidate) f_f = ex.submit(score_citation_faithfulness, retrieved_context, candidate) return { "correctness": f_c.result(), "evidence_support": f_e.result(), "citation_faithfulness": f_f.result(), } """## 10. Claim-based faithfulness """ CLAIM_EXTRACTION_SYSTEM = """Extract atomic factual claims from the given answer. Return ONLY a JSON object: {"claims": ["claim 1", "claim 2", ...]} Rules: - Each claim should be a single, minimal factual assertion - Ignore pure formatting, headers, or meta-commentary - Skip citation markers like [c004] when extracting claims - If there are no factual claims, return {"claims": []} """ EVIDENCE_EXTRACTION_SYSTEM = """Extract factual assertions from the given text chunk. 
Return ONLY a JSON object: {"assertions": ["assertion 1", "assertion 2", ...]} Rules: - One atomic factual assertion per entry - Skip anything that is a question, opinion, or example - If there are no assertions, return {"assertions": []} """ CLAIM_VERIFICATION_SYSTEM = """Classify if a claim is supported, contradicted, or insufficient given evidence. Return ONLY a JSON object: {"label": "supported" | "contradicted" | "insufficient", "reason": ""} Definitions: - supported: the evidence directly supports the claim - contradicted: the evidence contradicts the claim - insufficient: the evidence is silent or unclear on the claim """ def _json_call(system_prompt, user_prompt, model=None): model = model or JUDGE_MODEL resp = client_oai.chat.completions.create( model=model, temperature=JUDGE_TEMPERATURE, response_format={"type": "json_object"}, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ], ) try: return json.loads(resp.choices[0].message.content) except Exception: return {} def extract_claims(answer): data = _json_call(CLAIM_EXTRACTION_SYSTEM, f"ANSWER:\n{answer}") return [c for c in data.get("claims", []) if c and isinstance(c, str)] _ASSERTION_CACHE = {} def extract_assertions_from_chunk(chunk): key = hash(chunk) if key in _ASSERTION_CACHE: return _ASSERTION_CACHE[key] data = _json_call(EVIDENCE_EXTRACTION_SYSTEM, f"CHUNK:\n{chunk}") out = [a for a in data.get("assertions", []) if a and isinstance(a, str)] _ASSERTION_CACHE[key] = out return out def _normalize_label(label): x = (label or "").strip().lower() if "support" in x: return "supported" if "contrad" in x: return "contradicted" return "insufficient" def verify_claim_against_evidence(claim, assertions): evidence_blob = "\n".join(assertions) if assertions else "NO_EVIDENCE" data = _json_call( CLAIM_VERIFICATION_SYSTEM, f"CLAIM:\n{claim}\n\nEVIDENCE:\n{evidence_blob}" ) return { "label": _normalize_label(data.get("label")), "reason": str(data.get("reason", "")).strip(), } def claim_based_faithfulness(answer, retrieved_chunks): if ABSTAIN_MESSAGE in answer: return { "n_claims": 0, "support_rate": np.nan, "contradiction_rate": np.nan, "unsupported_rate": np.nan, "abstained": True, "details": [], } claims = extract_claims(answer) if not claims: return { "n_claims": 0, "support_rate": np.nan, "contradiction_rate": np.nan, "unsupported_rate": np.nan, "abstained": False, "details": [], } with ThreadPoolExecutor(max_workers=JUDGE_PARALLELISM) as ex: all_assertions_lists = list(ex.map(extract_assertions_from_chunk, retrieved_chunks)) all_assertions = [a for sub in all_assertions_lists for a in sub] with ThreadPoolExecutor(max_workers=JUDGE_PARALLELISM) as ex: verify_results = list(ex.map( lambda c: verify_claim_against_evidence(c, all_assertions), claims )) labels = [r["label"] for r in verify_results] details = [{"claim": c, **r} for c, r in zip(claims, verify_results)] n = len(labels) return { "n_claims": n, "support_rate": sum(1 for l in labels if l == "supported") / n, "contradiction_rate": sum(1 for l in labels if l == "contradicted") / n, "unsupported_rate": sum(1 for l in labels if l == "insufficient") / n, "abstained": False, "details": details, } """## 11. 
Retrieval precision@k / recall@k and bootstrap CIs""" def precision_recall_at_k(retrieved_chunks, gold_facts, k=3): if not gold_facts: return np.nan, np.nan top_k = retrieved_chunks[:k] if not top_k: return 0.0, 0.0 rel_flags = [] for chunk in top_k: c = chunk.lower() is_rel = any(fact.lower() in c for fact in gold_facts) rel_flags.append(is_rel) precision = float(np.mean(rel_flags)) covered = 0 for fact in gold_facts: if any(fact.lower() in chunk.lower() for chunk in top_k): covered += 1 recall = covered / len(gold_facts) return precision, recall def bootstrap_ci(values, n_boot=None, alpha=None): n_boot = n_boot or BOOTSTRAP_N alpha = alpha or BOOTSTRAP_ALPHA values = np.array(values, dtype=float) values = values[~np.isnan(values)] if len(values) == 0: return np.nan, np.nan, np.nan boots = np.empty(n_boot) n = len(values) for i in range(n_boot): sample = _rng.choice(values, size=n, replace=True) boots[i] = sample.mean() lo = np.percentile(boots, 100 * (alpha / 2)) hi = np.percentile(boots, 100 * (1 - alpha / 2)) return float(values.mean()), float(lo), float(hi) def format_ci(values, digits=3): m, lo, hi = bootstrap_ci(values) return f"{m:.{digits}f} [{lo:.{digits}f}, {hi:.{digits}f}]" """## 12. Logging""" EVAL_LOG_DIR = Path("./eval_logs") EVAL_LOG_DIR.mkdir(exist_ok=True) def log_eval_row(experiment_id, passage_id, variant, retrieved_sources, generation_result, judge_scores, extra=None): row = { "experiment_id": experiment_id, "passage_id": passage_id, "variant": variant, "model": GENERATOR_MODEL, "n_retrieved": len(retrieved_sources), "retrieved_chunk_ids": ";".join( str(r.get("chunk_id") or r.get("concept_id") or "?") for r in retrieved_sources ), "guard_triggered": int(generation_result.get("guard_triggered", False)), "abstained": int(generation_result.get("abstained", False)), "answer_chars": len(generation_result.get("answer", "")), "generated_text": generation_result.get("answer", ""), "correctness": judge_scores.get("correctness", {}).get("score"), "evidence_support": judge_scores.get("evidence_support", {}).get("score"), "citation_faithfulness": judge_scores.get("citation_faithfulness", {}).get("score"), "semantic_fail_rate": generation_result.get("semantic_fail_rate", np.nan), } if extra: row.update(extra) path = EVAL_LOG_DIR / f"{experiment_id}.csv" pd.DataFrame([row]).to_csv( path, mode="a", header=not path.exists(), index=False ) return row def load_or_run_experiment(experiment_id, runner_fn): local_path = EVAL_LOG_DIR / f"{experiment_id}.csv" if RUN_EXPERIMENTS: if local_path.exists(): local_path.unlink() print(f"running {experiment_id} from scratch...") return runner_fn() url = f"{REPO_RAW_BASE}/eval_logs/{experiment_id}.csv" try: df = pd.read_csv(url) print(f"loaded {experiment_id} from repo cache: {len(df)} rows") df.to_csv(local_path, index=False) return df except Exception: pass if local_path.exists(): df = pd.read_csv(local_path) print(f"loaded {experiment_id} from local cache: {len(df)} rows") return df print(f"⚠ no cached results for {experiment_id}. Set RUN_EXPERIMENTS=True to generate.") return None """## 13. 
Judge calibration
"""

RUN_CALIBRATION = False

def calibrate_judge(n_items=5):
    items = [x for x in validation_set["items"]]
    sample = items[:n_items]
    diffs = {"correctness": [], "evidence_support": [], "citation_faithfulness": []}
    for item in sample:
        c, p = retrieve_for_variant(item["passage"], "hybrid")
        result = generate_with_citation_guard(item["passage"], c, p, do_semantic_check=False)
        explanation = result["answer"]
        context_text = _format_context_block(c, p)
        print("=" * 70)
        print(f"ITEM {item['id']}")
        print(f"PASSAGE: {item['passage'][:300]}")
        print(f"\nREFERENCE: {item['reference_explanation']}")
        print(f"\nSYSTEM EXPLANATION:\n{explanation}")
        print("\nScore each metric 0/1/2 (0=bad, 1=partial, 2=good):")
        try:
            human = {
                "correctness": int(input("  correctness: ")),
                "evidence_support": int(input("  evidence_support: ")),
                "citation_faithfulness": int(input("  citation_faithfulness: ")),
            }
        except (ValueError, EOFError):
            print("aborted")
            return None
        all_scores = score_all_metrics(
            item["passage"], item["reference_explanation"], context_text, explanation
        )
        scores_clean = {k: all_scores[k]["score"] for k in all_scores}
        for k in diffs:
            if scores_clean[k] is not None:
                diffs[k].append(abs(human[k] - scores_clean[k]))
    print("\n=== CALIBRATION RESULTS ===")
    for k, vals in diffs.items():
        if vals:
            mad = sum(vals) / len(vals)
            flag = " ⚠ DISAGREES" if mad > 0.5 else " ok"
            print(f"  {k}: mean abs diff = {mad:.2f}{flag}")
    return diffs

if RUN_CALIBRATION:
    calibrate_judge(n_items=5)
else:
    print("calibration skipped (RUN_CALIBRATION=False)")
    print("Last calibration: correctness MAD=0.60 (DISAGREES), evidence MAD=0.40, citation MAD=0.20")

"""## 14. Experiment A — retrieval ablation

**Question.** Does RAG help, and does the regex tier earn its place?

**Hypothesis.** All RAG variants will beat the no-RAG baseline on claim_support_rate and evidence_support. Hybrid will beat either single-tier variant.

**Variable changed.** Retrieval method ∈ {no_rag, embedding_only, regex_only, hybrid}. Everything else held constant.
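
Beyond per-variant CIs, a paired bootstrap over per-item differences gives a quick read on whether one variant actually beats another (an interval excluding zero suggests a real gap). A minimal sketch, not run by default, assuming the `experiment_A.csv` columns written by `log_eval_row` (`passage_id`, `variant`, `claim_support_rate`):

```python
import numpy as np
import pandas as pd

def paired_bootstrap_diff(csv_path, metric="claim_support_rate",
                          a="hybrid", b="no_rag", n_boot=2000, seed=7):
    # Mean per-item difference (a - b) with a 95% percentile bootstrap CI.
    df = pd.read_csv(csv_path)
    wide = df.pivot_table(index="passage_id", columns="variant", values=metric)
    wide = wide.dropna(subset=[a, b])   # keep items scored under both variants
    diffs = (wide[a] - wide[b]).to_numpy()
    rng = np.random.default_rng(seed)
    boots = [rng.choice(diffs, size=len(diffs), replace=True).mean()
             for _ in range(n_boot)]
    lo, hi = np.percentile(boots, [2.5, 97.5])
    return diffs.mean(), lo, hi

# Example: paired_bootstrap_diff("./eval_logs/experiment_A.csv")
```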
""" def run_experiment_A(): items = [x for x in validation_set["items"]] variants = ["no_rag", "embedding_only", "regex_only", "hybrid"] total_runs = len(items) * len(variants) print(f"running experiment A: {len(items)} items × {len(variants)} variants = {total_runs} runs") for i, item in enumerate(items): for variant in variants: try: c, p = retrieve_for_variant(item["passage"], variant) retrieved = c + p context_text = _format_context_block(c, p) result = generate_with_citation_guard( item["passage"], c, p, allow_no_context_bypass=(variant == "no_rag"), do_semantic_check=False, ) scores = score_all_metrics( item["passage"], item["reference_explanation"], context_text, result["answer"], ) cb = claim_based_faithfulness( result["answer"], [r["content"] for r in retrieved], ) rp, rr = precision_recall_at_k( [r["content"] for r in retrieved], item["key_terms"], k=3, ) log_eval_row( "experiment_A", item["id"], variant, retrieved, result, scores, extra={ "category": item["category"], "claim_support_rate": cb["support_rate"], "claim_contradiction_rate": cb["contradiction_rate"], "claim_unsupported_rate": cb["unsupported_rate"], "n_claims": cb["n_claims"], "retrieval_precision_at_3": rp, "retrieval_recall_at_3": rr, } ) except Exception as e: print(f" ERROR {item['id']}/{variant}: {e}") print(f" done {item['id']} ({i+1}/{len(items)})") return pd.read_csv(EVAL_LOG_DIR / "experiment_A.csv") experiment_A_df = load_or_run_experiment("experiment_A", run_experiment_A) def analyze_experiment_A(): df = pd.read_csv(EVAL_LOG_DIR / "experiment_A.csv") metric_cols = ["correctness", "evidence_support", "citation_faithfulness", "claim_support_rate", "retrieval_recall_at_3", "abstained"] print("=" * 70) print("OVERALL means with 95% bootstrap CIs") print("=" * 70) for variant in ["no_rag", "embedding_only", "regex_only", "hybrid"]: sub = df[df.variant == variant] print(f"\n{variant}") for m in metric_cols: if m in sub.columns: print(f" {m:28s} {format_ci(sub[m].values)}") print("\n" + "=" * 70) print("HEADLINE METRIC: claim_support_rate (correctness saturates — see report)") print("=" * 70) for variant in ["no_rag", "embedding_only", "regex_only", "hybrid"]: sub = df[df.variant == variant] if "claim_support_rate" in sub.columns: print(f" {variant:18s} {format_ci(sub['claim_support_rate'].values)}") return df def plot_experiment_A(): df = pd.read_csv(EVAL_LOG_DIR / "experiment_A.csv") variant_order = ["no_rag", "embedding_only", "regex_only", "hybrid"] colors = ["#888", "#4c72b0", "#dd8452", "#55a868"] fig, axes = plt.subplots(1, 3, figsize=(16, 5)) means, los, his = [], [], [] for v in variant_order: sub = df[df.variant == v] if "claim_support_rate" in sub.columns: m, lo, hi = bootstrap_ci(sub["claim_support_rate"].values) else: m, lo, hi = 0, 0, 0 means.append(m); los.append(m - lo); his.append(hi - m) axes[0].bar(variant_order, means, yerr=[los, his], color=colors, capsize=5) axes[0].set_title("Claim support rate (headline)") axes[0].set_ylabel("Fraction of claims supported") axes[0].set_ylim(0, 1) axes[0].tick_params(axis="x", rotation=20) means, los, his = [], [], [] for v in variant_order: sub = df[df.variant == v] if "retrieval_recall_at_3" in sub.columns: m, lo, hi = bootstrap_ci(sub["retrieval_recall_at_3"].values) else: m, lo, hi = 0, 0, 0 means.append(m); los.append(m - lo); his.append(hi - m) axes[1].bar(variant_order, means, yerr=[los, his], color=colors, capsize=5) axes[1].set_title("Retrieval recall@3") axes[1].set_ylabel("Fraction of gold key_terms covered") axes[1].set_ylim(0, 1) 
axes[1].tick_params(axis="x", rotation=20) abs_by_var = df.groupby("variant")["abstained"].mean().reindex(variant_order) axes[2].bar(variant_order, abs_by_var.values, color=colors) axes[2].set_title("Abstention rate") axes[2].set_ylabel("Fraction of items guard triggered") axes[2].set_ylim(0, 1) axes[2].tick_params(axis="x", rotation=20) plt.tight_layout() plt.show() if experiment_A_df is not None: analyze_experiment_A() plot_experiment_A() """### Release gate""" def release_gate_A(variant="hybrid"): df = pd.read_csv(EVAL_LOG_DIR / "experiment_A.csv") sub = df[df.variant == variant] thresholds = { "claim_support_rate": 0.70, # primary "evidence_support": 1.40, "citation_faithfulness": 1.40, "retrieval_recall_at_3": 0.60, "abstained": 0.30, } lower_is_better = {"abstained"} agg = {k: float(np.nanmean(sub[k].values)) for k in thresholds if k in sub.columns} print(f"Release gate for variant: {variant}") print("=" * 60) all_pass = True for k, t in thresholds.items(): if k not in agg: continue v = agg[k] ok = (v <= t) if k in lower_is_better else (v >= t) direction = "≤" if k in lower_is_better else "≥" status = "PASS" if ok else "FAIL" print(f" {k:28s} {v:.3f} (need {direction} {t}) {status}") all_pass = all_pass and ok print(f"\nFINAL: {'PASS' if all_pass else 'FAIL'}") return all_pass if experiment_A_df is not None: release_gate_A(variant="hybrid") print() release_gate_A(variant="regex_only") """## 16. Experiment B — top-k sweep **Question:** How does the number of retrieved sources (top-k) affect answer correctness? **Hypothesis:** Performance peaks somewhere in the middle. k=1 misses context; large k dilutes the prompt with irrelevant chunks. **Variable changed:** `top_k ∈ {1, 3, 5, 7}`, applied to both retrieval tiers. Hybrid retrieval; everything else held constant. 
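
A useful companion measurement is how much the prompt context actually grows with k, since the dilution hypothesis is really about context length. A minimal sketch, assuming the retrieval helpers defined earlier in this notebook (word count is only a rough proxy for tokens):

```python
def context_growth(passage, ks=(1, 3, 5, 7)):
    # For each k, retrieve with the hybrid variant and measure the formatted context block.
    rows = []
    for k in ks:
        c, p = retrieve_for_variant(passage, "hybrid", n_concepts=k, n_papers=k)
        block = _format_context_block(c, p)
        rows.append({"top_k": k,
                     "n_sources": len(c) + len(p),
                     "context_words": len(block.split())})
    return rows

# Example: context_growth(validation_set["items"][0]["passage"])
```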
""" def run_experiment_B(top_k_values=(1, 3, 5, 7)): items = [x for x in validation_set["items"]] print(f"running experiment B: {len(items)} items × {len(top_k_values)} top-k values " f"= {len(items) * len(top_k_values)} runs") for k in top_k_values: time.sleep(1.5) print(f"\n--- top_k = {k} ---") for item in items: try: c, p = retrieve_for_variant( item["passage"], "hybrid", n_concepts=k, n_papers=k, ) retrieved = c + p context_text = _format_context_block(c, p) result = generate_with_citation_guard( item["passage"], c, p, do_semantic_check=False ) scores = score_all_metrics( item["passage"], item["reference_explanation"], context_text, result["answer"] ) avg_dist = (float(np.mean([r["distance"] for r in retrieved if r["distance"] > 0])) if retrieved else None) log_eval_row( "experiment_B", item["id"], f"topk_{k}", retrieved, result, scores, extra={ "category": item["category"], "top_k": k, "n_retrieved_total": len(retrieved), "avg_distance": avg_dist, } ) except Exception as e: print(f" ERROR {item['id']}: {e}") return pd.read_csv(EVAL_LOG_DIR / "experiment_B.csv") def analyze_experiment_B(): df = pd.read_csv(EVAL_LOG_DIR / "experiment_B.csv") print("mean correctness by top-k (with 95% CI):") for k in sorted(df["top_k"].unique()): sub = df[df.top_k == k] print(f" top_k={k} " f"correctness={format_ci(sub['correctness'].values)} " f"evidence={format_ci(sub['evidence_support'].values)} " f"avg_n_retrieved={sub['n_retrieved_total'].mean():.1f}") return df def plot_experiment_B(): df = pd.read_csv(EVAL_LOG_DIR / "experiment_B.csv") ks = sorted(df["top_k"].unique()) means, los, his = [], [], [] for k in ks: m, lo, hi = bootstrap_ci(df[df.top_k == k]["correctness"].values) means.append(m); los.append(m - lo); his.append(hi - m) fig, ax = plt.subplots(figsize=(8, 5)) ax.errorbar(ks, means, yerr=[los, his], marker="o", linewidth=2, capsize=5) ax.set_xlabel("top-k (per retrieval tier)") ax.set_ylabel("Mean correctness (0-2)") ax.set_title("Experiment B — Correctness vs top-k (95% CI)") ax.set_xticks(ks) ax.set_ylim(0, 2) ax.grid(alpha=0.3) plt.tight_layout() plt.show() experiment_B_df = load_or_run_experiment("experiment_B", run_experiment_B) if experiment_B_df is not None: analyze_experiment_B() plot_experiment_B() """## 17. Experiment C — confidence threshold tuning **Question.** Where should the low-confidence threshold sit so that warnings correlate with wrong answers? **Hypothesis.** The default 1.3 threshold was a guess. The F1-maximizing threshold is probably lower. **Variable changed.** Threshold ∈ [0.6, 1.6] by 0.1. 
""" def run_experiment_C(): items = [x for x in validation_set["items"]] for item in items: try: time.sleep(1.5) c, p = retrieve_for_variant(item["passage"], "hybrid") all_dists = [r["distance"] for r in (c + p) if r["distance"] > 0] best_dist = float(min(all_dists)) if all_dists else 999.0 context_text = _format_context_block(c, p) result = generate_with_citation_guard( item["passage"], c, p, do_semantic_check=False ) scores = score_all_metrics( item["passage"], item["reference_explanation"], context_text, result["answer"] ) corr = scores["correctness"]["score"] log_eval_row( "experiment_C", item["id"], "default_system", c + p, result, scores, extra={ "category": item["category"], "best_distance": best_dist, "correctness_raw": corr, } ) except Exception as e: print(f" ERROR {item['id']}: {e}") return pd.read_csv(EVAL_LOG_DIR / "experiment_C.csv") def analyze_experiment_C(): df = pd.read_csv(EVAL_LOG_DIR / "experiment_C.csv") if "correctness_raw" not in df.columns: df["correctness_raw"] = df["correctness"] df = df.dropna(subset=["best_distance", "correctness_raw"]) strict_pos = (df["correctness_raw"] < 1.0).sum() if strict_pos > 0: df["is_wrong"] = (df["correctness_raw"] < 1.0).astype(int) wrongness_def = "correctness < 1.0 (strict)" else: df["is_wrong"] = (df["correctness_raw"] < 2.0).astype(int) wrongness_def = "correctness < 2.0 (saturation fallback)" print(f"using wrongness definition: {wrongness_def}") print(f" positive class size: {df['is_wrong'].sum()}/{len(df)}") if df["is_wrong"].sum() == 0: print("⚠ WARNING: no wrong answers in eval set. Tuning is meaningless on this data.") return None, None, None thresholds = [round(0.6 + 0.1 * i, 2) for i in range(11)] rows = [] for t in thresholds: warns = df["best_distance"] > t wrong = df["is_wrong"] == 1 tp = int(((warns) & (wrong)).sum()) fp = int(((warns) & (~wrong)).sum()) fn = int(((~warns) & (wrong)).sum()) tn = int(((~warns) & (~wrong)).sum()) precision = tp / (tp + fp) if (tp + fp) > 0 else 1.0 recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 f1 = 2*precision*recall / (precision + recall) if (precision + recall) > 0 else 0.0 tpr = recall fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0 rows.append({"threshold": t, "tp": tp, "fp": fp, "fn": fn, "tn": tn, "refusal_precision": round(precision, 3), "refusal_recall": round(recall, 3), "f1": round(f1, 3), "tpr": round(tpr, 3), "fpr": round(fpr, 3)}) sweep = pd.DataFrame(rows) print(sweep) best = sweep.loc[sweep["f1"].idxmax()] print(f"\nF1-maximizing threshold: {best['threshold']} (F1={best['f1']})") print(f" refusal precision: {best['refusal_precision']}") print(f" refusal recall: {best['refusal_recall']}") s = sweep.sort_values("fpr") auc = 0.0 for i in range(1, len(s)): auc += (s.iloc[i]["fpr"] - s.iloc[i-1]["fpr"]) * (s.iloc[i]["tpr"] + s.iloc[i-1]["tpr"]) / 2 print(f"approx ROC AUC: {auc:.3f}") return sweep, best, auc def plot_experiment_C(): out = analyze_experiment_C() if out is None or out[0] is None: return sweep, best, auc = out fig, ax = plt.subplots(figsize=(7, 7)) s = sweep.sort_values("fpr") ax.plot(s["fpr"], s["tpr"], marker="o", linewidth=2, label=f"ROC (AUC≈{auc:.3f})") ax.plot([0, 1], [0, 1], "--", color="gray", alpha=0.5, label="chance") ax.scatter([best["fpr"]], [best["tpr"]], s=200, color="red", zorder=5, label=f"best F1 @ threshold={best['threshold']}") ax.set_xlabel("False positive rate") ax.set_ylabel("True positive rate (refusal recall)") ax.set_title("Experiment C — Abstention threshold ROC") ax.set_xlim(-0.05, 1.05); ax.set_ylim(-0.05, 1.05) ax.legend(loc="lower 
right") ax.grid(alpha=0.3) plt.tight_layout() plt.show() experiment_C_df = load_or_run_experiment("experiment_C", run_experiment_C) if experiment_C_df is not None: plot_experiment_C() DEFAULT_CONFIDENCE_THRESHOLD = 1.3 try: if experiment_C_df is not None: out = analyze_experiment_C() if out and out[1] is not None: _, best, _ = out TUNED_CONFIDENCE_THRESHOLD = float(best["threshold"]) else: TUNED_CONFIDENCE_THRESHOLD = DEFAULT_CONFIDENCE_THRESHOLD else: TUNED_CONFIDENCE_THRESHOLD = DEFAULT_CONFIDENCE_THRESHOLD except Exception as e: print(f"falling back to default threshold ({e})") TUNED_CONFIDENCE_THRESHOLD = DEFAULT_CONFIDENCE_THRESHOLD print(f"TUNED_CONFIDENCE_THRESHOLD = {TUNED_CONFIDENCE_THRESHOLD}") """## 18. Main pipeline (with citations + tuned threshold + semantic check)""" def check_input_quality(text): if len(text.strip()) < 20: return False, "That's pretty short — try pasting a full sentence or paragraph from a paper." if len(text.strip()) > 3000: return False, "That's a lot of text. Try pasting just 1-2 paragraphs at a time." if len(text.split()) < 5: return False, "Try a longer passage — at least a full sentence from a paper." return True, "ok" def assess_retrieval_confidence(concept_results, paper_results, threshold=None): threshold = threshold if threshold is not None else TUNED_CONFIDENCE_THRESHOLD dists = [r["distance"] for r in (concept_results + paper_results) if r["distance"] > 0] if not dists: return "low", "I couldn't find any relevant context in my knowledge base." best = min(dists) if best < 0.8: return "high", "" elif best < threshold: return "medium", ("Note: my knowledge base has some related material, but the match isn't perfect. " "Double-check against the paper's own definitions.") else: return "low", "Heads up: the concepts in this passage don't match well with my current knowledge base." SCOPE_DISCLAIMER = ( "---\n" "*This tool helps you understand papers; it doesn't replace them. " "Every factual sentence above is cited to a specific retrieved source. " "⚠️ marks indicate the semantic guard flagged that sentence as not fully supported by its citation. " "Always check the original paper.*" ) def scimplify(passage, variant="hybrid"): is_ok, msg = check_input_quality(passage) if not is_ok: return msg c, p = retrieve_for_variant(passage, variant) confidence, warning = assess_retrieval_confidence(c, p) result = generate_with_citation_guard(passage, c, p) parts = [] if result["guard_triggered"]: which = "semantic" if any("semantic" in i for i in result.get("issues", [])) else "lexical" parts.append(f"⚠️ The {which} citation guard triggered. 
Returning abstention rather than a potentially ungrounded answer.") if result.get("issues"): parts.append(f"\n*Reason: {'; '.join(result['issues'])}*") parts.append(f"\n{result['answer']}") parts.append(f"\n{SCOPE_DISCLAIMER}") return "\n".join(parts) if result["answer"].strip() == ABSTAIN_MESSAGE or result["abstained"]: parts = [result["answer"], SCOPE_DISCLAIMER] return "\n".join(parts) if confidence == "low": parts.append(f"⚠️ {warning}\n") elif confidence == "medium": parts.append(f"ℹ️ {warning}\n") parts.append(result["answer"]) # show retrieved sources concept_names = [r["concept_name"] for r in c if "concept_name" in r] if concept_names: parts.append(f"\n\n**Retrieved concepts:** {', '.join(concept_names)}") if p: sources = sorted(set(r["source_name"] for r in p)) parts.append(f"**Paper sources:** {', '.join(sources)}") # surface semantic check stats if any sentences were checked findings = result.get("semantic_findings", []) if findings: n_total = len(findings) n_unsupported = sum(1 for f in findings if f["label"] in ("contradicted", "insufficient")) if n_unsupported > 0: parts.append(f"\n*Semantic guard: {n_unsupported}/{n_total} cited sentences flagged as not fully supported.*") else: parts.append(f"\n*Semantic guard: all {n_total} cited sentences supported by their citations ✓*") parts.append(f"\n{SCOPE_DISCLAIMER}") return "\n".join(parts) """## 19. Gradio UI """ def add_pdf_to_kb(pdf_file, source_name, source_type): if pdf_file is None: return "Please upload a PDF file." if not source_name.strip(): return "Please provide a name for this source." try: text = extract_text_from_pdf(pdf_file) chunks = chunk_text(text) base = source_name.strip().replace(" ", "_") ids = [f"user_{base}::c{i}" for i in range(len(chunks))] metas = [{ "source_name": source_name.strip(), "source_type": source_type, "chunk_id": ids[i], } for i in range(len(chunks))] if chunks: papers_collection.add(documents=chunks, ids=ids, metadatas=metas) return f"Added {len(chunks)} chunks. Total: {papers_collection.count()}" except Exception as e: return f"Error: {e}" def pull_from_arxiv_ui(query, max_results): """Gradio handler for arXiv ingestion.""" try: max_results = int(max_results) if max_results < 1 or max_results > 25: return "Please pick a max_results between 1 and 25." summary = ingest_from_arxiv(query=query, max_results=max_results, verbose=False) msg = ( f"✅ Ingested {summary['n_papers']} new paper(s), " f"{summary['n_chunks']} chunks. " f"Skipped {summary['n_skipped']} duplicates. " f"Total in KB: {summary['total_in_kb']} chunks." ) if summary["errors"]: msg += f"\n\n⚠️ Errors: {'; '.join(summary['errors'][:3])}" return msg except Exception as e: return f"Error: {e}" def get_kb_status(): n_concepts = concepts_collection.count() n_papers = papers_collection.count() status = f"**Concept definitions:** {n_concepts}\n\n**Paper chunks:** {n_papers}\n" if n_papers > 0: metas = papers_collection.get()["metadatas"] sources = Counter(m["source_name"] for m in metas) status += "\n**Ingested sources:**\n" for name, count in sources.most_common(): status += f"- {name} — {count} chunks\n" return status DEMO_CLEAN = "The Diels-Alder reaction is a [4+2] cycloaddition between a conjugated diene and a dienophile, producing a six-membered ring with up to four new stereocenters. The reaction proceeds through a concerted, suprafacial transition state and is highly stereospecific: cis-dienophiles yield cis-substituted cyclohexenes. Electron-withdrawing groups on the dienophile dramatically accelerate the reaction." 
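# Hedged sketch of a maintenance companion to add_pdf_to_kb above: drop all chunks for a
# user-added source by its source_name. Uses chromadb's metadata `where` filter; not wired
# into the UI, just callable from a notebook cell.
def remove_source_from_kb(source_name):
    before = papers_collection.count()
    papers_collection.delete(where={"source_name": source_name.strip()})
    removed = before - papers_collection.count()
    return f"Removed {removed} chunk(s) for '{source_name.strip()}'."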
DEMO_PAPER = "Multi-region neural population dynamics in the brain have been studied using techniques like LFADS to model the latent factors driving observed activity across regions." DEMO_ABSTAIN = "Laminated pastry dough is created by repeatedly folding butter into flour-water dough, producing alternating layers that puff up during baking as steam expands between them. Croissants are the canonical example." with gr.Blocks(title="Scimplify") as app: gr.Markdown("# Scimplify — NeuroAI Paper Simplifier") gr.Markdown( "Paste a NeuroAI paragraph; get a plain-language explanation with citations. " "Every factual sentence is grounded in a retrieved source. The lexical guard rejects " "invented citation IDs, and the semantic guard verifies that each cited chunk actually " "supports the claim. If neither passes, the system abstains rather than hallucinate." ) with gr.Tab("Explain Passage"): with gr.Row(): with gr.Column(scale=1): inp = gr.Textbox(label="Passage", lines=8, placeholder="Paste a paragraph from a paper...") btn = gr.Button("Explain", variant="primary") gr.Examples( examples=[ [DEMO_CLEAN], [DEMO_PAPER], [DEMO_ABSTAIN], ], inputs=[inp], label="Demo passages (clean / paper-chunk / out-of-scope)", ) with gr.Column(scale=2): out = gr.Markdown(label="Explanation") btn.click(fn=lambda x: scimplify(x), inputs=[inp], outputs=[out]) with gr.Tab("Add Papers (PDF)"): pdf_in = gr.File(label="PDF", file_types=[".pdf"]) name_in = gr.Textbox(label="Source name") type_in = gr.Radio(["paper", "article", "review"], label="Type", value="paper") add_btn = gr.Button("Add to knowledge base") add_out = gr.Textbox(label="Status") add_btn.click(fn=add_pdf_to_kb, inputs=[pdf_in, name_in, type_in], outputs=[add_out]) with gr.Tab("Pull from arXiv"): gr.Markdown( "Fetch recent NeuroAI papers from arXiv directly. " "Skips papers already in the knowledge base (matched by arxiv_id)." ) arxiv_query = gr.Textbox( label="arXiv query", value="NeuroAI", placeholder="e.g. NeuroAI, brain-inspired deep learning, neural population dynamics", ) arxiv_n = gr.Slider(label="Max papers", minimum=1, maximum=20, value=5, step=1) arxiv_btn = gr.Button("Pull from arXiv", variant="primary") arxiv_out = gr.Markdown() arxiv_btn.click(fn=pull_from_arxiv_ui, inputs=[arxiv_query, arxiv_n], outputs=[arxiv_out]) with gr.Tab("Knowledge Base"): status_out = gr.Markdown(value=get_kb_status()) refresh_btn = gr.Button("Refresh") refresh_btn.click(fn=get_kb_status, outputs=[status_out]) app.launch(share=True)
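
# Hedged sketch: persist the in-memory paper chunks so a later session can reload them in the
# same shape that Section 5 expects from paper_chunks.json, instead of re-downloading PDFs.
# The output path is illustrative; user-uploaded chunks carry fewer metadata fields.
def export_paper_chunks(path="./data/paper_chunks_export.json"):
    data = papers_collection.get()   # returns documents + metadatas by default
    chunks = [{**meta, "text": doc}
              for doc, meta in zip(data["documents"], data["metadatas"])]
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    Path(path).write_text(json.dumps(chunks, indent=2))
    return f"exported {len(chunks)} chunks to {path}"

# Example (run after ingesting): print(export_paper_chunks())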