"""Verify USES/EXTENDS citation-context candidates via an LLM and write per-paper results."""

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, List

# Make the project source root importable when this script is run directly.
SRC_ROOT = Path(__file__).resolve().parents[1]
if str(SRC_ROOT) not in sys.path:
    sys.path.insert(0, str(SRC_ROOT))

# Project-local imports; these resolve via the sys.path entry added above.
from common.llm_client import LLMClient
from prompts import build_uses_extends_verification_prompt
from schemas import USES_EXTENDS_VERIFICATION_JSON_SCHEMA

# Per-paper input and output file names inside each processed paper directory.
PAPER_META_FILE = "paper_metadata.json"
USAGE_LABELS_FILE = "usage_context_labels.json"
OUT_FILE = "usage_uses_extends_verified.json"

# Classifier labels that mark a citation context as a usage/extension candidate.
USE_LABELS = {"Uses", "Extends"}

def load_json(path: Path) -> Any | None:
    if not path.exists():
        return None
    try:
        return json.loads(path.read_text(encoding="utf-8"))
    except Exception:
        return None


def iter_paper_dirs(root: Path) -> List[Path]:
    out: List[Path] = []
    for child in root.iterdir():
        if child.is_dir() and (child / PAPER_META_FILE).exists():
            out.append(child)
    return out


def _normalize_author_last(name: str) -> str:
    parts = [p for p in (name or "").split() if p.strip()]
    return parts[-1] if parts else ""

def extract_target_info(meta: Any) -> Dict[str, str]:
    """Pull the target paper's title, first-author surname, and year from its metadata."""
    if isinstance(meta, list) and meta:
        meta = meta[0]
    if not isinstance(meta, dict):
        return {"title": "", "first_author_last": "", "year": ""}
    authors = meta.get("authors") or []
    first = authors[0] if authors else {}
    # Author entries are expected to be dicts with a "name" field; degrade gracefully otherwise.
    first_author = first.get("name", "") if isinstance(first, dict) else str(first or "")
    return {
        "title": meta.get("title", ""),
        "first_author_last": _normalize_author_last(first_author),
        "year": str(meta.get("year", "")),
    }

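# Sketch of the verification response this module expects (inferred from the parsing
# in verify_candidates below, not from the schema definition itself):
#   {"labels": [{"id": ..., "label": "USES" | "EXTENDS" | "NOT_CONFIRMED",
#                "cue_span": "...", "rationale": "..."}]}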
def verify_candidates(
    client: LLMClient,
    target_info: Dict[str, str],
    candidates: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Ask the LLM to verify one batch of candidates; fall back to NOT_CONFIRMED on failure."""

    def not_confirmed_all() -> List[Dict[str, Any]]:
        # Conservative fallback: keep candidate metadata but mark every item NOT_CONFIRMED.
        return [
            {
                "id": item.get("id"),
                "label": "NOT_CONFIRMED",
                "cue_span": "",
                "rationale": "",
                "text": item.get("text", ""),
                "citing_paper_id": item.get("citing_paper_id", ""),
                "citing_title": item.get("citing_title", ""),
                "original_label": item.get("original_label", ""),
            }
            for item in candidates
        ]

    prompt = build_uses_extends_verification_prompt(target_info, candidates)
    try:
        raw = client.call(prompt, schema=USES_EXTENDS_VERIFICATION_JSON_SCHEMA)
    except Exception as exc:
        print(f"[WARN] LLM call failed: {exc}. Marking all candidates NOT_CONFIRMED.")
        return not_confirmed_all()

    data = _parse_llm_json(raw)
    if not isinstance(data, dict):
        print("[WARN] Failed to parse LLM JSON response; marking all candidates NOT_CONFIRMED.")
        return not_confirmed_all()

    labels = data.get("labels", [])
    by_id = {item.get("id"): item for item in labels if isinstance(item, dict)}

    verified: List[Dict[str, Any]] = []
    for candidate in candidates:
        item_id = candidate["id"]
        model = by_id.get(item_id, {})
        label = model.get("label", "NOT_CONFIRMED")
        cue_span = model.get("cue_span", "")
        if not cue_span:
            # A confirmation without a supporting cue span is downgraded to NOT_CONFIRMED.
            label = "NOT_CONFIRMED"
        verified.append(
            {
                "id": item_id,
                "label": label,
                "cue_span": cue_span,
                "rationale": model.get("rationale", ""),
                "text": candidate.get("text", ""),
                "citing_paper_id": candidate.get("citing_paper_id", ""),
                "citing_title": candidate.get("citing_title", ""),
                "original_label": candidate.get("original_label", ""),
            }
        )
    return verified

def _parse_llm_json(raw: str) -> Any | None:
    """Parse the model output as JSON, tolerating code fences and surrounding prose."""
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        pass

    # Strip a Markdown code fence and an optional leading "json" language tag.
    cleaned = raw.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned.strip("`").strip()
        if cleaned.lower().startswith("json"):
            cleaned = cleaned[len("json"):].lstrip()

    # As a last resort, parse the outermost {...} block.
    start = cleaned.find("{")
    end = cleaned.rfind("}")
    if start == -1 or end == -1 or end <= start:
        return None

    snippet = cleaned[start : end + 1]
    try:
        return json.loads(snippet)
    except json.JSONDecodeError:
        return None

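# Assumed shape of usage_context_labels.json (inferred from the fields read in
# process_paper below; the upstream labeler may include additional fields):
#   {
#     "paper_id": "...",
#     "labels": [
#       {"id": ..., "text": "...", "citing_paper_id": "...", "citing_title": "...",
#        "label": "Uses" | "Extends" | ..., "confidence": 0.87}
#     ]
#   }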
def process_paper(
    paper_dir: Path,
    client: LLMClient,
    k: int,
    batch_size: int,
    overwrite: bool,
    resume: bool,
) -> str:
    """Verify one paper's Uses/Extends candidates and write OUT_FILE.

    Returns a status string ("verified", "skipped", "missing_labels", "no_candidates")
    used by the run summary in main().
    """
    labels_path = paper_dir / USAGE_LABELS_FILE
    payload = load_json(labels_path)
    if not isinstance(payload, dict):
        return "missing_labels"

    out_path = paper_dir / OUT_FILE
    if out_path.exists() and (resume or not overwrite):
        return "skipped"

    labels = payload.get("labels", [])
    candidates_all = []
    for item in labels:
        if item.get("label") in USE_LABELS:
            candidates_all.append(
                {
                    "id": item.get("id"),
                    "text": item.get("text", ""),
                    "citing_paper_id": item.get("citing_paper_id", ""),
                    "citing_title": item.get("citing_title", ""),
                    "original_label": item.get("label"),
                    "confidence": float(item.get("confidence", 0.0) or 0.0),
                }
            )

    if not candidates_all:
        result = {
            "paper_id": payload.get("paper_id"),
            "target": {},
            "candidates_total": 0,
            "candidates_considered": 0,
            "verified": [],
            "confirmed": [],
        }
        out_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
        return "no_candidates"

    # Rank candidates by classifier confidence and keep the top-k (k <= 0 keeps all).
    candidates_all = sorted(
        candidates_all,
        key=lambda x: x.get("confidence", 0.0),
        reverse=True,
    )
    candidates = candidates_all if k <= 0 else candidates_all[:k]

    target_info = extract_target_info(load_json(paper_dir / PAPER_META_FILE))
    verified: List[Dict[str, Any]] = []
    if batch_size <= 0:
        # Guard against a non-positive batch size; fall back to the CLI default.
        batch_size = 25
    for i in range(0, len(candidates), batch_size):
        batch = candidates[i : i + batch_size]
        verified.extend(verify_candidates(client, target_info, batch))

    # Paper-level label: EXTENDS outranks USES; no confirmation at all means NOT_CONFIRMED.
    confirmed = [v for v in verified if v["label"] in {"USES", "EXTENDS"}]
    if any(item["label"] == "EXTENDS" for item in confirmed):
        final_label = "EXTENDS"
    elif confirmed:
        final_label = "USES"
    else:
        final_label = "NOT_CONFIRMED"

    result = {
        "paper_id": payload.get("paper_id"),
        "target": target_info,
        "candidates_total": len(candidates_all),
        "candidates_considered": len(candidates),
        "verification_batch_size": int(batch_size),
        "verification_num_batches": (len(candidates) + batch_size - 1) // batch_size if candidates else 0,
        "candidates_selected": len(confirmed),
        "verified": verified,
        "confirmed": confirmed,
        "confirmed_extends": sum(1 for x in confirmed if x.get("label") == "EXTENDS"),
        "confirmed_uses": sum(1 for x in confirmed if x.get("label") == "USES"),
        "final_label": final_label,
    }
    out_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
    return "verified"

def main() -> None:
    parser = argparse.ArgumentParser(
        description="Verify USES/EXTENDS candidates via LLM and select top-K."
    )
    parser.add_argument(
        "--root",
        type=str,
        default="runs/processed_papers",
        help="Root directory containing processed paper directories.",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=0,
        help="Verify top-k USES/EXTENDS candidates ranked by classifier confidence (<=0 means all).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=25,
        help="Number of candidates per LLM verification batch.",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing usage_uses_extends_verified.json files.",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Skip papers with existing output files (even if --overwrite is set).",
    )
    args = parser.parse_args()

    root = Path(args.root).expanduser().resolve()
    if not root.exists():
        raise SystemExit(f"Root directory does not exist: {root}")

    client = LLMClient()
    paper_dirs = sorted(iter_paper_dirs(root), key=lambda p: p.name)
    print(f"[INFO] Found {len(paper_dirs)} paper dirs under {root}")

    counts = {"verified": 0, "skipped": 0, "missing_labels": 0, "no_candidates": 0}
    for paper_dir in paper_dirs:
        status = process_paper(
            paper_dir,
            client,
            args.k,
            args.batch_size,
            args.overwrite,
            args.resume,
        )
        counts[status] = counts.get(status, 0) + 1
        print(f"[{status.upper()}] {paper_dir.name}")

    print(
        "[SUMMARY] verified={verified}, skipped={skipped}, missing_labels={missing_labels}, "
        "no_candidates={no_candidates}".format(**counts)
    )

if __name__ == "__main__":
    main()
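# Example invocation (script filename is illustrative; flags match the argparse options above):
#   python verify_uses_extends.py --root runs/processed_papers --k 50 --batch-size 25 --resume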