# scipaths/src/step_05_verify_uses_extends/verify_uses_extends.py
# Author: Eric Chamoun
# Initial SciPaths Space release (commit 0a55f0f)
import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, List
# Make the src/ root importable so the sibling packages (common, prompts,
# schemas) resolve when this file is run directly as a script.
SRC_ROOT = Path(__file__).resolve().parents[1]
if str(SRC_ROOT) not in sys.path:
    sys.path.insert(0, str(SRC_ROOT))
from common.llm_client import LLMClient
from prompts import build_uses_extends_verification_prompt
from schemas import USES_EXTENDS_VERIFICATION_JSON_SCHEMA

# Per-paper input/output filenames inside each processed-paper directory.
PAPER_META_FILE = "paper_metadata.json"
USAGE_LABELS_FILE = "usage_context_labels.json"
OUT_FILE = "usage_uses_extends_verified.json"
# Classifier labels that make a citation context a USES/EXTENDS candidate.
USE_LABELS = {"Uses", "Extends"}
def load_json(path: Path) -> Any | None:
if not path.exists():
return None
try:
return json.loads(path.read_text(encoding="utf-8"))
except Exception:
return None
def iter_paper_dirs(root: Path) -> List[Path]:
    """Return the immediate subdirectories of *root* that hold a paper metadata file."""
    return [
        entry
        for entry in root.iterdir()
        if entry.is_dir() and (entry / PAPER_META_FILE).exists()
    ]
def _normalize_author_last(name: str) -> str:
parts = [p for p in (name or "").split() if p.strip()]
return parts[-1] if parts else ""
def extract_target_info(meta: Any) -> Dict[str, str]:
    """Extract the target paper's title, first-author surname, and year.

    Accepts either a metadata dict or a non-empty list whose first element is
    that dict (some pipeline steps wrap the record in a list). Malformed input
    yields empty-string fields rather than raising.
    """
    if isinstance(meta, list) and meta:
        meta = meta[0]
    if not isinstance(meta, dict):
        return {"title": "", "first_author_last": "", "year": ""}
    authors = meta.get("authors") or []
    first_author = ""
    if authors:
        entry = authors[0]
        # Author entries are usually {"name": ...} dicts, but tolerate plain
        # strings and dicts missing "name" (the old code raised on those).
        if isinstance(entry, dict):
            first_author = entry.get("name") or ""
        elif isinstance(entry, str):
            first_author = entry
    return {
        "title": meta.get("title", ""),
        "first_author_last": _normalize_author_last(first_author),
        "year": str(meta.get("year", "")),
    }
