"""Label citation usage contexts with a Deep-Citation checkpoint.

For every processed paper directory under --root, read usage_contexts.json,
classify each citation context with the trained Deep-Citation model, aggregate
the predictions into per-citing-paper and final labels, and write
usage_context_labels.json back into the same directory.

Example (script name and paths are illustrative):

    python label_usage_contexts.py --model-path path/to/best_model.pt \
        --root runs/processed_papers
"""

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, List

# Resolve the Deep-Citation checkout relative to this file and fail fast if it is missing.
DEEP_CITATION_ROOT = Path(__file__).resolve().parents[2] / "Deep-Citation"
if not DEEP_CITATION_ROOT.exists():
    raise SystemExit(f"Deep-Citation repo not found at {DEEP_CITATION_ROOT}")

sys.path.insert(0, str(DEEP_CITATION_ROOT))
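
# NOTE: the Deep-Citation imports below (data, Model) only resolve once the
# repo root has been added to sys.path, which is why they appear here rather
# than at the top of the file.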

from data import CollateFn, create_data_channels
from Model import MultiHeadLanguageModel
import torch
from torch.utils.data import DataLoader


PAPER_META_FILE = "paper_metadata.json"
USAGE_CONTEXTS_FILE = "usage_contexts.json"
OUT_FILE = "usage_context_labels.json"
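# Per-paper file contract: a directory under --root is treated as a paper if it
# contains PAPER_META_FILE; USAGE_CONTEXTS_FILE provides the citation contexts,
# and the labelling result is written back to the same directory as OUT_FILE.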


LABEL_SET = [
    "Background",
    "Uses",
    "Extends",
    "CompareOrContrast",
    "Motivation",
    "Future",
]
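# These six classes appear to follow the ACL-ARC citation-function taxonomy.
# LABEL_SET is only echoed into the output payload; the label order actually
# used to decode model predictions comes from create_data_channels below.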


def load_json(path: Path) -> Any | None:
    """Return the parsed JSON content of path, or None if it is missing or unreadable."""
    if not path.exists():
        return None
    try:
        return json.loads(path.read_text(encoding="utf-8"))
    except Exception:
        return None


def iter_paper_dirs(root: Path) -> List[Path]:
    """Return the child directories of root that contain a paper metadata file."""
    out: List[Path] = []
    for child in root.iterdir():
        if child.is_dir() and (child / PAPER_META_FILE).exists():
            out.append(child)
    return out


def flatten_contexts(usage: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Flatten usage_contexts.json into a list of non-empty context records.

    Each record carries a 1-based "id" that later ties model predictions and
    evidence_context_ids back to the original context.
    """
    contexts: List[Dict[str, Any]] = []
    idx = 1
    for entry in usage.get("citing_papers", []) or []:
        if not isinstance(entry, dict):
            continue
        citing_title = entry.get("title") or "Unknown citing paper"
        citing_paper_id = entry.get("citing_paper_id") or ""
        for c in entry.get("contexts", []) or []:
            if not isinstance(c, dict):
                continue
            text = (c.get("text") or "").strip()
            if not text:
                continue
            contexts.append(
                {
                    "id": idx,
                    "text": text,
                    "citing_title": citing_title,
                    "citing_paper_id": citing_paper_id,
                }
            )
            idx += 1
    return contexts


def _resolve_model_name(lm: str) -> str:
    """Map a short --model-lm alias to the corresponding Hugging Face model name."""
    if lm == "scibert":
        return "allenai/scibert_scivocab_uncased"
    if lm == "bert":
        return "bert-base-uncased"
    if lm == "deberta":
        return "microsoft/deberta-v3-base"
    if lm == "deberta-large":
        return "microsoft/deberta-v3-large"
    return lm


def _infer_head_sizes(state_dict: Dict[str, Any]) -> List[int]:
    """Infer each classification head's output size from the lns.<i>.weight shapes."""
    head_weights = [
        (k, v) for k, v in state_dict.items() if k.startswith("lns.") and k.endswith(".weight")
    ]
    head_weights.sort(key=lambda x: int(x[0].split(".")[1]))
    return [int(weight.shape[0]) for _, weight in head_weights]


class _ContextDataset:
    """Minimal dataset wrapper that feeds raw context strings to CollateFn."""

    def __init__(self, texts: List[str]):
        self.texts = texts

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int):
        # The zero tensors are placeholders for the per-example fields that the
        # Deep-Citation CollateFn expects alongside the text; only the text
        # carries real content at inference time.
        return (self.texts[idx], torch.tensor(0), torch.tensor(0))


def label_with_model(
    contexts: List[Dict[str, Any]],
    model_path: Path,
    data_dir: Path,
    class_definition: Path,
    lm: str,
    device: str,
    batch_size: int,
) -> Dict[int, Dict[str, Any]]:
    """Classify every context and return prediction records keyed by 1-based context id."""
    # The training data channel is loaded only to recover the class definitions
    # and the label order that the checkpoint was trained with.
    data_file = data_dir / "acl.tsv"
    train_data, _, _, label_names = create_data_channels(
        str(data_file),
        str(class_definition),
        lmbd=1.0,
    )
    modelname = _resolve_model_name(lm)
    state_dict = torch.load(model_path, map_location=device)
    # The checkpoint itself determines how many heads (and classes per head) to build.
    head_sizes = _infer_head_sizes(state_dict)
    model = MultiHeadLanguageModel(
        modelname=modelname,
        device=device,
        readout="ch",
        num_classes=head_sizes,
    ).to(device)
    model.load_state_dict(state_dict)
    model.eval()

    collate_fn = CollateFn(
        modelname=modelname,
        class_definitions=train_data.class_definitions,
        instance_weights=False,
    )
    dataset = _ContextDataset([ctx["text"] for ctx in contexts])
    # shuffle=False keeps batch order aligned with the 1-based context ids.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    outputs: Dict[int, Dict[str, Any]] = {}
    idx_offset = 0
    with torch.no_grad():
        for batched_text, labels, ds_indices, class_tokens, class_ds_indices in loader:
            ds_indices = ds_indices.to(device)
            class_ds_indices = class_ds_indices.to(device)
            # Take the first element of the model output (assumed to be the
            # citation-function head) as the logits to decode.
            logits = model(batched_text, ds_indices, class_tokens, class_ds_indices)[0]
            probs = torch.softmax(logits, dim=1)
            preds = logits.argmax(dim=1).cpu().tolist()
            pred_confidences = probs.max(dim=1).values.cpu().tolist()
            # Gap between the top two probabilities serves as a simple confidence margin.
            top2 = torch.topk(probs, k=2, dim=1).values.cpu()
            margins = (top2[:, 0] - top2[:, 1]).tolist()
            for i, pred in enumerate(preds):
                raw_label = label_names[pred]
                outputs[idx_offset + i + 1] = {
                    "id": idx_offset + i + 1,
                    "label": raw_label,
                    "confidence": float(pred_confidences[i]),
                    "confidence_margin": float(margins[i]),
                    "cue_span": "",
                    "rationale": "scibert_model",
                }
            idx_offset += len(preds)
    return outputs


def aggregate_citing_labels(labels: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Aggregate per-context labels into one label per citing paper.

    Precedence is Extends > Uses > CompareOrContrast > Background; any other
    label (Motivation, Future) falls through to Background.
    """
    by_citing: Dict[str, List[Dict[str, Any]]] = {}
    for item in labels:
        citing_id = item.get("citing_paper_id") or ""
        by_citing.setdefault(citing_id, []).append(item)

    aggregated: List[Dict[str, Any]] = []
    for citing_id, items in by_citing.items():
        title = items[0].get("citing_title", "")
        labels_set = {it.get("label") for it in items}

        if "Extends" in labels_set:
            label = "Extends"
            evidence_ids = [it["id"] for it in items if it.get("label") == "Extends"]
        elif "Uses" in labels_set:
            label = "Uses"
            evidence_ids = [it["id"] for it in items if it.get("label") == "Uses"]
        elif "CompareOrContrast" in labels_set:
            label = "CompareOrContrast"
            evidence_ids = [
                it["id"] for it in items if it.get("label") == "CompareOrContrast"
            ]
        else:
            label = "Background"
            evidence_ids = []

        aggregated.append(
            {
                "citing_paper_id": citing_id,
                "citing_title": title,
                "label": label,
                "evidence_context_ids": evidence_ids,
            }
        )

    return aggregated


def aggregate_final_label(citing_labels: List[Dict[str, Any]]) -> str:
    """Collapse the per-citing-paper labels into a single paper-level label."""
    labels_set = {item.get("label") for item in citing_labels}
    if "Extends" in labels_set:
        return "Extends"
    if "Uses" in labels_set:
        return "Uses"
    if "CompareOrContrast" in labels_set:
        return "CompareOrContrast"
    return "Background"
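
# Both aggregation steps apply the same precedence
# (Extends > Uses > CompareOrContrast > Background); changing it means editing
# aggregate_citing_labels and aggregate_final_label together. Motivation and
# Future contexts never win an aggregate label on their own.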


def score_for_paper(
    paper_dir: Path,
    batch_size: int,
    overwrite: bool,
    model_path: Path,
    model_data_dir: Path,
    model_class_def: Path,
    model_lm: str,
    device: str,
) -> str:
    """Label one paper directory and return a status string for the summary counts.

    Note that label_with_model reloads the checkpoint on every call, i.e. once
    per paper directory.
    """
    usage_path = paper_dir / USAGE_CONTEXTS_FILE
    usage = load_json(usage_path)
    if not isinstance(usage, dict):
        return "missing_usage"

    contexts = flatten_contexts(usage)
    if not contexts:
        return "empty_contexts"

    out_path = paper_dir / OUT_FILE
    if out_path.exists() and not overwrite:
        return "skipped"

    labeled = label_with_model(
        contexts=contexts,
        model_path=model_path,
        data_dir=model_data_dir,
        class_definition=model_class_def,
        lm=model_lm,
        device=device,
        batch_size=batch_size,
    )

    labels_sorted = []
    for context in contexts:
        context_id = context["id"]
        item = labeled.get(context_id)
        if not item:
            # Fall back to a neutral Background record if a context id is
            # missing from the model outputs.
            item = {
                "id": context_id,
                "label": "Background",
                "confidence": 0.0,
                "cue_span": "",
                "rationale": "missing label",
            }
        item = dict(item)
        item["citing_paper_id"] = context.get("citing_paper_id", "")
        item["citing_title"] = context.get("citing_title", "")
        item["text"] = context.get("text", "")
        labels_sorted.append(item)

    citing_labels = aggregate_citing_labels(labels_sorted)
    payload = {
        "paper_id": usage.get("paper_id"),
        "num_contexts": len(contexts),
        "label_set": LABEL_SET,
        "labels": labels_sorted,
        "citing_paper_labels": citing_labels,
        "final_label": aggregate_final_label(citing_labels),
    }
    out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    return "labeled"
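
# For reference, the JSON written to usage_context_labels.json has roughly this
# shape (values illustrative):
# {
#   "paper_id": "...",
#   "num_contexts": 2,
#   "label_set": ["Background", "Uses", ...],
#   "labels": [{"id": 1, "label": "Uses", "confidence": 0.91, ...}, ...],
#   "citing_paper_labels": [{"citing_paper_id": "...", "label": "Uses",
#                            "evidence_context_ids": [1]}, ...],
#   "final_label": "Uses"
# }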


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Label citation functions using a Deep-Citation checkpoint."
    )
    parser.add_argument(
        "--root",
        type=str,
        default="runs/processed_papers",
        help="Root directory containing processed paper directories.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for model inference.",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing usage_context_labels.json files.",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to the Deep-Citation best_model.pt checkpoint.",
    )
    parser.add_argument(
        "--model-data-dir",
        type=str,
        default="Deep-Citation/Data",
        help="Deep-Citation data directory (used to recover the label order).",
    )
    parser.add_argument(
        "--model-class-def",
        type=str,
        default="Deep-Citation/Data/class_def.json",
        help="Path to the Deep-Citation class_def.json file.",
    )
    parser.add_argument(
        "--model-lm",
        type=str,
        default="scibert",
        help="Language-model alias the checkpoint was trained with "
        "(scibert, bert, deberta, deberta-large, or a full model name).",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device for model inference (cuda or cpu).",
    )
    args = parser.parse_args()

    model_path = Path(args.model_path).expanduser().resolve()
    if not model_path.exists():
        raise SystemExit(f"Model path does not exist: {model_path}")

    root = Path(args.root).expanduser().resolve()
    if not root.exists():
        raise SystemExit(f"Root directory does not exist: {root}")

    paper_dirs = sorted(iter_paper_dirs(root), key=lambda p: p.name)
    print(f"[INFO] Found {len(paper_dirs)} paper dirs under {root}")

    counts = {
        "labeled": 0,
        "skipped": 0,
        "missing_usage": 0,
        "empty_contexts": 0,
    }

    for paper_dir in paper_dirs:
        status = score_for_paper(
            paper_dir,
            args.batch_size,
            args.overwrite,
            model_path=model_path,
            model_data_dir=Path(args.model_data_dir).expanduser().resolve(),
            model_class_def=Path(args.model_class_def).expanduser().resolve(),
            model_lm=args.model_lm,
            device=args.device,
        )
        counts[status] = counts.get(status, 0) + 1
        print(f"[{status.upper()}] {paper_dir.name}")

    print(
        "[SUMMARY] labeled={labeled}, skipped={skipped}, missing_usage={missing_usage}, "
        "empty_contexts={empty_contexts}".format(**counts)
    )


if __name__ == "__main__":
    main()