def _fallback_not_confirmed(candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Map every candidate to a NOT_CONFIRMED record (used when the LLM step fails)."""
    return [
        {
            "id": item.get("id"),
            "label": "NOT_CONFIRMED",
            "cue_span": "",
            "rationale": "",
            "text": item.get("text", ""),
            "citing_paper_id": item.get("citing_paper_id", ""),
            "citing_title": item.get("citing_title", ""),
            "original_label": item.get("original_label", ""),
        }
        for item in candidates
    ]


def verify_candidates(
    client: LLMClient,
    target_info: Dict[str, str],
    candidates: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Ask the LLM to verify a batch of USES/EXTENDS candidate contexts.

    Returns one record per candidate, in input order. Any failure (LLM call
    error or unparseable JSON) degrades gracefully to NOT_CONFIRMED for the
    whole batch instead of raising. A model answer without a supporting
    cue_span is also demoted to NOT_CONFIRMED, so every confirmed label is
    grounded in quoted evidence text.
    """
    prompt = build_uses_extends_verification_prompt(target_info, candidates)
    try:
        raw = client.call(prompt, schema=USES_EXTENDS_VERIFICATION_JSON_SCHEMA)
    except Exception as exc:
        print(f"[WARN] LLM call failed: {exc}. Marking all candidates NOT_CONFIRMED.")
        return _fallback_not_confirmed(candidates)
    data = _parse_llm_json(raw)
    if not isinstance(data, dict):
        print("[WARN] Failed to parse LLM JSON response; marking all candidates NOT_CONFIRMED.")
        return _fallback_not_confirmed(candidates)
    labels = data.get("labels", [])
    # Index model answers by candidate id; non-dict entries are ignored.
    by_id = {item.get("id"): item for item in labels if isinstance(item, dict)}
    verified: List[Dict[str, Any]] = []
    for candidate in candidates:
        item_id = candidate["id"]
        model = by_id.get(item_id, {})
        label = model.get("label", "NOT_CONFIRMED")
        cue_span = model.get("cue_span", "")
        # Require a verbatim cue span as evidence; unsupported labels are rejected.
        if not cue_span:
            label = "NOT_CONFIRMED"
        verified.append(
            {
                "id": item_id,
                "label": label,
                "cue_span": cue_span,
                "rationale": model.get("rationale", ""),
                "text": candidate.get("text", ""),
                "citing_paper_id": candidate.get("citing_paper_id", ""),
                "citing_title": candidate.get("citing_title", ""),
                "original_label": candidate.get("original_label", ""),
            }
        )
    return verified
def _parse_llm_json(raw: str) -> Any | None:
try:
return json.loads(raw)
except json.JSONDecodeError:
pass
cleaned = raw.strip()
if cleaned.startswith("```"):
cleaned = cleaned.strip("`")
cleaned = cleaned.replace("json", "", 1).strip()
start = cleaned.find("{")
end = cleaned.rfind("}")
if start == -1 or end == -1 or end <= start:
return None
snippet = cleaned[start : end + 1]
try:
return json.loads(snippet)
except json.JSONDecodeError:
return None
def process_paper(
    paper_dir: Path,
    client: LLMClient,
    k: int,
    batch_size: int,
    overwrite: bool,
    resume: bool,
) -> str:
    """Verify one paper directory and write usage_uses_extends_verified.json.

    Returns a status string for tallying in main(): "verified", "skipped",
    "missing_labels", or "no_candidates".
    """
    labels_path = paper_dir / USAGE_LABELS_FILE
    payload = load_json(labels_path)
    if not isinstance(payload, dict):
        return "missing_labels"
    out_path = paper_dir / OUT_FILE
    # Skip existing output when resuming, or when --overwrite was not given.
    # NOTE(review): --resume deliberately wins over --overwrite (matches the
    # CLI help text in main()).
    if out_path.exists() and (resume or not overwrite):
        return "skipped"
    labels = payload.get("labels", [])
    candidates_all = []
    for item in labels:
        # Only classifier-labelled "Uses"/"Extends" contexts become candidates.
        if item.get("label") in USE_LABELS:
            candidates_all.append(
                {
                    "id": item.get("id"),
                    "text": item.get("text", ""),
                    "citing_paper_id": item.get("citing_paper_id", ""),
                    "citing_title": item.get("citing_title", ""),
                    "original_label": item.get("label"),
                    # `or 0.0` guards against an explicit null confidence.
                    "confidence": float(item.get("confidence", 0.0) or 0.0),
                }
            )
    if not candidates_all:
        # Still write a minimal result file so downstream steps can tell
        # "processed, no candidates" apart from "not processed".
        result = {
            "paper_id": payload.get("paper_id"),
            "target": {},
            "candidates_total": 0,
            "candidates_considered": 0,
            "verified": [],
            "confirmed": [],
        }
        out_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
        return "no_candidates"
    # Keep top-k highest-confidence USES/EXTENDS contexts for LLM verification.
    # If k <= 0, verify all candidates.
    candidates_all = sorted(
        candidates_all,
        key=lambda x: x.get("confidence", 0.0),
        reverse=True,
    )
    candidates = candidates_all if k <= 0 else candidates_all[:k]
    target_info = extract_target_info(load_json(paper_dir / PAPER_META_FILE))
    verified: List[Dict[str, Any]] = []
    if batch_size <= 0:
        batch_size = 25  # fall back to the CLI default for nonsensical values
    # Send candidates to the LLM in fixed-size batches.
    for i in range(0, len(candidates), batch_size):
        batch = candidates[i : i + batch_size]
        verified.extend(verify_candidates(client, target_info, batch))
    # NOTE(review): confirmation compares upper-case labels ("USES"/"EXTENDS"),
    # presumably what the verification schema emits, unlike the input labels
    # "Uses"/"Extends" — confirm against USES_EXTENDS_VERIFICATION_JSON_SCHEMA.
    confirmed = [v for v in verified if v["label"] in {"USES", "EXTENDS"}]
    # Final label precedence: any confirmed EXTENDS outranks USES.
    if any(item["label"] == "EXTENDS" for item in confirmed):
        final_label = "EXTENDS"
    elif confirmed:
        final_label = "USES"
    else:
        final_label = "NOT_CONFIRMED"
    result = {
        "paper_id": payload.get("paper_id"),
        "target": target_info,
        "candidates_total": len(candidates_all),
        "candidates_considered": len(candidates),
        "verification_batch_size": int(batch_size),
        "verification_num_batches": (len(candidates) + batch_size - 1) // batch_size if candidates else 0,
        "candidates_selected": len(confirmed),
        "verified": verified,
        "confirmed": confirmed,
        "confirmed_extends": sum(1 for x in confirmed if x.get("label") == "EXTENDS"),
        "confirmed_uses": sum(1 for x in confirmed if x.get("label") == "USES"),
        "final_label": final_label,
    }
    out_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
    return "verified"
def _build_arg_parser() -> argparse.ArgumentParser:
    """Construct the CLI parser for this verification step."""
    parser = argparse.ArgumentParser(
        description="Verify USES/EXTENDS candidates via LLM and select top-K."
    )
    parser.add_argument(
        "--root",
        type=str,
        default="runs/processed_papers",
        help="Root directory containing processed paper directories.",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=0,
        help="Verify top-k USES/EXTENDS candidates ranked by classifier confidence (<=0 means all).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=25,
        help="Number of candidates per LLM verification batch.",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing usage_uses_extends_verified.json files.",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Skip papers with existing output files (even if --overwrite is set).",
    )
    return parser


def main() -> None:
    """CLI entry point: verify every paper directory under --root, print a summary."""
    args = _build_arg_parser().parse_args()
    root = Path(args.root).expanduser().resolve()
    if not root.exists():
        raise SystemExit(f"Root directory does not exist: {root}")
    client = LLMClient()
    paper_dirs = sorted(iter_paper_dirs(root), key=lambda p: p.name)
    print(f"[INFO] Found {len(paper_dirs)} paper dirs under {root}")
    counts = {"verified": 0, "skipped": 0, "missing_labels": 0, "no_candidates": 0}
    for paper_dir in paper_dirs:
        status = process_paper(
            paper_dir,
            client,
            args.k,
            args.batch_size,
            args.overwrite,
            args.resume,
        )
        counts[status] = counts.get(status, 0) + 1
        print(f"[{status.upper()}] {paper_dir.name}")
    print(
        "[SUMMARY] verified={verified}, skipped={skipped}, missing_labels={missing_labels}, "
        "no_candidates={no_candidates}".format(**counts)
    )
# Run the verification step when executed as a script.
if __name__ == "__main__":
    main()