Eric Chamoun commited on
Commit ·
0a55f0f
0
Parent(s):
Initial SciPaths Space release
Browse files- .dockerignore +7 -0
- .gitattributes +1 -0
- .gitignore +7 -0
- Deep-Citation/Data/acl.tsv +0 -0
- Deep-Citation/Data/class_def.json +23 -0
- Deep-Citation/Model/__init__.py +1 -0
- Deep-Citation/Model/model.py +89 -0
- Deep-Citation/Workspace/acl_scicite_wksp_trl/args.txt +21 -0
- Deep-Citation/Workspace/acl_scicite_wksp_trl/best_model.pt +3 -0
- Deep-Citation/data.py +211 -0
- Dockerfile +24 -0
- README.md +232 -0
- app.py +5 -0
- hf_space/requirements.txt +17 -0
- hf_space/runner.py +333 -0
- hf_space/streamlit_app.py +864 -0
- hf_space/streamlit_config.py +30 -0
- requirements.txt +1 -0
- src/common/__init__.py +0 -0
- src/common/llm_client.py +49 -0
- src/common/model_client.py +143 -0
- src/common/paper_package.py +288 -0
- src/step_01_fetch/config.py +6 -0
- src/step_01_fetch/fetch_metadata.py +440 -0
- src/step_01_fetch/process_tex_source.py +203 -0
- src/step_01_fetch/semanticscholar_client.py +158 -0
- src/step_02_mark_citations/replace_citation_markers.py +440 -0
- src/step_03_usage_contexts/build_usage_contexts.py +184 -0
- src/step_04_label_citations/label_citation_functions.py +373 -0
- src/step_05_verify_uses_extends/prompts.py +115 -0
- src/step_05_verify_uses_extends/schemas.py +22 -0
- src/step_05_verify_uses_extends/verify_uses_extends.py +296 -0
- src/step_06_extract_paragraphs/extract_arxiv_paragraphs.py +488 -0
- src/step_07_extract_and_refine/extract_contributions_from_citations.py +329 -0
- src/step_07_extract_and_refine/prompts.py +65 -0
- src/step_07_extract_and_refine/refine_and_filter_clusters_llm.py +402 -0
- src/step_07_extract_and_refine/schemas.py +12 -0
- src/step_08_annotation/__init__.py +3 -0
- src/step_08_annotation/cli.py +99 -0
- src/step_08_annotation/final_prompts.py +0 -0
- src/step_08_annotation/paper_package.py +52 -0
- src/step_08_annotation/pipeline.py +256 -0
- src/step_08_annotation/schemas.py +127 -0
.dockerignore
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
__pycache__
|
| 3 |
+
*.pyc
|
| 4 |
+
hf_space/runs
|
| 5 |
+
**/__pycache__
|
| 6 |
+
**/*.pyc
|
| 7 |
+
*.zip
|
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Deep-Citation/Workspace/acl_scicite_wksp_trl/best_model.pt filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
.DS_Store
|
| 4 |
+
.streamlit/secrets.toml
|
| 5 |
+
hf_space/runs/
|
| 6 |
+
runs/
|
| 7 |
+
*.zip
|
Deep-Citation/Data/acl.tsv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Deep-Citation/Data/class_def.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"acl":
|
| 3 |
+
{
|
| 4 |
+
"BACKGROUND": "The citation provides relevant information for the domain that the present paper discusses.",
|
| 5 |
+
"MOTIVATION": "The citation illustrates the need for data, goals, methods, etc that is proposed in the present paper.",
|
| 6 |
+
"USES": "The present paper uses data, methods, etc., from the paper associated with the citation.",
|
| 7 |
+
"EXTENDS": "The present paper extends the data, methods, etc. from the paper associated with the citation.",
|
| 8 |
+
"COMPAREORCONTRAST": "The present paper expresses similarity / differences to the citation.",
|
| 9 |
+
"FUTURE": "The citation is a potential avenue for future work of the present paper."
|
| 10 |
+
},
|
| 11 |
+
"kim":
|
| 12 |
+
{
|
| 13 |
+
"Used": "The present paper uses at least one method that is proposed in the paper associated with the citation.",
|
| 14 |
+
"Not used": "The present paper does not use or extend any methods that is proposed in the paper associated with the citation.",
|
| 15 |
+
"Extended": "The present paper uses an extended / modified version of the method proposed in the paper associated with the citation."
|
| 16 |
+
},
|
| 17 |
+
"scicite":
|
| 18 |
+
{
|
| 19 |
+
"Background": "The citation states, mentions, or points to the background information giving more context about a problem, concept, approach, topic, or importance of the problem that is discussed in the present paper.",
|
| 20 |
+
"Method": "The present paper uses a method, tool, approach or dataset that is proposed in the paper associated with the citation.",
|
| 21 |
+
"Result": "The present paper compares its results/findings with the results/findings of the paper associated with the citation."
|
| 22 |
+
}
|
| 23 |
+
}
|
Deep-Citation/Model/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .model import LanguageModel, MultiHeadLanguageModel
|
Deep-Citation/Model/model.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
from typing import List
|
| 6 |
+
from transformers import AutoModel
|
| 7 |
+
|
| 8 |
+
def mask_pooling(model_output, attention_mask):
|
| 9 |
+
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
|
| 10 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
| 11 |
+
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
| 12 |
+
|
| 13 |
+
class LanguageModel(nn.Module):
|
| 14 |
+
def __init__(self,
|
| 15 |
+
modelname: str,
|
| 16 |
+
device: str,
|
| 17 |
+
readout: str
|
| 18 |
+
):
|
| 19 |
+
super(LanguageModel, self).__init__()
|
| 20 |
+
self.device = device
|
| 21 |
+
self.modelname = modelname
|
| 22 |
+
self.readout_fn = readout
|
| 23 |
+
|
| 24 |
+
self.model = AutoModel.from_pretrained(modelname)
|
| 25 |
+
self.hidden_size = self.model.config.hidden_size
|
| 26 |
+
|
| 27 |
+
def readout(self, model_inputs, model_outputs, readout_masks=None):
|
| 28 |
+
if self.readout_fn == 'cls':
|
| 29 |
+
if 'bert' in self.modelname or 'deberta' in self.modelname:
|
| 30 |
+
text_representations = model_outputs.last_hidden_state[:, 0]
|
| 31 |
+
elif 'xlnet' in self.modelname:
|
| 32 |
+
text_representations = model_outputs.last_hidden_state[:, -1]
|
| 33 |
+
else:
|
| 34 |
+
raise ValueError('Invalid model name {} for the cls readout.'.format(self.modelname))
|
| 35 |
+
elif self.readout_fn == 'mean':
|
| 36 |
+
text_representations = mask_pooling(model_outputs, model_inputs['attention_mask'])
|
| 37 |
+
elif self.readout_fn == 'ch' and readout_masks is not None:
|
| 38 |
+
text_representations = mask_pooling(model_outputs, readout_masks)
|
| 39 |
+
else:
|
| 40 |
+
raise ValueError('Invalid readout function.')
|
| 41 |
+
return text_representations
|
| 42 |
+
|
| 43 |
+
def _lm_forward(self, tokens):
|
| 44 |
+
tokens = tokens.to(self.device)
|
| 45 |
+
if 'readout_mask' in tokens:
|
| 46 |
+
readout_mask = tokens.pop('readout_mask')
|
| 47 |
+
else:
|
| 48 |
+
readout_mask = None
|
| 49 |
+
outputs = self.model(**tokens)
|
| 50 |
+
return self.readout(tokens, outputs, readout_mask)
|
| 51 |
+
|
| 52 |
+
def forward(self):
|
| 53 |
+
raise NotImplementedError
|
| 54 |
+
|
| 55 |
+
def save_pretrained(self, modeldir):
|
| 56 |
+
model_filename = os.path.join(modeldir, 'checkpoint.pt')
|
| 57 |
+
torch.save(self.state_dict(), model_filename)
|
| 58 |
+
|
| 59 |
+
def load_pretrained(self, modeldir):
|
| 60 |
+
model_filename = os.path.join(modeldir, 'checkpoint.pt')
|
| 61 |
+
self.load_state_dict(torch.load(model_filename))
|
| 62 |
+
|
| 63 |
+
class MultiHeadLanguageModel(LanguageModel):
|
| 64 |
+
def __init__(self,
|
| 65 |
+
modelname: str,
|
| 66 |
+
device: str,
|
| 67 |
+
readout: str,
|
| 68 |
+
num_classes: List
|
| 69 |
+
):
|
| 70 |
+
super().__init__(
|
| 71 |
+
modelname,
|
| 72 |
+
device,
|
| 73 |
+
readout
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
self.num_classes = num_classes
|
| 77 |
+
self.lns = nn.ModuleList([nn.Linear(self.hidden_size, num_class) for num_class in num_classes])
|
| 78 |
+
|
| 79 |
+
def forward(self, input_tokens, input_head_indices, class_tokens, class_head_indices):
|
| 80 |
+
head_indices = torch.unique(input_head_indices)
|
| 81 |
+
text_representations = self._lm_forward(input_tokens)
|
| 82 |
+
|
| 83 |
+
final_preds = {}
|
| 84 |
+
for i in head_indices:
|
| 85 |
+
if torch.any(input_head_indices == i):
|
| 86 |
+
final_preds[i.item()] = self.lns[i.item()](text_representations[input_head_indices == i])
|
| 87 |
+
else:
|
| 88 |
+
final_preds[i.item()] = torch.tensor([]).to(self.device)
|
| 89 |
+
return final_preds
|
Deep-Citation/Workspace/acl_scicite_wksp_trl/args.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Namespace(dataset='acl-scicite',
|
| 2 |
+
lambdas='1-0.063',
|
| 3 |
+
data_dir='Data',
|
| 4 |
+
workspace='Workspace/acl_scicite_wksp_trl',
|
| 5 |
+
class_definition='Data/class_def.json',
|
| 6 |
+
batch_size=32,
|
| 7 |
+
lr=5e-05,
|
| 8 |
+
decay_rate=0.5,
|
| 9 |
+
decay_step=5,
|
| 10 |
+
num_epochs=10,
|
| 11 |
+
scheduler='slanted',
|
| 12 |
+
dropout_rate=0.2,
|
| 13 |
+
l2=0.0,
|
| 14 |
+
device='cuda',
|
| 15 |
+
tol=10,
|
| 16 |
+
inference_only=False,
|
| 17 |
+
seed=1,
|
| 18 |
+
lm='scibert',
|
| 19 |
+
max_length=512,
|
| 20 |
+
batch_size_factor=2,
|
| 21 |
+
readout='ch')
|
Deep-Citation/Workspace/acl_scicite_wksp_trl/best_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e45ab11942439f80a121dad5b2d9da392470e0cedf6a7335991fa0a1f616dcb2
|
| 3 |
+
size 439784777
|
Deep-Citation/data.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import copy
|
| 4 |
+
import torch
|
| 5 |
+
import scipy
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from scipy.special import softmax
|
| 10 |
+
|
| 11 |
+
from transformers import AutoTokenizer
|
| 12 |
+
|
| 13 |
+
class CollateFn(object):
|
| 14 |
+
def __init__(self, modelname, class_definitions=None, instance_weights=False):
|
| 15 |
+
self.instance_weights = instance_weights
|
| 16 |
+
use_fast = False if 'deberta' in modelname else True
|
| 17 |
+
self.tokenizer = AutoTokenizer.from_pretrained(modelname, use_fast=use_fast)
|
| 18 |
+
cited_ids = self.tokenizer.encode('<CITED HERE>', add_special_tokens=False)
|
| 19 |
+
self.cited_here_tokens = torch.tensor(cited_ids, dtype=torch.long)
|
| 20 |
+
|
| 21 |
+
if class_definitions is not None:
|
| 22 |
+
self.class_definitions = []
|
| 23 |
+
self.class_head_indices = []
|
| 24 |
+
for i, defs in enumerate(class_definitions):
|
| 25 |
+
self.class_definitions += defs
|
| 26 |
+
self.class_head_indices.append(i * torch.ones(len(defs), dtype=torch.long))
|
| 27 |
+
self.class_head_indices = torch.cat(self.class_head_indices, dim=0)
|
| 28 |
+
self.class_tokens = self.tokenizer(
|
| 29 |
+
self.class_definitions,
|
| 30 |
+
return_tensors="pt",
|
| 31 |
+
max_length=512,
|
| 32 |
+
truncation=True,
|
| 33 |
+
padding=True
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
def _get_readout_mask(self, tokens):
|
| 37 |
+
# cited_here_tokens = torch.tensor([962, 8412, 1530, 1374])
|
| 38 |
+
readout_mask = torch.zeros_like(tokens['input_ids'], dtype=torch.bool)
|
| 39 |
+
|
| 40 |
+
batch_size = tokens['input_ids'].size(0)
|
| 41 |
+
l = tokens['input_ids'].size(1)
|
| 42 |
+
ctk_l = self.cited_here_tokens.size(0)
|
| 43 |
+
for b in range(batch_size):
|
| 44 |
+
for i in range(1, l - ctk_l):
|
| 45 |
+
if torch.equal(tokens['input_ids'][b, i:i+ctk_l], self.cited_here_tokens):
|
| 46 |
+
readout_mask[b, i:i+ctk_l] = True
|
| 47 |
+
if not readout_mask[b].any():
|
| 48 |
+
# Fallback to CLS if the citation marker isn't matched.
|
| 49 |
+
readout_mask[b, 0] = True
|
| 50 |
+
return readout_mask
|
| 51 |
+
|
| 52 |
+
def _tokenize_context(self, context):
|
| 53 |
+
tokens = self.tokenizer(
|
| 54 |
+
context,
|
| 55 |
+
return_tensors="pt",
|
| 56 |
+
max_length=512,
|
| 57 |
+
truncation=True,
|
| 58 |
+
padding=True
|
| 59 |
+
)
|
| 60 |
+
tokens['readout_mask'] = self._get_readout_mask(
|
| 61 |
+
tokens
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
return tokens
|
| 65 |
+
|
| 66 |
+
def __call__(self, samples):
|
| 67 |
+
if self.instance_weights:
|
| 68 |
+
text, labels, ds_indices, instance_weights = list(map(list, zip(*samples)))
|
| 69 |
+
batched_text = self._tokenize_context(text)
|
| 70 |
+
labels = torch.stack(labels)
|
| 71 |
+
ds_indices = torch.stack(ds_indices)
|
| 72 |
+
instance_weights = torch.stack(instance_weights)
|
| 73 |
+
return batched_text, labels, ds_indices, instance_weights
|
| 74 |
+
else:
|
| 75 |
+
text, labels, ds_indices = list(map(list, zip(*samples)))
|
| 76 |
+
batched_text = self._tokenize_context(text)
|
| 77 |
+
labels = torch.stack(labels)
|
| 78 |
+
ds_indices = torch.stack(ds_indices)
|
| 79 |
+
|
| 80 |
+
return batched_text, labels, ds_indices, copy.deepcopy(self.class_tokens), self.class_head_indices
|
| 81 |
+
|
| 82 |
+
class Dataset(object):
|
| 83 |
+
def __init__(self, dataframe, class_definitions, lmbd=1.0):
|
| 84 |
+
self.class_definitions = class_definitions
|
| 85 |
+
self.lmbd = lmbd
|
| 86 |
+
self._load_data(dataframe)
|
| 87 |
+
|
| 88 |
+
def __len__(self):
|
| 89 |
+
return len(self.labels)
|
| 90 |
+
|
| 91 |
+
def __getitem__(self, idx):
|
| 92 |
+
'''Get datapoint with index'''
|
| 93 |
+
return (self.text[idx], self.labels[idx], self.ds_index[idx])
|
| 94 |
+
|
| 95 |
+
def _load_data(self, annotated_data):
|
| 96 |
+
self.labels = torch.LongTensor(annotated_data['label'].tolist())
|
| 97 |
+
self.original_labels = torch.LongTensor(annotated_data['label'].tolist())
|
| 98 |
+
self.ds_index = torch.zeros_like(self.original_labels)
|
| 99 |
+
self.text = annotated_data['context'].tolist()
|
| 100 |
+
|
| 101 |
+
class MultiHeadDatasets(object):
|
| 102 |
+
def __init__(self, datasets, batch_size_factor=2):
|
| 103 |
+
self.text = []
|
| 104 |
+
self.ds_index = []
|
| 105 |
+
self.labels = []
|
| 106 |
+
self.class_definitions = []
|
| 107 |
+
self.lambdas = []
|
| 108 |
+
|
| 109 |
+
self.dataset_sizes = [len(d.labels) for d in datasets]
|
| 110 |
+
if len(self.dataset_sizes) > 1:
|
| 111 |
+
if sum(self.dataset_sizes) / self.dataset_sizes[0] <= batch_size_factor:
|
| 112 |
+
self.sample_auxiliary = False
|
| 113 |
+
self.adjusted_batch_size_factor = sum(self.dataset_sizes) / self.dataset_sizes[0]
|
| 114 |
+
else:
|
| 115 |
+
self.sample_auxiliary = True
|
| 116 |
+
self.sample_distribution = np.array([d.lmbd for d in datasets[1:]]) / sum([d.lmbd for d in datasets[1:]])
|
| 117 |
+
self.adjusted_batch_size_factor = batch_size_factor
|
| 118 |
+
else:
|
| 119 |
+
self.sample_auxiliary = False
|
| 120 |
+
self.adjusted_batch_size_factor = 1
|
| 121 |
+
|
| 122 |
+
for i, d in enumerate(datasets):
|
| 123 |
+
self.text += d.text
|
| 124 |
+
self.ds_index.append(i * torch.ones(len(d.text), dtype=torch.long))
|
| 125 |
+
self.labels.append(d.labels)
|
| 126 |
+
self.class_definitions.append(d.class_definitions)
|
| 127 |
+
self.lambdas.append(d.lmbd)
|
| 128 |
+
self.labels = torch.cat(self.labels, dim=0)
|
| 129 |
+
self.ds_index = torch.cat(self.ds_index, dim=0)
|
| 130 |
+
|
| 131 |
+
def sample_auxiliary_instace(self):
|
| 132 |
+
sampled_dataset_idx = np.random.choice(
|
| 133 |
+
np.arange(1, len(self.dataset_sizes)),
|
| 134 |
+
p=self.sample_distribution
|
| 135 |
+
)
|
| 136 |
+
instance_idx = np.random.choice(
|
| 137 |
+
self.dataset_sizes[sampled_dataset_idx]
|
| 138 |
+
) + sum(self.dataset_sizes[:sampled_dataset_idx])
|
| 139 |
+
return instance_idx
|
| 140 |
+
|
| 141 |
+
def __len__(self):
|
| 142 |
+
if self.sample_auxiliary: # if the auxiliary dataset is larger than the main dataset
|
| 143 |
+
return self.dataset_sizes[0] * self.adjusted_batch_size_factor
|
| 144 |
+
return len(self.labels)
|
| 145 |
+
|
| 146 |
+
def __getitem__(self, idx):
|
| 147 |
+
'''Get datapoint with index'''
|
| 148 |
+
if idx < self.dataset_sizes[0] or not self.sample_auxiliary:
|
| 149 |
+
return (self.text[idx], self.labels[idx], self.ds_index[idx])
|
| 150 |
+
else:
|
| 151 |
+
real_idx = self.sample_auxiliary_instace()
|
| 152 |
+
return (self.text[real_idx], self.labels[real_idx], self.ds_index[real_idx])
|
| 153 |
+
|
| 154 |
+
def load_class_definitions(filename):
|
| 155 |
+
with open(filename, 'r') as f:
|
| 156 |
+
class_definitions = json.load(f)
|
| 157 |
+
|
| 158 |
+
results = {k:{} for k in class_definitions.keys()}
|
| 159 |
+
for k, v in class_definitions.items():
|
| 160 |
+
for kk, vv in v.items():
|
| 161 |
+
results[k][kk.lower()] = vv
|
| 162 |
+
return results
|
| 163 |
+
|
| 164 |
+
def create_data_channels(filename, class_definition_filename, split=None, lmbd=1.0):
|
| 165 |
+
data = pd.read_csv(filename, sep='\t')
|
| 166 |
+
data = data.fillna(' ')
|
| 167 |
+
|
| 168 |
+
print('Number of data instance: {}'.format(data.shape[0]))
|
| 169 |
+
|
| 170 |
+
# map labels to ids
|
| 171 |
+
unique_labels = data['label'].unique().tolist()
|
| 172 |
+
label2id = {lb: i for i, lb in enumerate(unique_labels)}
|
| 173 |
+
|
| 174 |
+
data['label'] = data['label'].apply(
|
| 175 |
+
lambda x: label2id[x])
|
| 176 |
+
|
| 177 |
+
data_train = data[data['split'] == 'train'].reset_index()
|
| 178 |
+
data_val = data[data['split'] == 'val'].reset_index()
|
| 179 |
+
data_test = data[data['split'] == 'test'].reset_index()
|
| 180 |
+
|
| 181 |
+
class_definitions = load_class_definitions(class_definition_filename)
|
| 182 |
+
dataname = filename.split('/')[-1].split('.')[0]
|
| 183 |
+
data_class_definitions = [class_definitions[dataname][lb.lower()] for lb in unique_labels]
|
| 184 |
+
|
| 185 |
+
train_data = Dataset(data_train, data_class_definitions, lmbd=lmbd)
|
| 186 |
+
val_data = Dataset(data_val, data_class_definitions, lmbd=lmbd)
|
| 187 |
+
test_data = Dataset(data_test, data_class_definitions, lmbd=lmbd)
|
| 188 |
+
|
| 189 |
+
return train_data, val_data, test_data, unique_labels
|
| 190 |
+
|
| 191 |
+
def create_single_data_object(filename, class_definition_filename, split=None, lmbd=1.0):
|
| 192 |
+
data = pd.read_csv(filename, sep='\t')
|
| 193 |
+
data = data.fillna(' ')
|
| 194 |
+
|
| 195 |
+
print('Number of data instance: {}'.format(data.shape[0]))
|
| 196 |
+
|
| 197 |
+
# map labels to ids
|
| 198 |
+
unique_labels = data['label'].unique()
|
| 199 |
+
label2id = {lb: i for i, lb in enumerate(unique_labels)}
|
| 200 |
+
|
| 201 |
+
data['label'] = data['label'].apply(
|
| 202 |
+
lambda x: label2id[x])
|
| 203 |
+
|
| 204 |
+
class_definitions = load_class_definitions(class_definition_filename)
|
| 205 |
+
dataname = filename.split('/')[-1].split('.')[0]
|
| 206 |
+
data_class_definitions = [class_definitions[dataname][lb.lower()] for lb in unique_labels]
|
| 207 |
+
|
| 208 |
+
if split is None:
|
| 209 |
+
return Dataset(data, data_class_definitions, lmbd=lmbd), unique_labels
|
| 210 |
+
else:
|
| 211 |
+
return Dataset(data[data['split'] == split].reset_index(), data_class_definitions, lmbd=lmbd), unique_labels
|
Dockerfile
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 4 |
+
PYTHONUNBUFFERED=1 \
|
| 5 |
+
PIP_NO_CACHE_DIR=1 \
|
| 6 |
+
STREAMLIT_SERVER_HEADLESS=true
|
| 7 |
+
|
| 8 |
+
WORKDIR /app
|
| 9 |
+
|
| 10 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 11 |
+
git \
|
| 12 |
+
build-essential \
|
| 13 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 14 |
+
|
| 15 |
+
COPY requirements.txt /app/requirements.txt
|
| 16 |
+
COPY hf_space/requirements.txt /app/hf_space/requirements.txt
|
| 17 |
+
RUN python -m pip install --upgrade pip && \
|
| 18 |
+
pip install -r requirements.txt
|
| 19 |
+
|
| 20 |
+
COPY . /app
|
| 21 |
+
|
| 22 |
+
EXPOSE 7860
|
| 23 |
+
|
| 24 |
+
CMD ["streamlit", "run", "hf_space/streamlit_app.py", "--server.address", "0.0.0.0", "--server.port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: SciPaths
|
| 3 |
+
emoji: 🔬
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# SciPaths
|
| 11 |
+
|
| 12 |
+
SciPaths runs an end-to-end target-contribution pathway pipeline for arXiv papers. It collects downstream citation evidence, derives target contributions from refined citation clusters, decomposes each target contribution into enabling contributions, and grounds those enabling contributions in prior studies.
|
| 13 |
+
|
| 14 |
+
The Hugging Face Space launches the Streamlit app from `hf_space/streamlit_app.py`.
|
| 15 |
+
|
| 16 |
+
## Citation
|
| 17 |
+
|
| 18 |
+
If you find this useful, please cite our paper as:
|
| 19 |
+
|
| 20 |
+
```bibtex
|
| 21 |
+
@misc{chamoun2026scipathsforecastingpathwaysscientific,
|
| 22 |
+
title={SciPaths: Forecasting Pathways to Scientific Discovery},
|
| 23 |
+
author={Eric Chamoun and Yizhou Chi and Yulong Chen and Rui Cao and Zifeng Ding and Michalis Korakakis and Andreas Vlachos},
|
| 24 |
+
year={2026},
|
| 25 |
+
eprint={2605.14600},
|
| 26 |
+
archivePrefix={arXiv},
|
| 27 |
+
primaryClass={cs.CL},
|
| 28 |
+
url={https://arxiv.org/abs/2605.14600},
|
| 29 |
+
}
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
Paper URL: https://arxiv.org/abs/2605.14600
|
| 33 |
+
|
| 34 |
+
## Required Secrets
|
| 35 |
+
|
| 36 |
+
Set this in the Space settings before publishing:
|
| 37 |
+
|
| 38 |
+
```text
|
| 39 |
+
GEMINI_API_KEY=<Google Gemini API key>
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
Optional, for saving completed run artifacts to a Hugging Face Dataset:
|
| 43 |
+
|
| 44 |
+
```text
|
| 45 |
+
HF_WRITE_TOKEN=<Hugging Face write token>
|
| 46 |
+
RUNS_REPO_ID=<owner/dataset-name>
|
| 47 |
+
RUNS_REPO_TYPE=dataset
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
Optional, for higher Semantic Scholar limits:
|
| 51 |
+
|
| 52 |
+
```text
|
| 53 |
+
SEMANTIC_SCHOLAR_API_KEY=<Semantic Scholar API key>
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
## Run The Demo Locally
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
pip install -r requirements.txt
|
| 60 |
+
streamlit run hf_space/streamlit_app.py
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
Then enter an arXiv URL or ID, for example:
|
| 64 |
+
|
| 65 |
+
```text
|
| 66 |
+
https://arxiv.org/abs/2211.08788
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
The app writes each run under:
|
| 70 |
+
|
| 71 |
+
```text
|
| 72 |
+
hf_space/runs/<job_id>/
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
## Run One Example From The Command Line
|
| 76 |
+
|
| 77 |
+
This example stores all intermediate files under `runs/example/processed_papers`.
|
| 78 |
+
|
| 79 |
+
```bash
|
| 80 |
+
mkdir -p runs/example
|
| 81 |
+
printf '[{"id":"2211.08788","title":"","id_type":"ArXiv"}]\n' > runs/example/input_ids.json
|
| 82 |
+
|
| 83 |
+
python src/step_01_fetch/fetch_metadata.py \
|
| 84 |
+
--ids runs/example/input_ids.json \
|
| 85 |
+
--outdir runs/example/processed_papers
|
| 86 |
+
|
| 87 |
+
python src/step_02_mark_citations/replace_citation_markers.py \
|
| 88 |
+
--root runs/example/processed_papers
|
| 89 |
+
|
| 90 |
+
python src/step_03_usage_contexts/build_usage_contexts.py \
|
| 91 |
+
--root runs/example/processed_papers \
|
| 92 |
+
--out-name usage_contexts.json
|
| 93 |
+
|
| 94 |
+
python src/step_04_label_citations/label_citation_functions.py \
|
| 95 |
+
--root runs/example/processed_papers \
|
| 96 |
+
--model-path Deep-Citation/Workspace/acl_scicite_wksp_trl/best_model.pt \
|
| 97 |
+
--model-data-dir Deep-Citation/Data \
|
| 98 |
+
--model-class-def Deep-Citation/Data/class_def.json \
|
| 99 |
+
--model-lm scibert \
|
| 100 |
+
--device cpu
|
| 101 |
+
|
| 102 |
+
python src/step_05_verify_uses_extends/verify_uses_extends.py \
|
| 103 |
+
--root runs/example/processed_papers \
|
| 104 |
+
--k 0 \
|
| 105 |
+
--batch-size 25
|
| 106 |
+
|
| 107 |
+
python src/step_06_extract_paragraphs/extract_arxiv_paragraphs.py \
|
| 108 |
+
--root runs/example/processed_papers
|
| 109 |
+
|
| 110 |
+
python src/step_07_extract_and_refine/extract_contributions_from_citations.py \
|
| 111 |
+
--root runs/example/processed_papers
|
| 112 |
+
|
| 113 |
+
python src/step_07_extract_and_refine/refine_and_filter_clusters_llm.py \
|
| 114 |
+
--root runs/example/processed_papers \
|
| 115 |
+
--inplace \
|
| 116 |
+
--overwrite
|
| 117 |
+
|
| 118 |
+
PYTHONPATH=src \
|
| 119 |
+
python -m step_08_annotation.cli run \
|
| 120 |
+
--paper-dir runs/example/processed_papers/2211.08788 \
|
| 121 |
+
--provider gemini \
|
| 122 |
+
--model gemini/gemini-3.1-pro-preview \
|
| 123 |
+
--formatter-model gemini/gemini-3.1-pro-preview \
|
| 124 |
+
--judge-model gemini/gemini-3.1-pro-preview \
|
| 125 |
+
--candidate-count 3 \
|
| 126 |
+
--output-root runs/example/two_pass_outputs
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
The final UI payload is written as `pass_2_ui_payload.json` inside the annotation run directory printed by the last command.
|
| 130 |
+
|
| 131 |
+
## Run Each Step On A Set Of Papers
|
| 132 |
+
|
| 133 |
+
Create an ID file with one entry per paper:
|
| 134 |
+
|
| 135 |
+
```json
|
| 136 |
+
[
|
| 137 |
+
{"id": "2211.08788", "title": "", "id_type": "ArXiv"},
|
| 138 |
+
{"id": "2311.14919", "title": "", "id_type": "ArXiv"}
|
| 139 |
+
]
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
Save it as `runs/batch/input_ids.json`, then run:
|
| 143 |
+
|
| 144 |
+
```bash
|
| 145 |
+
mkdir -p runs/batch
|
| 146 |
+
|
| 147 |
+
# 1. Fetch metadata + LaTeX for each input paper.
|
| 148 |
+
python src/step_01_fetch/fetch_metadata.py \
|
| 149 |
+
--ids runs/batch/input_ids.json \
|
| 150 |
+
--outdir runs/batch/processed_papers
|
| 151 |
+
|
| 152 |
+
# 2. Add explicit citation markers to the target-paper text.
|
| 153 |
+
python src/step_02_mark_citations/replace_citation_markers.py \
|
| 154 |
+
--root runs/batch/processed_papers
|
| 155 |
+
|
| 156 |
+
# 3. Build downstream citation usage contexts.
|
| 157 |
+
python src/step_03_usage_contexts/build_usage_contexts.py \
|
| 158 |
+
--root runs/batch/processed_papers \
|
| 159 |
+
--out-name usage_contexts.json
|
| 160 |
+
|
| 161 |
+
# 4. Label citation functions with the bundled Deep-Citation classifier.
|
| 162 |
+
python src/step_04_label_citations/label_citation_functions.py \
|
| 163 |
+
--root runs/batch/processed_papers \
|
| 164 |
+
--model-path Deep-Citation/Workspace/acl_scicite_wksp_trl/best_model.pt \
|
| 165 |
+
--model-data-dir Deep-Citation/Data \
|
| 166 |
+
--model-class-def Deep-Citation/Data/class_def.json \
|
| 167 |
+
--model-lm scibert \
|
| 168 |
+
--device cpu
|
| 169 |
+
|
| 170 |
+
# 5. Verify USES/EXTENDS citations with an LLM.
|
| 171 |
+
python src/step_05_verify_uses_extends/verify_uses_extends.py \
|
| 172 |
+
--root runs/batch/processed_papers \
|
| 173 |
+
--k 0 \
|
| 174 |
+
--batch-size 25
|
| 175 |
+
|
| 176 |
+
# 6. Extract arXiv paragraphs from downstream citing papers.
|
| 177 |
+
python src/step_06_extract_paragraphs/extract_arxiv_paragraphs.py \
|
| 178 |
+
--root runs/batch/processed_papers
|
| 179 |
+
|
| 180 |
+
# 7. Extract downstream contribution clusters, then merge/filter them.
|
| 181 |
+
python src/step_07_extract_and_refine/extract_contributions_from_citations.py \
|
| 182 |
+
--root runs/batch/processed_papers
|
| 183 |
+
|
| 184 |
+
python src/step_07_extract_and_refine/refine_and_filter_clusters_llm.py \
|
| 185 |
+
--root runs/batch/processed_papers \
|
| 186 |
+
--inplace \
|
| 187 |
+
--overwrite
|
| 188 |
+
|
| 189 |
+
# 8. Annotate each ready paper: target contributions, enabling contributions, and groundings.
|
| 190 |
+
for paper_dir in runs/batch/processed_papers/*; do
|
| 191 |
+
[ -d "$paper_dir" ] || continue
|
| 192 |
+
[ -f "$paper_dir/usage_discovery_from_contributions.json" ] || continue
|
| 193 |
+
PYTHONPATH=src \
|
| 194 |
+
python -m step_08_annotation.cli run \
|
| 195 |
+
--paper-dir "$paper_dir" \
|
| 196 |
+
--provider gemini \
|
| 197 |
+
--model gemini/gemini-3.1-pro-preview \
|
| 198 |
+
--formatter-model gemini/gemini-3.1-pro-preview \
|
| 199 |
+
--judge-model gemini/gemini-3.1-pro-preview \
|
| 200 |
+
--candidate-count 3 \
|
| 201 |
+
--output-root runs/batch/two_pass_outputs
|
| 202 |
+
done
|
| 203 |
+
```
|
| 204 |
+
|
| 205 |
+
## Pipeline Steps
|
| 206 |
+
|
| 207 |
+
1. **Fetch metadata + LaTeX.** Downloads target-paper metadata, references, citing-paper metadata, and arXiv source where available.
|
| 208 |
+
2. **Add citation markers.** Inserts normalized citation markers into the target paper so downstream citation contexts can be aligned.
|
| 209 |
+
3. **Build usage contexts.** Collects text windows around downstream citations to the target paper.
|
| 210 |
+
4. **Label citation functions.** Uses the bundled Deep-Citation classifier to label citation contexts as background, use, extension, comparison, and related categories.
|
| 211 |
+
5. **Verify USES/EXTENDS.** Uses an LLM to check whether candidate downstream citations genuinely use or extend the target paper.
|
| 212 |
+
6. **Extract arXiv paragraphs.** Retrieves fuller paragraphs from citing papers so the system has enough context for contribution extraction.
|
| 213 |
+
7. **Extract and refine target-contribution clusters.** Extracts what downstream papers use the target paper for, clusters near-duplicates, and filters weak/non-usage evidence.
|
| 214 |
+
8. **Annotate pathways.** Derives target contributions from the refined clusters, decomposes each into enabling contributions, selects primary groundings, and records additional grounding studies.
|
| 215 |
+
|
| 216 |
+
## Important Files
|
| 217 |
+
|
| 218 |
+
```text
|
| 219 |
+
hf_space/streamlit_app.py Streamlit UI
|
| 220 |
+
hf_space/runner.py Orchestrates steps 1-7 for the UI
|
| 221 |
+
hf_space/streamlit_config.py Example papers and tab names
|
| 222 |
+
src/common/ Shared LLM and paper-package utilities
|
| 223 |
+
src/step_01_fetch/ Metadata, references, citations, and LaTeX
|
| 224 |
+
src/step_02_mark_citations/ Citation-marker insertion
|
| 225 |
+
src/step_03_usage_contexts/ Downstream usage-context construction
|
| 226 |
+
src/step_04_label_citations/ Deep-Citation citation-function labeling
|
| 227 |
+
src/step_05_verify_uses_extends/ LLM verification of USES/EXTENDS citations
|
| 228 |
+
src/step_06_extract_paragraphs/ ArXiv paragraph extraction from citing papers
|
| 229 |
+
src/step_07_extract_and_refine/ Contribution extraction and cluster refinement
|
| 230 |
+
src/step_08_annotation/ Target/enabling contribution annotation and grounding
|
| 231 |
+
Deep-Citation/ Bundled citation-function classifier assets
|
| 232 |
+
```
|
app.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from hf_space.streamlit_app import main
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
if __name__ == "__main__":
|
| 5 |
+
main()
|
hf_space/requirements.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit>=1.36.0
|
| 2 |
+
arxiv==2.2.0
|
| 3 |
+
requests==2.32.5
|
| 4 |
+
google-generativeai
|
| 5 |
+
litellm
|
| 6 |
+
rapidfuzz
|
| 7 |
+
bibtexparser
|
| 8 |
+
sentence-transformers
|
| 9 |
+
transformers
|
| 10 |
+
torch
|
| 11 |
+
huggingface_hub
|
| 12 |
+
typer
|
| 13 |
+
tqdm
|
| 14 |
+
pydantic
|
| 15 |
+
numpy
|
| 16 |
+
pandas
|
| 17 |
+
scipy
|
hf_space/runner.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import shutil
|
| 5 |
+
import subprocess
|
| 6 |
+
import sys
|
| 7 |
+
import time
|
| 8 |
+
import uuid
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Generator, List, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class PipelineConfig:
|
| 16 |
+
repo_root: Path
|
| 17 |
+
source_root: Path
|
| 18 |
+
paper_input: str
|
| 19 |
+
llm_provider: str
|
| 20 |
+
llm_model: str
|
| 21 |
+
llm_model_step4: str
|
| 22 |
+
model_path: str
|
| 23 |
+
model_data_dir: str
|
| 24 |
+
model_class_def: str
|
| 25 |
+
model_lm: str
|
| 26 |
+
device: str
|
| 27 |
+
embedding_model: str
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class PipelineResult:
|
| 32 |
+
job_id: str
|
| 33 |
+
job_dir: Path
|
| 34 |
+
paper_dir: Path
|
| 35 |
+
zip_path: Path
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
STEP_LABELS = {
|
| 39 |
+
1: "Fetch metadata + LaTeX for input paper",
|
| 40 |
+
2: "Add citation markers",
|
| 41 |
+
3: "Build usage contexts",
|
| 42 |
+
4: "Label citation functions",
|
| 43 |
+
5: "Verify USES/EXTENDS",
|
| 44 |
+
6: "Extract arXiv paragraphs",
|
| 45 |
+
7: "Extract target contributions and refine clusters",
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
FULL_STEPS = [1, 2, 3, 4, 5, 6, 7]
|
| 49 |
+
STOP_PREFIX = "Pipeline stopped:"
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def parse_arxiv_id(paper_input: str) -> str:
|
| 53 |
+
s = (paper_input or "").strip()
|
| 54 |
+
if not s:
|
| 55 |
+
raise ValueError("paper_input is required")
|
| 56 |
+
if "arxiv.org" in s:
|
| 57 |
+
m = re.search(r"arxiv\.org/(abs|pdf)/([^/?#]+)", s)
|
| 58 |
+
if not m:
|
| 59 |
+
raise ValueError(f"Could not parse arXiv ID from URL: {s}")
|
| 60 |
+
s = m.group(2)
|
| 61 |
+
s = s.replace(".pdf", "")
|
| 62 |
+
s = re.sub(r"v\d+$", "", s)
|
| 63 |
+
if not re.match(r"^[0-9]{4}\.[0-9]{4,5}$", s):
|
| 64 |
+
raise ValueError(f"Invalid arXiv ID format: {s}")
|
| 65 |
+
return s
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _build_commands(
|
| 69 |
+
cfg: PipelineConfig,
|
| 70 |
+
step: int,
|
| 71 |
+
job_processed_root: Path,
|
| 72 |
+
paper_id: str,
|
| 73 |
+
ids_path: Optional[Path],
|
| 74 |
+
) -> List[List[str]]:
|
| 75 |
+
py = sys.executable
|
| 76 |
+
if step == 1:
|
| 77 |
+
assert ids_path is not None
|
| 78 |
+
return [[
|
| 79 |
+
py,
|
| 80 |
+
"src/step_01_fetch/fetch_metadata.py",
|
| 81 |
+
"--ids",
|
| 82 |
+
str(ids_path),
|
| 83 |
+
"--outdir",
|
| 84 |
+
str(job_processed_root),
|
| 85 |
+
]]
|
| 86 |
+
if step == 2:
|
| 87 |
+
return [[py, "src/step_02_mark_citations/replace_citation_markers.py", "--root", str(job_processed_root)]]
|
| 88 |
+
if step == 3:
|
| 89 |
+
return [[py, "src/step_03_usage_contexts/build_usage_contexts.py", "--root", str(job_processed_root), "--out-name", "usage_contexts.json"]]
|
| 90 |
+
if step == 4:
|
| 91 |
+
return [[
|
| 92 |
+
py,
|
| 93 |
+
"src/step_04_label_citations/label_citation_functions.py",
|
| 94 |
+
"--root",
|
| 95 |
+
str(job_processed_root),
|
| 96 |
+
"--model-path",
|
| 97 |
+
cfg.model_path,
|
| 98 |
+
"--model-data-dir",
|
| 99 |
+
cfg.model_data_dir,
|
| 100 |
+
"--model-class-def",
|
| 101 |
+
cfg.model_class_def,
|
| 102 |
+
"--model-lm",
|
| 103 |
+
cfg.model_lm,
|
| 104 |
+
"--device",
|
| 105 |
+
cfg.device,
|
| 106 |
+
]]
|
| 107 |
+
if step == 5:
|
| 108 |
+
return [[
|
| 109 |
+
py,
|
| 110 |
+
"src/step_05_verify_uses_extends/verify_uses_extends.py",
|
| 111 |
+
"--root",
|
| 112 |
+
str(job_processed_root),
|
| 113 |
+
"--k",
|
| 114 |
+
"0",
|
| 115 |
+
"--batch-size",
|
| 116 |
+
"25",
|
| 117 |
+
]]
|
| 118 |
+
if step == 6:
|
| 119 |
+
return [[py, "src/step_06_extract_paragraphs/extract_arxiv_paragraphs.py", "--root", str(job_processed_root)]]
|
| 120 |
+
if step == 7:
|
| 121 |
+
return [
|
| 122 |
+
[py, "src/step_07_extract_and_refine/extract_contributions_from_citations.py", "--root", str(job_processed_root)],
|
| 123 |
+
[py, "src/step_07_extract_and_refine/refine_and_filter_clusters_llm.py", "--root", str(job_processed_root), "--inplace", "--overwrite"],
|
| 124 |
+
]
|
| 125 |
+
raise ValueError(f"Unknown step: {step}")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _write_single_id_file(job_dir: Path, arxiv_id: str) -> Path:
|
| 129 |
+
ids_path = job_dir / "input_ids.json"
|
| 130 |
+
payload = [{"id": arxiv_id, "title": "", "id_type": "ArXiv"}]
|
| 131 |
+
ids_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
| 132 |
+
return ids_path
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _write_run_metadata(cfg: PipelineConfig, job_dir: Path, paper_id: str, arxiv_id: str) -> None:
|
| 136 |
+
payload = {
|
| 137 |
+
"paper_input": cfg.paper_input,
|
| 138 |
+
"paper_id": paper_id,
|
| 139 |
+
"arxiv_id": arxiv_id,
|
| 140 |
+
"source_root": str(cfg.source_root),
|
| 141 |
+
"steps": FULL_STEPS + ["annotation"],
|
| 142 |
+
"llm_provider": cfg.llm_provider,
|
| 143 |
+
"llm_model": cfg.llm_model,
|
| 144 |
+
"llm_model_step4": cfg.llm_model_step4,
|
| 145 |
+
"device": cfg.device,
|
| 146 |
+
"embedding_model": cfg.embedding_model,
|
| 147 |
+
"timestamp": int(time.time()),
|
| 148 |
+
}
|
| 149 |
+
(job_dir / "run_config.json").write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _zip_job_dir(job_dir: Path) -> Path:
|
| 153 |
+
zip_base = job_dir.parent / job_dir.name
|
| 154 |
+
archive = shutil.make_archive(str(zip_base), "zip", root_dir=str(job_dir))
|
| 155 |
+
return Path(archive)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _tail_log(path: Path, max_lines: int = 60) -> str:
|
| 159 |
+
try:
|
| 160 |
+
lines = path.read_text(encoding="utf-8", errors="ignore").splitlines()
|
| 161 |
+
except Exception:
|
| 162 |
+
return ""
|
| 163 |
+
if not lines:
|
| 164 |
+
return ""
|
| 165 |
+
return "\n".join(lines[-max_lines:])
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _load_json(path: Path, default=None):
|
| 169 |
+
try:
|
| 170 |
+
return json.loads(path.read_text(encoding="utf-8"))
|
| 171 |
+
except Exception:
|
| 172 |
+
return default
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def _write_summary_and_zip(job_dir: Path, summary_lines: List[str]) -> Path:
|
| 176 |
+
(job_dir / "summary.txt").write_text("\n".join(summary_lines), encoding="utf-8")
|
| 177 |
+
return _zip_job_dir(job_dir)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _count_verified_uses_extends(payload: dict) -> int:
|
| 181 |
+
records = payload.get("confirmed") or payload.get("verified_contexts") or payload.get("contexts") or payload.get("items") or []
|
| 182 |
+
if not isinstance(records, list):
|
| 183 |
+
return 0
|
| 184 |
+
accepted = {"USES", "EXTENDS", "Uses", "Extends"}
|
| 185 |
+
return sum(1 for item in records if isinstance(item, dict) and item.get("label") in accepted)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _stop_reason_after_step(step: int, paper_dir: Path) -> str | None:
|
| 189 |
+
if step == 1:
|
| 190 |
+
if not paper_dir.exists():
|
| 191 |
+
return "metadata could not be fetched for this paper"
|
| 192 |
+
if not (paper_dir / "processed_main.tex").exists():
|
| 193 |
+
return "arXiv source could not be retrieved or converted for this paper"
|
| 194 |
+
citations = _load_json(paper_dir / "citations_metadata.json", [])
|
| 195 |
+
if not isinstance(citations, list) or not citations:
|
| 196 |
+
return "Semantic Scholar returned no citing papers for this target paper"
|
| 197 |
+
|
| 198 |
+
if step == 3:
|
| 199 |
+
usage = _load_json(paper_dir / "usage_contexts.json", {})
|
| 200 |
+
if not isinstance(usage, dict):
|
| 201 |
+
return "citation usage contexts could not be built"
|
| 202 |
+
if int(usage.get("num_contexts") or 0) == 0:
|
| 203 |
+
return "no citation usage contexts were found"
|
| 204 |
+
|
| 205 |
+
if step == 4:
|
| 206 |
+
labels = _load_json(paper_dir / "usage_context_labels.json", {})
|
| 207 |
+
contexts = labels.get("labels") if isinstance(labels, dict) else None
|
| 208 |
+
if not isinstance(contexts, list) or not contexts:
|
| 209 |
+
return "citation-function labeling produced no labeled contexts"
|
| 210 |
+
|
| 211 |
+
if step == 5:
|
| 212 |
+
verified = _load_json(paper_dir / "usage_uses_extends_verified.json", {})
|
| 213 |
+
if not isinstance(verified, dict):
|
| 214 |
+
return "USES/EXTENDS verification did not produce an output file"
|
| 215 |
+
if _count_verified_uses_extends(verified) == 0:
|
| 216 |
+
return "no downstream citations were verified as USES or EXTENDS"
|
| 217 |
+
|
| 218 |
+
if step == 6:
|
| 219 |
+
paragraphs = _load_json(paper_dir / "usage_citing_paragraphs.json", {})
|
| 220 |
+
citing = paragraphs.get("citing_papers") if isinstance(paragraphs, dict) else None
|
| 221 |
+
if not isinstance(citing, list) or not citing:
|
| 222 |
+
return "no citing-paper paragraphs could be extracted from arXiv"
|
| 223 |
+
usable = [
|
| 224 |
+
item for item in citing
|
| 225 |
+
if isinstance(item, dict)
|
| 226 |
+
and not item.get("error")
|
| 227 |
+
and (item.get("matched_paragraphs") or item.get("target_citing_paragraphs"))
|
| 228 |
+
]
|
| 229 |
+
if not usable:
|
| 230 |
+
return "arXiv paragraph extraction returned no usable citing-paper text"
|
| 231 |
+
|
| 232 |
+
if step == 7:
|
| 233 |
+
contributions = _load_json(paper_dir / "usage_contributions.json", {})
|
| 234 |
+
items = contributions.get("contributions") if isinstance(contributions, dict) else None
|
| 235 |
+
if not isinstance(items, list) or not items:
|
| 236 |
+
return "no downstream target-contribution evidence could be extracted"
|
| 237 |
+
refined = _load_json(paper_dir / "usage_discovery_from_contributions.json", {})
|
| 238 |
+
clusters = refined.get("clusters") if isinstance(refined, dict) else None
|
| 239 |
+
if not isinstance(clusters, list) or not clusters:
|
| 240 |
+
return "no valid downstream usage clusters survived refinement"
|
| 241 |
+
|
| 242 |
+
return None
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def run_pipeline(cfg: PipelineConfig, output_root: Path) -> Generator[Tuple[str, Optional[str]], None, PipelineResult]:
|
| 246 |
+
output_root.mkdir(parents=True, exist_ok=True)
|
| 247 |
+
job_id = f"job_{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
| 248 |
+
job_dir = output_root / job_id
|
| 249 |
+
job_processed_root = job_dir / "processed_papers"
|
| 250 |
+
job_logs = job_dir / "logs"
|
| 251 |
+
|
| 252 |
+
job_processed_root.mkdir(parents=True, exist_ok=True)
|
| 253 |
+
job_logs.mkdir(parents=True, exist_ok=True)
|
| 254 |
+
|
| 255 |
+
arxiv_id = parse_arxiv_id(cfg.paper_input)
|
| 256 |
+
paper_id = arxiv_id
|
| 257 |
+
ids_path = _write_single_id_file(job_dir, arxiv_id)
|
| 258 |
+
_write_run_metadata(cfg, job_dir, paper_id, arxiv_id)
|
| 259 |
+
|
| 260 |
+
base_env = os.environ.copy()
|
| 261 |
+
base_env["LLM_PROVIDER"] = cfg.llm_provider
|
| 262 |
+
base_env["LLM_MODEL"] = cfg.llm_model
|
| 263 |
+
|
| 264 |
+
summary_lines: List[str] = []
|
| 265 |
+
paper_dir = job_processed_root / paper_id
|
| 266 |
+
|
| 267 |
+
max_step = 8
|
| 268 |
+
for step in FULL_STEPS:
|
| 269 |
+
label = STEP_LABELS[step]
|
| 270 |
+
log_file = job_logs / f"step_{step:02d}.log"
|
| 271 |
+
summary_lines.append(f"[{step}] {label}")
|
| 272 |
+
yield (f"Step {step}/{max_step}: {label}", None)
|
| 273 |
+
|
| 274 |
+
env = base_env.copy()
|
| 275 |
+
if step == 5 and cfg.llm_model_step4:
|
| 276 |
+
env["LLM_MODEL"] = cfg.llm_model_step4
|
| 277 |
+
|
| 278 |
+
with log_file.open("w", encoding="utf-8") as lf:
|
| 279 |
+
return_code = 0
|
| 280 |
+
failed_cmd: List[str] | None = None
|
| 281 |
+
for cmd in _build_commands(cfg, step, job_processed_root, paper_id, ids_path):
|
| 282 |
+
lf.write(f"$ {' '.join(cmd)}\n\n")
|
| 283 |
+
proc = subprocess.Popen(
|
| 284 |
+
cmd,
|
| 285 |
+
cwd=str(cfg.repo_root),
|
| 286 |
+
stdout=subprocess.PIPE,
|
| 287 |
+
stderr=subprocess.STDOUT,
|
| 288 |
+
text=True,
|
| 289 |
+
encoding="utf-8",
|
| 290 |
+
errors="ignore",
|
| 291 |
+
env=env,
|
| 292 |
+
)
|
| 293 |
+
assert proc.stdout is not None
|
| 294 |
+
for line in proc.stdout:
|
| 295 |
+
lf.write(line)
|
| 296 |
+
return_code = proc.wait()
|
| 297 |
+
if return_code != 0:
|
| 298 |
+
failed_cmd = cmd
|
| 299 |
+
break
|
| 300 |
+
|
| 301 |
+
if return_code != 0:
|
| 302 |
+
summary_lines.append(f"FAILED at step {step}")
|
| 303 |
+
zip_path = _write_summary_and_zip(job_dir, summary_lines)
|
| 304 |
+
tail = _tail_log(log_file)
|
| 305 |
+
if tail:
|
| 306 |
+
yield (
|
| 307 |
+
f"Step {step} failed.\n\nCommand: {' '.join(failed_cmd or [])}\n\nLast log lines:\n{tail}",
|
| 308 |
+
str(zip_path),
|
| 309 |
+
)
|
| 310 |
+
else:
|
| 311 |
+
yield (f"Step {step} failed. Command: {' '.join(failed_cmd or [])}", str(zip_path))
|
| 312 |
+
return PipelineResult(job_id=job_id, job_dir=job_dir, paper_dir=paper_dir, zip_path=zip_path)
|
| 313 |
+
else:
|
| 314 |
+
yield (f"Step {step} complete", None)
|
| 315 |
+
|
| 316 |
+
if step == 1 and not paper_dir.exists():
|
| 317 |
+
summary_lines.append("FAILED: fetch_metadata did not create paper directory")
|
| 318 |
+
zip_path = _write_summary_and_zip(job_dir, summary_lines)
|
| 319 |
+
yield (f"Step 1 finished but paper dir missing: {paper_dir}", str(zip_path))
|
| 320 |
+
return PipelineResult(job_id=job_id, job_dir=job_dir, paper_dir=paper_dir, zip_path=zip_path)
|
| 321 |
+
|
| 322 |
+
stop_reason = _stop_reason_after_step(step, paper_dir)
|
| 323 |
+
if stop_reason:
|
| 324 |
+
message = f"{STOP_PREFIX} {stop_reason}."
|
| 325 |
+
summary_lines.append(message)
|
| 326 |
+
zip_path = _write_summary_and_zip(job_dir, summary_lines)
|
| 327 |
+
yield (message, str(zip_path))
|
| 328 |
+
return PipelineResult(job_id=job_id, job_dir=job_dir, paper_dir=paper_dir, zip_path=zip_path)
|
| 329 |
+
|
| 330 |
+
summary_lines.append("SUCCESS")
|
| 331 |
+
zip_path = _write_summary_and_zip(job_dir, summary_lines)
|
| 332 |
+
yield ("Pipeline completed successfully.", str(zip_path))
|
| 333 |
+
return PipelineResult(job_id=job_id, job_dir=job_dir, paper_dir=paper_dir, zip_path=zip_path)
|
hf_space/streamlit_app.py
ADDED
|
@@ -0,0 +1,864 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import time
|
| 5 |
+
import html
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Optional
|
| 8 |
+
|
| 9 |
+
import streamlit as st
|
| 10 |
+
try:
|
| 11 |
+
from huggingface_hub import HfApi
|
| 12 |
+
except Exception:
|
| 13 |
+
HfApi = None
|
| 14 |
+
|
| 15 |
+
SRC = Path(__file__).resolve().parent
|
| 16 |
+
REPO_ROOT = SRC.parent
|
| 17 |
+
for extra in (SRC, REPO_ROOT / "src"):
|
| 18 |
+
extra_str = str(extra)
|
| 19 |
+
if extra_str not in sys.path:
|
| 20 |
+
sys.path.insert(0, extra_str)
|
| 21 |
+
|
| 22 |
+
import runner as runner_module
|
| 23 |
+
from runner import PipelineConfig
|
| 24 |
+
from common.paper_package import load_paper_package
|
| 25 |
+
from step_08_annotation.pipeline import TwoPassAnnotationPipeline
|
| 26 |
+
from streamlit_config import EXAMPLES, TAB_NAMES
|
| 27 |
+
|
| 28 |
+
DEFAULT_SOURCE_ROOT = str(REPO_ROOT / "src" / "processed_papers")
|
| 29 |
+
DEFAULT_OUTPUT_ROOT = str(REPO_ROOT / "hf_space" / "runs")
|
| 30 |
+
|
| 31 |
+
CUSTOM_CSS = """
|
| 32 |
+
<style>
|
| 33 |
+
.block-container {max-width: 1450px; padding-top: 2rem; padding-bottom: 2rem;}
|
| 34 |
+
[data-testid="stSidebar"] {background: #f5f7fb; border-right: 1px solid #e2e8f0;}
|
| 35 |
+
.hero-title {font-size: 3rem; font-weight: 800; letter-spacing: -0.03em; color: #1f2937; margin-bottom: 0.35rem;}
|
| 36 |
+
.hero-sub {font-size: 1rem; color: #6b7280; max-width: 920px; margin-bottom: 1.25rem;}
|
| 37 |
+
.metric-card {background: #ffffff; border: 1px solid #e5e7eb; border-radius: 16px; padding: 1rem 1.1rem; min-height: 96px;}
|
| 38 |
+
.metric-label {font-size: 0.78rem; font-weight: 700; color: #6b7280; text-transform: uppercase; letter-spacing: 0.04em;}
|
| 39 |
+
.metric-value {font-size: 1.7rem; font-weight: 800; color: #111827; margin-top: 0.35rem;}
|
| 40 |
+
.soft-card {background: #ffffff; border: 1px solid #e5e7eb; border-radius: 16px; padding: 1rem 1.1rem;}
|
| 41 |
+
.claim-card {background: #ffffff; border: 1px solid #e5e7eb; border-radius: 18px; overflow: hidden; margin-bottom: 1rem;}
|
| 42 |
+
.claim-head {padding: 1rem 1.1rem; border-bottom: 1px solid #eef2f7; background: #fcfdff;}
|
| 43 |
+
.claim-kicker {font-size: 0.78rem; font-weight: 800; color: #2563eb; text-transform: uppercase; letter-spacing: 0.04em; margin-bottom: 0.45rem;}
|
| 44 |
+
.claim-text {font-size: 1.05rem; line-height: 1.55; font-weight: 700; color: #111827;}
|
| 45 |
+
.claim-grid {display: grid; grid-template-columns: 1.7fr 1fr;}
|
| 46 |
+
.claim-main, .claim-side {padding: 1rem 1.1rem;}
|
| 47 |
+
.claim-side {border-left: 1px solid #eef2f7; background: #fbfdff;}
|
| 48 |
+
.section-label {font-size: 0.78rem; font-weight: 800; color: #6b7280; text-transform: uppercase; letter-spacing: 0.04em; margin-bottom: 0.7rem;}
|
| 49 |
+
.pill-row {display: flex; flex-wrap: wrap; gap: 0.45rem; margin-top: 0.8rem;}
|
| 50 |
+
.pill {display: inline-block; padding: 0.28rem 0.7rem; border-radius: 999px; border: 1px solid #dbe4f0; background: #f8fbff; color: #1d4ed8; font-size: 0.78rem; font-weight: 700;}
|
| 51 |
+
.ingredient-card {border: 1px solid #e6edf7; border-left: 4px solid #2563eb; border-radius: 12px; background: #ffffff; padding: 0.9rem; margin-bottom: 0.8rem;}
|
| 52 |
+
.ingredient-top {display: flex; justify-content: space-between; gap: 0.7rem; align-items: flex-start; margin-bottom: 0.45rem;}
|
| 53 |
+
.ingredient-name {font-size: 0.98rem; font-weight: 800; color: #111827; line-height: 1.4;}
|
| 54 |
+
.role-pill {display: inline-block; padding: 0.2rem 0.55rem; border-radius: 999px; border: 1px solid #ddd6fe; background: #f5f3ff; color: #6d28d9; font-size: 0.72rem; font-weight: 800; white-space: nowrap;}
|
| 55 |
+
.field {font-size: 0.88rem; line-height: 1.5; color: #374151; margin-top: 0.4rem;}
|
| 56 |
+
.field b {color: #111827;}
|
| 57 |
+
.grounding-block {margin-top: 0.75rem; display: grid; gap: 0.55rem;}
|
| 58 |
+
.grounding-card {border-radius: 10px; padding: 0.65rem 0.75rem; border: 1px solid #bfdbfe; background: #eff6ff;}
|
| 59 |
+
.grounding-card.additional {border-color: #fed7aa; background: #fff7ed;}
|
| 60 |
+
.grounding-label {font-size: 0.7rem; font-weight: 900; text-transform: uppercase; letter-spacing: 0.05em; margin-bottom: 0.25rem;}
|
| 61 |
+
.grounding-label.primary {color: #1d4ed8;}
|
| 62 |
+
.grounding-label.additional {color: #c2410c;}
|
| 63 |
+
.grounding-title {font-size: 0.9rem; font-weight: 800; color: #111827; line-height: 1.35;}
|
| 64 |
+
.grounding-meta {font-size: 0.78rem; color: #64748b; margin-top: 0.2rem;}
|
| 65 |
+
.cluster-card {border: 1px solid #e5e7eb; border-radius: 16px; background: #ffffff; padding: 1rem 1.1rem; margin-bottom: 0.9rem;}
|
| 66 |
+
.cluster-card.additional-study {border-color: #fed7aa; background: #fff7ed;}
|
| 67 |
+
.cluster-title {font-size: 1rem; font-weight: 800; color: #111827; line-height: 1.45; margin-bottom: 0.4rem;}
|
| 68 |
+
.cluster-meta {font-size: 0.86rem; color: #6b7280; margin-bottom: 0.65rem;}
|
| 69 |
+
.empty-card {border: 1px dashed #cbd5e1; border-radius: 14px; padding: 1rem; background: #ffffff; color: #64748b;}
|
| 70 |
+
.example-btn button {border-radius: 999px !important; border: 1px solid #fecaca !important; color: #991b1b !important; background: #fff !important;}
|
| 71 |
+
@media (max-width: 1050px) {.claim-grid {grid-template-columns: 1fr;} .claim-side {border-left: none; border-top: 1px solid #eef2f7;}}
|
| 72 |
+
</style>
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def get_secret(name: str, default: str = "") -> str:
|
| 77 |
+
value = os.getenv(name)
|
| 78 |
+
if value:
|
| 79 |
+
return value
|
| 80 |
+
try:
|
| 81 |
+
return st.secrets[name]
|
| 82 |
+
except Exception:
|
| 83 |
+
return default
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def run_repo_config() -> tuple[str | None, str, str | None]:
|
| 87 |
+
repo_id = get_secret("RUNS_REPO_ID", "")
|
| 88 |
+
repo_type = get_secret("RUNS_REPO_TYPE", "dataset")
|
| 89 |
+
token = get_secret("HF_WRITE_TOKEN", "") or get_secret("HF_TOKEN", "")
|
| 90 |
+
return repo_id or None, repo_type, token or None
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def remote_run_prefix(job_id: str) -> str:
|
| 94 |
+
return f"runs/{job_id}"
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def upload_run_artifact(job_dir: Path) -> str:
|
| 98 |
+
repo_id, repo_type, token = run_repo_config()
|
| 99 |
+
if not repo_id or not token:
|
| 100 |
+
return ""
|
| 101 |
+
if HfApi is None:
|
| 102 |
+
return "upload_failed: huggingface_hub is not installed"
|
| 103 |
+
|
| 104 |
+
job_id = job_dir.name
|
| 105 |
+
remote_prefix = remote_run_prefix(job_id)
|
| 106 |
+
uploaded: list[str] = []
|
| 107 |
+
try:
|
| 108 |
+
api = HfApi(token=token)
|
| 109 |
+
for name in ["input_ids.json", "run_config.json", "summary.txt"]:
|
| 110 |
+
path = job_dir / name
|
| 111 |
+
if path.exists():
|
| 112 |
+
api.upload_file(
|
| 113 |
+
path_or_fileobj=str(path),
|
| 114 |
+
path_in_repo=f"{remote_prefix}/{name}",
|
| 115 |
+
repo_id=repo_id,
|
| 116 |
+
repo_type=repo_type,
|
| 117 |
+
commit_message=f"Upload {name} for {job_id}",
|
| 118 |
+
)
|
| 119 |
+
uploaded.append(name)
|
| 120 |
+
|
| 121 |
+
for folder_name in ["logs", "processed_papers", "two_pass_outputs"]:
|
| 122 |
+
folder = job_dir / folder_name
|
| 123 |
+
if not folder.exists():
|
| 124 |
+
continue
|
| 125 |
+
files = [path for path in folder.rglob("*") if path.is_file()]
|
| 126 |
+
if not files:
|
| 127 |
+
continue
|
| 128 |
+
api.upload_folder(
|
| 129 |
+
folder_path=str(folder),
|
| 130 |
+
path_in_repo=f"{remote_prefix}/{folder_name}",
|
| 131 |
+
repo_id=repo_id,
|
| 132 |
+
repo_type=repo_type,
|
| 133 |
+
commit_message=f"Upload {folder_name} for {job_id}",
|
| 134 |
+
ignore_patterns=["__pycache__/*", "*.pyc", "*.zip"],
|
| 135 |
+
)
|
| 136 |
+
uploaded.append(f"{folder_name}[{len(files)} files]")
|
| 137 |
+
|
| 138 |
+
return f"{repo_type}:{repo_id}/{remote_prefix}/ (uploaded: {', '.join(uploaded) or 'nothing'})"
|
| 139 |
+
except Exception as exc:
|
| 140 |
+
return f"upload_failed: {exc}"
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def _load_json(path: Path) -> Optional[dict]:
|
| 144 |
+
if not path.exists():
|
| 145 |
+
return None
|
| 146 |
+
try:
|
| 147 |
+
return json.loads(path.read_text(encoding="utf-8"))
|
| 148 |
+
except Exception:
|
| 149 |
+
return None
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _status_from_line(line: str, current: str) -> str:
|
| 153 |
+
text = (line or "").strip()
|
| 154 |
+
text = _display_log_line(text)
|
| 155 |
+
if text.startswith("Pipeline stopped:"):
|
| 156 |
+
return "Stopped"
|
| 157 |
+
if text.startswith("Step "):
|
| 158 |
+
return text
|
| 159 |
+
if "failed" in text.lower():
|
| 160 |
+
return f"Failed: {text}"
|
| 161 |
+
if "completed successfully" in text.lower():
|
| 162 |
+
return "Completed"
|
| 163 |
+
return current
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def _display_log_line(line: str) -> str:
|
| 167 |
+
text = (line or "").strip()
|
| 168 |
+
if text.startswith("Step ") and " failed." in text:
|
| 169 |
+
return text.splitlines()[0]
|
| 170 |
+
if text == "[annotation] starting cluster-first two-pass annotation":
|
| 171 |
+
return "Step 8/8: Annotate target contributions and enabling contributions"
|
| 172 |
+
if text.startswith("[annotation] complete:"):
|
| 173 |
+
return "Step 8 complete"
|
| 174 |
+
if text == "Pipeline completed successfully.":
|
| 175 |
+
return text
|
| 176 |
+
return text
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def _format_step_event(line: str) -> str:
|
| 180 |
+
text = _display_log_line(line)
|
| 181 |
+
if not text:
|
| 182 |
+
return ""
|
| 183 |
+
if text.startswith("Step ") and "/" in text and ":" in text:
|
| 184 |
+
return f"🛠️ {text}"
|
| 185 |
+
if text.startswith("Step ") and text.endswith(" complete"):
|
| 186 |
+
return f"✅ {text}"
|
| 187 |
+
if text.lower().startswith("stopped after step"):
|
| 188 |
+
return f"⏹️ {text}"
|
| 189 |
+
if text.startswith("Pipeline stopped:"):
|
| 190 |
+
return f"⏹️ {text}"
|
| 191 |
+
if "failed" in text.lower():
|
| 192 |
+
return f"❌ {text}"
|
| 193 |
+
if "completed successfully" in text.lower():
|
| 194 |
+
return f"✅ {text}"
|
| 195 |
+
return f"• {text}"
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def _ensure_state():
|
| 199 |
+
defaults = {
|
| 200 |
+
"paper_input": "",
|
| 201 |
+
"run_status": "Idle",
|
| 202 |
+
"run_logs": [],
|
| 203 |
+
"run_events": [],
|
| 204 |
+
"artifact_path": None,
|
| 205 |
+
"run_dir_path": None,
|
| 206 |
+
"paper_dir_path": None,
|
| 207 |
+
"annotation_payload_path": None,
|
| 208 |
+
"run_summary": None,
|
| 209 |
+
"annotation_skipped_reason": None,
|
| 210 |
+
"pipeline_failed_reason": None,
|
| 211 |
+
"remote_artifact_ref": "",
|
| 212 |
+
}
|
| 213 |
+
for key, value in defaults.items():
|
| 214 |
+
st.session_state.setdefault(key, value)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def _metric_card(label: str, value: Any):
|
| 218 |
+
st.markdown(
|
| 219 |
+
f"<div class='metric-card'><div class='metric-label'>{label}</div><div class='metric-value'>{value}</div></div>",
|
| 220 |
+
unsafe_allow_html=True,
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def _esc(value: Any) -> str:
|
| 225 |
+
return html.escape("" if value is None else str(value))
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def _safe_int(value: Any, default: int = 0) -> int:
|
| 229 |
+
try:
|
| 230 |
+
return int(value)
|
| 231 |
+
except (TypeError, ValueError):
|
| 232 |
+
return default
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def _grounding_html(grounding: Optional[dict], label: str, kind: str) -> str:
|
| 236 |
+
if not grounding:
|
| 237 |
+
return ""
|
| 238 |
+
title = (
|
| 239 |
+
grounding.get("ref_title")
|
| 240 |
+
or grounding.get("title")
|
| 241 |
+
or grounding.get("paper_id")
|
| 242 |
+
or grounding.get("ref_id")
|
| 243 |
+
or "__NONE__"
|
| 244 |
+
)
|
| 245 |
+
meta = []
|
| 246 |
+
if grounding.get("paper_id"):
|
| 247 |
+
meta.append(f"paper_id: {grounding.get('paper_id')}")
|
| 248 |
+
elif grounding.get("ref_id"):
|
| 249 |
+
meta.append(f"ref_id: {grounding.get('ref_id')}")
|
| 250 |
+
if grounding.get("ref_year"):
|
| 251 |
+
meta.append(str(grounding.get("ref_year")))
|
| 252 |
+
authors = grounding.get("ref_authors")
|
| 253 |
+
if isinstance(authors, list) and authors:
|
| 254 |
+
meta.append(", ".join(str(author) for author in authors[:3]))
|
| 255 |
+
meta_html = f"<div class='grounding-meta'>{_esc(' · '.join(meta))}</div>" if meta else ""
|
| 256 |
+
extra_class = " additional" if kind == "additional" else ""
|
| 257 |
+
return (
|
| 258 |
+
f"<div class='grounding-card{extra_class}'>"
|
| 259 |
+
f"<div class='grounding-label {kind}'>{_esc(label)}</div>"
|
| 260 |
+
f"<div class='grounding-title'>{_esc(title)}</div>"
|
| 261 |
+
f"{meta_html}"
|
| 262 |
+
"</div>"
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def _study_key(item: dict) -> str:
|
| 267 |
+
for key in ["paper_id", "ref_id", "ref_title", "title"]:
|
| 268 |
+
value = item.get(key)
|
| 269 |
+
if value:
|
| 270 |
+
return str(value).lower()
|
| 271 |
+
return ""
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def _collect_grounded_studies(discoveries: list[dict], ingredients: list[dict]) -> list[dict]:
|
| 275 |
+
studies: list[dict] = []
|
| 276 |
+
seen: set[str] = set()
|
| 277 |
+
for item in discoveries:
|
| 278 |
+
if not isinstance(item, dict):
|
| 279 |
+
continue
|
| 280 |
+
copied = dict(item)
|
| 281 |
+
copied["_grounding_kind"] = "primary"
|
| 282 |
+
copied["_grounding_label"] = "Primary study"
|
| 283 |
+
key = _study_key(copied)
|
| 284 |
+
if key:
|
| 285 |
+
seen.add(key)
|
| 286 |
+
studies.append(copied)
|
| 287 |
+
|
| 288 |
+
for idx, ingredient in enumerate(ingredients, start=1):
|
| 289 |
+
if not isinstance(ingredient, dict):
|
| 290 |
+
continue
|
| 291 |
+
canonical = ingredient.get("canonical_grounding") or {}
|
| 292 |
+
canonical_key = _study_key(canonical) if isinstance(canonical, dict) else ""
|
| 293 |
+
annotation = ingredient.get("canonical_annotation") or {}
|
| 294 |
+
for ref in ingredient.get("additional_groundings") or []:
|
| 295 |
+
if not isinstance(ref, dict):
|
| 296 |
+
continue
|
| 297 |
+
key = _study_key(ref)
|
| 298 |
+
if key and (key == canonical_key or key in seen):
|
| 299 |
+
continue
|
| 300 |
+
copied = dict(ref)
|
| 301 |
+
copied["_grounding_kind"] = "additional"
|
| 302 |
+
copied["_grounding_label"] = f"Additional study for enabling contribution {idx}"
|
| 303 |
+
copied.setdefault("role", annotation.get("role") or ", ".join(annotation.get("roles") or []))
|
| 304 |
+
copied.setdefault("contribution", annotation.get("contribution"))
|
| 305 |
+
copied.setdefault("rationale", annotation.get("rationale"))
|
| 306 |
+
if key:
|
| 307 |
+
seen.add(key)
|
| 308 |
+
studies.append(copied)
|
| 309 |
+
return studies
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def _render_reference_list(discoveries: list[dict], ingredients: Optional[list[dict]] = None):
|
| 313 |
+
studies = _collect_grounded_studies(discoveries, ingredients or [])
|
| 314 |
+
if not studies:
|
| 315 |
+
st.markdown("<div class='empty-card'>No grounded studies listed for this target contribution.</div>", unsafe_allow_html=True)
|
| 316 |
+
return
|
| 317 |
+
for item in studies:
|
| 318 |
+
title = item.get("ref_title") or item.get("title") or item.get("ref_id") or item.get("paper_id") or "Untitled reference"
|
| 319 |
+
is_additional = item.get("_grounding_kind") == "additional"
|
| 320 |
+
meta = []
|
| 321 |
+
if item.get("_grounding_label"):
|
| 322 |
+
meta.append(str(item.get("_grounding_label")))
|
| 323 |
+
if item.get("role"):
|
| 324 |
+
meta.append(str(item.get("role")))
|
| 325 |
+
if item.get("ref_year"):
|
| 326 |
+
meta.append(str(item.get("ref_year")))
|
| 327 |
+
class_name = "cluster-card additional-study" if is_additional else "cluster-card"
|
| 328 |
+
body = [f"<div class='{class_name}'><div class='cluster-title'>{_esc(title)}</div>"]
|
| 329 |
+
if meta:
|
| 330 |
+
body.append(f"<div class='cluster-meta'>{_esc(' · '.join(meta))}</div>")
|
| 331 |
+
if item.get("contribution"):
|
| 332 |
+
body.append(f"<div class='field'><b>Contribution.</b> {_esc(item.get('contribution'))}</div>")
|
| 333 |
+
if item.get("rationale"):
|
| 334 |
+
body.append(f"<div class='field'><b>Rationale.</b> {_esc(item.get('rationale'))}</div>")
|
| 335 |
+
body.append("</div>")
|
| 336 |
+
st.markdown("".join(body), unsafe_allow_html=True)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def _render_claims_tab(payload: Optional[dict]):
|
| 340 |
+
if not payload:
|
| 341 |
+
st.markdown("<div class='empty-card'>No annotation payload is available yet.</div>", unsafe_allow_html=True)
|
| 342 |
+
return
|
| 343 |
+
claims = payload.get("claims") or []
|
| 344 |
+
if not claims:
|
| 345 |
+
st.markdown("<div class='empty-card'>The run completed, but no target contributions were produced.</div>", unsafe_allow_html=True)
|
| 346 |
+
return
|
| 347 |
+
|
| 348 |
+
for idx, claim in enumerate(claims, start=1):
|
| 349 |
+
claim_id = claim.get("claim_id") or f"C{idx}"
|
| 350 |
+
claim_text = claim.get("rewritten_claim") or claim.get("text") or "(missing target contribution text)"
|
| 351 |
+
ingredients = claim.get("ingredients") or []
|
| 352 |
+
discoveries = claim.get("enabling_discoveries") or []
|
| 353 |
+
grounded_studies = _collect_grounded_studies(discoveries, ingredients)
|
| 354 |
+
meta_pills = []
|
| 355 |
+
if claim.get("decision"):
|
| 356 |
+
meta_pills.append(str(claim.get("decision")))
|
| 357 |
+
if claim.get("cluster_id"):
|
| 358 |
+
meta_pills.append(f"cluster {claim.get('cluster_id')}")
|
| 359 |
+
meta_pills.append(f"{len(ingredients)} enabling contribution{'s' if len(ingredients) != 1 else ''}")
|
| 360 |
+
meta_pills.append(f"{len(grounded_studies)} grounded stud{'ies' if len(grounded_studies) != 1 else 'y'}")
|
| 361 |
+
|
| 362 |
+
pills_html = "".join(f"<span class='pill'>{_esc(p)}</span>" for p in meta_pills)
|
| 363 |
+
st.markdown(
|
| 364 |
+
f"""
|
| 365 |
+
<div class='claim-card'>
|
| 366 |
+
<div class='claim-head'>
|
| 367 |
+
<div class='claim-kicker'>Target contribution {idx} · {_esc(claim_id)}</div>
|
| 368 |
+
<div class='claim-text'>{_esc(claim_text)}</div>
|
| 369 |
+
<div class='pill-row'>{pills_html}</div>
|
| 370 |
+
</div>
|
| 371 |
+
</div>
|
| 372 |
+
""",
|
| 373 |
+
unsafe_allow_html=True,
|
| 374 |
+
)
|
| 375 |
+
left, right = st.columns([1.7, 1.0], gap="large")
|
| 376 |
+
with left:
|
| 377 |
+
st.markdown("<div class='section-label'>Decomposition</div>", unsafe_allow_html=True)
|
| 378 |
+
if not ingredients:
|
| 379 |
+
st.markdown("<div class='empty-card'>No enabling contributions for this target contribution.</div>", unsafe_allow_html=True)
|
| 380 |
+
for ingredient_idx, ingredient in enumerate(ingredients, start=1):
|
| 381 |
+
annotation = ingredient.get("canonical_annotation") or {}
|
| 382 |
+
role = annotation.get("role") or ", ".join(annotation.get("roles") or []) or "UNSPECIFIED"
|
| 383 |
+
canonical_grounding = ingredient.get("canonical_grounding") or {}
|
| 384 |
+
extras = ingredient.get("additional_groundings") or []
|
| 385 |
+
grounding_parts = []
|
| 386 |
+
if canonical_grounding:
|
| 387 |
+
grounding_parts.append(
|
| 388 |
+
_grounding_html(canonical_grounding, "Primary grounding", "primary")
|
| 389 |
+
)
|
| 390 |
+
for ref in extras:
|
| 391 |
+
if not isinstance(ref, dict):
|
| 392 |
+
continue
|
| 393 |
+
if canonical_grounding and (
|
| 394 |
+
ref.get("paper_id") == canonical_grounding.get("paper_id")
|
| 395 |
+
or ref.get("ref_id") == canonical_grounding.get("ref_id")
|
| 396 |
+
):
|
| 397 |
+
continue
|
| 398 |
+
grounding_parts.append(
|
| 399 |
+
_grounding_html(ref, "Additional grounding", "additional")
|
| 400 |
+
)
|
| 401 |
+
if not grounding_parts:
|
| 402 |
+
canonical_ref_id = ingredient.get("canonical_ref_id") or "__NONE__"
|
| 403 |
+
grounding_parts.append(
|
| 404 |
+
"<div class='grounding-card'>"
|
| 405 |
+
"<div class='grounding-label primary'>Grounding</div>"
|
| 406 |
+
f"<div class='grounding-title'>{_esc(canonical_ref_id)}</div>"
|
| 407 |
+
"</div>"
|
| 408 |
+
)
|
| 409 |
+
grounding_block = (
|
| 410 |
+
"<div class='grounding-block'>"
|
| 411 |
+
f"<div class='section-label'>Groundings for enabling contribution {ingredient_idx}</div>"
|
| 412 |
+
+ "".join(grounding_parts)
|
| 413 |
+
+ "</div>"
|
| 414 |
+
)
|
| 415 |
+
st.markdown(
|
| 416 |
+
f"""
|
| 417 |
+
<div class='ingredient-card'>
|
| 418 |
+
<div class='ingredient-top'>
|
| 419 |
+
<div class='ingredient-name'>{ingredient_idx}. {_esc(ingredient.get('ingredient') or '(missing enabling contribution)')}</div>
|
| 420 |
+
<div class='role-pill'>{_esc(role)}</div>
|
| 421 |
+
</div>
|
| 422 |
+
<div class='field'><b>Contribution.</b> {_esc(annotation.get('contribution') or '')}</div>
|
| 423 |
+
<div class='field'><b>Rationale.</b> {_esc(annotation.get('rationale') or '')}</div>
|
| 424 |
+
<div class='field'><b>Evidence.</b> {_esc(annotation.get('evidence_span') or '')}</div>
|
| 425 |
+
{grounding_block}
|
| 426 |
+
</div>
|
| 427 |
+
""",
|
| 428 |
+
unsafe_allow_html=True,
|
| 429 |
+
)
|
| 430 |
+
with right:
|
| 431 |
+
st.markdown("<div class='section-label'>Grounded and additional studies</div>", unsafe_allow_html=True)
|
| 432 |
+
_render_reference_list(discoveries, ingredients)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def _render_clusters_tab(discovery: Optional[dict], contributions: list[dict]):
|
| 436 |
+
if not discovery:
|
| 437 |
+
st.markdown("<div class='empty-card'>No refined cluster file is available yet.</div>", unsafe_allow_html=True)
|
| 438 |
+
return
|
| 439 |
+
clusters = discovery.get("clusters") or []
|
| 440 |
+
dropped = discovery.get("dropped_clusters") or []
|
| 441 |
+
if not clusters:
|
| 442 |
+
st.markdown("<div class='empty-card'>No valid downstream usage clusters survived refinement and filtering.</div>", unsafe_allow_html=True)
|
| 443 |
+
if dropped:
|
| 444 |
+
with st.expander(f"Dropped clusters ({len(dropped)})", expanded=False):
|
| 445 |
+
st.json(dropped)
|
| 446 |
+
return
|
| 447 |
+
|
| 448 |
+
for cluster in clusters:
|
| 449 |
+
cluster_id = cluster.get("cluster_id", "")
|
| 450 |
+
rep = cluster.get("representative_claim") or cluster.get("cluster_title") or "(missing representative claim)"
|
| 451 |
+
count = _safe_int(cluster.get("count"), len(cluster.get("claim_indices") or []))
|
| 452 |
+
source_ids = cluster.get("source_cluster_ids") or []
|
| 453 |
+
merge_rationale = cluster.get("merge_rationale") or ""
|
| 454 |
+
st.markdown(
|
| 455 |
+
f"""
|
| 456 |
+
<div class='cluster-card'>
|
| 457 |
+
<div class='cluster-title'>{_esc(rep)}</div>
|
| 458 |
+
<div class='cluster-meta'>Cluster {_esc(cluster_id)} · {count} contribution instance{'s' if count != 1 else ''}</div>
|
| 459 |
+
</div>
|
| 460 |
+
""",
|
| 461 |
+
unsafe_allow_html=True,
|
| 462 |
+
)
|
| 463 |
+
meta_cols = st.columns([1.3, 1.3, 1.4])
|
| 464 |
+
with meta_cols[0]:
|
| 465 |
+
st.caption("Cluster ID")
|
| 466 |
+
st.code(str(cluster_id), language="text")
|
| 467 |
+
with meta_cols[1]:
|
| 468 |
+
st.caption("Source clusters")
|
| 469 |
+
st.code(", ".join(str(x) for x in source_ids) if source_ids else "singleton", language="text")
|
| 470 |
+
with meta_cols[2]:
|
| 471 |
+
st.caption("Merge rationale")
|
| 472 |
+
st.write(merge_rationale or "—")
|
| 473 |
+
|
| 474 |
+
claim_indices = cluster.get("claim_indices") or []
|
| 475 |
+
if claim_indices:
|
| 476 |
+
with st.expander(f"Linked contribution instances ({len(claim_indices)})", expanded=False):
|
| 477 |
+
for idx in claim_indices:
|
| 478 |
+
try:
|
| 479 |
+
j = int(idx)
|
| 480 |
+
except Exception:
|
| 481 |
+
continue
|
| 482 |
+
if 0 <= j < len(contributions):
|
| 483 |
+
item = contributions[j] or {}
|
| 484 |
+
title = item.get("citing_title") or item.get("citing_paper_id") or "Unknown citing paper"
|
| 485 |
+
claim = item.get("paper_claim") or item.get("claim") or "(missing claim)"
|
| 486 |
+
rationale = item.get("rationale") or ""
|
| 487 |
+
evidence = item.get("evidence_span") or ""
|
| 488 |
+
st.markdown(f"**{title}**")
|
| 489 |
+
st.write(claim)
|
| 490 |
+
if rationale:
|
| 491 |
+
st.caption(f"Rationale: {rationale}")
|
| 492 |
+
if evidence:
|
| 493 |
+
st.caption(f"Evidence: {evidence}")
|
| 494 |
+
st.divider()
|
| 495 |
+
|
| 496 |
+
if dropped:
|
| 497 |
+
with st.expander(f"Dropped clusters ({len(dropped)})", expanded=False):
|
| 498 |
+
st.json(dropped)
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def run_two_pass_annotation(
|
| 502 |
+
paper_dir: Path,
|
| 503 |
+
annotation_output_root: Path,
|
| 504 |
+
llm_provider: str,
|
| 505 |
+
llm_model: str,
|
| 506 |
+
formatter_model: str,
|
| 507 |
+
judge_model: str,
|
| 508 |
+
candidate_count: int,
|
| 509 |
+
):
|
| 510 |
+
paper = load_paper_package(paper_dir)
|
| 511 |
+
pipeline = TwoPassAnnotationPipeline(
|
| 512 |
+
provider=llm_provider,
|
| 513 |
+
model=llm_model,
|
| 514 |
+
formatter_model=formatter_model or None,
|
| 515 |
+
judge_model=judge_model or None,
|
| 516 |
+
output_root=annotation_output_root,
|
| 517 |
+
annotator_id="streamlit_hf_space",
|
| 518 |
+
candidate_count=max(1, int(candidate_count)),
|
| 519 |
+
formatter_max_attempts=3,
|
| 520 |
+
include_reference_examples=True,
|
| 521 |
+
prompt_profile="full",
|
| 522 |
+
)
|
| 523 |
+
result = pipeline.run(paper)
|
| 524 |
+
return result.result, result.run_dir
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def run_pipeline_stream(
|
| 528 |
+
paper_input: str,
|
| 529 |
+
source_root: str,
|
| 530 |
+
output_root: str,
|
| 531 |
+
llm_provider: str,
|
| 532 |
+
llm_model: str,
|
| 533 |
+
llm_model_step4: str,
|
| 534 |
+
formatter_model: str,
|
| 535 |
+
judge_model: str,
|
| 536 |
+
candidate_count: int,
|
| 537 |
+
):
|
| 538 |
+
gemini_key = get_secret("GEMINI_API_KEY")
|
| 539 |
+
if gemini_key:
|
| 540 |
+
os.environ["GEMINI_API_KEY"] = gemini_key
|
| 541 |
+
|
| 542 |
+
cfg = PipelineConfig(
|
| 543 |
+
repo_root=REPO_ROOT,
|
| 544 |
+
source_root=Path(source_root).expanduser().resolve(),
|
| 545 |
+
paper_input=paper_input.strip(),
|
| 546 |
+
llm_provider=llm_provider.strip() or "gemini",
|
| 547 |
+
llm_model=llm_model.strip() or "gemini-3.1-pro-preview",
|
| 548 |
+
llm_model_step4=llm_model_step4.strip() or "gemini-3-flash-preview",
|
| 549 |
+
model_path="Deep-Citation/Workspace/acl_scicite_wksp_trl/best_model.pt",
|
| 550 |
+
model_data_dir="Deep-Citation/Data",
|
| 551 |
+
model_class_def="Deep-Citation/Data/class_def.json",
|
| 552 |
+
model_lm="scibert",
|
| 553 |
+
device="cpu",
|
| 554 |
+
embedding_model="sentence-transformers/all-mpnet-base-v2",
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
status_placeholder = st.empty()
|
| 558 |
+
activity_placeholder = st.empty()
|
| 559 |
+
status = "Starting"
|
| 560 |
+
logs: list[str] = []
|
| 561 |
+
events: list[str] = []
|
| 562 |
+
seen_events: set[str] = set()
|
| 563 |
+
artifact_path = None
|
| 564 |
+
annotation_payload_path = None
|
| 565 |
+
annotation_skipped_reason = None
|
| 566 |
+
run_summary = None
|
| 567 |
+
pipeline_stopped_reason = None
|
| 568 |
+
pipeline_failed_reason = None
|
| 569 |
+
|
| 570 |
+
def render_activity(items: list[str]):
|
| 571 |
+
if not items:
|
| 572 |
+
activity_placeholder.info("Waiting for first step...")
|
| 573 |
+
return
|
| 574 |
+
activity_placeholder.markdown("### Activity\n" + "\n".join(f"- {item}" for item in items[-20:]))
|
| 575 |
+
|
| 576 |
+
def append_display_line(line: str):
|
| 577 |
+
display_line = _display_log_line(line)
|
| 578 |
+
if not display_line:
|
| 579 |
+
return
|
| 580 |
+
logs.append(display_line)
|
| 581 |
+
event = _format_step_event(display_line)
|
| 582 |
+
if event and event not in seen_events:
|
| 583 |
+
seen_events.add(event)
|
| 584 |
+
events.append(event)
|
| 585 |
+
render_activity(events)
|
| 586 |
+
|
| 587 |
+
for line, maybe_artifact in runner_module.run_pipeline(cfg, Path(output_root).expanduser().resolve()):
|
| 588 |
+
if line:
|
| 589 |
+
if line.strip() == "Pipeline completed successfully.":
|
| 590 |
+
if maybe_artifact:
|
| 591 |
+
artifact_path = maybe_artifact
|
| 592 |
+
continue
|
| 593 |
+
display_line = _display_log_line(line)
|
| 594 |
+
if display_line:
|
| 595 |
+
logs.append(display_line)
|
| 596 |
+
status = _status_from_line(display_line, status)
|
| 597 |
+
if display_line.startswith("Pipeline stopped:"):
|
| 598 |
+
pipeline_stopped_reason = display_line
|
| 599 |
+
if "failed" in display_line.lower():
|
| 600 |
+
pipeline_failed_reason = display_line
|
| 601 |
+
event = _format_step_event(display_line)
|
| 602 |
+
if event and event not in seen_events:
|
| 603 |
+
seen_events.add(event)
|
| 604 |
+
events.append(event)
|
| 605 |
+
if maybe_artifact:
|
| 606 |
+
artifact_path = maybe_artifact
|
| 607 |
+
status_placeholder.info(f"Current status: {status}")
|
| 608 |
+
render_activity(events)
|
| 609 |
+
|
| 610 |
+
run_dir_path = None
|
| 611 |
+
paper_dir_path = None
|
| 612 |
+
remote_artifact_ref = ""
|
| 613 |
+
if artifact_path:
|
| 614 |
+
job_dir = Path(str(artifact_path)).with_suffix("")
|
| 615 |
+
run_dir_path = str(job_dir)
|
| 616 |
+
paper_id = runner_module.parse_arxiv_id(paper_input.strip())
|
| 617 |
+
paper_dir = job_dir / "processed_papers" / paper_id
|
| 618 |
+
paper_dir_path = str(paper_dir)
|
| 619 |
+
if pipeline_failed_reason:
|
| 620 |
+
annotation_skipped_reason = f"{pipeline_failed_reason} Annotation was not run."
|
| 621 |
+
elif pipeline_stopped_reason:
|
| 622 |
+
annotation_skipped_reason = f"{pipeline_stopped_reason} Annotation was not run."
|
| 623 |
+
else:
|
| 624 |
+
discovery = _load_json(paper_dir / "usage_discovery_from_contributions.json") or {}
|
| 625 |
+
refined_clusters = discovery.get("clusters") or []
|
| 626 |
+
if not refined_clusters:
|
| 627 |
+
annotation_skipped_reason = "No valid downstream usage clusters remained after refinement and filtering. Annotation was skipped."
|
| 628 |
+
logs.append("[annotation] skipped: no refined downstream usage clusters")
|
| 629 |
+
else:
|
| 630 |
+
append_display_line("[annotation] starting cluster-first two-pass annotation")
|
| 631 |
+
status_placeholder.info("Current status: Running annotation")
|
| 632 |
+
try:
|
| 633 |
+
run_output, annotation_run_dir = run_two_pass_annotation(
|
| 634 |
+
paper_dir=paper_dir,
|
| 635 |
+
annotation_output_root=job_dir / "two_pass_outputs",
|
| 636 |
+
llm_provider=llm_provider,
|
| 637 |
+
llm_model=llm_model,
|
| 638 |
+
formatter_model=formatter_model,
|
| 639 |
+
judge_model=judge_model,
|
| 640 |
+
candidate_count=candidate_count,
|
| 641 |
+
)
|
| 642 |
+
payload_path = run_output.get("ui_payload_path") if isinstance(run_output, dict) else None
|
| 643 |
+
if payload_path and Path(payload_path).exists():
|
| 644 |
+
annotation_payload_path = str(Path(payload_path))
|
| 645 |
+
append_display_line(f"[annotation] complete: {annotation_run_dir}")
|
| 646 |
+
except Exception as exc:
|
| 647 |
+
pipeline_failed_reason = f"Annotation failed: {exc}"
|
| 648 |
+
annotation_skipped_reason = pipeline_failed_reason
|
| 649 |
+
logs.append(f"[annotation] failed: {exc}")
|
| 650 |
+
logs.append("[upload] uploading run artifact to Hugging Face dataset")
|
| 651 |
+
status_placeholder.info("Current status: Finalizing run")
|
| 652 |
+
remote_artifact_ref = upload_run_artifact(job_dir)
|
| 653 |
+
if remote_artifact_ref:
|
| 654 |
+
logs.append(f"[upload] {remote_artifact_ref}")
|
| 655 |
+
else:
|
| 656 |
+
logs.append("[upload] skipped: RUNS_REPO_ID/HF_WRITE_TOKEN not configured")
|
| 657 |
+
if not pipeline_stopped_reason and not pipeline_failed_reason:
|
| 658 |
+
append_display_line("Pipeline completed successfully.")
|
| 659 |
+
|
| 660 |
+
if pipeline_failed_reason:
|
| 661 |
+
status = "Failed"
|
| 662 |
+
elif artifact_path and pipeline_stopped_reason:
|
| 663 |
+
status = "Stopped"
|
| 664 |
+
else:
|
| 665 |
+
status = "Completed" if artifact_path else "Failed"
|
| 666 |
+
if status == "Completed":
|
| 667 |
+
status_placeholder.success(f"Final status: {status}")
|
| 668 |
+
elif status == "Stopped":
|
| 669 |
+
status_placeholder.warning(f"Final status: {status}")
|
| 670 |
+
else:
|
| 671 |
+
status_placeholder.error("Final status: Failed")
|
| 672 |
+
|
| 673 |
+
st.session_state["run_status"] = status
|
| 674 |
+
st.session_state["run_logs"] = logs
|
| 675 |
+
st.session_state["run_events"] = events
|
| 676 |
+
st.session_state["artifact_path"] = artifact_path
|
| 677 |
+
st.session_state["run_dir_path"] = run_dir_path
|
| 678 |
+
st.session_state["paper_dir_path"] = paper_dir_path
|
| 679 |
+
st.session_state["annotation_payload_path"] = annotation_payload_path
|
| 680 |
+
st.session_state["annotation_skipped_reason"] = annotation_skipped_reason
|
| 681 |
+
st.session_state["pipeline_stopped_reason"] = pipeline_stopped_reason
|
| 682 |
+
st.session_state["pipeline_failed_reason"] = pipeline_failed_reason
|
| 683 |
+
st.session_state["run_summary"] = run_summary
|
| 684 |
+
st.session_state["remote_artifact_ref"] = remote_artifact_ref
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
def _load_result_bundle():
|
| 688 |
+
paper_dir_path = st.session_state.get("paper_dir_path")
|
| 689 |
+
annotation_payload_path = st.session_state.get("annotation_payload_path")
|
| 690 |
+
paper_dir = Path(paper_dir_path) if paper_dir_path else None
|
| 691 |
+
payload = _load_json(Path(annotation_payload_path)) if annotation_payload_path else None
|
| 692 |
+
discovery = _load_json(paper_dir / "usage_discovery_from_contributions.json") if paper_dir and paper_dir.exists() else None
|
| 693 |
+
contributions_data = _load_json(paper_dir / "usage_contributions.json") if paper_dir and paper_dir.exists() else None
|
| 694 |
+
contributions = (contributions_data or {}).get("contributions") or []
|
| 695 |
+
return paper_dir, discovery, contributions, payload
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
def _render_overview(payload: Optional[dict], discovery: Optional[dict]):
|
| 699 |
+
claims = (payload or {}).get("claims") or []
|
| 700 |
+
ingredients = sum(len(claim.get("ingredients") or []) for claim in claims)
|
| 701 |
+
studies = sum(
|
| 702 |
+
len(_collect_grounded_studies(claim.get("enabling_discoveries") or [], claim.get("ingredients") or []))
|
| 703 |
+
for claim in claims
|
| 704 |
+
)
|
| 705 |
+
clusters = len((discovery or {}).get("clusters") or [])
|
| 706 |
+
|
| 707 |
+
c1, c2, c3, c4 = st.columns(4)
|
| 708 |
+
with c1:
|
| 709 |
+
_metric_card("Refined clusters", clusters)
|
| 710 |
+
with c2:
|
| 711 |
+
_metric_card("Target contributions", len(claims))
|
| 712 |
+
with c3:
|
| 713 |
+
_metric_card("Enabling contributions", ingredients)
|
| 714 |
+
with c4:
|
| 715 |
+
_metric_card("Grounded studies", studies)
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
def _build_public_export(discovery: Optional[dict], payload: Optional[dict]) -> dict:
|
| 719 |
+
claims = []
|
| 720 |
+
for claim in (payload or {}).get("claims") or []:
|
| 721 |
+
if not isinstance(claim, dict):
|
| 722 |
+
continue
|
| 723 |
+
ingredients = []
|
| 724 |
+
for ingredient in claim.get("ingredients") or []:
|
| 725 |
+
if not isinstance(ingredient, dict):
|
| 726 |
+
continue
|
| 727 |
+
ingredients.append({
|
| 728 |
+
"ingredient_id": ingredient.get("ingredient_id"),
|
| 729 |
+
"enabling_contribution": ingredient.get("ingredient"),
|
| 730 |
+
"canonical_annotation": ingredient.get("canonical_annotation") or {},
|
| 731 |
+
"primary_grounding": ingredient.get("canonical_grounding") or {},
|
| 732 |
+
"additional_groundings": ingredient.get("additional_groundings") or [],
|
| 733 |
+
})
|
| 734 |
+
claims.append({
|
| 735 |
+
"claim_id": claim.get("claim_id"),
|
| 736 |
+
"target_contribution": claim.get("rewritten_claim") or claim.get("text"),
|
| 737 |
+
"cluster_id": claim.get("cluster_id"),
|
| 738 |
+
"decision": claim.get("decision"),
|
| 739 |
+
"enabling_contributions": ingredients,
|
| 740 |
+
"grounded_studies": _collect_grounded_studies(claim.get("enabling_discoveries") or [], claim.get("ingredients") or []),
|
| 741 |
+
})
|
| 742 |
+
|
| 743 |
+
return {
|
| 744 |
+
"citation_clusters": (discovery or {}).get("clusters") or [],
|
| 745 |
+
"target_contribution_decompositions": claims,
|
| 746 |
+
}
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
def main():
|
| 750 |
+
llm_provider = os.getenv("LLM_PROVIDER", "gemini")
|
| 751 |
+
llm_model = os.getenv("LLM_MODEL", "gemini-3.1-pro-preview")
|
| 752 |
+
llm_model_step4 = os.getenv("LLM_MODEL_STEP4", "gemini-3-flash-preview")
|
| 753 |
+
formatter_model = os.getenv("ANNOTATION_FORMATTER_MODEL", "gemini/gemini-3.1-pro-preview")
|
| 754 |
+
judge_model = os.getenv("ANNOTATION_JUDGE_MODEL", "gemini/gemini-3.1-pro-preview")
|
| 755 |
+
candidate_count = int(os.getenv("ANNOTATION_CANDIDATE_COUNT", "3"))
|
| 756 |
+
source_root = DEFAULT_SOURCE_ROOT
|
| 757 |
+
output_root = DEFAULT_OUTPUT_ROOT
|
| 758 |
+
|
| 759 |
+
st.set_page_config(page_title="Forecasting Scientific Contribution Pathways", page_icon="📚", layout="wide")
|
| 760 |
+
st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
|
| 761 |
+
_ensure_state()
|
| 762 |
+
|
| 763 |
+
with st.sidebar:
|
| 764 |
+
st.markdown("## SciPaths")
|
| 765 |
+
st.caption("Enter an arXiv paper and run the target-contribution pathway annotation pipeline.")
|
| 766 |
+
st.divider()
|
| 767 |
+
st.markdown("### Citation")
|
| 768 |
+
st.caption("If you find this useful, please cite our paper as:")
|
| 769 |
+
st.code(
|
| 770 |
+
"@misc{chamoun2026scipathsforecastingpathwaysscientific,\n"
|
| 771 |
+
" title={SciPaths: Forecasting Pathways to Scientific Discovery}, \n"
|
| 772 |
+
" author={Eric Chamoun and Yizhou Chi and Yulong Chen and Rui Cao and Zifeng Ding and Michalis Korakakis and Andreas Vlachos},\n"
|
| 773 |
+
" year={2026},\n"
|
| 774 |
+
" eprint={2605.14600},\n"
|
| 775 |
+
" archivePrefix={arXiv},\n"
|
| 776 |
+
" primaryClass={cs.CL},\n"
|
| 777 |
+
" url={https://arxiv.org/abs/2605.14600}, \n"
|
| 778 |
+
"}",
|
| 779 |
+
language="bibtex",
|
| 780 |
+
)
|
| 781 |
+
st.caption("Paper URL: https://arxiv.org/abs/2605.14600")
|
| 782 |
+
st.caption("Questions or feedback: ec806@cam.ac.uk")
|
| 783 |
+
st.divider()
|
| 784 |
+
if st.button("Clear chat / restart", use_container_width=True):
|
| 785 |
+
for key in [
|
| 786 |
+
"paper_input", "run_status", "run_logs", "run_events", "artifact_path",
|
| 787 |
+
"run_dir_path", "paper_dir_path", "annotation_payload_path",
|
| 788 |
+
"run_summary", "annotation_skipped_reason", "pipeline_stopped_reason",
|
| 789 |
+
"pipeline_failed_reason", "remote_artifact_ref",
|
| 790 |
+
]:
|
| 791 |
+
if key in st.session_state:
|
| 792 |
+
del st.session_state[key]
|
| 793 |
+
st.rerun()
|
| 794 |
+
if not get_secret("GEMINI_API_KEY"):
|
| 795 |
+
st.warning("No GEMINI_API_KEY found in environment or secrets.", icon="🔑")
|
| 796 |
+
|
| 797 |
+
st.markdown("<div class='hero-title'>Forecasting Scientific Contribution Pathways</div>", unsafe_allow_html=True)
|
| 798 |
+
st.markdown(
|
| 799 |
+
"<div class='hero-sub'>Run the SciPaths pipeline through refined downstream citation clusters, then derive target contributions from those clusters and decompose each target contribution into enabling contributions and grounded studies.</div>",
|
| 800 |
+
unsafe_allow_html=True,
|
| 801 |
+
)
|
| 802 |
+
|
| 803 |
+
tabs = st.tabs(TAB_NAMES)
|
| 804 |
+
|
| 805 |
+
with tabs[0]:
|
| 806 |
+
with st.expander("Try an example", expanded=True):
|
| 807 |
+
cols = st.columns(len(EXAMPLES))
|
| 808 |
+
for i, (label, value) in enumerate(EXAMPLES.items()):
|
| 809 |
+
with cols[i]:
|
| 810 |
+
if st.button(label, key=f"example::{label}", use_container_width=True):
|
| 811 |
+
st.session_state["paper_input"] = value
|
| 812 |
+
st.rerun()
|
| 813 |
+
|
| 814 |
+
paper_input = st.text_input(
|
| 815 |
+
"Paper input (arXiv URL or ID)",
|
| 816 |
+
key="paper_input",
|
| 817 |
+
placeholder="https://arxiv.org/abs/2311.14919",
|
| 818 |
+
)
|
| 819 |
+
|
| 820 |
+
if st.button("Run pipeline + annotation", type="primary", use_container_width=True):
|
| 821 |
+
if not paper_input.strip():
|
| 822 |
+
st.error("Paper input is required.")
|
| 823 |
+
else:
|
| 824 |
+
run_pipeline_stream(
|
| 825 |
+
paper_input=paper_input,
|
| 826 |
+
source_root=source_root,
|
| 827 |
+
output_root=output_root,
|
| 828 |
+
llm_provider=llm_provider,
|
| 829 |
+
llm_model=llm_model,
|
| 830 |
+
llm_model_step4=llm_model_step4,
|
| 831 |
+
formatter_model=formatter_model,
|
| 832 |
+
judge_model=judge_model,
|
| 833 |
+
candidate_count=candidate_count,
|
| 834 |
+
)
|
| 835 |
+
|
| 836 |
+
st.markdown("### Latest run")
|
| 837 |
+
st.info(f"Status: {st.session_state.get('run_status', 'Idle')}")
|
| 838 |
+
if st.session_state.get("pipeline_failed_reason"):
|
| 839 |
+
st.error(st.session_state["pipeline_failed_reason"])
|
| 840 |
+
if st.session_state.get("annotation_skipped_reason"):
|
| 841 |
+
st.warning(st.session_state["annotation_skipped_reason"])
|
| 842 |
+
|
| 843 |
+
paper_dir, discovery, contributions, payload = _load_result_bundle()
|
| 844 |
+
public_export = _build_public_export(discovery, payload)
|
| 845 |
+
if public_export["citation_clusters"] or public_export["target_contribution_decompositions"]:
|
| 846 |
+
st.download_button(
|
| 847 |
+
"Download citation clusters and contribution groundings",
|
| 848 |
+
data=json.dumps(public_export, indent=2, ensure_ascii=False),
|
| 849 |
+
file_name="scipaths_run_results.json",
|
| 850 |
+
mime="application/json",
|
| 851 |
+
use_container_width=False,
|
| 852 |
+
)
|
| 853 |
+
_render_overview(payload, discovery)
|
| 854 |
+
|
| 855 |
+
with tabs[1]:
|
| 856 |
+
paper_dir, discovery, contributions, payload = _load_result_bundle()
|
| 857 |
+
_render_clusters_tab(discovery, contributions)
|
| 858 |
+
|
| 859 |
+
with tabs[2]:
|
| 860 |
+
paper_dir, discovery, contributions, payload = _load_result_bundle()
|
| 861 |
+
_render_claims_tab(payload)
|
| 862 |
+
|
| 863 |
+
if __name__ == "__main__":
|
| 864 |
+
main()
|
hf_space/streamlit_config.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from runner import STEP_LABELS
|
| 2 |
+
|
| 3 |
+
EXAMPLES = {
|
| 4 |
+
"Confidence-based MBR Decoding": "https://arxiv.org/abs/2311.14919",
|
| 5 |
+
"AVerImaTeC": "https://arxiv.org/abs/2505.17978",
|
| 6 |
+
"CSCD-NS (2022)": "https://arxiv.org/abs/2211.08788",
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
TAB_NAMES = [
|
| 10 |
+
"Pipeline Run",
|
| 11 |
+
"Citation Clusters",
|
| 12 |
+
"Target Contribution Decomposition",
|
| 13 |
+
]
|
| 14 |
+
|
| 15 |
+
METHOD_NOTES = {
|
| 16 |
+
"Pipeline scope": "Runs steps 0, 1, 2, 3, 4, 5, 6, and 8, then launches cluster-first two-pass annotation.",
|
| 17 |
+
"Input": "Accepts a single arXiv URL or arXiv ID.",
|
| 18 |
+
"Cluster-first annotation": "Uses all refined downstream USES/EXTENDS clusters to derive target contributions, then decomposes each target contribution separately.",
|
| 19 |
+
"Stopping rule": "If no valid downstream usage clusters remain after refinement and filtering, annotation is skipped.",
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
DISPLAY_STEPS = [0, 1, 2, 3, 4, 5, 6, 8]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def pipeline_steps_markdown() -> str:
|
| 26 |
+
lines = []
|
| 27 |
+
for idx in DISPLAY_STEPS:
|
| 28 |
+
lines.append(f"{idx}. {STEP_LABELS[idx]}")
|
| 29 |
+
lines.append("9. Cluster-first target contribution annotation and enabling contribution decomposition")
|
| 30 |
+
return "\n".join(lines)
|
requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
-r hf_space/requirements.txt
|
src/common/__init__.py
ADDED
|
File without changes
|
src/common/llm_client.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import google.generativeai as genai
|
| 5 |
+
from google.generativeai.types import GenerationConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class LLMClient:
|
| 9 |
+
def __init__(self):
|
| 10 |
+
self.provider = os.getenv("LLM_PROVIDER", "gemini").lower()
|
| 11 |
+
self.model_name = os.getenv("LLM_MODEL", "gemini-3.1-pro-preview")
|
| 12 |
+
|
| 13 |
+
if self.provider == "gemini":
|
| 14 |
+
if genai is None:
|
| 15 |
+
raise ImportError("google-generativeai not installed.")
|
| 16 |
+
key = os.getenv("GEMINI_API_KEY")
|
| 17 |
+
if not key:
|
| 18 |
+
raise ValueError("GEMINI_API_KEY not set.")
|
| 19 |
+
genai.configure(api_key=key)
|
| 20 |
+
self.model = genai.GenerativeModel(self.model_name)
|
| 21 |
+
else:
|
| 22 |
+
raise NotImplementedError("Only Gemini provider is wired for now.")
|
| 23 |
+
|
| 24 |
+
def call(self, prompt: str, schema: Optional[dict] = None) -> str:
|
| 25 |
+
"""
|
| 26 |
+
Call the underlying LLM.
|
| 27 |
+
|
| 28 |
+
If `schema` is provided (as a plain JSON schema dict), and provider is Gemini,
|
| 29 |
+
use it as response_schema with JSON mime type.
|
| 30 |
+
"""
|
| 31 |
+
if self.provider == "gemini":
|
| 32 |
+
if schema and GenerationConfig is not None:
|
| 33 |
+
config = GenerationConfig(
|
| 34 |
+
response_schema=schema,
|
| 35 |
+
response_mime_type="application/json",
|
| 36 |
+
)
|
| 37 |
+
response = self.model.generate_content(
|
| 38 |
+
prompt,
|
| 39 |
+
generation_config=config,
|
| 40 |
+
)
|
| 41 |
+
else:
|
| 42 |
+
response = self.model.generate_content(prompt)
|
| 43 |
+
|
| 44 |
+
text = getattr(response, "text", "")
|
| 45 |
+
if not text:
|
| 46 |
+
raise RuntimeError("LLM response did not contain text.")
|
| 47 |
+
return text
|
| 48 |
+
|
| 49 |
+
raise NotImplementedError("Schema-based calls only wired for Gemini right now.")
|
src/common/model_client.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Type
|
| 8 |
+
|
| 9 |
+
import litellm
|
| 10 |
+
from litellm import completion
|
| 11 |
+
from pydantic import BaseModel, ValidationError
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class ModelConfig:
|
| 16 |
+
provider: str
|
| 17 |
+
model: str
|
| 18 |
+
temperature: float = 0.2
|
| 19 |
+
max_tokens: int = 12000
|
| 20 |
+
|
| 21 |
+
@property
|
| 22 |
+
def model_name(self) -> str:
|
| 23 |
+
if "/" in self.model:
|
| 24 |
+
return self.model
|
| 25 |
+
if self.provider.lower() == "openai":
|
| 26 |
+
return f"openai/{self.model}"
|
| 27 |
+
if self.provider.lower() == "gemini":
|
| 28 |
+
return f"gemini/{self.model}"
|
| 29 |
+
return self.model
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class MultiProviderLLMClient:
|
| 33 |
+
def __init__(self, default_config: ModelConfig, stage_models: dict[str, str] | None = None):
|
| 34 |
+
self.default_config = default_config
|
| 35 |
+
self.stage_models = stage_models or {}
|
| 36 |
+
litellm.drop_params = True
|
| 37 |
+
self._validate_env(default_config.provider)
|
| 38 |
+
|
| 39 |
+
def _validate_env(self, provider: str) -> None:
|
| 40 |
+
provider = provider.lower()
|
| 41 |
+
if provider == "openai" and not os.getenv("OPENAI_API_KEY"):
|
| 42 |
+
raise ValueError("OPENAI_API_KEY is required for provider=openai")
|
| 43 |
+
if provider == "gemini" and not os.getenv("GEMINI_API_KEY"):
|
| 44 |
+
raise ValueError("GEMINI_API_KEY is required for provider=gemini")
|
| 45 |
+
|
| 46 |
+
def config_for_stage(self, stage_name: str) -> ModelConfig:
|
| 47 |
+
model_override = self.stage_models.get(stage_name)
|
| 48 |
+
if not model_override:
|
| 49 |
+
return self.default_config
|
| 50 |
+
provider = self.default_config.provider
|
| 51 |
+
model = model_override
|
| 52 |
+
if "/" in model_override:
|
| 53 |
+
provider, model = model_override.split("/", 1)
|
| 54 |
+
self._validate_env(provider)
|
| 55 |
+
return ModelConfig(
|
| 56 |
+
provider=provider,
|
| 57 |
+
model=model,
|
| 58 |
+
temperature=self.default_config.temperature,
|
| 59 |
+
max_tokens=self.default_config.max_tokens,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
def generate_structured(
|
| 63 |
+
self,
|
| 64 |
+
*,
|
| 65 |
+
stage_name: str,
|
| 66 |
+
system_prompt: str,
|
| 67 |
+
user_prompt: str,
|
| 68 |
+
response_model: Type[BaseModel],
|
| 69 |
+
) -> BaseModel:
|
| 70 |
+
config = self.config_for_stage(stage_name)
|
| 71 |
+
completion_kwargs = {
|
| 72 |
+
"model": config.model_name,
|
| 73 |
+
"messages": [
|
| 74 |
+
{"role": "system", "content": system_prompt},
|
| 75 |
+
{"role": "user", "content": user_prompt},
|
| 76 |
+
],
|
| 77 |
+
"max_tokens": config.max_tokens,
|
| 78 |
+
"response_format": {"type": "json_object"},
|
| 79 |
+
}
|
| 80 |
+
temperature = self._temperature_for_model(config)
|
| 81 |
+
if temperature is not None:
|
| 82 |
+
completion_kwargs["temperature"] = temperature
|
| 83 |
+
|
| 84 |
+
response = completion(
|
| 85 |
+
**completion_kwargs,
|
| 86 |
+
)
|
| 87 |
+
content = response.choices[0].message.content or ""
|
| 88 |
+
payload = self._parse_json(content)
|
| 89 |
+
try:
|
| 90 |
+
return response_model.model_validate(payload)
|
| 91 |
+
except ValidationError as exc:
|
| 92 |
+
if isinstance(payload, list) and len(payload) == 1 and isinstance(payload[0], dict):
|
| 93 |
+
try:
|
| 94 |
+
return response_model.model_validate(payload[0])
|
| 95 |
+
except ValidationError:
|
| 96 |
+
pass
|
| 97 |
+
raise ValueError(
|
| 98 |
+
f"Stage {stage_name} returned invalid JSON for {response_model.__name__}: {exc}\nRaw content:\n{content}"
|
| 99 |
+
) from exc
|
| 100 |
+
|
| 101 |
+
def generate_text(
|
| 102 |
+
self,
|
| 103 |
+
*,
|
| 104 |
+
stage_name: str,
|
| 105 |
+
system_prompt: str,
|
| 106 |
+
user_prompt: str,
|
| 107 |
+
) -> str:
|
| 108 |
+
config = self.config_for_stage(stage_name)
|
| 109 |
+
completion_kwargs = {
|
| 110 |
+
"model": config.model_name,
|
| 111 |
+
"messages": [
|
| 112 |
+
{"role": "system", "content": system_prompt},
|
| 113 |
+
{"role": "user", "content": user_prompt},
|
| 114 |
+
],
|
| 115 |
+
"max_tokens": config.max_tokens,
|
| 116 |
+
}
|
| 117 |
+
temperature = self._temperature_for_model(config)
|
| 118 |
+
if temperature is not None:
|
| 119 |
+
completion_kwargs["temperature"] = temperature
|
| 120 |
+
response = completion(**completion_kwargs)
|
| 121 |
+
return (response.choices[0].message.content or "").strip()
|
| 122 |
+
|
| 123 |
+
@staticmethod
|
| 124 |
+
def _parse_json(text: str) -> Any:
|
| 125 |
+
text = text.strip()
|
| 126 |
+
if text.startswith("```"):
|
| 127 |
+
match = re.search(r"```(?:json)?\s*(.*?)```", text, flags=re.S)
|
| 128 |
+
if match:
|
| 129 |
+
text = match.group(1).strip()
|
| 130 |
+
try:
|
| 131 |
+
return json.loads(text)
|
| 132 |
+
except json.JSONDecodeError:
|
| 133 |
+
match = re.search(r"(\{.*\}|\[.*\])", text, flags=re.S)
|
| 134 |
+
if match:
|
| 135 |
+
return json.loads(match.group(1))
|
| 136 |
+
raise
|
| 137 |
+
|
| 138 |
+
@staticmethod
|
| 139 |
+
def _temperature_for_model(config: ModelConfig) -> float | None:
|
| 140 |
+
model_name = config.model_name.lower()
|
| 141 |
+
if "gpt-5" in model_name:
|
| 142 |
+
return None
|
| 143 |
+
return config.temperature
|
src/common/paper_package.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Dict, List
|
| 7 |
+
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
SECTION_FILES = [
|
| 12 |
+
"abstract.txt",
|
| 13 |
+
"introduction.tex",
|
| 14 |
+
"related_work.tex",
|
| 15 |
+
"tldr.txt",
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class PaperPackage(BaseModel):
|
| 20 |
+
paper_dir: Path
|
| 21 |
+
paper_metadata: Dict[str, Any]
|
| 22 |
+
extracted_discovery_claim: str
|
| 23 |
+
downstream_cluster_evidence: List[Dict[str, Any]]
|
| 24 |
+
paper_text: Dict[str, str]
|
| 25 |
+
full_processed_text: str
|
| 26 |
+
bibliography: List[Dict[str, Any]]
|
| 27 |
+
citation_contexts: List[Dict[str, Any]]
|
| 28 |
+
|
| 29 |
+
def to_prompt_payload(self) -> Dict[str, Any]:
|
| 30 |
+
return {
|
| 31 |
+
"paper_metadata": self.paper_metadata,
|
| 32 |
+
"extracted_discovery_claim": self.extracted_discovery_claim,
|
| 33 |
+
"downstream_cluster_evidence": self.downstream_cluster_evidence,
|
| 34 |
+
"paper_text": self.paper_text,
|
| 35 |
+
"full_processed_text": self.full_processed_text,
|
| 36 |
+
"bibliography": self.bibliography,
|
| 37 |
+
"citation_contexts": self.citation_contexts,
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _load_json(path: Path, default: Any) -> Any:
|
| 42 |
+
try:
|
| 43 |
+
return json.loads(path.read_text())
|
| 44 |
+
except Exception:
|
| 45 |
+
return default
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _read_text(path: Path) -> str:
|
| 49 |
+
try:
|
| 50 |
+
return path.read_text()
|
| 51 |
+
except Exception:
|
| 52 |
+
return ""
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _normalize_dict_payload(value: Any) -> Dict[str, Any]:
|
| 56 |
+
if isinstance(value, dict):
|
| 57 |
+
return value
|
| 58 |
+
if isinstance(value, list):
|
| 59 |
+
for item in value:
|
| 60 |
+
if isinstance(item, dict):
|
| 61 |
+
return item
|
| 62 |
+
return {}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _collect_sections(paper_dir: Path) -> Dict[str, str]:
|
| 66 |
+
sections_dir = paper_dir / "sections"
|
| 67 |
+
out: Dict[str, str] = {}
|
| 68 |
+
for name in SECTION_FILES:
|
| 69 |
+
text = _read_text(sections_dir / name).strip()
|
| 70 |
+
if text:
|
| 71 |
+
out[name] = text[:12000]
|
| 72 |
+
if not out:
|
| 73 |
+
processed = _read_text(paper_dir / "processed_main.tex").strip()
|
| 74 |
+
if processed:
|
| 75 |
+
out["processed_main.tex"] = processed[:24000]
|
| 76 |
+
return out
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _collect_full_processed_text(paper_dir: Path) -> str:
|
| 80 |
+
processed = _read_text(paper_dir / "processed_main.tex").strip()
|
| 81 |
+
if processed:
|
| 82 |
+
return processed
|
| 83 |
+
|
| 84 |
+
sections_dir = paper_dir / "sections"
|
| 85 |
+
parts: List[str] = []
|
| 86 |
+
if sections_dir.exists():
|
| 87 |
+
for path in sorted(sections_dir.iterdir()):
|
| 88 |
+
if not path.is_file():
|
| 89 |
+
continue
|
| 90 |
+
text = _read_text(path).strip()
|
| 91 |
+
if text:
|
| 92 |
+
parts.append(f"[{path.name}]\n{text}")
|
| 93 |
+
return "\n\n".join(parts)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _extract_year(value: Any) -> Any:
|
| 97 |
+
if value:
|
| 98 |
+
return value
|
| 99 |
+
return None
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _normalise_reference_record(ref: Dict[str, Any]) -> Dict[str, Any]:
|
| 103 |
+
cited = ref.get("citedPaper")
|
| 104 |
+
source = cited if isinstance(cited, dict) else ref
|
| 105 |
+
external_ids = source.get("external_ids") or source.get("externalIds") or {}
|
| 106 |
+
return {
|
| 107 |
+
"ref_id": (
|
| 108 |
+
ref.get("ref_id")
|
| 109 |
+
or ref.get("bib_key")
|
| 110 |
+
or source.get("ref_id")
|
| 111 |
+
or source.get("bib_key")
|
| 112 |
+
or source.get("paperId")
|
| 113 |
+
or source.get("paper_id")
|
| 114 |
+
or external_ids.get("ACL")
|
| 115 |
+
or external_ids.get("ArXiv")
|
| 116 |
+
or external_ids.get("DOI")
|
| 117 |
+
),
|
| 118 |
+
"title": source.get("title") or source.get("ref_title"),
|
| 119 |
+
"authors": source.get("authors") or source.get("ref_authors"),
|
| 120 |
+
"year": _extract_year(source.get("year") or source.get("ref_year")),
|
| 121 |
+
"external_ids": external_ids,
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _parse_bibtex_entries(text: str, limit: int) -> List[Dict[str, Any]]:
|
| 126 |
+
entries: List[Dict[str, Any]] = []
|
| 127 |
+
for match in re.finditer(r"@\w+\s*\{\s*([^,]+),(.*?)(?=\n@\w+\s*\{|\Z)", text, re.S):
|
| 128 |
+
key = match.group(1).strip()
|
| 129 |
+
body = match.group(2)
|
| 130 |
+
fields: Dict[str, str] = {}
|
| 131 |
+
for field in ("title", "author", "year", "doi", "url", "eprint"):
|
| 132 |
+
field_match = re.search(
|
| 133 |
+
rf"\b{field}\s*=\s*(\{{(?:[^{{}}]|\{{[^{{}}]*\}})*\}}|\"[^\"]*\"|[^,\n]+)",
|
| 134 |
+
body,
|
| 135 |
+
re.I | re.S,
|
| 136 |
+
)
|
| 137 |
+
if field_match:
|
| 138 |
+
value = field_match.group(1).strip().strip(",")
|
| 139 |
+
if (value.startswith("{") and value.endswith("}")) or (
|
| 140 |
+
value.startswith('"') and value.endswith('"')
|
| 141 |
+
):
|
| 142 |
+
value = value[1:-1]
|
| 143 |
+
fields[field] = re.sub(r"\s+", " ", value).strip()
|
| 144 |
+
if fields:
|
| 145 |
+
external_ids: Dict[str, Any] = {}
|
| 146 |
+
if fields.get("doi"):
|
| 147 |
+
external_ids["DOI"] = fields["doi"]
|
| 148 |
+
if fields.get("eprint"):
|
| 149 |
+
external_ids["ArXiv"] = fields["eprint"]
|
| 150 |
+
entries.append(
|
| 151 |
+
{
|
| 152 |
+
"ref_id": key,
|
| 153 |
+
"title": fields.get("title"),
|
| 154 |
+
"authors": fields.get("author"),
|
| 155 |
+
"year": fields.get("year"),
|
| 156 |
+
"external_ids": external_ids,
|
| 157 |
+
}
|
| 158 |
+
)
|
| 159 |
+
if len(entries) >= limit:
|
| 160 |
+
break
|
| 161 |
+
return entries
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _collect_bibtex_citation_contexts(paper_dir: Path, limit: int = 60) -> List[Dict[str, Any]]:
|
| 165 |
+
bibtex = _read_text(paper_dir / "references.bib")
|
| 166 |
+
processed = _read_text(paper_dir / "processed_main.tex")
|
| 167 |
+
if not bibtex or not processed:
|
| 168 |
+
return []
|
| 169 |
+
|
| 170 |
+
refs = _parse_bibtex_entries(bibtex, limit=500)
|
| 171 |
+
out: List[Dict[str, Any]] = []
|
| 172 |
+
seen: set[tuple[str, int]] = set()
|
| 173 |
+
for ref in refs:
|
| 174 |
+
ref_id = ref.get("ref_id")
|
| 175 |
+
if not ref_id:
|
| 176 |
+
continue
|
| 177 |
+
for match in re.finditer(rf"\\cite\w*\s*(?:\[[^\]]*\]\s*)*\{{[^}}]*\b{re.escape(str(ref_id))}\b[^}}]*\}}", processed):
|
| 178 |
+
key = (str(ref_id), match.start())
|
| 179 |
+
if key in seen:
|
| 180 |
+
continue
|
| 181 |
+
seen.add(key)
|
| 182 |
+
start = max(0, match.start() - 350)
|
| 183 |
+
end = min(len(processed), match.end() + 350)
|
| 184 |
+
snippet = re.sub(r"\s+", " ", processed[start:end]).strip()
|
| 185 |
+
out.append(
|
| 186 |
+
{
|
| 187 |
+
"ref_id": ref_id,
|
| 188 |
+
"citation_marker": ref.get("title") or ref_id,
|
| 189 |
+
"text": snippet,
|
| 190 |
+
"section": None,
|
| 191 |
+
"intents": [],
|
| 192 |
+
}
|
| 193 |
+
)
|
| 194 |
+
if len(out) >= limit:
|
| 195 |
+
return out
|
| 196 |
+
return out
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def _collect_bibliography(paper_dir: Path, limit: int = 80) -> List[Dict[str, Any]]:
|
| 200 |
+
refs = _load_json(paper_dir / "references_metadata.json", [])
|
| 201 |
+
if isinstance(refs, list) and refs:
|
| 202 |
+
return [_normalise_reference_record(ref) for ref in refs[:limit] if isinstance(ref, dict)]
|
| 203 |
+
|
| 204 |
+
bibtex = _read_text(paper_dir / "references.bib")
|
| 205 |
+
if bibtex:
|
| 206 |
+
return _parse_bibtex_entries(bibtex, limit)
|
| 207 |
+
return []
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def _collect_citation_contexts(paper_dir: Path, limit: int = 60) -> List[Dict[str, Any]]:
|
| 211 |
+
refs = _load_json(paper_dir / "references_metadata.json", [])
|
| 212 |
+
out = []
|
| 213 |
+
if isinstance(refs, list):
|
| 214 |
+
for ref in refs:
|
| 215 |
+
if not isinstance(ref, dict):
|
| 216 |
+
continue
|
| 217 |
+
ref_record = _normalise_reference_record(ref)
|
| 218 |
+
for context in ref.get("contextsWithIntent") or []:
|
| 219 |
+
if not isinstance(context, dict):
|
| 220 |
+
continue
|
| 221 |
+
text = context.get("context") or context.get("text") or ""
|
| 222 |
+
if not text:
|
| 223 |
+
continue
|
| 224 |
+
out.append(
|
| 225 |
+
{
|
| 226 |
+
"ref_id": ref_record.get("ref_id"),
|
| 227 |
+
"citation_marker": ref_record.get("title"),
|
| 228 |
+
"text": text,
|
| 229 |
+
"section": context.get("section"),
|
| 230 |
+
"intents": context.get("intents", []),
|
| 231 |
+
}
|
| 232 |
+
)
|
| 233 |
+
if len(out) >= limit:
|
| 234 |
+
return out
|
| 235 |
+
contexts = _load_json(paper_dir / "usage_contexts.json", [])
|
| 236 |
+
if isinstance(contexts, list):
|
| 237 |
+
for item in contexts:
|
| 238 |
+
entry = {
|
| 239 |
+
"ref_id": item.get("ref_id") or item.get("bib_key"),
|
| 240 |
+
"citation_marker": item.get("citation_marker"),
|
| 241 |
+
"text": item.get("text") or item.get("text_raw") or "",
|
| 242 |
+
"section": item.get("section"),
|
| 243 |
+
}
|
| 244 |
+
if entry["text"]:
|
| 245 |
+
out.append(entry)
|
| 246 |
+
if len(out) >= limit:
|
| 247 |
+
break
|
| 248 |
+
if not out:
|
| 249 |
+
out = _collect_bibtex_citation_contexts(paper_dir, limit=limit)
|
| 250 |
+
return out
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def _collect_downstream_cluster_evidence(paper_dir: Path) -> List[Dict[str, Any]]:
|
| 254 |
+
discovery = _normalize_dict_payload(_load_json(paper_dir / "usage_discovery_from_contributions.json", {}))
|
| 255 |
+
clusters = discovery.get("clusters", [])
|
| 256 |
+
out = []
|
| 257 |
+
for cluster in clusters:
|
| 258 |
+
out.append(
|
| 259 |
+
{
|
| 260 |
+
"cluster_id": cluster.get("cluster_id"),
|
| 261 |
+
"representative_claim": cluster.get("representative_claim") or cluster.get("cluster_title"),
|
| 262 |
+
"cluster_title": cluster.get("cluster_title"),
|
| 263 |
+
"count": cluster.get("count"),
|
| 264 |
+
"merge_rationale": cluster.get("merge_rationale"),
|
| 265 |
+
}
|
| 266 |
+
)
|
| 267 |
+
return out
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def load_paper_package(paper_dir: str | Path, extracted_claim_override: str | None = None) -> PaperPackage:
|
| 271 |
+
paper_dir = Path(paper_dir)
|
| 272 |
+
discovery = _normalize_dict_payload(_load_json(paper_dir / "usage_discovery_from_contributions.json", {}))
|
| 273 |
+
paper_metadata = _normalize_dict_payload(_load_json(paper_dir / "paper_metadata.json", {}))
|
| 274 |
+
claim = extracted_claim_override or (
|
| 275 |
+
discovery.get("most_impactful_contribution_self_contained")
|
| 276 |
+
or discovery.get("most_impactful_contribution")
|
| 277 |
+
or ""
|
| 278 |
+
)
|
| 279 |
+
return PaperPackage(
|
| 280 |
+
paper_dir=paper_dir,
|
| 281 |
+
paper_metadata=paper_metadata,
|
| 282 |
+
extracted_discovery_claim=claim,
|
| 283 |
+
downstream_cluster_evidence=_collect_downstream_cluster_evidence(paper_dir),
|
| 284 |
+
paper_text=_collect_sections(paper_dir),
|
| 285 |
+
full_processed_text=_collect_full_processed_text(paper_dir),
|
| 286 |
+
bibliography=_collect_bibliography(paper_dir),
|
| 287 |
+
citation_contexts=_collect_citation_contexts(paper_dir),
|
| 288 |
+
)
|
src/step_01_fetch/config.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
ACL_IDS_PATH = Path("input_ids.json")
|
| 5 |
+
PAPERS_DIR = Path("papers")
|
| 6 |
+
SEMANTIC_SCHOLAR_API_KEY = os.getenv("SEMANTIC_SCHOLAR_API_KEY", "")
|
src/step_01_fetch/fetch_metadata.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import argparse
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import re
|
| 7 |
+
import tarfile
|
| 8 |
+
import time
|
| 9 |
+
|
| 10 |
+
import arxiv
|
| 11 |
+
import requests
|
| 12 |
+
|
| 13 |
+
from config import ACL_IDS_PATH
|
| 14 |
+
from process_tex_source import preprocess_tex, extract_introduction_and_related
|
| 15 |
+
from semanticscholar_client import get_paper, get_paper_links, search_by_title
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def load_ids(path: Path):
|
| 19 |
+
return json.loads(path.read_text(encoding="utf-8"))
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def ensure_dir(path: Path):
|
| 23 |
+
path.mkdir(parents=True, exist_ok=True)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
_ARXIV_LAST_TS = 0.0
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _cleanup_partial_source_dir(source_dir: Path) -> None:
|
| 30 |
+
for pattern in ("*.tar.gz", "*.tgz", "*.tar"):
|
| 31 |
+
for path in source_dir.glob(pattern):
|
| 32 |
+
try:
|
| 33 |
+
path.unlink()
|
| 34 |
+
except Exception:
|
| 35 |
+
pass
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _download_arxiv_source_with_retries(paper, source_dir: Path, arxiv_id: str) -> Path | None:
|
| 39 |
+
max_retries = int(os.getenv("ARXIV_SOURCE_MAX_RETRIES", "4"))
|
| 40 |
+
base_sleep = float(os.getenv("ARXIV_SOURCE_BASE_SLEEP", "2.0"))
|
| 41 |
+
max_sleep = float(os.getenv("ARXIV_MAX_BACKOFF", "60"))
|
| 42 |
+
last_exc = None
|
| 43 |
+
|
| 44 |
+
for attempt in range(max_retries):
|
| 45 |
+
_cleanup_partial_source_dir(source_dir)
|
| 46 |
+
try:
|
| 47 |
+
_arxiv_min_interval_sleep()
|
| 48 |
+
tar_path = Path(paper.download_source(dirpath=str(source_dir)))
|
| 49 |
+
if not tar_path.exists():
|
| 50 |
+
raise FileNotFoundError(f"download_source returned {tar_path}, but the file does not exist")
|
| 51 |
+
if tar_path.stat().st_size < 1024:
|
| 52 |
+
raise IOError(f"downloaded source archive is unexpectedly small ({tar_path.stat().st_size} bytes)")
|
| 53 |
+
return tar_path
|
| 54 |
+
except Exception as exc:
|
| 55 |
+
last_exc = exc
|
| 56 |
+
sleep = min(base_sleep * (2**attempt), max_sleep) + random.uniform(0.0, 0.5)
|
| 57 |
+
print(f"[WARN] Failed to download source for {arxiv_id} on attempt {attempt + 1}/{max_retries}: {exc}")
|
| 58 |
+
if attempt + 1 < max_retries:
|
| 59 |
+
print(f"[INFO] Retrying source download in {sleep:.2f}s")
|
| 60 |
+
time.sleep(sleep)
|
| 61 |
+
|
| 62 |
+
print(f"[WARN] Source download failed for {arxiv_id} after {max_retries} attempts: {last_exc}")
|
| 63 |
+
return None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _arxiv_min_interval_sleep() -> None:
|
| 67 |
+
"""Global throttle to avoid arXiv API rate limits."""
|
| 68 |
+
global _ARXIV_LAST_TS
|
| 69 |
+
min_interval = float(os.getenv("ARXIV_MIN_INTERVAL", "1.0"))
|
| 70 |
+
now = time.monotonic()
|
| 71 |
+
elapsed = now - _ARXIV_LAST_TS
|
| 72 |
+
if elapsed < min_interval:
|
| 73 |
+
time.sleep(min_interval - elapsed)
|
| 74 |
+
_ARXIV_LAST_TS = time.monotonic()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def download_arxiv_tex(arxiv_id: str, base_dir: Path) -> Path | None:
|
| 78 |
+
"""
|
| 79 |
+
Download LaTeX source from arXiv and return the path to a merged TeX file.
|
| 80 |
+
|
| 81 |
+
- arxiv_id: e.g. "2410.22815"
|
| 82 |
+
- base_dir: paper directory where source should be unpacked
|
| 83 |
+
"""
|
| 84 |
+
source_dir = base_dir / f"tex_{arxiv_id}"
|
| 85 |
+
source_dir.mkdir(parents=True, exist_ok=True)
|
| 86 |
+
search = arxiv.Search(id_list=[arxiv_id])
|
| 87 |
+
max_retries = int(os.getenv("ARXIV_MAX_RETRIES", "6"))
|
| 88 |
+
base_sleep = float(os.getenv("ARXIV_BASE_SLEEP", "2.0"))
|
| 89 |
+
max_sleep = float(os.getenv("ARXIV_MAX_BACKOFF", "60"))
|
| 90 |
+
paper = None
|
| 91 |
+
|
| 92 |
+
for attempt in range(max_retries):
|
| 93 |
+
try:
|
| 94 |
+
_arxiv_min_interval_sleep()
|
| 95 |
+
paper = next(search.results())
|
| 96 |
+
break
|
| 97 |
+
except StopIteration:
|
| 98 |
+
print(f"[WARN] No arXiv paper found for ID {arxiv_id}")
|
| 99 |
+
return None
|
| 100 |
+
except arxiv.HTTPError as exc:
|
| 101 |
+
if getattr(exc, "status", None) == 429 or "429" in str(exc):
|
| 102 |
+
sleep = min(base_sleep * (2**attempt), max_sleep) + random.uniform(0.0, 0.5)
|
| 103 |
+
print(f"[WARN] arXiv 429 → retrying in {sleep:.2f}s")
|
| 104 |
+
time.sleep(sleep)
|
| 105 |
+
continue
|
| 106 |
+
print(f"[WARN] arXiv HTTP error for {arxiv_id}: {exc}")
|
| 107 |
+
return None
|
| 108 |
+
except Exception as exc:
|
| 109 |
+
sleep = min(base_sleep * (2**attempt), max_sleep) + random.uniform(0.0, 0.5)
|
| 110 |
+
print(f"[WARN] arXiv error {exc} → retrying in {sleep:.2f}s")
|
| 111 |
+
time.sleep(sleep)
|
| 112 |
+
continue
|
| 113 |
+
|
| 114 |
+
if paper is None:
|
| 115 |
+
print(f"[ERROR] Giving up after {max_retries} attempts for arXiv ID {arxiv_id}")
|
| 116 |
+
return None
|
| 117 |
+
|
| 118 |
+
tar_path = _download_arxiv_source_with_retries(paper, source_dir, arxiv_id)
|
| 119 |
+
if tar_path is None:
|
| 120 |
+
return None
|
| 121 |
+
|
| 122 |
+
try:
|
| 123 |
+
with tarfile.open(tar_path) as tar:
|
| 124 |
+
tar.extractall(path=source_dir)
|
| 125 |
+
os.remove(tar_path)
|
| 126 |
+
except Exception as exc:
|
| 127 |
+
print(f"[WARN] Failed to extract source for {arxiv_id}: {exc}")
|
| 128 |
+
return None
|
| 129 |
+
|
| 130 |
+
processed_tex = preprocess_tex(source_dir)
|
| 131 |
+
if processed_tex:
|
| 132 |
+
extract_introduction_and_related(processed_tex)
|
| 133 |
+
|
| 134 |
+
if not processed_tex or not processed_tex.exists():
|
| 135 |
+
print(f"[WARN] Could not produce merged TeX for {arxiv_id}")
|
| 136 |
+
return None
|
| 137 |
+
|
| 138 |
+
print(f"[INFO] Processed LaTeX for {arxiv_id} at {processed_tex}")
|
| 139 |
+
return processed_tex
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _extract_arxiv_id_from_text(text: str) -> str | None:
|
| 143 |
+
if not text:
|
| 144 |
+
return None
|
| 145 |
+
match = re.search(r"\b(\d{4}\.\d{4,5}(?:v\d+)?)\b", text)
|
| 146 |
+
if match:
|
| 147 |
+
return match.group(1)
|
| 148 |
+
match = re.search(r"arxiv[:\s/]*(\d{4}\.\d{4,5}(?:v\d+)?)", text, re.IGNORECASE)
|
| 149 |
+
if match:
|
| 150 |
+
return match.group(1)
|
| 151 |
+
return None
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _safe_write_json(path: Path, payload) -> None:
|
| 155 |
+
path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _safe_write_text(path: Path, text: str) -> None:
|
| 159 |
+
path.write_text(text, encoding="utf-8")
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _query_openreview_for_paper(openreview_id: str) -> dict | None:
|
| 163 |
+
"""Query OpenReview using a real OpenReview note/forum id."""
|
| 164 |
+
if not openreview_id:
|
| 165 |
+
return None
|
| 166 |
+
|
| 167 |
+
try_urls = [
|
| 168 |
+
f"https://api.openreview.net/notes?forum={openreview_id}",
|
| 169 |
+
f"https://api2.openreview.net/notes?forum={openreview_id}",
|
| 170 |
+
f"https://api.openreview.net/notes?id={openreview_id}",
|
| 171 |
+
f"https://api2.openreview.net/notes?id={openreview_id}",
|
| 172 |
+
]
|
| 173 |
+
|
| 174 |
+
for url in try_urls:
|
| 175 |
+
try:
|
| 176 |
+
response = requests.get(url, timeout=20)
|
| 177 |
+
if response.status_code != 200:
|
| 178 |
+
continue
|
| 179 |
+
payload = response.json()
|
| 180 |
+
except Exception:
|
| 181 |
+
continue
|
| 182 |
+
|
| 183 |
+
notes = None
|
| 184 |
+
if isinstance(payload, dict) and isinstance(payload.get("notes"), list):
|
| 185 |
+
notes = payload["notes"]
|
| 186 |
+
elif isinstance(payload, dict) and payload.get("content"):
|
| 187 |
+
notes = [payload]
|
| 188 |
+
elif isinstance(payload, list):
|
| 189 |
+
notes = payload
|
| 190 |
+
|
| 191 |
+
if not notes:
|
| 192 |
+
continue
|
| 193 |
+
|
| 194 |
+
note = notes[0]
|
| 195 |
+
content = note.get("content") if isinstance(note, dict) else None
|
| 196 |
+
title = None
|
| 197 |
+
arxiv_id = None
|
| 198 |
+
pdf_url = None
|
| 199 |
+
|
| 200 |
+
if isinstance(content, dict):
|
| 201 |
+
raw_title = content.get("title") or content.get("paperTitle")
|
| 202 |
+
title = raw_title.get("value") if isinstance(raw_title, dict) else raw_title
|
| 203 |
+
|
| 204 |
+
raw_pdf = content.get("pdf")
|
| 205 |
+
pdf_url = raw_pdf.get("value") if isinstance(raw_pdf, dict) else raw_pdf
|
| 206 |
+
|
| 207 |
+
for value in content.values():
|
| 208 |
+
if isinstance(value, dict):
|
| 209 |
+
value = value.get("value")
|
| 210 |
+
if isinstance(value, list):
|
| 211 |
+
value = " ".join(str(item) for item in value)
|
| 212 |
+
if isinstance(value, str):
|
| 213 |
+
arxiv_id = _extract_arxiv_id_from_text(value)
|
| 214 |
+
if arxiv_id:
|
| 215 |
+
break
|
| 216 |
+
|
| 217 |
+
if not title and isinstance(note, dict):
|
| 218 |
+
title = note.get("title") or note.get("forumTitle")
|
| 219 |
+
|
| 220 |
+
if not arxiv_id and isinstance(note, dict):
|
| 221 |
+
for value in note.values():
|
| 222 |
+
if isinstance(value, str):
|
| 223 |
+
arxiv_id = _extract_arxiv_id_from_text(value)
|
| 224 |
+
if arxiv_id:
|
| 225 |
+
break
|
| 226 |
+
|
| 227 |
+
return {
|
| 228 |
+
"title": title,
|
| 229 |
+
"arxiv_id": arxiv_id,
|
| 230 |
+
"pdf_url": pdf_url,
|
| 231 |
+
"openreview_id": openreview_id,
|
| 232 |
+
"source_url": url,
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
return None
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def _treat_as_openreview(paper: dict) -> bool:
|
| 239 |
+
acl_id = str(paper.get("id", "")).lower()
|
| 240 |
+
id_type = str(paper.get("id_type", "")).lower()
|
| 241 |
+
return (
|
| 242 |
+
id_type == "openreview"
|
| 243 |
+
or bool(paper.get("openreview_id"))
|
| 244 |
+
or acl_id.startswith("neurips-")
|
| 245 |
+
or acl_id.startswith("icml-")
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def _fetch_s2_by_title(title: str, acl_id: str) -> tuple[int, dict | None]:
|
| 250 |
+
if not title:
|
| 251 |
+
print(f"[WARN] no title available for {acl_id} → skipping.")
|
| 252 |
+
return 0, None
|
| 253 |
+
hit = search_by_title(title)
|
| 254 |
+
if not hit:
|
| 255 |
+
print(f"[WARN] no S2 match for {acl_id} ({title}) → skipping.")
|
| 256 |
+
return 0, None
|
| 257 |
+
s2_id = hit["paperId"]
|
| 258 |
+
print(f"[DEBUG] title search matched semantic scholar paperId={s2_id}")
|
| 259 |
+
return get_paper(s2_id, id_type="SemanticScholar")
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def _best_arxiv_id(*values: str) -> str | None:
|
| 263 |
+
for value in values:
|
| 264 |
+
arxiv_id = _extract_arxiv_id_from_text(value or "")
|
| 265 |
+
if arxiv_id:
|
| 266 |
+
return arxiv_id
|
| 267 |
+
return None
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def _write_openreview_snapshot(paper_dir: Path, payload: dict) -> None:
|
| 271 |
+
if payload:
|
| 272 |
+
_safe_write_json(paper_dir / "openreview_metadata.json", payload)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def _write_metadata_outputs(paper_dir: Path, acl_id: str, data: dict) -> None:
|
| 276 |
+
meta_path = paper_dir / "paper_metadata.json"
|
| 277 |
+
_safe_write_json(meta_path, [data])
|
| 278 |
+
print(f"[DEBUG] wrote metadata to {meta_path}")
|
| 279 |
+
|
| 280 |
+
external_ids = data.get("externalIds", {}) or {}
|
| 281 |
+
arxiv_id = external_ids.get("ArXiv")
|
| 282 |
+
if arxiv_id:
|
| 283 |
+
download_arxiv_tex(arxiv_id=arxiv_id, base_dir=paper_dir)
|
| 284 |
+
|
| 285 |
+
sections_dir = paper_dir / "sections"
|
| 286 |
+
sections_dir.mkdir(exist_ok=True)
|
| 287 |
+
|
| 288 |
+
abstract = data.get("abstract")
|
| 289 |
+
if abstract:
|
| 290 |
+
_safe_write_text(sections_dir / "abstract.txt", abstract)
|
| 291 |
+
|
| 292 |
+
tldr_obj = data.get("tldr")
|
| 293 |
+
if isinstance(tldr_obj, dict) and tldr_obj.get("text"):
|
| 294 |
+
_safe_write_text(sections_dir / "tldr.txt", tldr_obj["text"])
|
| 295 |
+
|
| 296 |
+
semantic_id = data.get("paperId")
|
| 297 |
+
if not semantic_id:
|
| 298 |
+
print(f"[WARN] no semantic_id for {acl_id} → skip refs/cites.")
|
| 299 |
+
return
|
| 300 |
+
|
| 301 |
+
citation_count = data.get("citationCount", 0)
|
| 302 |
+
reference_count = data.get("referenceCount", 0)
|
| 303 |
+
|
| 304 |
+
ref_status, refs = get_paper_links(semantic_id, "references", reference_count)
|
| 305 |
+
if ref_status == 200:
|
| 306 |
+
_safe_write_json(paper_dir / "references_metadata.json", refs)
|
| 307 |
+
|
| 308 |
+
cit_status, cits = get_paper_links(semantic_id, "citations", citation_count)
|
| 309 |
+
if cit_status == 200:
|
| 310 |
+
_safe_write_json(paper_dir / "citations_metadata.json", cits)
|
| 311 |
+
|
| 312 |
+
if "ArXiv" not in external_ids:
|
| 313 |
+
_safe_write_text(paper_dir / "no_arxiv.txt", "no arxiv for this paper")
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def fetch_one_acl_id(paper: dict, base_dir: Path):
|
| 317 |
+
acl_id = paper["id"]
|
| 318 |
+
title = (paper.get("title") or "").strip()
|
| 319 |
+
id_type = paper.get("id_type", "ACL")
|
| 320 |
+
openreview_id = paper.get("openreview_id", "")
|
| 321 |
+
input_pdf_url = paper.get("pdf_url", "")
|
| 322 |
+
s2_key = os.getenv("SEMANTIC_SCHOLAR_API_KEY", "")
|
| 323 |
+
print(
|
| 324 |
+
f"[DEBUG] fetch_one_acl_id: id={acl_id} id_type={id_type} "
|
| 325 |
+
f"title_len={len(title)} s2_key_present={'yes' if bool(s2_key) else 'no'} "
|
| 326 |
+
f"s2_key_len={len(s2_key)}"
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
paper_dir = base_dir / acl_id
|
| 330 |
+
ensure_dir(paper_dir)
|
| 331 |
+
meta_path = paper_dir / "paper_metadata.json"
|
| 332 |
+
|
| 333 |
+
if meta_path.exists():
|
| 334 |
+
return
|
| 335 |
+
|
| 336 |
+
status, data = 0, None
|
| 337 |
+
fetch_label = f"{id_type}:{acl_id}"
|
| 338 |
+
is_openreview = _treat_as_openreview(paper)
|
| 339 |
+
openreview_meta = None
|
| 340 |
+
attempted_title_search = False
|
| 341 |
+
|
| 342 |
+
if is_openreview:
|
| 343 |
+
try:
|
| 344 |
+
openreview_meta = _query_openreview_for_paper(openreview_id or acl_id)
|
| 345 |
+
except Exception as exc:
|
| 346 |
+
print(f"[WARN] OpenReview lookup failed for {acl_id}: {exc}")
|
| 347 |
+
openreview_meta = None
|
| 348 |
+
|
| 349 |
+
if openreview_meta:
|
| 350 |
+
_write_openreview_snapshot(paper_dir, openreview_meta)
|
| 351 |
+
or_title = (openreview_meta.get("title") or title or "").strip()
|
| 352 |
+
arxiv_id = (
|
| 353 |
+
_best_arxiv_id(
|
| 354 |
+
openreview_meta.get("arxiv_id", ""),
|
| 355 |
+
openreview_meta.get("pdf_url", ""),
|
| 356 |
+
input_pdf_url,
|
| 357 |
+
)
|
| 358 |
+
or ""
|
| 359 |
+
)
|
| 360 |
+
if arxiv_id:
|
| 361 |
+
print(f"[DEBUG] OpenReview -> found ArXiv {arxiv_id} for {acl_id}")
|
| 362 |
+
status, data = get_paper(arxiv_id, id_type="ArXiv")
|
| 363 |
+
fetch_label = f"ArXiv:{arxiv_id}"
|
| 364 |
+
title = or_title or title
|
| 365 |
+
elif or_title:
|
| 366 |
+
print(f"[DEBUG] OpenReview -> no arXiv for {acl_id}, title-searching")
|
| 367 |
+
status, data = _fetch_s2_by_title(or_title, acl_id)
|
| 368 |
+
fetch_label = f"title:{or_title[:80]}"
|
| 369 |
+
title = or_title
|
| 370 |
+
attempted_title_search = True
|
| 371 |
+
else:
|
| 372 |
+
print(f"[WARN] OpenReview metadata for {acl_id} had neither title nor arXiv")
|
| 373 |
+
else:
|
| 374 |
+
print(f"[WARN] no OpenReview metadata for {acl_id} (openreview_id={openreview_id or acl_id})")
|
| 375 |
+
|
| 376 |
+
if data is None and title and not attempted_title_search:
|
| 377 |
+
print(f"[DEBUG] OpenReview fallback -> title-searching extracted title for {acl_id}")
|
| 378 |
+
status, data = _fetch_s2_by_title(title, acl_id)
|
| 379 |
+
fetch_label = f"title:{title[:80]}"
|
| 380 |
+
attempted_title_search = True
|
| 381 |
+
|
| 382 |
+
if data is None and not is_openreview:
|
| 383 |
+
status, data = get_paper(acl_id, id_type=id_type)
|
| 384 |
+
fetch_label = f"{id_type}:{acl_id}"
|
| 385 |
+
|
| 386 |
+
if data is None and not attempted_title_search:
|
| 387 |
+
print(
|
| 388 |
+
f"[WARN] direct fetch failed for {fetch_label} "
|
| 389 |
+
f"(status={status}) → trying title search with title_len={len(title)}"
|
| 390 |
+
)
|
| 391 |
+
status, data = _fetch_s2_by_title(title, acl_id)
|
| 392 |
+
|
| 393 |
+
if status != 200 or data is None:
|
| 394 |
+
print(f"[WARN] still no data for {acl_id} → skipping.")
|
| 395 |
+
return
|
| 396 |
+
|
| 397 |
+
_write_metadata_outputs(paper_dir, acl_id, data)
|
| 398 |
+
print("[SUCCESS]")
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def fetch_all_metadata(acl_ids_path: Path, out_dir: Path, start_from: str | None = None, resume: bool = False):
|
| 402 |
+
raw = json.loads(acl_ids_path.read_text(encoding="utf-8"))
|
| 403 |
+
papers = raw if isinstance(raw[0], dict) else [{"id": x, "title": ""} for x in raw]
|
| 404 |
+
|
| 405 |
+
start_seen = start_from is None
|
| 406 |
+
for paper in papers:
|
| 407 |
+
pid = str(paper.get("id", ""))
|
| 408 |
+
if not start_seen:
|
| 409 |
+
if pid == start_from:
|
| 410 |
+
start_seen = True
|
| 411 |
+
else:
|
| 412 |
+
continue
|
| 413 |
+
if resume:
|
| 414 |
+
paper_dir = out_dir / pid
|
| 415 |
+
if (paper_dir / "paper_metadata.json").exists():
|
| 416 |
+
continue
|
| 417 |
+
fetch_one_acl_id(paper, out_dir)
|
| 418 |
+
return "Meta Data Completed"
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
if __name__ == "__main__":
|
| 422 |
+
parser = argparse.ArgumentParser()
|
| 423 |
+
parser.add_argument("--ids", type=str, required=True, help="Path to JSON file with paper IDs.")
|
| 424 |
+
parser.add_argument("--outdir", type=str, default="papers", help="Output directory for metadata.")
|
| 425 |
+
parser.add_argument("--start-from", type=str, default=None, help="Start from this paper ID.")
|
| 426 |
+
parser.add_argument("--resume", action="store_true", help="Skip papers that already have paper_metadata.json.")
|
| 427 |
+
args = parser.parse_args()
|
| 428 |
+
|
| 429 |
+
ACL_IDS_PATH = Path(args.ids).expanduser().resolve()
|
| 430 |
+
OUTDIR = Path(args.outdir).expanduser().resolve()
|
| 431 |
+
|
| 432 |
+
if not ACL_IDS_PATH.exists():
|
| 433 |
+
raise FileNotFoundError(f"Could not find {ACL_IDS_PATH}")
|
| 434 |
+
|
| 435 |
+
print(f"[INFO] Using ID list from {ACL_IDS_PATH}")
|
| 436 |
+
print(f"[INFO] Output will be saved to {OUTDIR}")
|
| 437 |
+
|
| 438 |
+
start = time.time()
|
| 439 |
+
fetch_all_metadata(acl_ids_path=ACL_IDS_PATH, out_dir=OUTDIR, start_from=args.start_from, resume=args.resume)
|
| 440 |
+
print("done in", time.time() - start, "s")
|
src/step_01_fetch/process_tex_source.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import shutil
|
| 5 |
+
|
| 6 |
+
def read_tex(path: Path) -> str:
|
| 7 |
+
try:
|
| 8 |
+
return path.read_text(encoding="utf-8", errors="ignore")
|
| 9 |
+
except Exception:
|
| 10 |
+
return ""
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def resolve_inputs(tex: str, base_dir: Path, seen=None) -> str:
|
| 14 |
+
"""
|
| 15 |
+
Recursively replace \\input{...} and \\include{...} with file contents.
|
| 16 |
+
"""
|
| 17 |
+
if seen is None:
|
| 18 |
+
seen = set()
|
| 19 |
+
|
| 20 |
+
pattern = r'\\(?:input|include)\{([^}]+)\}'
|
| 21 |
+
|
| 22 |
+
def repl(match):
|
| 23 |
+
name = match.group(1)
|
| 24 |
+
if not name.endswith(".tex"):
|
| 25 |
+
name += ".tex"
|
| 26 |
+
|
| 27 |
+
full = base_dir / name
|
| 28 |
+
|
| 29 |
+
if full in seen:
|
| 30 |
+
return f"% WARNING: skipped circular input {full}\n"
|
| 31 |
+
|
| 32 |
+
if not full.exists():
|
| 33 |
+
return f"% WARNING: missing file {full}\n"
|
| 34 |
+
|
| 35 |
+
seen.add(full)
|
| 36 |
+
content = read_tex(full)
|
| 37 |
+
return resolve_inputs(content, full.parent, seen)
|
| 38 |
+
|
| 39 |
+
return re.sub(pattern, repl, tex)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def find_main_tex(source_dir: Path) -> Path | None:
|
| 43 |
+
"""
|
| 44 |
+
Heuristic to find the main .tex file:
|
| 45 |
+
1. match .bbl → .tex
|
| 46 |
+
2. else top-level .tex that contains \\begin{document}
|
| 47 |
+
3. else first .tex in directory
|
| 48 |
+
"""
|
| 49 |
+
bbls = list(source_dir.glob("*.bbl"))
|
| 50 |
+
if bbls:
|
| 51 |
+
main_candidate = source_dir / (bbls[0].stem + ".tex")
|
| 52 |
+
if main_candidate.exists():
|
| 53 |
+
return main_candidate
|
| 54 |
+
|
| 55 |
+
for tex in source_dir.glob("*.tex"):
|
| 56 |
+
if "\\begin{document}" in read_tex(tex):
|
| 57 |
+
return tex
|
| 58 |
+
|
| 59 |
+
tex_files = list(source_dir.glob("*.tex"))
|
| 60 |
+
return tex_files[0] if tex_files else None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def preprocess_tex(source_dir: Path) -> Path | None:
|
| 64 |
+
"""
|
| 65 |
+
Given an extracted arXiv source directory, produce:
|
| 66 |
+
- a merged TeX file named 'processed_main.tex'
|
| 67 |
+
- a concatenated BibTeX file named 'references.bib'
|
| 68 |
+
Both are written in the parent directory of source_dir (the paper dir).
|
| 69 |
+
|
| 70 |
+
Then delete the extracted source_dir.
|
| 71 |
+
"""
|
| 72 |
+
main_tex = find_main_tex(source_dir)
|
| 73 |
+
if not main_tex:
|
| 74 |
+
print(f"[WARN] No main .tex found in {source_dir}")
|
| 75 |
+
shutil.rmtree(source_dir, ignore_errors=True)
|
| 76 |
+
return None
|
| 77 |
+
|
| 78 |
+
raw = read_tex(main_tex)
|
| 79 |
+
merged = resolve_inputs(raw, main_tex.parent)
|
| 80 |
+
|
| 81 |
+
paper_dir = source_dir.parent
|
| 82 |
+
out_tex_path = paper_dir / "processed_main.tex"
|
| 83 |
+
out_tex_path.write_text(merged, encoding="utf-8")
|
| 84 |
+
|
| 85 |
+
bib_files = list(source_dir.rglob("*.bib"))
|
| 86 |
+
if bib_files:
|
| 87 |
+
bib_texts = []
|
| 88 |
+
for bib in bib_files:
|
| 89 |
+
try:
|
| 90 |
+
bib_texts.append(bib.read_text(encoding="utf-8", errors="ignore"))
|
| 91 |
+
except Exception:
|
| 92 |
+
print(f"[WARN] Could not read bib file {bib}")
|
| 93 |
+
if bib_texts:
|
| 94 |
+
bib_out = paper_dir / "references.bib"
|
| 95 |
+
bib_out.write_text("\n\n".join(bib_texts), encoding="utf-8")
|
| 96 |
+
print(f"[INFO] Wrote combined BibTeX to {bib_out}")
|
| 97 |
+
|
| 98 |
+
shutil.rmtree(source_dir, ignore_errors=True)
|
| 99 |
+
|
| 100 |
+
return out_tex_path
|
| 101 |
+
|
| 102 |
+
def _load_tex(path: Path) -> str:
|
| 103 |
+
return path.read_text(encoding="utf-8", errors="ignore")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
SECTION_PATTERN = re.compile(
|
| 107 |
+
r'\\section\*?\{([^}]*)\}',
|
| 108 |
+
flags=re.IGNORECASE
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _split_into_sections(tex: str):
|
| 113 |
+
"""
|
| 114 |
+
Returns a list of (section_title, content) in order.
|
| 115 |
+
Title is the raw LaTeX title text (without braces).
|
| 116 |
+
Content is the text from this \\section line up to (but not including)
|
| 117 |
+
the next \\section or end of document.
|
| 118 |
+
"""
|
| 119 |
+
sections = []
|
| 120 |
+
matches = list(SECTION_PATTERN.finditer(tex))
|
| 121 |
+
|
| 122 |
+
if not matches:
|
| 123 |
+
return sections
|
| 124 |
+
|
| 125 |
+
for i, m in enumerate(matches):
|
| 126 |
+
title = m.group(1).strip()
|
| 127 |
+
start = m.start()
|
| 128 |
+
end = matches[i + 1].start() if i + 1 < len(matches) else len(tex)
|
| 129 |
+
content = tex[start:end]
|
| 130 |
+
sections.append((title, content))
|
| 131 |
+
|
| 132 |
+
return sections
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _normalize_title(title: str) -> str:
|
| 136 |
+
"""Lowercase and strip punctuation-ish stuff for robust matching."""
|
| 137 |
+
t = title.lower()
|
| 138 |
+
t = re.sub(r'[^a-z0-9\s]', ' ', t)
|
| 139 |
+
t = re.sub(r'\s+', ' ', t).strip()
|
| 140 |
+
return t
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def _find_best_section(sections, candidates):
|
| 144 |
+
"""
|
| 145 |
+
sections: list of (raw_title, content)
|
| 146 |
+
candidates: list of strings to match against normalized title
|
| 147 |
+
Returns the content of the best-matching section or None.
|
| 148 |
+
"""
|
| 149 |
+
norm_candidates = [c.lower() for c in candidates]
|
| 150 |
+
|
| 151 |
+
for raw_title, content in sections:
|
| 152 |
+
nt = _normalize_title(raw_title)
|
| 153 |
+
for cand in norm_candidates:
|
| 154 |
+
if nt == cand or cand in nt:
|
| 155 |
+
return content
|
| 156 |
+
return None
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def extract_introduction_and_related(
|
| 160 |
+
processed_tex_path: Path,
|
| 161 |
+
out_dir: Path | None = None,
|
| 162 |
+
) -> dict:
|
| 163 |
+
"""
|
| 164 |
+
Given path to processed_main.tex, extract Introduction and Related Work sections
|
| 165 |
+
into separate .tex files.
|
| 166 |
+
|
| 167 |
+
Returns a dict with keys:
|
| 168 |
+
{
|
| 169 |
+
"introduction": Path | None,
|
| 170 |
+
"related_work": Path | None
|
| 171 |
+
}
|
| 172 |
+
"""
|
| 173 |
+
if out_dir is None:
|
| 174 |
+
out_dir = processed_tex_path.parent / "sections"
|
| 175 |
+
|
| 176 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 177 |
+
|
| 178 |
+
tex = _load_tex(processed_tex_path)
|
| 179 |
+
sections = _split_into_sections(tex)
|
| 180 |
+
|
| 181 |
+
intro_candidates = ["introduction"]
|
| 182 |
+
related_candidates = ["related work"]
|
| 183 |
+
|
| 184 |
+
intro_content = _find_best_section(sections, intro_candidates)
|
| 185 |
+
related_content = _find_best_section(sections, related_candidates)
|
| 186 |
+
|
| 187 |
+
results = {"introduction": None, "related_work": None}
|
| 188 |
+
|
| 189 |
+
if intro_content:
|
| 190 |
+
intro_path = out_dir / "introduction.tex"
|
| 191 |
+
intro_path.write_text(intro_content, encoding="utf-8")
|
| 192 |
+
results["introduction"] = intro_path
|
| 193 |
+
else:
|
| 194 |
+
print(f"[WARN] No Introduction section found in {processed_tex_path}")
|
| 195 |
+
|
| 196 |
+
if related_content:
|
| 197 |
+
rw_path = out_dir / "related_work.tex"
|
| 198 |
+
rw_path.write_text(related_content, encoding="utf-8")
|
| 199 |
+
results["related_work"] = rw_path
|
| 200 |
+
else:
|
| 201 |
+
print(f"[WARN] No Related Work section found in {processed_tex_path}")
|
| 202 |
+
|
| 203 |
+
return results
|
src/step_01_fetch/semanticscholar_client.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import random
|
| 3 |
+
import requests
|
| 4 |
+
from typing import Optional, Tuple, Any
|
| 5 |
+
from config import SEMANTIC_SCHOLAR_API_KEY
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
BASE_URL = "https://api.semanticscholar.org/graph/v1/paper"
|
| 9 |
+
|
| 10 |
+
_LAST_REQUEST_TS = 0.0
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _min_interval_sleep() -> None:
|
| 14 |
+
"""Global throttle to avoid hammering Semantic Scholar."""
|
| 15 |
+
global _LAST_REQUEST_TS
|
| 16 |
+
min_interval = float(os.getenv("S2_MIN_INTERVAL", "1.0"))
|
| 17 |
+
now = time.monotonic()
|
| 18 |
+
elapsed = now - _LAST_REQUEST_TS
|
| 19 |
+
if elapsed < min_interval:
|
| 20 |
+
time.sleep(min_interval - elapsed)
|
| 21 |
+
_LAST_REQUEST_TS = time.monotonic()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def robust_request(url, params=None, headers=None, max_retries=8, base_sleep=2.0):
|
| 25 |
+
"""
|
| 26 |
+
Make a GET request with exponential backoff.
|
| 27 |
+
Retries on:
|
| 28 |
+
- connection errors
|
| 29 |
+
- 429 (Too Many Requests)
|
| 30 |
+
- 500–599 server errors
|
| 31 |
+
- invalid JSON
|
| 32 |
+
Returns (status_code, json_or_None).
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
for attempt in range(max_retries):
|
| 36 |
+
try:
|
| 37 |
+
_min_interval_sleep()
|
| 38 |
+
resp = requests.get(url, params=params, headers=headers, timeout=30)
|
| 39 |
+
status = resp.status_code
|
| 40 |
+
|
| 41 |
+
if status == 200:
|
| 42 |
+
try:
|
| 43 |
+
return 200, resp.json()
|
| 44 |
+
except Exception:
|
| 45 |
+
print(f"[WARN] JSON decode failed on attempt {attempt+1}/{max_retries}")
|
| 46 |
+
|
| 47 |
+
if status == 429:
|
| 48 |
+
retry_after = resp.headers.get("Retry-After")
|
| 49 |
+
if retry_after:
|
| 50 |
+
try:
|
| 51 |
+
sleep = float(retry_after)
|
| 52 |
+
except Exception:
|
| 53 |
+
sleep = base_sleep * (2 ** attempt)
|
| 54 |
+
else:
|
| 55 |
+
sleep = base_sleep * (2 ** attempt)
|
| 56 |
+
max_sleep = float(os.getenv("S2_MAX_BACKOFF", "60"))
|
| 57 |
+
sleep = min(sleep, max_sleep)
|
| 58 |
+
sleep += random.uniform(0.0, 0.5)
|
| 59 |
+
print(f"[WARN] 429 Too Many Requests → retrying in {sleep:.2f}s")
|
| 60 |
+
time.sleep(sleep)
|
| 61 |
+
continue
|
| 62 |
+
|
| 63 |
+
if 500 <= status < 600:
|
| 64 |
+
sleep = base_sleep * (2 ** attempt)
|
| 65 |
+
max_sleep = float(os.getenv("S2_MAX_BACKOFF", "60"))
|
| 66 |
+
sleep = min(sleep, max_sleep)
|
| 67 |
+
sleep += random.uniform(0.0, 0.5)
|
| 68 |
+
print(f"[WARN] Server error {status} → retrying in {sleep:.2f}s")
|
| 69 |
+
time.sleep(sleep)
|
| 70 |
+
continue
|
| 71 |
+
|
| 72 |
+
return status, None
|
| 73 |
+
|
| 74 |
+
except requests.exceptions.RequestException as e:
|
| 75 |
+
sleep = base_sleep * (2 ** attempt)
|
| 76 |
+
max_sleep = float(os.getenv("S2_MAX_BACKOFF", "60"))
|
| 77 |
+
sleep = min(sleep, max_sleep)
|
| 78 |
+
sleep += random.uniform(0.0, 0.5)
|
| 79 |
+
print(f"[WARN] Network error {e} → retrying in {sleep:.2f}s")
|
| 80 |
+
time.sleep(sleep)
|
| 81 |
+
continue
|
| 82 |
+
|
| 83 |
+
print(f"[ERROR] Giving up after {max_retries} attempts for URL: {url}")
|
| 84 |
+
return None, None
|
| 85 |
+
def get_paper(paper_id: str, id_type: str = "ACL") -> Tuple[int, Optional[dict]]:
|
| 86 |
+
"""
|
| 87 |
+
id_type can be "ACL" or "SemanticScholar" or "ArXiv" etc.
|
| 88 |
+
"""
|
| 89 |
+
if id_type == "SemanticScholar":
|
| 90 |
+
full_id = paper_id
|
| 91 |
+
else:
|
| 92 |
+
full_id = f"{id_type}:{paper_id}"
|
| 93 |
+
|
| 94 |
+
url = f"{BASE_URL}/{full_id}"
|
| 95 |
+
params = {
|
| 96 |
+
"fields": (
|
| 97 |
+
"title,year,publicationDate,authors,url,venue,externalIds,"
|
| 98 |
+
"tldr,abstract,citationCount,referenceCount,openAccessPdf"
|
| 99 |
+
)
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
headers = {"x-api-key": SEMANTIC_SCHOLAR_API_KEY} if SEMANTIC_SCHOLAR_API_KEY else {}
|
| 103 |
+
|
| 104 |
+
status, data = robust_request(url, params=params, headers=headers, max_retries=5, base_sleep=1.0)
|
| 105 |
+
if status == 200 and data is not None:
|
| 106 |
+
return status, data
|
| 107 |
+
else:
|
| 108 |
+
print(f"[WARN] {status} on {full_id}")
|
| 109 |
+
return status or 0, None
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def get_paper_links(semantic_id: str, target_type: str, total: int, limit: int = 1000):
|
| 113 |
+
headers = {"x-api-key": SEMANTIC_SCHOLAR_API_KEY} if SEMANTIC_SCHOLAR_API_KEY else {}
|
| 114 |
+
loops = total // limit + 1 if total else 0
|
| 115 |
+
collected = []
|
| 116 |
+
|
| 117 |
+
for i in range(loops):
|
| 118 |
+
offset = i * limit
|
| 119 |
+
url = f"{BASE_URL}/{semantic_id}/{target_type}"
|
| 120 |
+
params = {
|
| 121 |
+
"offset": offset,
|
| 122 |
+
"limit": limit,
|
| 123 |
+
"fields": "paperId,title,isInfluential,externalIds,contextsWithIntent,openAccessPdf",
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
status, data = robust_request(url, params=params, headers=headers, max_retries=5, base_sleep=1.0)
|
| 127 |
+
|
| 128 |
+
if status != 200 or data is None:
|
| 129 |
+
print(f"[WARN] {target_type} fetch failed for {semantic_id} (status {status})")
|
| 130 |
+
return status or 0, []
|
| 131 |
+
|
| 132 |
+
items = data.get("data")
|
| 133 |
+
if not isinstance(items, list):
|
| 134 |
+
print(f"[WARN] malformed {target_type} response for {semantic_id}")
|
| 135 |
+
return status, []
|
| 136 |
+
|
| 137 |
+
collected.extend(items)
|
| 138 |
+
|
| 139 |
+
return 200, collected
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def search_by_title(title: str, limit: int = 1):
|
| 143 |
+
"""Search Semantic Scholar by paper title."""
|
| 144 |
+
url = "https://api.semanticscholar.org/graph/v1/paper/search"
|
| 145 |
+
params = {
|
| 146 |
+
"query": title,
|
| 147 |
+
"limit": limit,
|
| 148 |
+
"fields": "paperId,title,year,venue,externalIds",
|
| 149 |
+
}
|
| 150 |
+
headers = {"x-api-key": SEMANTIC_SCHOLAR_API_KEY} if SEMANTIC_SCHOLAR_API_KEY else {}
|
| 151 |
+
|
| 152 |
+
status, data = robust_request(url, params=params, headers=headers, max_retries=5, base_sleep=1.0)
|
| 153 |
+
if status == 200 and data is not None:
|
| 154 |
+
items = data.get("data", [])
|
| 155 |
+
return items[0] if items else None
|
| 156 |
+
else:
|
| 157 |
+
print(f"[WARN] title search failed for '{title[:60]}...' (status {status})")
|
| 158 |
+
return None
|
src/step_02_mark_citations/replace_citation_markers.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Dict, List, Tuple
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
PAPER_META_FILE = "paper_metadata.json"
|
| 9 |
+
USAGE_CLAIMS_FILE = "usage_claims.json"
|
| 10 |
+
USAGE_CONTEXTS_FILE = "usage_contexts.json"
|
| 11 |
+
CITATIONS_FILE = "citations_metadata.json"
|
| 12 |
+
PROCESSED_MAIN_FILE = "processed_main.tex"
|
| 13 |
+
REFERENCES_META_FILE = "references_metadata.json"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_json(path: Path) -> Any | None:
|
| 17 |
+
if not path.exists():
|
| 18 |
+
return None
|
| 19 |
+
try:
|
| 20 |
+
return json.loads(path.read_text(encoding="utf-8"))
|
| 21 |
+
except Exception:
|
| 22 |
+
return None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def save_json(path: Path, data: Any) -> None:
|
| 26 |
+
path.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def iter_paper_dirs(root: Path) -> List[Path]:
|
| 30 |
+
out: List[Path] = []
|
| 31 |
+
for child in root.iterdir():
|
| 32 |
+
if child.is_dir() and (child / PAPER_META_FILE).exists():
|
| 33 |
+
out.append(child)
|
| 34 |
+
return out
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def load_paper_metadata(paper_dir: Path) -> Dict[str, Any]:
|
| 38 |
+
meta = load_json(paper_dir / PAPER_META_FILE)
|
| 39 |
+
if isinstance(meta, list) and meta:
|
| 40 |
+
return meta[0]
|
| 41 |
+
if isinstance(meta, dict):
|
| 42 |
+
return meta
|
| 43 |
+
return {}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _is_structurally_complete(paper_dir: Path) -> bool:
|
| 47 |
+
return (
|
| 48 |
+
(paper_dir / PAPER_META_FILE).exists()
|
| 49 |
+
and (paper_dir / PROCESSED_MAIN_FILE).exists()
|
| 50 |
+
and (paper_dir / REFERENCES_META_FILE).exists()
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _author_last_names(authors: List[Any]) -> List[str]:
|
| 55 |
+
last_names: List[str] = []
|
| 56 |
+
for author in authors:
|
| 57 |
+
if isinstance(author, dict):
|
| 58 |
+
name = author.get("name")
|
| 59 |
+
else:
|
| 60 |
+
name = author
|
| 61 |
+
if not isinstance(name, str):
|
| 62 |
+
continue
|
| 63 |
+
parts = [p for p in re.split(r"\s+", name.strip()) if p]
|
| 64 |
+
if not parts:
|
| 65 |
+
continue
|
| 66 |
+
last_names.append(parts[-1])
|
| 67 |
+
return list(dict.fromkeys(last_names))
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _title_aliases(title: str) -> List[str]:
|
| 71 |
+
aliases = [title]
|
| 72 |
+
if ":" in title:
|
| 73 |
+
aliases.append(title.split(":", 1)[0])
|
| 74 |
+
acronym = "".join([c for c in title if c.isupper()])
|
| 75 |
+
if 3 <= len(acronym) <= 10:
|
| 76 |
+
aliases.append(acronym)
|
| 77 |
+
return list(dict.fromkeys([a for a in aliases if a]))
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _artifact_aliases(paper_dir: Path) -> List[str]:
|
| 81 |
+
aliases: List[str] = []
|
| 82 |
+
usage_claims = load_json(paper_dir / USAGE_CLAIMS_FILE)
|
| 83 |
+
if isinstance(usage_claims, dict):
|
| 84 |
+
caps = usage_claims.get("capabilities") or []
|
| 85 |
+
if isinstance(caps, list):
|
| 86 |
+
for cap in caps:
|
| 87 |
+
if not isinstance(cap, dict):
|
| 88 |
+
continue
|
| 89 |
+
name = cap.get("artifact_name")
|
| 90 |
+
if isinstance(name, str) and name.strip():
|
| 91 |
+
aliases.append(name.strip())
|
| 92 |
+
return list(dict.fromkeys(aliases))
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _loose_alias_pattern(alias: str) -> str:
|
| 96 |
+
parts = re.split(r"[^A-Za-z0-9]+", alias)
|
| 97 |
+
parts = [p for p in parts if p]
|
| 98 |
+
if not parts:
|
| 99 |
+
return ""
|
| 100 |
+
return r"\b" + r"[-\s]*".join(map(re.escape, parts)) + r"\b"
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def build_patterns(
|
| 104 |
+
meta: Dict[str, Any],
|
| 105 |
+
paper_dir: Path,
|
| 106 |
+
) -> Tuple[List[re.Pattern], List[re.Pattern], str | None]:
|
| 107 |
+
year = meta.get("year")
|
| 108 |
+
year_str = str(year) if isinstance(year, int) else None
|
| 109 |
+
authors = meta.get("authors") if isinstance(meta.get("authors"), list) else []
|
| 110 |
+
last_names = _author_last_names(authors)
|
| 111 |
+
title = meta.get("title") if isinstance(meta.get("title"), str) else ""
|
| 112 |
+
aliases = _title_aliases(title) + _artifact_aliases(paper_dir)
|
| 113 |
+
aliases = [a for a in aliases if a]
|
| 114 |
+
|
| 115 |
+
author_patterns: List[re.Pattern] = []
|
| 116 |
+
alias_patterns: List[re.Pattern] = []
|
| 117 |
+
|
| 118 |
+
if year_str and last_names:
|
| 119 |
+
year_pat = rf"{re.escape(year_str)}[a-z]?"
|
| 120 |
+
first_last = re.escape(last_names[0])
|
| 121 |
+
author_patterns.append(
|
| 122 |
+
re.compile(
|
| 123 |
+
rf"\b{first_last}\s+et\s+al\.?\s*(?:,\s*|\s*){year_pat}",
|
| 124 |
+
re.IGNORECASE,
|
| 125 |
+
)
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
for alias in aliases:
|
| 129 |
+
pat = _loose_alias_pattern(alias)
|
| 130 |
+
if pat:
|
| 131 |
+
alias_patterns.append(re.compile(pat, re.IGNORECASE))
|
| 132 |
+
|
| 133 |
+
first_last = last_names[0] if last_names else None
|
| 134 |
+
return author_patterns, alias_patterns, first_last
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _replace_author_span(text: str, first_last: str) -> Tuple[str, bool]:
|
| 138 |
+
occurrences = list(re.finditer(rf"\b{re.escape(first_last)}\b", text, re.IGNORECASE))
|
| 139 |
+
if len(occurrences) != 1:
|
| 140 |
+
return text, False
|
| 141 |
+
author_pat = re.compile(
|
| 142 |
+
rf"\(?\b{re.escape(first_last)}\b"
|
| 143 |
+
rf"\s+(?:et\s+al\.?|and|&)\s*"
|
| 144 |
+
rf"(?:,?\s*\(?\d{{4}}[a-z]?\)?)?"
|
| 145 |
+
rf"\)?",
|
| 146 |
+
re.IGNORECASE,
|
| 147 |
+
)
|
| 148 |
+
new_text, count = author_pat.subn("<CITED HERE>", text, count=1)
|
| 149 |
+
return new_text, count > 0
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
_BRACKET_NUM_RE = re.compile(r"\[[0-9,;\s]+\]")
|
| 153 |
+
_BRACKET_GROUP_RE = re.compile(r"\[([0-9,;\s]+)\]")
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _extract_bracket_numbers(text: str) -> List[str]:
|
| 157 |
+
numbers: List[str] = []
|
| 158 |
+
for match in _BRACKET_GROUP_RE.finditer(text):
|
| 159 |
+
parts = re.split(r"[,\s;]+", match.group(1).strip())
|
| 160 |
+
for part in parts:
|
| 161 |
+
if part.isdigit():
|
| 162 |
+
numbers.append(part)
|
| 163 |
+
return numbers
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def _dominant_bracket(contexts: List[Dict[str, Any]]) -> str | None:
|
| 167 |
+
counts: Dict[str, int] = {}
|
| 168 |
+
for ctx in contexts:
|
| 169 |
+
if not isinstance(ctx, dict):
|
| 170 |
+
continue
|
| 171 |
+
text = ctx.get("context") or ctx.get("text")
|
| 172 |
+
if not isinstance(text, str):
|
| 173 |
+
continue
|
| 174 |
+
for num in _extract_bracket_numbers(text):
|
| 175 |
+
counts[num] = counts.get(num, 0) + 1
|
| 176 |
+
if not counts:
|
| 177 |
+
return None
|
| 178 |
+
best = max(counts.values())
|
| 179 |
+
winners = [num for num, count in counts.items() if count == best]
|
| 180 |
+
if len(winners) == 1:
|
| 181 |
+
return winners[0]
|
| 182 |
+
return None
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def _single_bracket_candidate(contexts: List[Dict[str, Any]]) -> str | None:
|
| 186 |
+
counts: Dict[str, int] = {}
|
| 187 |
+
for ctx in contexts:
|
| 188 |
+
if not isinstance(ctx, dict):
|
| 189 |
+
continue
|
| 190 |
+
text = ctx.get("context") or ctx.get("text")
|
| 191 |
+
if not isinstance(text, str):
|
| 192 |
+
continue
|
| 193 |
+
matches = list(_BRACKET_GROUP_RE.finditer(text))
|
| 194 |
+
if len(matches) == 1:
|
| 195 |
+
nums = _extract_bracket_numbers(text)
|
| 196 |
+
if len(nums) != 1:
|
| 197 |
+
continue
|
| 198 |
+
num = nums[0]
|
| 199 |
+
counts[num] = counts.get(num, 0) + 1
|
| 200 |
+
if not counts:
|
| 201 |
+
return None
|
| 202 |
+
best = max(counts.values())
|
| 203 |
+
winners = [num for num, count in counts.items() if count == best]
|
| 204 |
+
if len(winners) == 1:
|
| 205 |
+
return winners[0]
|
| 206 |
+
return None
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def _replace_single_bracket(text: str, dominant: str | None) -> Tuple[str, bool]:
|
| 210 |
+
matches = list(_BRACKET_GROUP_RE.finditer(text))
|
| 211 |
+
if len(matches) != 1:
|
| 212 |
+
return text, False
|
| 213 |
+
nums = _extract_bracket_numbers(text)
|
| 214 |
+
if len(nums) != 1:
|
| 215 |
+
return text, False
|
| 216 |
+
num = nums[0]
|
| 217 |
+
if dominant is not None and num != dominant:
|
| 218 |
+
return text, False
|
| 219 |
+
start, end = matches[0].span()
|
| 220 |
+
return text[:start] + "<CITED HERE>" + text[end:], True
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def replace_with_marker(
|
| 224 |
+
text: str,
|
| 225 |
+
author_patterns: List[re.Pattern],
|
| 226 |
+
alias_patterns: List[re.Pattern],
|
| 227 |
+
dominant_bracket: str | None = None,
|
| 228 |
+
first_author_last: str | None = None,
|
| 229 |
+
) -> Tuple[str, bool]:
|
| 230 |
+
def _collapse_markers(value: str) -> str:
|
| 231 |
+
value = re.sub(r"(?:<CITED HERE>[\s()\[\],;:]*){2,}", "<CITED HERE> ", value)
|
| 232 |
+
value = re.sub(r"<CITED HERE>(?:\s+<CITED HERE>)+", "<CITED HERE>", value)
|
| 233 |
+
return value.strip()
|
| 234 |
+
|
| 235 |
+
updated = text
|
| 236 |
+
changed = False
|
| 237 |
+
|
| 238 |
+
author_changed = False
|
| 239 |
+
if first_author_last:
|
| 240 |
+
new, author_changed = _replace_author_span(updated, first_author_last)
|
| 241 |
+
if author_changed:
|
| 242 |
+
changed = True
|
| 243 |
+
updated = _collapse_markers(new)
|
| 244 |
+
|
| 245 |
+
if dominant_bracket:
|
| 246 |
+
def _replace_if_contains(match: re.Match) -> str:
|
| 247 |
+
nums = re.split(r"[,\s;]+", match.group(1).strip())
|
| 248 |
+
if any(n == dominant_bracket for n in nums if n.isdigit()):
|
| 249 |
+
return "<CITED HERE>"
|
| 250 |
+
return match.group(0)
|
| 251 |
+
|
| 252 |
+
new = _BRACKET_GROUP_RE.sub(_replace_if_contains, updated)
|
| 253 |
+
if new != updated:
|
| 254 |
+
changed = True
|
| 255 |
+
updated = _collapse_markers(new)
|
| 256 |
+
|
| 257 |
+
for pat in author_patterns:
|
| 258 |
+
new = pat.sub("<CITED HERE>", updated)
|
| 259 |
+
if new != updated:
|
| 260 |
+
changed = True
|
| 261 |
+
updated = _collapse_markers(new)
|
| 262 |
+
|
| 263 |
+
if not author_changed:
|
| 264 |
+
for pat in alias_patterns:
|
| 265 |
+
new = pat.sub("<CITED HERE>", updated)
|
| 266 |
+
if new != updated:
|
| 267 |
+
changed = True
|
| 268 |
+
updated = _collapse_markers(new)
|
| 269 |
+
|
| 270 |
+
new, bracket_changed = _replace_single_bracket(updated, dominant_bracket)
|
| 271 |
+
if bracket_changed:
|
| 272 |
+
changed = True
|
| 273 |
+
updated = _collapse_markers(new)
|
| 274 |
+
|
| 275 |
+
updated = _collapse_markers(updated)
|
| 276 |
+
return updated, changed
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def _process_contexts(
|
| 280 |
+
contexts: List[Dict[str, Any]],
|
| 281 |
+
author_patterns: List[re.Pattern],
|
| 282 |
+
alias_patterns: List[re.Pattern],
|
| 283 |
+
dominant_bracket: str | None,
|
| 284 |
+
first_author_last: str | None,
|
| 285 |
+
) -> Tuple[int, int]:
|
| 286 |
+
updated_count = 0
|
| 287 |
+
total = 0
|
| 288 |
+
for ctx in contexts:
|
| 289 |
+
if not isinstance(ctx, dict):
|
| 290 |
+
continue
|
| 291 |
+
text = ctx.get("context") or ctx.get("text")
|
| 292 |
+
if not isinstance(text, str):
|
| 293 |
+
continue
|
| 294 |
+
total += 1
|
| 295 |
+
new_text, changed = replace_with_marker(
|
| 296 |
+
text,
|
| 297 |
+
author_patterns=author_patterns,
|
| 298 |
+
alias_patterns=alias_patterns,
|
| 299 |
+
dominant_bracket=dominant_bracket,
|
| 300 |
+
first_author_last=first_author_last,
|
| 301 |
+
)
|
| 302 |
+
if changed:
|
| 303 |
+
updated_count += 1
|
| 304 |
+
ctx["context_with_marker"] = new_text
|
| 305 |
+
return updated_count, total
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def update_citations_file(
|
| 309 |
+
paper_dir: Path,
|
| 310 |
+
author_patterns: List[re.Pattern],
|
| 311 |
+
alias_patterns: List[re.Pattern],
|
| 312 |
+
first_author_last: str | None,
|
| 313 |
+
) -> Tuple[int, int]:
|
| 314 |
+
path = paper_dir / CITATIONS_FILE
|
| 315 |
+
data = load_json(path)
|
| 316 |
+
if not isinstance(data, list):
|
| 317 |
+
return 0, 0
|
| 318 |
+
updated = 0
|
| 319 |
+
total = 0
|
| 320 |
+
for entry in data:
|
| 321 |
+
if not isinstance(entry, dict):
|
| 322 |
+
continue
|
| 323 |
+
ctxs = entry.get("contextsWithIntent") or []
|
| 324 |
+
if isinstance(ctxs, list):
|
| 325 |
+
dominant = _dominant_bracket(ctxs)
|
| 326 |
+
if dominant is None:
|
| 327 |
+
dominant = _single_bracket_candidate(ctxs)
|
| 328 |
+
upd, tot = _process_contexts(
|
| 329 |
+
ctxs,
|
| 330 |
+
author_patterns,
|
| 331 |
+
alias_patterns,
|
| 332 |
+
dominant,
|
| 333 |
+
first_author_last,
|
| 334 |
+
)
|
| 335 |
+
updated += upd
|
| 336 |
+
total += tot
|
| 337 |
+
save_json(path, data)
|
| 338 |
+
return updated, total
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def update_usage_contexts_file(
|
| 342 |
+
paper_dir: Path,
|
| 343 |
+
author_patterns: List[re.Pattern],
|
| 344 |
+
alias_patterns: List[re.Pattern],
|
| 345 |
+
first_author_last: str | None,
|
| 346 |
+
) -> Tuple[int, int]:
|
| 347 |
+
path = paper_dir / USAGE_CONTEXTS_FILE
|
| 348 |
+
data = load_json(path)
|
| 349 |
+
if not isinstance(data, dict):
|
| 350 |
+
return 0, 0
|
| 351 |
+
updated = 0
|
| 352 |
+
total = 0
|
| 353 |
+
for entry in data.get("citing_papers", []) or []:
|
| 354 |
+
if not isinstance(entry, dict):
|
| 355 |
+
continue
|
| 356 |
+
ctxs = entry.get("contexts") or []
|
| 357 |
+
if isinstance(ctxs, list):
|
| 358 |
+
dominant = _dominant_bracket(ctxs)
|
| 359 |
+
if dominant is None:
|
| 360 |
+
dominant = _single_bracket_candidate(ctxs)
|
| 361 |
+
upd, tot = _process_contexts(
|
| 362 |
+
ctxs,
|
| 363 |
+
author_patterns,
|
| 364 |
+
alias_patterns,
|
| 365 |
+
dominant,
|
| 366 |
+
first_author_last,
|
| 367 |
+
)
|
| 368 |
+
updated += upd
|
| 369 |
+
total += tot
|
| 370 |
+
save_json(path, data)
|
| 371 |
+
return updated, total
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def main() -> None:
|
| 375 |
+
parser = argparse.ArgumentParser(
|
| 376 |
+
description="Replace citation mentions with <CITED HERE> in context fields."
|
| 377 |
+
)
|
| 378 |
+
parser.add_argument(
|
| 379 |
+
"--root",
|
| 380 |
+
type=str,
|
| 381 |
+
default="runs/processed_papers",
|
| 382 |
+
help="Root directory containing processed paper directories.",
|
| 383 |
+
)
|
| 384 |
+
parser.add_argument(
|
| 385 |
+
"--usage-contexts",
|
| 386 |
+
action="store_true",
|
| 387 |
+
help="Also update usage_contexts.json.",
|
| 388 |
+
)
|
| 389 |
+
args = parser.parse_args()
|
| 390 |
+
|
| 391 |
+
root = Path(args.root).expanduser().resolve()
|
| 392 |
+
if not root.exists():
|
| 393 |
+
raise SystemExit(f"Root directory does not exist: {root}")
|
| 394 |
+
|
| 395 |
+
paper_dirs = sorted(iter_paper_dirs(root), key=lambda p: p.name)
|
| 396 |
+
print(f"[INFO] Found {len(paper_dirs)} paper dirs under {root}")
|
| 397 |
+
total_updated = 0
|
| 398 |
+
total_contexts = 0
|
| 399 |
+
skipped_incomplete = 0
|
| 400 |
+
|
| 401 |
+
for paper_dir in paper_dirs:
|
| 402 |
+
if not _is_structurally_complete(paper_dir):
|
| 403 |
+
skipped_incomplete += 1
|
| 404 |
+
continue
|
| 405 |
+
meta = load_paper_metadata(paper_dir)
|
| 406 |
+
if not meta:
|
| 407 |
+
continue
|
| 408 |
+
author_patterns, alias_patterns, first_author_last = build_patterns(meta, paper_dir)
|
| 409 |
+
if not (author_patterns or alias_patterns):
|
| 410 |
+
continue
|
| 411 |
+
updated, total = update_citations_file(
|
| 412 |
+
paper_dir,
|
| 413 |
+
author_patterns,
|
| 414 |
+
alias_patterns,
|
| 415 |
+
first_author_last,
|
| 416 |
+
)
|
| 417 |
+
total_updated += updated
|
| 418 |
+
total_contexts += total
|
| 419 |
+
if args.usage_contexts:
|
| 420 |
+
upd_usage, tot_usage = update_usage_contexts_file(
|
| 421 |
+
paper_dir,
|
| 422 |
+
author_patterns,
|
| 423 |
+
alias_patterns,
|
| 424 |
+
first_author_last,
|
| 425 |
+
)
|
| 426 |
+
updated += upd_usage
|
| 427 |
+
total += tot_usage
|
| 428 |
+
total_updated += upd_usage
|
| 429 |
+
total_contexts += tot_usage
|
| 430 |
+
if total:
|
| 431 |
+
print(f"[OK] {paper_dir.name}: updated {updated} contexts over {total}")
|
| 432 |
+
|
| 433 |
+
print(
|
| 434 |
+
f"[SUMMARY] total_updated={total_updated} over {total_contexts}; "
|
| 435 |
+
f"skipped_incomplete={skipped_incomplete}"
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
if __name__ == "__main__":
|
| 440 |
+
main()
|
src/step_03_usage_contexts/build_usage_contexts.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any, Dict, List, Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
PAPER_META_FILE = "paper_metadata.json"
|
| 8 |
+
CITATIONS_FILE = "citations_metadata.json"
|
| 9 |
+
DEFAULT_OUT_NAME = "usage_contexts.json"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_json(path: Path) -> Any | None:
|
| 13 |
+
if not path.exists():
|
| 14 |
+
return None
|
| 15 |
+
try:
|
| 16 |
+
return json.loads(path.read_text(encoding="utf-8"))
|
| 17 |
+
except Exception as e:
|
| 18 |
+
print(f"[WARN] could not parse JSON at {path}: {e}")
|
| 19 |
+
return None
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def iter_paper_dirs(root: Path) -> List[Path]:
|
| 23 |
+
out: List[Path] = []
|
| 24 |
+
for child in root.iterdir():
|
| 25 |
+
if child.is_dir() and (child / PAPER_META_FILE).exists():
|
| 26 |
+
out.append(child)
|
| 27 |
+
return out
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _extract_contexts(item: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 31 |
+
contexts: List[Dict[str, Any]] = []
|
| 32 |
+
|
| 33 |
+
raw = item.get("contextsWithIntent") or []
|
| 34 |
+
if isinstance(raw, list) and raw:
|
| 35 |
+
for entry in raw:
|
| 36 |
+
if not isinstance(entry, dict):
|
| 37 |
+
continue
|
| 38 |
+
text_raw = (entry.get("context") or "").strip()
|
| 39 |
+
text = (entry.get("context_with_marker") or text_raw).strip()
|
| 40 |
+
intents = entry.get("intents") or []
|
| 41 |
+
contexts.append(
|
| 42 |
+
{
|
| 43 |
+
"text": text,
|
| 44 |
+
"text_raw": text_raw,
|
| 45 |
+
"intents": intents,
|
| 46 |
+
}
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# Fallback for older schema that only stores raw context strings.
|
| 50 |
+
if not contexts:
|
| 51 |
+
raw_alt = item.get("contexts") or []
|
| 52 |
+
if isinstance(raw_alt, list):
|
| 53 |
+
for text in raw_alt:
|
| 54 |
+
if not isinstance(text, str):
|
| 55 |
+
continue
|
| 56 |
+
text = text.strip()
|
| 57 |
+
if text:
|
| 58 |
+
contexts.append(
|
| 59 |
+
{
|
| 60 |
+
"text": text,
|
| 61 |
+
"intents": [],
|
| 62 |
+
}
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
return contexts
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def build_usage_contexts_for_paper(paper_dir: Path) -> Optional[Dict[str, Any]]:
|
| 69 |
+
citations_path = paper_dir / CITATIONS_FILE
|
| 70 |
+
data = load_json(citations_path)
|
| 71 |
+
if data is None:
|
| 72 |
+
return None
|
| 73 |
+
|
| 74 |
+
if not isinstance(data, list):
|
| 75 |
+
print(f"[WARN] {paper_dir.name}: {CITATIONS_FILE} is not a list")
|
| 76 |
+
return None
|
| 77 |
+
|
| 78 |
+
citing_entries: List[Dict[str, Any]] = []
|
| 79 |
+
total_contexts = 0
|
| 80 |
+
citing_with_context = 0
|
| 81 |
+
influential_citations = 0
|
| 82 |
+
influential_with_context = 0
|
| 83 |
+
influential_contexts: List[Dict[str, Any]] = []
|
| 84 |
+
|
| 85 |
+
for item in data:
|
| 86 |
+
if not isinstance(item, dict):
|
| 87 |
+
continue
|
| 88 |
+
citing = item.get("citingPaper") or {}
|
| 89 |
+
|
| 90 |
+
contexts = _extract_contexts(item)
|
| 91 |
+
is_influential = bool(item.get("isInfluential", False))
|
| 92 |
+
if is_influential:
|
| 93 |
+
influential_citations += 1
|
| 94 |
+
if contexts:
|
| 95 |
+
citing_with_context += 1
|
| 96 |
+
total_contexts += len(contexts)
|
| 97 |
+
if is_influential:
|
| 98 |
+
influential_with_context += 1
|
| 99 |
+
|
| 100 |
+
citing_entries.append(
|
| 101 |
+
{
|
| 102 |
+
"citing_paper_id": citing.get("paperId"),
|
| 103 |
+
"title": citing.get("title"),
|
| 104 |
+
"external_ids": citing.get("externalIds") or {},
|
| 105 |
+
"is_influential": is_influential,
|
| 106 |
+
"contexts": contexts,
|
| 107 |
+
}
|
| 108 |
+
)
|
| 109 |
+
if is_influential and contexts:
|
| 110 |
+
influential_contexts.append(
|
| 111 |
+
{
|
| 112 |
+
"citing_paper_id": citing.get("paperId"),
|
| 113 |
+
"title": citing.get("title"),
|
| 114 |
+
"external_ids": citing.get("externalIds") or {},
|
| 115 |
+
"contexts": contexts,
|
| 116 |
+
}
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
payload = {
|
| 120 |
+
"paper_id": paper_dir.name,
|
| 121 |
+
"total_citations": len(data),
|
| 122 |
+
"num_contexts": total_contexts,
|
| 123 |
+
"num_citing_with_context": citing_with_context,
|
| 124 |
+
"num_citing_without_context": len(data) - citing_with_context,
|
| 125 |
+
"num_influential_citations": influential_citations,
|
| 126 |
+
"num_influential_with_context": influential_with_context,
|
| 127 |
+
"influential_contexts": influential_contexts,
|
| 128 |
+
"citing_papers": citing_entries,
|
| 129 |
+
}
|
| 130 |
+
return payload
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def run(root: Path, out_name: str, overwrite: bool) -> None:
|
| 134 |
+
root = root.resolve()
|
| 135 |
+
if not root.exists():
|
| 136 |
+
raise SystemExit(f"Root directory does not exist: {root}")
|
| 137 |
+
|
| 138 |
+
paper_dirs = sorted(iter_paper_dirs(root), key=lambda p: p.name)
|
| 139 |
+
print(f"[INFO] Found {len(paper_dirs)} paper dirs under {root}")
|
| 140 |
+
|
| 141 |
+
for paper_dir in paper_dirs:
|
| 142 |
+
out_path = paper_dir / out_name
|
| 143 |
+
if out_path.exists() and not overwrite:
|
| 144 |
+
print(f"[SKIP] {paper_dir.name}: {out_name} already exists")
|
| 145 |
+
continue
|
| 146 |
+
payload = build_usage_contexts_for_paper(paper_dir)
|
| 147 |
+
if payload is None:
|
| 148 |
+
continue
|
| 149 |
+
|
| 150 |
+
out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
| 151 |
+
print(
|
| 152 |
+
f"[OK] {paper_dir.name}: wrote {out_name} "
|
| 153 |
+
f"({payload['num_contexts']} contexts from {payload['total_citations']} citations)"
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def main() -> None:
|
| 158 |
+
parser = argparse.ArgumentParser(
|
| 159 |
+
description="Build usage_contexts.json from citations_metadata.json files."
|
| 160 |
+
)
|
| 161 |
+
parser.add_argument(
|
| 162 |
+
"--root",
|
| 163 |
+
type=str,
|
| 164 |
+
default="processed_papers/acl_2024",
|
| 165 |
+
help="Root directory containing processed_papers/acl_2024/<paper_id> dirs.",
|
| 166 |
+
)
|
| 167 |
+
parser.add_argument(
|
| 168 |
+
"--out-name",
|
| 169 |
+
type=str,
|
| 170 |
+
default=DEFAULT_OUT_NAME,
|
| 171 |
+
help="Output filename to write inside each paper dir.",
|
| 172 |
+
)
|
| 173 |
+
parser.add_argument(
|
| 174 |
+
"--overwrite",
|
| 175 |
+
action="store_true",
|
| 176 |
+
help="Overwrite existing usage_contexts.json files.",
|
| 177 |
+
)
|
| 178 |
+
args = parser.parse_args()
|
| 179 |
+
|
| 180 |
+
run(Path(args.root), out_name=args.out_name, overwrite=args.overwrite)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
if __name__ == "__main__":
|
| 184 |
+
main()
|
src/step_04_label_citations/label_citation_functions.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Dict, List
|
| 6 |
+
|
| 7 |
+
DEEP_CITATION_ROOT = Path(__file__).resolve().parents[2] / "Deep-Citation"
|
| 8 |
+
if not DEEP_CITATION_ROOT.exists():
|
| 9 |
+
raise SystemExit(f"Deep-Citation repo not found at {DEEP_CITATION_ROOT}")
|
| 10 |
+
|
| 11 |
+
sys.path.insert(0, str(DEEP_CITATION_ROOT))
|
| 12 |
+
|
| 13 |
+
from data import CollateFn, create_data_channels
|
| 14 |
+
from Model import MultiHeadLanguageModel
|
| 15 |
+
import torch
|
| 16 |
+
from torch.utils.data import DataLoader
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
PAPER_META_FILE = "paper_metadata.json"
|
| 20 |
+
USAGE_CONTEXTS_FILE = "usage_contexts.json"
|
| 21 |
+
OUT_FILE = "usage_context_labels.json"
|
| 22 |
+
|
| 23 |
+
LABEL_SET = [
|
| 24 |
+
"Background",
|
| 25 |
+
"Uses",
|
| 26 |
+
"Extends",
|
| 27 |
+
"CompareOrContrast",
|
| 28 |
+
"Motivation",
|
| 29 |
+
"Future",
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def load_json(path: Path) -> Any | None:
|
| 34 |
+
if not path.exists():
|
| 35 |
+
return None
|
| 36 |
+
try:
|
| 37 |
+
return json.loads(path.read_text(encoding="utf-8"))
|
| 38 |
+
except Exception:
|
| 39 |
+
return None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def iter_paper_dirs(root: Path) -> List[Path]:
|
| 43 |
+
out: List[Path] = []
|
| 44 |
+
for child in root.iterdir():
|
| 45 |
+
if child.is_dir() and (child / PAPER_META_FILE).exists():
|
| 46 |
+
out.append(child)
|
| 47 |
+
return out
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def flatten_contexts(usage: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 51 |
+
contexts: List[Dict[str, Any]] = []
|
| 52 |
+
idx = 1
|
| 53 |
+
for entry in usage.get("citing_papers", []) or []:
|
| 54 |
+
if not isinstance(entry, dict):
|
| 55 |
+
continue
|
| 56 |
+
citing_title = entry.get("title") or "Unknown citing paper"
|
| 57 |
+
citing_paper_id = entry.get("citing_paper_id") or ""
|
| 58 |
+
for c in entry.get("contexts", []) or []:
|
| 59 |
+
if not isinstance(c, dict):
|
| 60 |
+
continue
|
| 61 |
+
text = (c.get("text") or "").strip()
|
| 62 |
+
if not text:
|
| 63 |
+
continue
|
| 64 |
+
contexts.append(
|
| 65 |
+
{
|
| 66 |
+
"id": idx,
|
| 67 |
+
"text": text,
|
| 68 |
+
"citing_title": citing_title,
|
| 69 |
+
"citing_paper_id": citing_paper_id,
|
| 70 |
+
}
|
| 71 |
+
)
|
| 72 |
+
idx += 1
|
| 73 |
+
return contexts
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _resolve_model_name(lm: str) -> str:
|
| 77 |
+
if lm == "scibert":
|
| 78 |
+
return "allenai/scibert_scivocab_uncased"
|
| 79 |
+
if lm == "bert":
|
| 80 |
+
return "bert-base-uncased"
|
| 81 |
+
if lm == "deberta":
|
| 82 |
+
return "microsoft/deberta-v3-base"
|
| 83 |
+
if lm == "deberta-large":
|
| 84 |
+
return "microsoft/deberta-v3-large"
|
| 85 |
+
return lm
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _infer_head_sizes(state_dict: Dict[str, Any]) -> List[int]:
|
| 89 |
+
head_weights = [
|
| 90 |
+
(k, v) for k, v in state_dict.items() if k.startswith("lns.") and k.endswith(".weight")
|
| 91 |
+
]
|
| 92 |
+
head_weights.sort(key=lambda x: int(x[0].split(".")[1]))
|
| 93 |
+
return [int(weight.shape[0]) for _, weight in head_weights]
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class _ContextDataset:
|
| 97 |
+
def __init__(self, texts: List[str]):
|
| 98 |
+
self.texts = texts
|
| 99 |
+
|
| 100 |
+
def __len__(self) -> int:
|
| 101 |
+
return len(self.texts)
|
| 102 |
+
|
| 103 |
+
def __getitem__(self, idx: int):
|
| 104 |
+
return (self.texts[idx], torch.tensor(0), torch.tensor(0))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def label_with_model(
|
| 108 |
+
contexts: List[Dict[str, Any]],
|
| 109 |
+
model_path: Path,
|
| 110 |
+
data_dir: Path,
|
| 111 |
+
class_definition: Path,
|
| 112 |
+
lm: str,
|
| 113 |
+
device: str,
|
| 114 |
+
batch_size: int,
|
| 115 |
+
) -> Dict[int, Dict[str, Any]]:
|
| 116 |
+
data_file = data_dir / "acl.tsv"
|
| 117 |
+
train_data, _, _, label_names = create_data_channels(
|
| 118 |
+
str(data_file),
|
| 119 |
+
str(class_definition),
|
| 120 |
+
lmbd=1.0,
|
| 121 |
+
)
|
| 122 |
+
modelname = _resolve_model_name(lm)
|
| 123 |
+
state_dict = torch.load(model_path, map_location=device)
|
| 124 |
+
head_sizes = _infer_head_sizes(state_dict)
|
| 125 |
+
model = MultiHeadLanguageModel(
|
| 126 |
+
modelname=modelname,
|
| 127 |
+
device=device,
|
| 128 |
+
readout="ch",
|
| 129 |
+
num_classes=head_sizes,
|
| 130 |
+
).to(device)
|
| 131 |
+
model.load_state_dict(state_dict)
|
| 132 |
+
model.eval()
|
| 133 |
+
|
| 134 |
+
collate_fn = CollateFn(
|
| 135 |
+
modelname=modelname,
|
| 136 |
+
class_definitions=train_data.class_definitions,
|
| 137 |
+
instance_weights=False,
|
| 138 |
+
)
|
| 139 |
+
dataset = _ContextDataset([ctx["text"] for ctx in contexts])
|
| 140 |
+
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
|
| 141 |
+
|
| 142 |
+
outputs: Dict[int, Dict[str, Any]] = {}
|
| 143 |
+
idx_offset = 0
|
| 144 |
+
with torch.no_grad():
|
| 145 |
+
for batched_text, labels, ds_indices, class_tokens, class_ds_indices in loader:
|
| 146 |
+
ds_indices = ds_indices.to(device)
|
| 147 |
+
class_ds_indices = class_ds_indices.to(device)
|
| 148 |
+
logits = model(batched_text, ds_indices, class_tokens, class_ds_indices)[0]
|
| 149 |
+
probs = torch.softmax(logits, dim=1)
|
| 150 |
+
preds = logits.argmax(dim=1).cpu().tolist()
|
| 151 |
+
pred_confidences = probs.max(dim=1).values.cpu().tolist()
|
| 152 |
+
top2 = torch.topk(probs, k=2, dim=1).values.cpu()
|
| 153 |
+
margins = (top2[:, 0] - top2[:, 1]).tolist()
|
| 154 |
+
for i, pred in enumerate(preds):
|
| 155 |
+
raw_label = label_names[pred]
|
| 156 |
+
outputs[idx_offset + i + 1] = {
|
| 157 |
+
"id": idx_offset + i + 1,
|
| 158 |
+
"label": raw_label,
|
| 159 |
+
"confidence": float(pred_confidences[i]),
|
| 160 |
+
"confidence_margin": float(margins[i]),
|
| 161 |
+
"cue_span": "",
|
| 162 |
+
"rationale": "scibert_model",
|
| 163 |
+
}
|
| 164 |
+
idx_offset += len(preds)
|
| 165 |
+
return outputs
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def aggregate_citing_labels(labels: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 169 |
+
by_citing: Dict[str, List[Dict[str, Any]]] = {}
|
| 170 |
+
for item in labels:
|
| 171 |
+
citing_id = item.get("citing_paper_id") or ""
|
| 172 |
+
by_citing.setdefault(citing_id, []).append(item)
|
| 173 |
+
|
| 174 |
+
aggregated: List[Dict[str, Any]] = []
|
| 175 |
+
for citing_id, items in by_citing.items():
|
| 176 |
+
title = items[0].get("citing_title", "")
|
| 177 |
+
labels_set = {it.get("label") for it in items}
|
| 178 |
+
|
| 179 |
+
if "Extends" in labels_set:
|
| 180 |
+
label = "Extends"
|
| 181 |
+
evidence_ids = [it["id"] for it in items if it.get("label") == "Extends"]
|
| 182 |
+
elif "Uses" in labels_set:
|
| 183 |
+
label = "Uses"
|
| 184 |
+
evidence_ids = [it["id"] for it in items if it.get("label") == "Uses"]
|
| 185 |
+
elif "CompareOrContrast" in labels_set:
|
| 186 |
+
label = "CompareOrContrast"
|
| 187 |
+
evidence_ids = [
|
| 188 |
+
it["id"] for it in items if it.get("label") == "CompareOrContrast"
|
| 189 |
+
]
|
| 190 |
+
else:
|
| 191 |
+
label = "Background"
|
| 192 |
+
evidence_ids = []
|
| 193 |
+
|
| 194 |
+
aggregated.append(
|
| 195 |
+
{
|
| 196 |
+
"citing_paper_id": citing_id,
|
| 197 |
+
"citing_title": title,
|
| 198 |
+
"label": label,
|
| 199 |
+
"evidence_context_ids": evidence_ids,
|
| 200 |
+
}
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
return aggregated
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def aggregate_final_label(citing_labels: List[Dict[str, Any]]) -> str:
|
| 207 |
+
labels_set = {item.get("label") for item in citing_labels}
|
| 208 |
+
if "Extends" in labels_set:
|
| 209 |
+
return "Extends"
|
| 210 |
+
if "Uses" in labels_set:
|
| 211 |
+
return "Uses"
|
| 212 |
+
if "CompareOrContrast" in labels_set:
|
| 213 |
+
return "CompareOrContrast"
|
| 214 |
+
return "Background"
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def score_for_paper(
|
| 218 |
+
paper_dir: Path,
|
| 219 |
+
batch_size: int,
|
| 220 |
+
overwrite: bool,
|
| 221 |
+
model_path: Path,
|
| 222 |
+
model_data_dir: Path,
|
| 223 |
+
model_class_def: Path,
|
| 224 |
+
model_lm: str,
|
| 225 |
+
device: str,
|
| 226 |
+
) -> str:
|
| 227 |
+
usage_path = paper_dir / USAGE_CONTEXTS_FILE
|
| 228 |
+
usage = load_json(usage_path)
|
| 229 |
+
if not isinstance(usage, dict):
|
| 230 |
+
return "missing_usage"
|
| 231 |
+
|
| 232 |
+
contexts = flatten_contexts(usage)
|
| 233 |
+
if not contexts:
|
| 234 |
+
return "empty_contexts"
|
| 235 |
+
|
| 236 |
+
out_path = paper_dir / OUT_FILE
|
| 237 |
+
if out_path.exists() and not overwrite:
|
| 238 |
+
return "skipped"
|
| 239 |
+
|
| 240 |
+
labeled = label_with_model(
|
| 241 |
+
contexts=contexts,
|
| 242 |
+
model_path=model_path,
|
| 243 |
+
data_dir=model_data_dir,
|
| 244 |
+
class_definition=model_class_def,
|
| 245 |
+
lm=model_lm,
|
| 246 |
+
device=device,
|
| 247 |
+
batch_size=batch_size,
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
labels_sorted = []
|
| 251 |
+
for context in contexts:
|
| 252 |
+
context_id = context["id"]
|
| 253 |
+
item = labeled.get(context_id)
|
| 254 |
+
if not item:
|
| 255 |
+
item = {
|
| 256 |
+
"id": context_id,
|
| 257 |
+
"label": "Background",
|
| 258 |
+
"confidence": 0.0,
|
| 259 |
+
"cue_span": "",
|
| 260 |
+
"rationale": "missing label",
|
| 261 |
+
}
|
| 262 |
+
item = dict(item)
|
| 263 |
+
item["citing_paper_id"] = context.get("citing_paper_id", "")
|
| 264 |
+
item["citing_title"] = context.get("citing_title", "")
|
| 265 |
+
item["text"] = context.get("text", "")
|
| 266 |
+
labels_sorted.append(item)
|
| 267 |
+
|
| 268 |
+
citing_labels = aggregate_citing_labels(labels_sorted)
|
| 269 |
+
payload = {
|
| 270 |
+
"paper_id": usage.get("paper_id"),
|
| 271 |
+
"num_contexts": len(contexts),
|
| 272 |
+
"label_set": LABEL_SET,
|
| 273 |
+
"labels": labels_sorted,
|
| 274 |
+
"citing_paper_labels": citing_labels,
|
| 275 |
+
"final_label": aggregate_final_label(citing_labels),
|
| 276 |
+
}
|
| 277 |
+
out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
| 278 |
+
return "labeled"
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def main() -> None:
|
| 282 |
+
parser = argparse.ArgumentParser(
|
| 283 |
+
description="Label citation functions using a Deep-Citation checkpoint."
|
| 284 |
+
)
|
| 285 |
+
parser.add_argument(
|
| 286 |
+
"--root",
|
| 287 |
+
type=str,
|
| 288 |
+
default="runs/processed_papers",
|
| 289 |
+
help="Root directory containing processed paper directories.",
|
| 290 |
+
)
|
| 291 |
+
parser.add_argument(
|
| 292 |
+
"--batch-size",
|
| 293 |
+
type=int,
|
| 294 |
+
default=32,
|
| 295 |
+
help="Batch size for model inference.",
|
| 296 |
+
)
|
| 297 |
+
parser.add_argument(
|
| 298 |
+
"--overwrite",
|
| 299 |
+
action="store_true",
|
| 300 |
+
help="Overwrite existing usage_context_labels.json files.",
|
| 301 |
+
)
|
| 302 |
+
parser.add_argument(
|
| 303 |
+
"--model-path",
|
| 304 |
+
type=str,
|
| 305 |
+
required=True,
|
| 306 |
+
help="Path to Deep-Citation best_model.pt checkpoint.",
|
| 307 |
+
)
|
| 308 |
+
parser.add_argument(
|
| 309 |
+
"--model-data-dir",
|
| 310 |
+
type=str,
|
| 311 |
+
default="Deep-Citation/Data",
|
| 312 |
+
help="Deep-Citation data directory (for label order).",
|
| 313 |
+
)
|
| 314 |
+
parser.add_argument(
|
| 315 |
+
"--model-class-def",
|
| 316 |
+
type=str,
|
| 317 |
+
default="Deep-Citation/Data/class_def.json",
|
| 318 |
+
help="Deep-Citation class_def.json path.",
|
| 319 |
+
)
|
| 320 |
+
parser.add_argument(
|
| 321 |
+
"--model-lm",
|
| 322 |
+
type=str,
|
| 323 |
+
default="scibert",
|
| 324 |
+
help="Model name used for the Deep-Citation checkpoint.",
|
| 325 |
+
)
|
| 326 |
+
parser.add_argument(
|
| 327 |
+
"--device",
|
| 328 |
+
type=str,
|
| 329 |
+
default="cuda",
|
| 330 |
+
help="Device for model inference (cuda/cpu).",
|
| 331 |
+
)
|
| 332 |
+
args = parser.parse_args()
|
| 333 |
+
|
| 334 |
+
model_path = Path(args.model_path).expanduser().resolve()
|
| 335 |
+
if not model_path.exists():
|
| 336 |
+
raise SystemExit(f"Model path does not exist: {model_path}")
|
| 337 |
+
|
| 338 |
+
root = Path(args.root).expanduser().resolve()
|
| 339 |
+
if not root.exists():
|
| 340 |
+
raise SystemExit(f"Root directory does not exist: {root}")
|
| 341 |
+
|
| 342 |
+
paper_dirs = sorted(iter_paper_dirs(root), key=lambda p: p.name)
|
| 343 |
+
print(f"[INFO] Found {len(paper_dirs)} paper dirs under {root}")
|
| 344 |
+
|
| 345 |
+
counts = {
|
| 346 |
+
"labeled": 0,
|
| 347 |
+
"skipped": 0,
|
| 348 |
+
"missing_usage": 0,
|
| 349 |
+
"empty_contexts": 0,
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
for paper_dir in paper_dirs:
|
| 353 |
+
status = score_for_paper(
|
| 354 |
+
paper_dir,
|
| 355 |
+
args.batch_size,
|
| 356 |
+
args.overwrite,
|
| 357 |
+
model_path=model_path,
|
| 358 |
+
model_data_dir=Path(args.model_data_dir).expanduser().resolve(),
|
| 359 |
+
model_class_def=Path(args.model_class_def).expanduser().resolve(),
|
| 360 |
+
model_lm=args.model_lm,
|
| 361 |
+
device=args.device,
|
| 362 |
+
)
|
| 363 |
+
counts[status] = counts.get(status, 0) + 1
|
| 364 |
+
print(f"[{status.upper()}] {paper_dir.name}")
|
| 365 |
+
|
| 366 |
+
print(
|
| 367 |
+
"[SUMMARY] labeled={labeled}, skipped={skipped}, missing_usage={missing_usage}, "
|
| 368 |
+
"empty_contexts={empty_contexts}".format(**counts)
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
if __name__ == "__main__":
|
| 373 |
+
main()
|
src/step_05_verify_uses_extends/prompts.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
USES_DEFINITION = (
|
| 5 |
+
"USES: The CITING_PAPER explicitly uses/adopts/evaluates on/includes/relies on "
|
| 6 |
+
"a dataset, benchmark, method, tool, or reported results from TARGET_PAPER "
|
| 7 |
+
"as part of the CITING_PAPER's own methodology or evaluation."
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
EXTENDS_DEFINITION = (
|
| 11 |
+
"EXTENDS: The CITING_PAPER explicitly extends/modifies/adapts/builds upon "
|
| 12 |
+
"TARGET_PAPER's method/dataset/benchmark/tool."
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
NOTES_DEFINITION = (
|
| 16 |
+
"NOT USES/EXTENDS: Merely describing what TARGET_PAPER introduces/offers/proposes "
|
| 17 |
+
"or listing it among related work or benchmarks (without stating adoption). "
|
| 18 |
+
"If no explicit adoption/extension cue is present, label NOT_CONFIRMED."
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
FEW_SHOT_USES = [
|
| 23 |
+
"We use the same splits as <CITED HERE> .",
|
| 24 |
+
"The Praat tool was used ( <CITED HERE> ) .",
|
| 25 |
+
"CCGBank ( <CITED HERE> ) is used to train the model .",
|
| 26 |
+
"This design idea was adopted from TANKA ( <CITED HERE>b ) .",
|
| 27 |
+
"Our strategy is based on the approach presented by <CITED HERE> .",
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
FEW_SHOT_EXTENDS = [
|
| 31 |
+
"The features can be easily obtained by modifying the TAT extraction algorithm described in ( <CITED HERE> ) .",
|
| 32 |
+
"Our own work ( <CITED HERE> ) extends the first idea to paraphrase fragment extraction on monolingual parallel and comparable corpora .",
|
| 33 |
+
"This article represents an extension of our previous work on unsupervised event coreference resolution ( Bejan et al. 2009 ; <CITED HERE> ) .",
|
| 34 |
+
"This evaluation set-up is an improvement versus the one we previously reported ( <CITED HERE> ) , in which fixed partitions were used for training , development , and testing .",
|
| 35 |
+
"The computational treatment of lexical rules proposed can be seen as an extension to the principled method discussed by Gotz and <CITED HERE> , 1996 , 1997b ) for encoding the main building block of HPSG grammars -- the implicative constraints -- as a logic program .",
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
FEW_SHOT_NOT_CONFIRMED = [
|
| 39 |
+
"<CITED HERE> introduced factored SMT .",
|
| 40 |
+
"See ( <CITED HERE> ) for a discussion .",
|
| 41 |
+
"See , among others , ( <CITED HERE> ) .",
|
| 42 |
+
"<CITED HERE> reported a correlation of r = .69 .",
|
| 43 |
+
"See <CITED HERE> for further discussion .",
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def build_uses_extends_verification_prompt(
|
| 48 |
+
target_info: Dict[str, str],
|
| 49 |
+
candidates: List[Dict[str, str]],
|
| 50 |
+
) -> str:
|
| 51 |
+
header = [
|
| 52 |
+
"You are verifying citation function for a TARGET paper inside a citing sentence.",
|
| 53 |
+
"Be strict: lists of related work or benchmarks are NOT USES/EXTENDS unless there is an explicit action",
|
| 54 |
+
"like \"use\", \"build on\", \"adopt\", \"extend\", \"based on\", \"trained on\", \"evaluate on\", \"implement\".",
|
| 55 |
+
"",
|
| 56 |
+
"Actor test (CRITICAL for USES/EXTENSION):",
|
| 57 |
+
"- Only label USES or EXTENSION if the ACTION is performed by the CITING_PAPER.",
|
| 58 |
+
"- The cue_span for USES/EXTENSION must include an explicit citing-paper actor phrase such as:",
|
| 59 |
+
" \"we\", \"our\", \"in this work\", \"in this paper\", \"we use\", \"we evaluate\",",
|
| 60 |
+
" \"our evaluation includes\", \"we extend\", \"we build on\", \"we adapt\".",
|
| 61 |
+
"- If the context says the TARGET_PAPER (or some other paper/system) uses/extends something",
|
| 62 |
+
" (e.g., \"TARGET_PAPER uses...\", \"TARGET_PAPER extends...\"),",
|
| 63 |
+
" then it is NOT USES/EXTENSION. Label NOT_CONFIRMED.",
|
| 64 |
+
"",
|
| 65 |
+
"Task: Label each sentence as USES, EXTENDS, or NOT_CONFIRMED.",
|
| 66 |
+
"Return JSON only with one entry per input sentence.",
|
| 67 |
+
"",
|
| 68 |
+
"Definitions:",
|
| 69 |
+
f"- {USES_DEFINITION}",
|
| 70 |
+
f"- {EXTENDS_DEFINITION}",
|
| 71 |
+
f"- {NOTES_DEFINITION}",
|
| 72 |
+
"",
|
| 73 |
+
"Output rules:",
|
| 74 |
+
"- label must be one of: USES, EXTENDS, NOT_CONFIRMED",
|
| 75 |
+
"- cue_span: exact substring from the sentence that justifies USES/EXTENDS, else empty",
|
| 76 |
+
"- rationale: one short sentence",
|
| 77 |
+
"- If cue_span is empty => label must be NOT_CONFIRMED",
|
| 78 |
+
"",
|
| 79 |
+
"Few-shot examples:",
|
| 80 |
+
"USES:",
|
| 81 |
+
]
|
| 82 |
+
for ex in FEW_SHOT_USES:
|
| 83 |
+
header.append(f"- {ex}")
|
| 84 |
+
header.append("EXTENDS:")
|
| 85 |
+
for ex in FEW_SHOT_EXTENDS:
|
| 86 |
+
header.append(f"- {ex}")
|
| 87 |
+
header.append("NOT_CONFIRMED:")
|
| 88 |
+
for ex in FEW_SHOT_NOT_CONFIRMED:
|
| 89 |
+
header.append(f"- {ex}")
|
| 90 |
+
|
| 91 |
+
header.extend(
|
| 92 |
+
[
|
| 93 |
+
"",
|
| 94 |
+
"TARGET_PAPER:",
|
| 95 |
+
f"- title: {target_info.get('title', '')}",
|
| 96 |
+
f"- first_author_last: {target_info.get('first_author_last', '')}",
|
| 97 |
+
f"- year: {target_info.get('year', '')}",
|
| 98 |
+
"",
|
| 99 |
+
"CANDIDATES:",
|
| 100 |
+
]
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
for item in candidates:
|
| 104 |
+
header.extend(
|
| 105 |
+
[
|
| 106 |
+
f"ID: {item['id']}",
|
| 107 |
+
f"Citing paper: {item.get('citing_title', '')}",
|
| 108 |
+
f"Sentence: {item.get('text', '')}",
|
| 109 |
+
"",
|
| 110 |
+
]
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
header.append("JSON OUTPUT:")
|
| 114 |
+
header.append("{\"labels\": [{\"id\": 1, \"label\": \"USES\", \"cue_span\": \"...\", \"rationale\": \"...\"}]}")
|
| 115 |
+
return "\n".join(header)
|
src/step_05_verify_uses_extends/schemas.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
USES_EXTENDS_VERIFICATION_JSON_SCHEMA = {
|
| 2 |
+
"type": "object",
|
| 3 |
+
"properties": {
|
| 4 |
+
"labels": {
|
| 5 |
+
"type": "array",
|
| 6 |
+
"items": {
|
| 7 |
+
"type": "object",
|
| 8 |
+
"properties": {
|
| 9 |
+
"id": {"type": "integer"},
|
| 10 |
+
"label": {
|
| 11 |
+
"type": "string",
|
| 12 |
+
"enum": ["USES", "EXTENDS", "NOT_CONFIRMED"],
|
| 13 |
+
},
|
| 14 |
+
"cue_span": {"type": "string"},
|
| 15 |
+
"rationale": {"type": "string"},
|
| 16 |
+
},
|
| 17 |
+
"required": ["id", "label", "cue_span", "rationale"],
|
| 18 |
+
},
|
| 19 |
+
},
|
| 20 |
+
},
|
| 21 |
+
"required": ["labels"],
|
| 22 |
+
}
|
src/step_05_verify_uses_extends/verify_uses_extends.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Dict, List
|
| 6 |
+
|
| 7 |
+
SRC_ROOT = Path(__file__).resolve().parents[1]
|
| 8 |
+
if str(SRC_ROOT) not in sys.path:
|
| 9 |
+
sys.path.insert(0, str(SRC_ROOT))
|
| 10 |
+
|
| 11 |
+
from common.llm_client import LLMClient
|
| 12 |
+
|
| 13 |
+
from prompts import build_uses_extends_verification_prompt
|
| 14 |
+
from schemas import USES_EXTENDS_VERIFICATION_JSON_SCHEMA
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
PAPER_META_FILE = "paper_metadata.json"
|
| 18 |
+
USAGE_LABELS_FILE = "usage_context_labels.json"
|
| 19 |
+
OUT_FILE = "usage_uses_extends_verified.json"
|
| 20 |
+
|
| 21 |
+
USE_LABELS = {"Uses", "Extends"}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def load_json(path: Path) -> Any | None:
|
| 25 |
+
if not path.exists():
|
| 26 |
+
return None
|
| 27 |
+
try:
|
| 28 |
+
return json.loads(path.read_text(encoding="utf-8"))
|
| 29 |
+
except Exception:
|
| 30 |
+
return None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def iter_paper_dirs(root: Path) -> List[Path]:
|
| 34 |
+
out: List[Path] = []
|
| 35 |
+
for child in root.iterdir():
|
| 36 |
+
if child.is_dir() and (child / PAPER_META_FILE).exists():
|
| 37 |
+
out.append(child)
|
| 38 |
+
return out
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _normalize_author_last(name: str) -> str:
|
| 42 |
+
parts = [p for p in (name or "").split() if p.strip()]
|
| 43 |
+
return parts[-1] if parts else ""
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def extract_target_info(meta: Any) -> Dict[str, str]:
|
| 47 |
+
if isinstance(meta, list) and meta:
|
| 48 |
+
meta = meta[0]
|
| 49 |
+
if not isinstance(meta, dict):
|
| 50 |
+
return {"title": "", "first_author_last": "", "year": ""}
|
| 51 |
+
authors = meta.get("authors") or []
|
| 52 |
+
first_author = authors[0]["name"] if authors else ""
|
| 53 |
+
return {
|
| 54 |
+
"title": meta.get("title", ""),
|
| 55 |
+
"first_author_last": _normalize_author_last(first_author),
|
| 56 |
+
"year": str(meta.get("year", "")),
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def verify_candidates(
|
| 61 |
+
client: LLMClient,
|
| 62 |
+
target_info: Dict[str, str],
|
| 63 |
+
candidates: List[Dict[str, Any]],
|
| 64 |
+
) -> List[Dict[str, Any]]:
|
| 65 |
+
prompt = build_uses_extends_verification_prompt(target_info, candidates)
|
| 66 |
+
try:
|
| 67 |
+
raw = client.call(prompt, schema=USES_EXTENDS_VERIFICATION_JSON_SCHEMA)
|
| 68 |
+
except Exception as exc:
|
| 69 |
+
print(f"[WARN] LLM call failed: {exc}. Marking all candidates NOT_CONFIRMED.")
|
| 70 |
+
return [
|
| 71 |
+
{
|
| 72 |
+
"id": item.get("id"),
|
| 73 |
+
"label": "NOT_CONFIRMED",
|
| 74 |
+
"cue_span": "",
|
| 75 |
+
"rationale": "",
|
| 76 |
+
"text": item.get("text", ""),
|
| 77 |
+
"citing_paper_id": item.get("citing_paper_id", ""),
|
| 78 |
+
"citing_title": item.get("citing_title", ""),
|
| 79 |
+
"original_label": item.get("original_label", ""),
|
| 80 |
+
}
|
| 81 |
+
for item in candidates
|
| 82 |
+
]
|
| 83 |
+
data = _parse_llm_json(raw)
|
| 84 |
+
if not isinstance(data, dict):
|
| 85 |
+
print("[WARN] Failed to parse LLM JSON response; marking all candidates NOT_CONFIRMED.")
|
| 86 |
+
return [
|
| 87 |
+
{
|
| 88 |
+
"id": item.get("id"),
|
| 89 |
+
"label": "NOT_CONFIRMED",
|
| 90 |
+
"cue_span": "",
|
| 91 |
+
"rationale": "",
|
| 92 |
+
"text": item.get("text", ""),
|
| 93 |
+
"citing_paper_id": item.get("citing_paper_id", ""),
|
| 94 |
+
"citing_title": item.get("citing_title", ""),
|
| 95 |
+
"original_label": item.get("original_label", ""),
|
| 96 |
+
}
|
| 97 |
+
for item in candidates
|
| 98 |
+
]
|
| 99 |
+
labels = data.get("labels", [])
|
| 100 |
+
by_id = {item.get("id"): item for item in labels if isinstance(item, dict)}
|
| 101 |
+
|
| 102 |
+
verified: List[Dict[str, Any]] = []
|
| 103 |
+
for candidate in candidates:
|
| 104 |
+
item_id = candidate["id"]
|
| 105 |
+
model = by_id.get(item_id, {})
|
| 106 |
+
label = model.get("label", "NOT_CONFIRMED")
|
| 107 |
+
cue_span = model.get("cue_span", "")
|
| 108 |
+
if not cue_span:
|
| 109 |
+
label = "NOT_CONFIRMED"
|
| 110 |
+
verified.append(
|
| 111 |
+
{
|
| 112 |
+
"id": item_id,
|
| 113 |
+
"label": label,
|
| 114 |
+
"cue_span": cue_span,
|
| 115 |
+
"rationale": model.get("rationale", ""),
|
| 116 |
+
"text": candidate.get("text", ""),
|
| 117 |
+
"citing_paper_id": candidate.get("citing_paper_id", ""),
|
| 118 |
+
"citing_title": candidate.get("citing_title", ""),
|
| 119 |
+
"original_label": candidate.get("original_label", ""),
|
| 120 |
+
}
|
| 121 |
+
)
|
| 122 |
+
return verified
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _parse_llm_json(raw: str) -> Any | None:
|
| 126 |
+
try:
|
| 127 |
+
return json.loads(raw)
|
| 128 |
+
except json.JSONDecodeError:
|
| 129 |
+
pass
|
| 130 |
+
|
| 131 |
+
cleaned = raw.strip()
|
| 132 |
+
if cleaned.startswith("```"):
|
| 133 |
+
cleaned = cleaned.strip("`")
|
| 134 |
+
cleaned = cleaned.replace("json", "", 1).strip()
|
| 135 |
+
|
| 136 |
+
start = cleaned.find("{")
|
| 137 |
+
end = cleaned.rfind("}")
|
| 138 |
+
if start == -1 or end == -1 or end <= start:
|
| 139 |
+
return None
|
| 140 |
+
|
| 141 |
+
snippet = cleaned[start : end + 1]
|
| 142 |
+
try:
|
| 143 |
+
return json.loads(snippet)
|
| 144 |
+
except json.JSONDecodeError:
|
| 145 |
+
return None
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def process_paper(
|
| 149 |
+
paper_dir: Path,
|
| 150 |
+
client: LLMClient,
|
| 151 |
+
k: int,
|
| 152 |
+
batch_size: int,
|
| 153 |
+
overwrite: bool,
|
| 154 |
+
resume: bool,
|
| 155 |
+
) -> str:
|
| 156 |
+
labels_path = paper_dir / USAGE_LABELS_FILE
|
| 157 |
+
payload = load_json(labels_path)
|
| 158 |
+
if not isinstance(payload, dict):
|
| 159 |
+
return "missing_labels"
|
| 160 |
+
|
| 161 |
+
out_path = paper_dir / OUT_FILE
|
| 162 |
+
if out_path.exists() and (resume or not overwrite):
|
| 163 |
+
return "skipped"
|
| 164 |
+
|
| 165 |
+
labels = payload.get("labels", [])
|
| 166 |
+
candidates_all = []
|
| 167 |
+
for item in labels:
|
| 168 |
+
if item.get("label") in USE_LABELS:
|
| 169 |
+
candidates_all.append(
|
| 170 |
+
{
|
| 171 |
+
"id": item.get("id"),
|
| 172 |
+
"text": item.get("text", ""),
|
| 173 |
+
"citing_paper_id": item.get("citing_paper_id", ""),
|
| 174 |
+
"citing_title": item.get("citing_title", ""),
|
| 175 |
+
"original_label": item.get("label"),
|
| 176 |
+
"confidence": float(item.get("confidence", 0.0) or 0.0),
|
| 177 |
+
}
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
if not candidates_all:
|
| 181 |
+
result = {
|
| 182 |
+
"paper_id": payload.get("paper_id"),
|
| 183 |
+
"target": {},
|
| 184 |
+
"candidates_total": 0,
|
| 185 |
+
"candidates_considered": 0,
|
| 186 |
+
"verified": [],
|
| 187 |
+
"confirmed": [],
|
| 188 |
+
}
|
| 189 |
+
out_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
| 190 |
+
return "no_candidates"
|
| 191 |
+
|
| 192 |
+
# Keep top-k highest-confidence USES/EXTENDS contexts for LLM verification.
|
| 193 |
+
# If k <= 0, verify all candidates.
|
| 194 |
+
candidates_all = sorted(
|
| 195 |
+
candidates_all,
|
| 196 |
+
key=lambda x: x.get("confidence", 0.0),
|
| 197 |
+
reverse=True,
|
| 198 |
+
)
|
| 199 |
+
candidates = candidates_all if k <= 0 else candidates_all[:k]
|
| 200 |
+
|
| 201 |
+
target_info = extract_target_info(load_json(paper_dir / PAPER_META_FILE))
|
| 202 |
+
verified: List[Dict[str, Any]] = []
|
| 203 |
+
if batch_size <= 0:
|
| 204 |
+
batch_size = 25
|
| 205 |
+
for i in range(0, len(candidates), batch_size):
|
| 206 |
+
batch = candidates[i : i + batch_size]
|
| 207 |
+
verified.extend(verify_candidates(client, target_info, batch))
|
| 208 |
+
confirmed = [v for v in verified if v["label"] in {"USES", "EXTENDS"}]
|
| 209 |
+
if any(item["label"] == "EXTENDS" for item in confirmed):
|
| 210 |
+
final_label = "EXTENDS"
|
| 211 |
+
elif confirmed:
|
| 212 |
+
final_label = "USES"
|
| 213 |
+
else:
|
| 214 |
+
final_label = "NOT_CONFIRMED"
|
| 215 |
+
|
| 216 |
+
result = {
|
| 217 |
+
"paper_id": payload.get("paper_id"),
|
| 218 |
+
"target": target_info,
|
| 219 |
+
"candidates_total": len(candidates_all),
|
| 220 |
+
"candidates_considered": len(candidates),
|
| 221 |
+
"verification_batch_size": int(batch_size),
|
| 222 |
+
"verification_num_batches": (len(candidates) + batch_size - 1) // batch_size if candidates else 0,
|
| 223 |
+
"candidates_selected": len(confirmed),
|
| 224 |
+
"verified": verified,
|
| 225 |
+
"confirmed": confirmed,
|
| 226 |
+
"confirmed_extends": sum(1 for x in confirmed if x.get("label") == "EXTENDS"),
|
| 227 |
+
"confirmed_uses": sum(1 for x in confirmed if x.get("label") == "USES"),
|
| 228 |
+
"final_label": final_label,
|
| 229 |
+
}
|
| 230 |
+
out_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
| 231 |
+
return "verified"
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def main() -> None:
|
| 235 |
+
parser = argparse.ArgumentParser(
|
| 236 |
+
description="Verify USES/EXTENDS candidates via LLM and select top-K."
|
| 237 |
+
)
|
| 238 |
+
parser.add_argument(
|
| 239 |
+
"--root",
|
| 240 |
+
type=str,
|
| 241 |
+
default="runs/processed_papers",
|
| 242 |
+
help="Root directory containing processed paper directories.",
|
| 243 |
+
)
|
| 244 |
+
parser.add_argument(
|
| 245 |
+
"--k",
|
| 246 |
+
type=int,
|
| 247 |
+
default=0,
|
| 248 |
+
help="Verify top-k USES/EXTENDS candidates ranked by classifier confidence (<=0 means all).",
|
| 249 |
+
)
|
| 250 |
+
parser.add_argument(
|
| 251 |
+
"--batch-size",
|
| 252 |
+
type=int,
|
| 253 |
+
default=25,
|
| 254 |
+
help="Number of candidates per LLM verification batch.",
|
| 255 |
+
)
|
| 256 |
+
parser.add_argument(
|
| 257 |
+
"--overwrite",
|
| 258 |
+
action="store_true",
|
| 259 |
+
help="Overwrite existing usage_uses_extends_verified.json files.",
|
| 260 |
+
)
|
| 261 |
+
parser.add_argument(
|
| 262 |
+
"--resume",
|
| 263 |
+
action="store_true",
|
| 264 |
+
help="Skip papers with existing output files (even if --overwrite is set).",
|
| 265 |
+
)
|
| 266 |
+
args = parser.parse_args()
|
| 267 |
+
|
| 268 |
+
root = Path(args.root).expanduser().resolve()
|
| 269 |
+
if not root.exists():
|
| 270 |
+
raise SystemExit(f"Root directory does not exist: {root}")
|
| 271 |
+
|
| 272 |
+
client = LLMClient()
|
| 273 |
+
paper_dirs = sorted(iter_paper_dirs(root), key=lambda p: p.name)
|
| 274 |
+
print(f"[INFO] Found {len(paper_dirs)} paper dirs under {root}")
|
| 275 |
+
|
| 276 |
+
counts = {"verified": 0, "skipped": 0, "missing_labels": 0, "no_candidates": 0}
|
| 277 |
+
for paper_dir in paper_dirs:
|
| 278 |
+
status = process_paper(
|
| 279 |
+
paper_dir,
|
| 280 |
+
client,
|
| 281 |
+
args.k,
|
| 282 |
+
args.batch_size,
|
| 283 |
+
args.overwrite,
|
| 284 |
+
args.resume,
|
| 285 |
+
)
|
| 286 |
+
counts[status] = counts.get(status, 0) + 1
|
| 287 |
+
print(f"[{status.upper()}] {paper_dir.name}")
|
| 288 |
+
|
| 289 |
+
print(
|
| 290 |
+
"[SUMMARY] verified={verified}, skipped={skipped}, missing_labels={missing_labels}, "
|
| 291 |
+
"no_candidates={no_candidates}".format(**counts)
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
if __name__ == "__main__":
|
| 296 |
+
main()
|
src/step_06_extract_paragraphs/extract_arxiv_paragraphs.py
ADDED
|
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import random
|
| 4 |
+
import re
|
| 5 |
+
import sys
|
| 6 |
+
import tarfile
|
| 7 |
+
import tempfile
|
| 8 |
+
import time
|
| 9 |
+
import urllib.request
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
SRC_ROOT = Path(__file__).resolve().parents[1]
|
| 15 |
+
if str(SRC_ROOT) not in sys.path:
|
| 16 |
+
sys.path.insert(0, str(SRC_ROOT))
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
PAPER_META_FILE = "paper_metadata.json"
|
| 20 |
+
USAGE_CONTEXTS_FILE = "usage_contexts.json"
|
| 21 |
+
VERIFIED_FILE = "usage_uses_extends_verified.json"
|
| 22 |
+
OUT_FILE = "usage_citing_paragraphs.json"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def load_json(path: Path) -> Any | None:
|
| 26 |
+
if not path.exists():
|
| 27 |
+
return None
|
| 28 |
+
try:
|
| 29 |
+
return json.loads(path.read_text(encoding="utf-8"))
|
| 30 |
+
except Exception:
|
| 31 |
+
return None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def iter_paper_dirs(root: Path) -> List[Path]:
|
| 35 |
+
out: List[Path] = []
|
| 36 |
+
for child in root.iterdir():
|
| 37 |
+
if child.is_dir() and (child / PAPER_META_FILE).exists():
|
| 38 |
+
out.append(child)
|
| 39 |
+
return out
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def safe_extract(tar: tarfile.TarFile, path: Path) -> None:
|
| 43 |
+
for member in tar.getmembers():
|
| 44 |
+
member_path = path / member.name
|
| 45 |
+
if not str(member_path.resolve()).startswith(str(path.resolve())):
|
| 46 |
+
raise RuntimeError(f"Blocked path traversal in tar: {member.name}")
|
| 47 |
+
tar.extractall(path)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
_ARXIV_LAST_TS = 0.0
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _arxiv_min_interval_sleep() -> None:
|
| 54 |
+
"""Global throttle to avoid arXiv API rate limits."""
|
| 55 |
+
global _ARXIV_LAST_TS
|
| 56 |
+
min_interval = float(os.getenv("ARXIV_MIN_INTERVAL", "1.0"))
|
| 57 |
+
now = time.monotonic()
|
| 58 |
+
elapsed = now - _ARXIV_LAST_TS
|
| 59 |
+
if elapsed < min_interval:
|
| 60 |
+
time.sleep(min_interval - elapsed)
|
| 61 |
+
_ARXIV_LAST_TS = time.monotonic()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def download_arxiv_source(arxiv_id: str, tmpdir: Path) -> Optional[Path]:
|
| 65 |
+
url = f"https://arxiv.org/e-print/{arxiv_id}"
|
| 66 |
+
archive_path = tmpdir / f"{arxiv_id.replace('/', '_')}.tar"
|
| 67 |
+
max_retries = int(os.getenv("ARXIV_MAX_RETRIES", "6"))
|
| 68 |
+
base_sleep = float(os.getenv("ARXIV_BASE_SLEEP", "2.0"))
|
| 69 |
+
max_sleep = float(os.getenv("ARXIV_MAX_BACKOFF", "60"))
|
| 70 |
+
|
| 71 |
+
for attempt in range(max_retries):
|
| 72 |
+
try:
|
| 73 |
+
_arxiv_min_interval_sleep()
|
| 74 |
+
urllib.request.urlretrieve(url, archive_path) # noqa: S310
|
| 75 |
+
try:
|
| 76 |
+
with tarfile.open(archive_path) as tar:
|
| 77 |
+
safe_extract(tar, tmpdir)
|
| 78 |
+
return tmpdir
|
| 79 |
+
except tarfile.ReadError as exc:
|
| 80 |
+
print(f"[WARN] Invalid arXiv archive for {arxiv_id}: {exc}")
|
| 81 |
+
return None
|
| 82 |
+
except Exception as exc:
|
| 83 |
+
# arXiv sometimes returns 429; treat any network error as retryable.
|
| 84 |
+
sleep = min(base_sleep * (2 ** attempt), max_sleep) + random.uniform(0.0, 0.5)
|
| 85 |
+
print(f"[WARN] Failed to download arXiv source for {arxiv_id}: {exc}")
|
| 86 |
+
print(f"[WARN] arXiv download retrying in {sleep:.2f}s")
|
| 87 |
+
time.sleep(sleep)
|
| 88 |
+
continue
|
| 89 |
+
|
| 90 |
+
print(f"[ERROR] Giving up after {max_retries} attempts for arXiv {arxiv_id}")
|
| 91 |
+
return None
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def find_main_tex(root: Path) -> Optional[Path]:
|
| 95 |
+
tex_files = list(root.rglob("*.tex"))
|
| 96 |
+
if not tex_files:
|
| 97 |
+
return None
|
| 98 |
+
|
| 99 |
+
candidates: List[Tuple[int, Path]] = []
|
| 100 |
+
for path in tex_files:
|
| 101 |
+
try:
|
| 102 |
+
text = path.read_text(encoding="utf-8", errors="ignore")
|
| 103 |
+
except Exception:
|
| 104 |
+
continue
|
| 105 |
+
score = 0
|
| 106 |
+
if "\\begin{document}" in text:
|
| 107 |
+
score += 3
|
| 108 |
+
if "\\documentclass" in text:
|
| 109 |
+
score += 2
|
| 110 |
+
score += len(text) // 1000
|
| 111 |
+
candidates.append((score, path))
|
| 112 |
+
|
| 113 |
+
candidates.sort(key=lambda x: x[0], reverse=True)
|
| 114 |
+
return candidates[0][1] if candidates else None
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def read_bib_files(root: Path) -> Dict[str, str]:
|
| 118 |
+
bibs: Dict[str, str] = {}
|
| 119 |
+
for path in root.rglob("*.bib"):
|
| 120 |
+
try:
|
| 121 |
+
bibs[str(path.relative_to(root))] = path.read_text(encoding="utf-8", errors="ignore")
|
| 122 |
+
except Exception:
|
| 123 |
+
continue
|
| 124 |
+
return bibs
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def normalize_text(text: str) -> str:
|
| 128 |
+
text = re.sub(r"[^a-z0-9\s]", " ", text.lower())
|
| 129 |
+
return re.sub(r"\s+", " ", text).strip()
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def tokenize(text: str) -> List[str]:
|
| 133 |
+
return [t for t in normalize_text(text).split() if t]
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def paragraphize(text: str) -> List[str]:
|
| 137 |
+
text = text.replace("\r\n", "\n")
|
| 138 |
+
text = re.sub(r"\n\s*\n", "\n\n", text)
|
| 139 |
+
paragraphs = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
|
| 140 |
+
return paragraphs
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def strip_latex_comments(text: str) -> str:
|
| 144 |
+
# Remove explicit comment environments first.
|
| 145 |
+
text = re.sub(r"\\begin\{comment\}.*?\\end\{comment\}", "", text, flags=re.DOTALL)
|
| 146 |
+
|
| 147 |
+
cleaned_lines: List[str] = []
|
| 148 |
+
for line in text.splitlines():
|
| 149 |
+
out_chars: List[str] = []
|
| 150 |
+
i = 0
|
| 151 |
+
while i < len(line):
|
| 152 |
+
ch = line[i]
|
| 153 |
+
if ch == "%":
|
| 154 |
+
# Keep escaped percent (\%) and continue parsing.
|
| 155 |
+
if i > 0 and line[i - 1] == "\\":
|
| 156 |
+
out_chars.append(ch)
|
| 157 |
+
i += 1
|
| 158 |
+
continue
|
| 159 |
+
# Unescaped percent starts a LaTeX comment; ignore rest of the line.
|
| 160 |
+
break
|
| 161 |
+
out_chars.append(ch)
|
| 162 |
+
i += 1
|
| 163 |
+
cleaned_lines.append("".join(out_chars))
|
| 164 |
+
return "\n".join(cleaned_lines)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def parse_bib_entries(bib_text: str) -> List[Dict[str, str]]:
|
| 168 |
+
entries: List[Dict[str, str]] = []
|
| 169 |
+
matches = list(re.finditer(r"@[\w]+\s*\{\s*([^,]+),", bib_text))
|
| 170 |
+
for i, match in enumerate(matches):
|
| 171 |
+
key = match.group(1).strip()
|
| 172 |
+
start = match.end()
|
| 173 |
+
end = matches[i + 1].start() if i + 1 < len(matches) else len(bib_text)
|
| 174 |
+
body = bib_text[start:end]
|
| 175 |
+
fields = {}
|
| 176 |
+
for f_match in re.finditer(r"(\w+)\s*=\s*[{|\"](.+?)[}|\"]\s*,", body, re.DOTALL):
|
| 177 |
+
fields[f_match.group(1).lower()] = f_match.group(2).strip()
|
| 178 |
+
entries.append({"key": key, **fields})
|
| 179 |
+
return entries
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def find_target_bib_keys(
|
| 183 |
+
bib_texts: Dict[str, str],
|
| 184 |
+
target_info: Dict[str, str],
|
| 185 |
+
) -> List[str]:
|
| 186 |
+
target_title = normalize_text(target_info.get("title", ""))
|
| 187 |
+
target_author = normalize_text(target_info.get("first_author_last", ""))
|
| 188 |
+
target_year = target_info.get("year", "")
|
| 189 |
+
if not target_title and not target_author:
|
| 190 |
+
return []
|
| 191 |
+
|
| 192 |
+
keys: List[str] = []
|
| 193 |
+
for bib_text in bib_texts.values():
|
| 194 |
+
for entry in parse_bib_entries(bib_text):
|
| 195 |
+
title = normalize_text(entry.get("title", ""))
|
| 196 |
+
author = normalize_text(entry.get("author", ""))
|
| 197 |
+
year = str(entry.get("year", ""))
|
| 198 |
+
has_title = bool(title)
|
| 199 |
+
title_match = target_title and (target_title in title or title in target_title)
|
| 200 |
+
author_match = target_author and target_author in author
|
| 201 |
+
year_match = target_year and target_year in year
|
| 202 |
+
|
| 203 |
+
if title_match and author_match:
|
| 204 |
+
keys.append(entry["key"])
|
| 205 |
+
elif not has_title and author_match and year_match:
|
| 206 |
+
keys.append(entry["key"])
|
| 207 |
+
elif author_match and year_match:
|
| 208 |
+
keys.append(entry["key"])
|
| 209 |
+
return keys
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def replace_target_citations(text: str, target_keys: List[str], target_info: Dict[str, str]) -> str:
|
| 213 |
+
key_set = set(target_keys or [])
|
| 214 |
+
author = target_info.get("first_author_last", "").lower()
|
| 215 |
+
year = target_info.get("year", "")
|
| 216 |
+
alt_years = {year}
|
| 217 |
+
if year.isdigit():
|
| 218 |
+
alt_years.add(str(int(year) - 1))
|
| 219 |
+
alt_years.add(str(int(year) + 1))
|
| 220 |
+
|
| 221 |
+
def repl(match: re.Match) -> str:
|
| 222 |
+
keys = [k.strip() for k in match.group(1).split(",")]
|
| 223 |
+
for key in keys:
|
| 224 |
+
if key in key_set:
|
| 225 |
+
return "<CITED HERE>"
|
| 226 |
+
key_lc = key.lower()
|
| 227 |
+
if author and author in key_lc and any(y in key_lc for y in alt_years if y):
|
| 228 |
+
return "<CITED HERE>"
|
| 229 |
+
return match.group(0)
|
| 230 |
+
|
| 231 |
+
return re.sub(r"\\cite[a-zA-Z]*\s*\{([^}]+)\}", repl, text)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def match_paragraphs(
|
| 235 |
+
paragraphs: List[str],
|
| 236 |
+
contexts: List[Dict[str, str]],
|
| 237 |
+
) -> List[Dict[str, Any]]:
|
| 238 |
+
results: List[Dict[str, Any]] = []
|
| 239 |
+
para_tokens = [set(tokenize(p)) for p in paragraphs]
|
| 240 |
+
|
| 241 |
+
for idx, ctx in enumerate(contexts, start=1):
|
| 242 |
+
ctx_text = ctx.get("text", "")
|
| 243 |
+
ctx_tokens = set(tokenize(ctx_text))
|
| 244 |
+
if not ctx_tokens:
|
| 245 |
+
continue
|
| 246 |
+
best = None
|
| 247 |
+
best_score = 0.0
|
| 248 |
+
for p_idx, tokens in enumerate(para_tokens):
|
| 249 |
+
if not tokens:
|
| 250 |
+
continue
|
| 251 |
+
overlap = len(ctx_tokens & tokens) / max(1, len(ctx_tokens))
|
| 252 |
+
if overlap > best_score:
|
| 253 |
+
best = p_idx
|
| 254 |
+
best_score = overlap
|
| 255 |
+
if best is not None and best_score >= 0.5:
|
| 256 |
+
paragraph = paragraphs[best]
|
| 257 |
+
results.append(
|
| 258 |
+
{
|
| 259 |
+
"context_id": idx,
|
| 260 |
+
"context": ctx_text,
|
| 261 |
+
"context_with_marker": ctx.get("text_with_marker", ctx_text),
|
| 262 |
+
"paragraph": paragraph,
|
| 263 |
+
"overlap": round(best_score, 3),
|
| 264 |
+
}
|
| 265 |
+
)
|
| 266 |
+
return results
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def _normalize_text(text: str) -> str:
|
| 270 |
+
return " ".join(text.split()).strip().lower()
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def _normalize_for_match(text: str) -> str:
|
| 274 |
+
text = text.replace("<CITED HERE>", "")
|
| 275 |
+
text = re.sub(r"\[[^\]]+\]", "", text)
|
| 276 |
+
return _normalize_text(text)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def _normalize_author_last(name: str) -> str:
|
| 280 |
+
parts = [p for p in (name or "").split() if p.strip()]
|
| 281 |
+
return parts[-1] if parts else ""
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def extract_target_info(meta: Any) -> Dict[str, str]:
|
| 285 |
+
if isinstance(meta, list) and meta:
|
| 286 |
+
meta = meta[0]
|
| 287 |
+
if not isinstance(meta, dict):
|
| 288 |
+
return {"title": "", "first_author_last": "", "year": ""}
|
| 289 |
+
authors = meta.get("authors") or []
|
| 290 |
+
first_author = authors[0]["name"] if authors else ""
|
| 291 |
+
return {
|
| 292 |
+
"title": meta.get("title", ""),
|
| 293 |
+
"first_author_last": _normalize_author_last(first_author),
|
| 294 |
+
"year": str(meta.get("year", "")),
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def build_citing_contexts_map(
|
| 299 |
+
usage: Dict[str, Any],
|
| 300 |
+
confirmed_texts_by_citing: Dict[str, set] | None,
|
| 301 |
+
) -> Dict[str, Dict[str, Any]]:
|
| 302 |
+
citing_map: Dict[str, Dict[str, Any]] = {}
|
| 303 |
+
for entry in usage.get("citing_papers", []) or []:
|
| 304 |
+
if not isinstance(entry, dict):
|
| 305 |
+
continue
|
| 306 |
+
citing_id = entry.get("citing_paper_id") or ""
|
| 307 |
+
allowed_texts = confirmed_texts_by_citing.get(citing_id) if confirmed_texts_by_citing else None
|
| 308 |
+
allowed_norms = (
|
| 309 |
+
{_normalize_for_match(text) for text in allowed_texts} if allowed_texts else None
|
| 310 |
+
)
|
| 311 |
+
contexts = []
|
| 312 |
+
seen = set()
|
| 313 |
+
for c in entry.get("contexts", []) or []:
|
| 314 |
+
if not isinstance(c, dict):
|
| 315 |
+
continue
|
| 316 |
+
text_raw = (c.get("text") or "").strip()
|
| 317 |
+
text_with_marker = (c.get("context_with_marker") or text_raw).strip()
|
| 318 |
+
if not text_raw:
|
| 319 |
+
continue
|
| 320 |
+
norm = _normalize_for_match(text_raw)
|
| 321 |
+
if allowed_norms is not None and norm not in allowed_norms:
|
| 322 |
+
continue
|
| 323 |
+
if norm in seen:
|
| 324 |
+
continue
|
| 325 |
+
seen.add(norm)
|
| 326 |
+
contexts.append({"text": text_raw, "text_with_marker": text_with_marker})
|
| 327 |
+
if allowed_texts is not None and not contexts:
|
| 328 |
+
for text in allowed_texts:
|
| 329 |
+
norm = _normalize_for_match(text)
|
| 330 |
+
if norm in seen:
|
| 331 |
+
continue
|
| 332 |
+
seen.add(norm)
|
| 333 |
+
contexts.append({"text": text, "text_with_marker": text})
|
| 334 |
+
citing_map[citing_id] = {
|
| 335 |
+
"title": entry.get("title", ""),
|
| 336 |
+
"paper_id": citing_id,
|
| 337 |
+
"arxiv_id": (entry.get("external_ids") or {}).get("ArXiv", ""),
|
| 338 |
+
"contexts": contexts,
|
| 339 |
+
}
|
| 340 |
+
return citing_map
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def process_citing_paper(citing: Dict[str, Any]) -> Dict[str, Any]:
|
| 344 |
+
target_info = citing.get("target_info", {})
|
| 345 |
+
arxiv_id = citing.get("arxiv_id", "")
|
| 346 |
+
if not arxiv_id:
|
| 347 |
+
return {"error": "missing_arxiv_id", **citing}
|
| 348 |
+
|
| 349 |
+
with tempfile.TemporaryDirectory() as tmp:
|
| 350 |
+
tmpdir = Path(tmp)
|
| 351 |
+
if not download_arxiv_source(arxiv_id, tmpdir):
|
| 352 |
+
return {"error": "bad_arxiv_archive", **citing}
|
| 353 |
+
main_tex = find_main_tex(tmpdir)
|
| 354 |
+
if not main_tex:
|
| 355 |
+
return {"error": "missing_main_tex", **citing}
|
| 356 |
+
|
| 357 |
+
tex_text = main_tex.read_text(encoding="utf-8", errors="ignore")
|
| 358 |
+
tex_text = strip_latex_comments(tex_text)
|
| 359 |
+
bibs = read_bib_files(tmpdir)
|
| 360 |
+
target_keys = find_target_bib_keys(bibs, target_info)
|
| 361 |
+
tex_text = replace_target_citations(tex_text, target_keys, target_info)
|
| 362 |
+
paragraphs = paragraphize(tex_text)
|
| 363 |
+
target_citing_paragraphs = [p for p in paragraphs if "<CITED HERE>" in p]
|
| 364 |
+
matched = match_paragraphs(paragraphs, citing.get("contexts", []))
|
| 365 |
+
|
| 366 |
+
return {
|
| 367 |
+
"citing_paper_id": citing.get("paper_id", ""),
|
| 368 |
+
"citing_title": citing.get("title", ""),
|
| 369 |
+
"arxiv_id": arxiv_id,
|
| 370 |
+
"main_tex_file": str(main_tex.relative_to(tmpdir)),
|
| 371 |
+
"bib_files": list(bibs.keys()),
|
| 372 |
+
"bib_texts": bibs,
|
| 373 |
+
"target_bib_keys": target_keys,
|
| 374 |
+
"contexts": citing.get("contexts", []),
|
| 375 |
+
"target_citing_paragraphs": target_citing_paragraphs,
|
| 376 |
+
"matched_paragraphs": matched,
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def process_paper(root: Path, overwrite: bool, include_all: bool, resume: bool) -> str:
|
| 381 |
+
usage = load_json(root / USAGE_CONTEXTS_FILE)
|
| 382 |
+
if not isinstance(usage, dict):
|
| 383 |
+
return "missing_usage"
|
| 384 |
+
|
| 385 |
+
out_path = root / OUT_FILE
|
| 386 |
+
if out_path.exists() and (resume or not overwrite):
|
| 387 |
+
return "skipped"
|
| 388 |
+
|
| 389 |
+
verified = None
|
| 390 |
+
confirmed_texts_by_citing: Dict[str, set] = {}
|
| 391 |
+
if not include_all:
|
| 392 |
+
verified = load_json(root / VERIFIED_FILE)
|
| 393 |
+
if not isinstance(verified, dict):
|
| 394 |
+
return "missing_verified"
|
| 395 |
+
for item in verified.get("confirmed", []) or []:
|
| 396 |
+
citing_id = item.get("citing_paper_id") or ""
|
| 397 |
+
text = item.get("text") or ""
|
| 398 |
+
if not citing_id or not text:
|
| 399 |
+
continue
|
| 400 |
+
confirmed_texts_by_citing.setdefault(citing_id, set()).add(text)
|
| 401 |
+
|
| 402 |
+
target_info = extract_target_info(load_json(root / PAPER_META_FILE))
|
| 403 |
+
citing_map = build_citing_contexts_map(
|
| 404 |
+
usage,
|
| 405 |
+
confirmed_texts_by_citing if confirmed_texts_by_citing else None,
|
| 406 |
+
)
|
| 407 |
+
if not citing_map:
|
| 408 |
+
out_path.write_text(
|
| 409 |
+
json.dumps({"paper_id": usage.get("paper_id"), "citing_papers": []}, indent=2),
|
| 410 |
+
encoding="utf-8",
|
| 411 |
+
)
|
| 412 |
+
return "empty_citing"
|
| 413 |
+
|
| 414 |
+
confirmed_ids: Optional[set] = None
|
| 415 |
+
if not include_all and isinstance(verified, dict):
|
| 416 |
+
confirmed = verified.get("confirmed", [])
|
| 417 |
+
confirmed_ids = {
|
| 418 |
+
item.get("citing_paper_id")
|
| 419 |
+
for item in confirmed
|
| 420 |
+
if item.get("citing_paper_id")
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
citing_papers = []
|
| 424 |
+
for citing_id, citing in citing_map.items():
|
| 425 |
+
if confirmed_ids is not None and citing_id not in confirmed_ids:
|
| 426 |
+
continue
|
| 427 |
+
citing["target_info"] = target_info
|
| 428 |
+
citing_papers.append(process_citing_paper(citing))
|
| 429 |
+
|
| 430 |
+
payload = {"paper_id": usage.get("paper_id"), "citing_papers": citing_papers}
|
| 431 |
+
out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
| 432 |
+
return "processed"
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def main() -> None:
|
| 436 |
+
parser = argparse.ArgumentParser(
|
| 437 |
+
description="Download arXiv sources and extract citation-local paragraphs."
|
| 438 |
+
)
|
| 439 |
+
parser.add_argument(
|
| 440 |
+
"--root",
|
| 441 |
+
type=str,
|
| 442 |
+
default="runs/processed_papers",
|
| 443 |
+
help="Root directory containing processed paper directories.",
|
| 444 |
+
)
|
| 445 |
+
parser.add_argument(
|
| 446 |
+
"--overwrite",
|
| 447 |
+
action="store_true",
|
| 448 |
+
help="Overwrite existing usage_citing_paragraphs.json files.",
|
| 449 |
+
)
|
| 450 |
+
parser.add_argument(
|
| 451 |
+
"--all",
|
| 452 |
+
action="store_true",
|
| 453 |
+
help="Process all citing papers (not just confirmed USES/EXTENDS).",
|
| 454 |
+
)
|
| 455 |
+
parser.add_argument(
|
| 456 |
+
"--resume",
|
| 457 |
+
action="store_true",
|
| 458 |
+
help="Skip papers with existing output files (even if --overwrite is set).",
|
| 459 |
+
)
|
| 460 |
+
args = parser.parse_args()
|
| 461 |
+
|
| 462 |
+
root = Path(args.root).expanduser().resolve()
|
| 463 |
+
if not root.exists():
|
| 464 |
+
raise SystemExit(f"Root directory does not exist: {root}")
|
| 465 |
+
|
| 466 |
+
paper_dirs = sorted(iter_paper_dirs(root), key=lambda p: p.name)
|
| 467 |
+
print(f"[INFO] Found {len(paper_dirs)} paper dirs under {root}")
|
| 468 |
+
|
| 469 |
+
counts = {
|
| 470 |
+
"processed": 0,
|
| 471 |
+
"skipped": 0,
|
| 472 |
+
"missing_usage": 0,
|
| 473 |
+
"missing_verified": 0,
|
| 474 |
+
"empty_citing": 0,
|
| 475 |
+
}
|
| 476 |
+
for paper_dir in paper_dirs:
|
| 477 |
+
status = process_paper(paper_dir, args.overwrite, args.all, args.resume)
|
| 478 |
+
counts[status] = counts.get(status, 0) + 1
|
| 479 |
+
print(f"[{status.upper()}] {paper_dir.name}")
|
| 480 |
+
|
| 481 |
+
print(
|
| 482 |
+
"[SUMMARY] processed={processed}, skipped={skipped}, missing_usage={missing_usage}, "
|
| 483 |
+
"missing_verified={missing_verified}, empty_citing={empty_citing}".format(**counts)
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
if __name__ == "__main__":
|
| 488 |
+
main()
|
src/step_07_extract_and_refine/extract_contributions_from_citations.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Dict, List
|
| 6 |
+
|
| 7 |
+
SRC_ROOT = Path(__file__).resolve().parents[1]
|
| 8 |
+
if str(SRC_ROOT) not in sys.path:
|
| 9 |
+
sys.path.insert(0, str(SRC_ROOT))
|
| 10 |
+
|
| 11 |
+
from common.llm_client import LLMClient
|
| 12 |
+
|
| 13 |
+
from prompts import build_contribution_prompt
|
| 14 |
+
from schemas import CONTRIBUTION_JSON_SCHEMA
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
PAPER_META_FILE = "paper_metadata.json"
|
| 18 |
+
USAGE_CONTEXTS_FILE = "usage_contexts.json"
|
| 19 |
+
ARXIV_PARAGRAPHS_FILE = "usage_citing_paragraphs.json"
|
| 20 |
+
VERIFIED_FILE = "usage_uses_extends_verified.json"
|
| 21 |
+
OUT_FILE = "usage_contributions.json"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def load_json(path: Path) -> Any | None:
|
| 25 |
+
if not path.exists():
|
| 26 |
+
return None
|
| 27 |
+
try:
|
| 28 |
+
return json.loads(path.read_text(encoding="utf-8"))
|
| 29 |
+
except Exception:
|
| 30 |
+
return None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def iter_paper_dirs(root: Path) -> List[Path]:
|
| 34 |
+
out: List[Path] = []
|
| 35 |
+
for child in root.iterdir():
|
| 36 |
+
if child.is_dir() and (child / PAPER_META_FILE).exists():
|
| 37 |
+
out.append(child)
|
| 38 |
+
return out
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _normalize_author_last(name: str) -> str:
|
| 42 |
+
parts = [p for p in (name or "").split() if p.strip()]
|
| 43 |
+
return parts[-1] if parts else ""
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def extract_target_info(meta: Any) -> Dict[str, str]:
|
| 47 |
+
if isinstance(meta, list) and meta:
|
| 48 |
+
meta = meta[0]
|
| 49 |
+
if not isinstance(meta, dict):
|
| 50 |
+
return {
|
| 51 |
+
"title": "",
|
| 52 |
+
"first_author_last": "",
|
| 53 |
+
"year": "",
|
| 54 |
+
"tldr": "",
|
| 55 |
+
"abstract": "",
|
| 56 |
+
}
|
| 57 |
+
authors = meta.get("authors") or []
|
| 58 |
+
first_author = authors[0]["name"] if authors else ""
|
| 59 |
+
tldr = ""
|
| 60 |
+
tldr_obj = meta.get("tldr")
|
| 61 |
+
if isinstance(tldr_obj, dict):
|
| 62 |
+
tldr = tldr_obj.get("text", "")
|
| 63 |
+
return {
|
| 64 |
+
"title": meta.get("title", ""),
|
| 65 |
+
"first_author_last": _normalize_author_last(first_author),
|
| 66 |
+
"year": str(meta.get("year", "")),
|
| 67 |
+
"tldr": tldr,
|
| 68 |
+
"abstract": meta.get("abstract", ""),
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def build_citing_contexts_map_from_paragraphs(
|
| 73 |
+
arxiv_data: Dict[str, Any],
|
| 74 |
+
) -> Dict[str, Dict[str, Any]]:
|
| 75 |
+
citing_map: Dict[str, Dict[str, Any]] = {}
|
| 76 |
+
for entry in arxiv_data.get("citing_papers", []) or []:
|
| 77 |
+
if not isinstance(entry, dict):
|
| 78 |
+
continue
|
| 79 |
+
citing_id = entry.get("citing_paper_id") or ""
|
| 80 |
+
contexts = []
|
| 81 |
+
seen = set()
|
| 82 |
+
for paragraph in entry.get("target_citing_paragraphs", []) or []:
|
| 83 |
+
paragraph = (paragraph or "").strip()
|
| 84 |
+
if not paragraph:
|
| 85 |
+
continue
|
| 86 |
+
combined = f"Target-citing paragraph: {paragraph}"
|
| 87 |
+
norm = " ".join(combined.split()).lower()
|
| 88 |
+
if norm in seen:
|
| 89 |
+
continue
|
| 90 |
+
seen.add(norm)
|
| 91 |
+
contexts.append(combined)
|
| 92 |
+
citing_map[citing_id] = {
|
| 93 |
+
"title": entry.get("citing_title", ""),
|
| 94 |
+
"paper_id": citing_id,
|
| 95 |
+
"contexts": contexts,
|
| 96 |
+
"source": "arxiv_paragraphs",
|
| 97 |
+
}
|
| 98 |
+
return citing_map
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def build_citing_contexts_map_from_usage(
|
| 102 |
+
usage: Dict[str, Any],
|
| 103 |
+
confirmed_texts_by_citing: Dict[str, set] | None,
|
| 104 |
+
) -> Dict[str, Dict[str, Any]]:
|
| 105 |
+
citing_map: Dict[str, Dict[str, Any]] = {}
|
| 106 |
+
for entry in usage.get("citing_papers", []) or []:
|
| 107 |
+
if not isinstance(entry, dict):
|
| 108 |
+
continue
|
| 109 |
+
citing_id = entry.get("citing_paper_id") or ""
|
| 110 |
+
allowed_texts = confirmed_texts_by_citing.get(citing_id) if confirmed_texts_by_citing else None
|
| 111 |
+
contexts = []
|
| 112 |
+
seen = set()
|
| 113 |
+
for c in entry.get("contexts", []) or []:
|
| 114 |
+
if not isinstance(c, dict):
|
| 115 |
+
continue
|
| 116 |
+
text = (c.get("context_with_marker") or c.get("text") or "").strip()
|
| 117 |
+
if not text:
|
| 118 |
+
continue
|
| 119 |
+
if allowed_texts is not None and text not in allowed_texts:
|
| 120 |
+
continue
|
| 121 |
+
norm = " ".join(text.split()).lower()
|
| 122 |
+
if norm in seen:
|
| 123 |
+
continue
|
| 124 |
+
seen.add(norm)
|
| 125 |
+
contexts.append(f"Target sentence: {text}")
|
| 126 |
+
citing_map[citing_id] = {
|
| 127 |
+
"title": entry.get("title", ""),
|
| 128 |
+
"paper_id": citing_id,
|
| 129 |
+
"contexts": contexts,
|
| 130 |
+
"source": "usage_contexts_fallback",
|
| 131 |
+
}
|
| 132 |
+
return citing_map
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def extract_contribution(
|
| 136 |
+
client: LLMClient,
|
| 137 |
+
target_info: Dict[str, str],
|
| 138 |
+
citing_info: Dict[str, Any],
|
| 139 |
+
) -> Dict[str, Any]:
|
| 140 |
+
contexts = citing_info.get("contexts", [])
|
| 141 |
+
prompt = build_contribution_prompt(target_info, citing_info, contexts)
|
| 142 |
+
raw = client.call(prompt, schema=CONTRIBUTION_JSON_SCHEMA)
|
| 143 |
+
data = _parse_llm_json(raw)
|
| 144 |
+
if not isinstance(data, dict):
|
| 145 |
+
return {
|
| 146 |
+
"citing_paper_id": citing_info.get("paper_id", ""),
|
| 147 |
+
"citing_title": citing_info.get("title", ""),
|
| 148 |
+
"label": "NOT_CONFIRMED",
|
| 149 |
+
"paper_claim": "",
|
| 150 |
+
"claim": "",
|
| 151 |
+
"cluster_title": "",
|
| 152 |
+
"cluster_key": "",
|
| 153 |
+
"evidence_span": "",
|
| 154 |
+
"rationale": "",
|
| 155 |
+
"contexts": contexts,
|
| 156 |
+
"source": citing_info.get("source", "unknown"),
|
| 157 |
+
}
|
| 158 |
+
label = data.get("label", "NOT_CONFIRMED")
|
| 159 |
+
paper_claim = data.get("paper_claim", "") or data.get("claim", "")
|
| 160 |
+
cluster_title = data.get("cluster_title", "") or data.get("cluster_claim", "")
|
| 161 |
+
cluster_key = data.get("cluster_key", "")
|
| 162 |
+
evidence_span = data.get("evidence_span", "")
|
| 163 |
+
if not evidence_span:
|
| 164 |
+
label = "NOT_CONFIRMED"
|
| 165 |
+
paper_claim = ""
|
| 166 |
+
cluster_title = ""
|
| 167 |
+
cluster_key = ""
|
| 168 |
+
if label in {"USES", "EXTENDS"} and not cluster_title:
|
| 169 |
+
cluster_title = paper_claim
|
| 170 |
+
if label in {"USES", "EXTENDS"} and not cluster_key:
|
| 171 |
+
cluster_key = f"{label}|contribution|unspecified"
|
| 172 |
+
return {
|
| 173 |
+
"citing_paper_id": citing_info.get("paper_id", ""),
|
| 174 |
+
"citing_title": citing_info.get("title", ""),
|
| 175 |
+
"label": label,
|
| 176 |
+
"paper_claim": paper_claim,
|
| 177 |
+
"claim": paper_claim,
|
| 178 |
+
"cluster_title": cluster_title,
|
| 179 |
+
"cluster_key": cluster_key,
|
| 180 |
+
"evidence_span": evidence_span,
|
| 181 |
+
"rationale": data.get("rationale", ""),
|
| 182 |
+
"contexts": contexts,
|
| 183 |
+
"source": citing_info.get("source", "unknown"),
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def _parse_llm_json(raw: str) -> Any | None:
|
| 188 |
+
try:
|
| 189 |
+
return json.loads(raw)
|
| 190 |
+
except json.JSONDecodeError:
|
| 191 |
+
pass
|
| 192 |
+
|
| 193 |
+
cleaned = raw.strip()
|
| 194 |
+
if cleaned.startswith("```"):
|
| 195 |
+
cleaned = cleaned.strip("`")
|
| 196 |
+
cleaned = cleaned.replace("json", "", 1).strip()
|
| 197 |
+
|
| 198 |
+
start = cleaned.find("{")
|
| 199 |
+
end = cleaned.rfind("}")
|
| 200 |
+
if start == -1 or end == -1 or end <= start:
|
| 201 |
+
return None
|
| 202 |
+
|
| 203 |
+
snippet = cleaned[start : end + 1]
|
| 204 |
+
try:
|
| 205 |
+
return json.loads(snippet)
|
| 206 |
+
except json.JSONDecodeError:
|
| 207 |
+
return None
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def process_paper(
|
| 211 |
+
paper_dir: Path,
|
| 212 |
+
client: LLMClient,
|
| 213 |
+
overwrite: bool,
|
| 214 |
+
resume: bool,
|
| 215 |
+
) -> str:
|
| 216 |
+
verified = load_json(paper_dir / VERIFIED_FILE)
|
| 217 |
+
if not isinstance(verified, dict):
|
| 218 |
+
return "missing_verified"
|
| 219 |
+
out_path = paper_dir / OUT_FILE
|
| 220 |
+
if out_path.exists() and (resume or not overwrite):
|
| 221 |
+
return "skipped"
|
| 222 |
+
|
| 223 |
+
if verified.get("final_label") == "NOT_CONFIRMED":
|
| 224 |
+
payload = {
|
| 225 |
+
"paper_id": verified.get("paper_id"),
|
| 226 |
+
"final_label": "NOT_CONFIRMED",
|
| 227 |
+
"contributions": [],
|
| 228 |
+
}
|
| 229 |
+
out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
| 230 |
+
return "no_confirmed"
|
| 231 |
+
|
| 232 |
+
arxiv_data = load_json(paper_dir / ARXIV_PARAGRAPHS_FILE)
|
| 233 |
+
if not isinstance(arxiv_data, dict):
|
| 234 |
+
return "missing_arxiv_paragraphs"
|
| 235 |
+
|
| 236 |
+
target_info = extract_target_info(load_json(paper_dir / PAPER_META_FILE))
|
| 237 |
+
citing_map = build_citing_contexts_map_from_paragraphs(arxiv_data)
|
| 238 |
+
usage = load_json(paper_dir / USAGE_CONTEXTS_FILE)
|
| 239 |
+
confirmed_texts_by_citing: Dict[str, set] = {}
|
| 240 |
+
for item in verified.get("confirmed", []) or []:
|
| 241 |
+
citing_id = item.get("citing_paper_id") or ""
|
| 242 |
+
text = item.get("text") or ""
|
| 243 |
+
if not citing_id or not text:
|
| 244 |
+
continue
|
| 245 |
+
confirmed_texts_by_citing.setdefault(citing_id, set()).add(text)
|
| 246 |
+
usage_map = (
|
| 247 |
+
build_citing_contexts_map_from_usage(usage, confirmed_texts_by_citing)
|
| 248 |
+
if isinstance(usage, dict)
|
| 249 |
+
else {}
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
confirmed = verified.get("confirmed", [])
|
| 253 |
+
confirmed_ids = {item.get("citing_paper_id") for item in confirmed if item.get("citing_paper_id")}
|
| 254 |
+
contributions: List[Dict[str, Any]] = []
|
| 255 |
+
fallback_citing_ids: List[str] = []
|
| 256 |
+
for citing_id in confirmed_ids:
|
| 257 |
+
citing_info = citing_map.get(citing_id)
|
| 258 |
+
if citing_info and not citing_info.get("contexts"):
|
| 259 |
+
citing_info = None
|
| 260 |
+
if not citing_info:
|
| 261 |
+
fallback = usage_map.get(citing_id)
|
| 262 |
+
if fallback and fallback.get("contexts"):
|
| 263 |
+
citing_info = fallback
|
| 264 |
+
fallback_citing_ids.append(citing_id)
|
| 265 |
+
else:
|
| 266 |
+
continue
|
| 267 |
+
contributions.append(extract_contribution(client, target_info, citing_info))
|
| 268 |
+
|
| 269 |
+
payload = {
|
| 270 |
+
"paper_id": verified.get("paper_id"),
|
| 271 |
+
"final_label": verified.get("final_label"),
|
| 272 |
+
"contributions": contributions,
|
| 273 |
+
"source": "arxiv_paragraphs",
|
| 274 |
+
"fallback_citing_ids": fallback_citing_ids,
|
| 275 |
+
}
|
| 276 |
+
out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
| 277 |
+
return "labeled"
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def main() -> None:
|
| 281 |
+
parser = argparse.ArgumentParser(
|
| 282 |
+
description="Extract per-citing-paper contribution claims from verified USES/EXTENDS."
|
| 283 |
+
)
|
| 284 |
+
parser.add_argument(
|
| 285 |
+
"--root",
|
| 286 |
+
type=str,
|
| 287 |
+
default="runs/processed_papers",
|
| 288 |
+
help="Root directory containing processed paper directories.",
|
| 289 |
+
)
|
| 290 |
+
parser.add_argument(
|
| 291 |
+
"--overwrite",
|
| 292 |
+
action="store_true",
|
| 293 |
+
help="Overwrite existing usage_contributions.json files.",
|
| 294 |
+
)
|
| 295 |
+
parser.add_argument(
|
| 296 |
+
"--resume",
|
| 297 |
+
action="store_true",
|
| 298 |
+
help="Skip papers with existing output files (even if --overwrite is set).",
|
| 299 |
+
)
|
| 300 |
+
args = parser.parse_args()
|
| 301 |
+
|
| 302 |
+
root = Path(args.root).expanduser().resolve()
|
| 303 |
+
if not root.exists():
|
| 304 |
+
raise SystemExit(f"Root directory does not exist: {root}")
|
| 305 |
+
|
| 306 |
+
client = LLMClient()
|
| 307 |
+
paper_dirs = sorted(iter_paper_dirs(root), key=lambda p: p.name)
|
| 308 |
+
print(f"[INFO] Found {len(paper_dirs)} paper dirs under {root}")
|
| 309 |
+
|
| 310 |
+
counts = {
|
| 311 |
+
"labeled": 0,
|
| 312 |
+
"skipped": 0,
|
| 313 |
+
"missing_verified": 0,
|
| 314 |
+
"missing_arxiv_paragraphs": 0,
|
| 315 |
+
"no_confirmed": 0,
|
| 316 |
+
}
|
| 317 |
+
for paper_dir in paper_dirs:
|
| 318 |
+
status = process_paper(paper_dir, client, args.overwrite, args.resume)
|
| 319 |
+
counts[status] = counts.get(status, 0) + 1
|
| 320 |
+
print(f"[{status.upper()}] {paper_dir.name}")
|
| 321 |
+
|
| 322 |
+
print(
|
| 323 |
+
"[SUMMARY] labeled={labeled}, skipped={skipped}, missing_verified={missing_verified}, "
|
| 324 |
+
"missing_arxiv_paragraphs={missing_arxiv_paragraphs}, no_confirmed={no_confirmed}".format(**counts)
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
if __name__ == "__main__":
|
| 329 |
+
main()
|
src/step_07_extract_and_refine/prompts.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def build_contribution_prompt(
|
| 5 |
+
target_info: Dict[str, str],
|
| 6 |
+
citing_info: Dict[str, str],
|
| 7 |
+
contexts: List[str],
|
| 8 |
+
) -> str:
|
| 9 |
+
header = [
|
| 10 |
+
"You are extracting how a citing paper uses or extends a target paper.",
|
| 11 |
+
"Read the paragraph(s) below and write ONE concise contribution claim.",
|
| 12 |
+
"Focus only on what the citing paper actually does with the target paper.",
|
| 13 |
+
"",
|
| 14 |
+
"Rules:",
|
| 15 |
+
"- If the citing paper explicitly uses/adopts/evaluates on the target paper's method/data/benchmark, label USES.",
|
| 16 |
+
"- If it explicitly extends/modifies/adapts/builds upon the target paper, label EXTENDS.",
|
| 17 |
+
"- If the paragraph is only descriptive/background or only compares/mentions the target paper, return label NOT_CONFIRMED and empty fields.",
|
| 18 |
+
"- Do not output comparison-only claims (e.g., 'compares to <CITED HERE>'); those are NOT_CONFIRMED.",
|
| 19 |
+
"- Output paper_claim: one concise, paper-specific contribution claim.",
|
| 20 |
+
"- Output cluster_title: concise natural-language cluster summary (6-14 words), generic across papers.",
|
| 21 |
+
"- Also output cluster_key in this exact format: RELATION|artifact|purpose",
|
| 22 |
+
"- cluster_key must be generic and reusable across papers.",
|
| 23 |
+
"- artifact and purpose must be short snake_case phrases (e.g., dataset, evaluation_protocol, evaluation).",
|
| 24 |
+
"- cluster_key RELATION must exactly match label.",
|
| 25 |
+
"- Avoid overly specific keys (no paper names, no model/version numbers, no citation keys).",
|
| 26 |
+
"- Prefer stable generic keys such as: USES|dataset|evaluation, EXTENDS|dataset|dataset_creation, USES|evaluation_protocol|evaluation.",
|
| 27 |
+
"- If label is NOT_CONFIRMED, paper_claim, cluster_title, cluster_key, and evidence_span must be empty.",
|
| 28 |
+
"- The evidence_span must be a verbatim substring from the provided contexts.",
|
| 29 |
+
"- The TARGET_PAPER abstract/TLDR is for background only; do not use it as evidence.",
|
| 30 |
+
"",
|
| 31 |
+
"Negative example (NOT_CONFIRMED):",
|
| 32 |
+
"Paragraph: \"We compare our method to <CITED HERE> and other baselines.\"",
|
| 33 |
+
"Output: {\"label\":\"NOT_CONFIRMED\",\"paper_claim\":\"\",\"cluster_title\":\"\",\"cluster_key\":\"\",\"evidence_span\":\"\",\"rationale\":\"Comparison only.\"}",
|
| 34 |
+
"",
|
| 35 |
+
"Return JSON only.",
|
| 36 |
+
"",
|
| 37 |
+
"TARGET_PAPER:",
|
| 38 |
+
f"- title: {target_info.get('title', '')}",
|
| 39 |
+
f"- first_author_last: {target_info.get('first_author_last', '')}",
|
| 40 |
+
f"- year: {target_info.get('year', '')}",
|
| 41 |
+
f"- tldr: {target_info.get('tldr', '')}",
|
| 42 |
+
f"- abstract: {target_info.get('abstract', '')}",
|
| 43 |
+
"",
|
| 44 |
+
"CITING_PAPER:",
|
| 45 |
+
f"- title: {citing_info.get('title', '')}",
|
| 46 |
+
f"- paper_id: {citing_info.get('paper_id', '')}",
|
| 47 |
+
"",
|
| 48 |
+
"CONTEXTS (verbatim, same order as extracted):",
|
| 49 |
+
]
|
| 50 |
+
for i, text in enumerate(contexts, start=1):
|
| 51 |
+
header.append(f"({i}) {text}")
|
| 52 |
+
|
| 53 |
+
header.append("")
|
| 54 |
+
header.append("JSON OUTPUT:")
|
| 55 |
+
header.append(
|
| 56 |
+
"{"
|
| 57 |
+
"\"label\":\"USES\","
|
| 58 |
+
"\"paper_claim\":\"...\","
|
| 59 |
+
"\"cluster_title\":\"Uses target dataset for evaluation\","
|
| 60 |
+
"\"cluster_key\":\"USES|dataset|evaluation\","
|
| 61 |
+
"\"evidence_span\":\"...\","
|
| 62 |
+
"\"rationale\":\"...\""
|
| 63 |
+
"}"
|
| 64 |
+
)
|
| 65 |
+
return "\n".join(header)
|
src/step_07_extract_and_refine/refine_and_filter_clusters_llm.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Dict, List, Tuple
|
| 6 |
+
|
| 7 |
+
SRC_ROOT = Path(__file__).resolve().parents[1]
|
| 8 |
+
if str(SRC_ROOT) not in sys.path:
|
| 9 |
+
sys.path.insert(0, str(SRC_ROOT))
|
| 10 |
+
|
| 11 |
+
from common.llm_client import LLMClient
|
| 12 |
+
|
| 13 |
+
PAPER_META_FILE = "paper_metadata.json"
|
| 14 |
+
CONTRIB_FILE = "usage_contributions.json"
|
| 15 |
+
DISCOVERY_FILE = "usage_discovery_from_contributions.json"
|
| 16 |
+
OUT_FILE = "usage_discovery_from_contributions_refined.json"
|
| 17 |
+
|
| 18 |
+
REFINE_SCHEMA = {
|
| 19 |
+
"type": "object",
|
| 20 |
+
"properties": {
|
| 21 |
+
"kept_groups": {
|
| 22 |
+
"type": "array",
|
| 23 |
+
"items": {
|
| 24 |
+
"type": "object",
|
| 25 |
+
"properties": {
|
| 26 |
+
"cluster_ids": {"type": "array", "items": {"type": "string"}},
|
| 27 |
+
"merged_title": {"type": "string"},
|
| 28 |
+
"merged_key": {"type": "string"},
|
| 29 |
+
"rationale": {"type": "string"},
|
| 30 |
+
},
|
| 31 |
+
"required": ["cluster_ids", "merged_title", "merged_key", "rationale"],
|
| 32 |
+
},
|
| 33 |
+
},
|
| 34 |
+
"dropped_clusters": {
|
| 35 |
+
"type": "array",
|
| 36 |
+
"items": {
|
| 37 |
+
"type": "object",
|
| 38 |
+
"properties": {
|
| 39 |
+
"cluster_id": {"type": "string"},
|
| 40 |
+
"reason": {"type": "string"},
|
| 41 |
+
},
|
| 42 |
+
"required": ["cluster_id", "reason"],
|
| 43 |
+
},
|
| 44 |
+
},
|
| 45 |
+
},
|
| 46 |
+
"required": ["kept_groups", "dropped_clusters"],
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_json(path: Path) -> Any | None:
|
| 51 |
+
if not path.exists():
|
| 52 |
+
return None
|
| 53 |
+
try:
|
| 54 |
+
return json.loads(path.read_text(encoding="utf-8"))
|
| 55 |
+
except Exception:
|
| 56 |
+
return None
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def iter_paper_dirs(root: Path) -> List[Path]:
|
| 60 |
+
return sorted([p for p in root.iterdir() if p.is_dir() and (p / PAPER_META_FILE).exists()], key=lambda p: p.name)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _to_int_indices(raw_indices: List[Any]) -> List[int]:
|
| 64 |
+
out: List[int] = []
|
| 65 |
+
for i in raw_indices or []:
|
| 66 |
+
try:
|
| 67 |
+
out.append(int(i))
|
| 68 |
+
except Exception:
|
| 69 |
+
continue
|
| 70 |
+
return out
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _parse_key(key: str) -> Tuple[str, str, str]:
|
| 74 |
+
parts = [p.strip() for p in str(key or "").split("|")]
|
| 75 |
+
if len(parts) >= 3:
|
| 76 |
+
return parts[0].upper(), parts[1], parts[2]
|
| 77 |
+
return "", "", ""
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _dominant_key(member_clusters: List[Dict[str, Any]]) -> str:
|
| 81 |
+
rel_count: Dict[str, int] = {}
|
| 82 |
+
art_count: Dict[str, int] = {}
|
| 83 |
+
pur_count: Dict[str, int] = {}
|
| 84 |
+
for c in member_clusters:
|
| 85 |
+
rel, art, pur = _parse_key(c.get("cluster_key", ""))
|
| 86 |
+
if rel:
|
| 87 |
+
rel_count[rel] = rel_count.get(rel, 0) + 1
|
| 88 |
+
if art:
|
| 89 |
+
art_count[art] = art_count.get(art, 0) + 1
|
| 90 |
+
if pur:
|
| 91 |
+
pur_count[pur] = pur_count.get(pur, 0) + 1
|
| 92 |
+
rel = max(rel_count, key=rel_count.get) if rel_count else "USES"
|
| 93 |
+
art = max(art_count, key=art_count.get) if art_count else "contribution"
|
| 94 |
+
pur = max(pur_count, key=pur_count.get) if pur_count else "unspecified"
|
| 95 |
+
return f"{rel}|{art}|{pur}"
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _extract_title(meta: Any) -> str:
|
| 99 |
+
if isinstance(meta, list) and meta:
|
| 100 |
+
meta = meta[0]
|
| 101 |
+
if not isinstance(meta, dict):
|
| 102 |
+
return ""
|
| 103 |
+
return str(meta.get("title", ""))
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _title_from_cluster_key(cluster_key: str) -> str:
|
| 107 |
+
parts = [p.strip() for p in str(cluster_key or "").split("|")]
|
| 108 |
+
if len(parts) >= 3:
|
| 109 |
+
relation, artifact, purpose = parts[0], parts[1], parts[2]
|
| 110 |
+
relation_txt = "Uses" if relation.upper() == "USES" else "Extends"
|
| 111 |
+
artifact_txt = artifact.replace("_", " ")
|
| 112 |
+
purpose_txt = purpose.replace("_", " ")
|
| 113 |
+
return f"{relation_txt} {artifact_txt} for {purpose_txt}".strip()
|
| 114 |
+
return cluster_key or ""
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _cluster_by_exact_keys(keys: List[str]) -> List[List[int]]:
|
| 118 |
+
groups: Dict[str, List[int]] = {}
|
| 119 |
+
order: List[str] = []
|
| 120 |
+
for i, key in enumerate(keys):
|
| 121 |
+
k = (key or "").strip()
|
| 122 |
+
if not k:
|
| 123 |
+
k = f"__EMPTY__::{i}"
|
| 124 |
+
if k not in groups:
|
| 125 |
+
groups[k] = []
|
| 126 |
+
order.append(k)
|
| 127 |
+
groups[k].append(i)
|
| 128 |
+
return [groups[k] for k in order]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _build_initial_clusters_from_contributions(contrib: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 132 |
+
contributions = [
|
| 133 |
+
c for c in contrib.get("contributions", []) or []
|
| 134 |
+
if c.get("label") in {"USES", "EXTENDS"} and (c.get("paper_claim") or c.get("claim"))
|
| 135 |
+
]
|
| 136 |
+
if not contributions:
|
| 137 |
+
return []
|
| 138 |
+
cluster_keys_all: List[str] = []
|
| 139 |
+
for c in contributions:
|
| 140 |
+
key = (c.get("cluster_key") or "").strip()
|
| 141 |
+
if not key:
|
| 142 |
+
label = str(c.get("label", "USES")).upper()
|
| 143 |
+
if label not in {"USES", "EXTENDS"}:
|
| 144 |
+
label = "USES"
|
| 145 |
+
key = f"{label}|contribution|unspecified"
|
| 146 |
+
cluster_keys_all.append(key)
|
| 147 |
+
clusters = _cluster_by_exact_keys(cluster_keys_all)
|
| 148 |
+
out: List[Dict[str, Any]] = []
|
| 149 |
+
for idx, cluster in enumerate(clusters, start=1):
|
| 150 |
+
first = contributions[cluster[0]]
|
| 151 |
+
key = cluster_keys_all[cluster[0]]
|
| 152 |
+
title = (first.get("cluster_title") or "").strip() or _title_from_cluster_key(key)
|
| 153 |
+
out.append({
|
| 154 |
+
"cluster_id": f"C{idx}",
|
| 155 |
+
"count": str(len(cluster)),
|
| 156 |
+
"representative_claim": title,
|
| 157 |
+
"cluster_key": key,
|
| 158 |
+
"cluster_title": title,
|
| 159 |
+
"claim_indices": [str(i) for i in cluster],
|
| 160 |
+
})
|
| 161 |
+
return out
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _cluster_support_summary(cluster: Dict[str, Any], contributions: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 165 |
+
indices = _to_int_indices(cluster.get("claim_indices") or [])
|
| 166 |
+
items: List[Dict[str, Any]] = []
|
| 167 |
+
for i in indices:
|
| 168 |
+
if 0 <= i < len(contributions):
|
| 169 |
+
items.append(contributions[i])
|
| 170 |
+
labels = [str(item.get("label", "")).upper() for item in items if item.get("label")]
|
| 171 |
+
examples: List[str] = []
|
| 172 |
+
for item in items:
|
| 173 |
+
text = str(item.get("paper_claim") or item.get("claim") or "").strip()
|
| 174 |
+
if text:
|
| 175 |
+
examples.append(text)
|
| 176 |
+
if len(examples) >= 3:
|
| 177 |
+
break
|
| 178 |
+
rationales = [str(item.get("rationale", "")).strip() for item in items if item.get("rationale")][:2]
|
| 179 |
+
use_count = sum(1 for x in labels if x == "USES")
|
| 180 |
+
ext_count = sum(1 for x in labels if x == "EXTENDS")
|
| 181 |
+
return {
|
| 182 |
+
"examples": examples,
|
| 183 |
+
"rationales": rationales,
|
| 184 |
+
"uses_count": use_count,
|
| 185 |
+
"extends_count": ext_count,
|
| 186 |
+
"member_count": len(items),
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def build_prompt(paper_title: str, centroids: List[Dict[str, Any]]) -> str:
|
| 191 |
+
lines: List[str] = [
|
| 192 |
+
"You are refining downstream citation contribution clusters for one target paper.",
|
| 193 |
+
"Input clusters are already built. Your job is to (a) conservatively merge near-duplicate downstream-usage clusters and (b) drop clusters that do not actually show substantive downstream usage of the target contribution.",
|
| 194 |
+
"",
|
| 195 |
+
f"Target paper: {paper_title}",
|
| 196 |
+
"",
|
| 197 |
+
"Rules:",
|
| 198 |
+
"- Operate only at cluster level. Do not invent new instances.",
|
| 199 |
+
"- Prefer conservative merges. If unsure, keep clusters separate.",
|
| 200 |
+
"- You may drop clusters only when they fail to show real downstream use or extension of the target contribution.",
|
| 201 |
+
"- Drop clusters that are clearly mere mention, loose comparison, background citation, noisy extraction, or off-target usage.",
|
| 202 |
+
"- Never merge USES and EXTENDS clusters together.",
|
| 203 |
+
"- Every input cluster_id must either appear in exactly one kept group or in dropped_clusters.",
|
| 204 |
+
"- kept merged_key must be in format RELATION|artifact|purpose.",
|
| 205 |
+
"- merged_title must be a short natural-language summary (5-12 words).",
|
| 206 |
+
"",
|
| 207 |
+
"Input clusters:",
|
| 208 |
+
]
|
| 209 |
+
for c in centroids:
|
| 210 |
+
lines.append(
|
| 211 |
+
f"- {c['cluster_id']}: key={c.get('cluster_key','')}; title={c.get('cluster_title','')}; count={c.get('count', 0)}; uses={c.get('uses_count',0)}; extends={c.get('extends_count',0)}; examples={' | '.join(c.get('examples',[])[:2])}; rationales={' | '.join(c.get('rationales',[])[:1])}"
|
| 212 |
+
)
|
| 213 |
+
lines += [
|
| 214 |
+
"",
|
| 215 |
+
"Return JSON only with this shape:",
|
| 216 |
+
"{",
|
| 217 |
+
' "kept_groups": [',
|
| 218 |
+
" {",
|
| 219 |
+
' "cluster_ids": ["C1","C3"],',
|
| 220 |
+
' "merged_title": "Uses target dataset for evaluation",',
|
| 221 |
+
' "merged_key": "USES|dataset|evaluation",',
|
| 222 |
+
' "rationale": "Both clusters describe the same downstream dataset use."',
|
| 223 |
+
" }",
|
| 224 |
+
" ],",
|
| 225 |
+
' "dropped_clusters": [',
|
| 226 |
+
' {"cluster_id": "C7", "reason": "Only background mention; no substantive downstream use."}',
|
| 227 |
+
" ]",
|
| 228 |
+
"}",
|
| 229 |
+
]
|
| 230 |
+
return "\n".join(lines)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def _normalize_decision(data: Dict[str, Any], original_clusters: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
| 234 |
+
valid_ids = [c.get("cluster_id", "") for c in original_clusters if c.get("cluster_id")]
|
| 235 |
+
valid_set = set(valid_ids)
|
| 236 |
+
assigned = set()
|
| 237 |
+
kept: List[Dict[str, Any]] = []
|
| 238 |
+
dropped: List[Dict[str, Any]] = []
|
| 239 |
+
|
| 240 |
+
for item in data.get("dropped_clusters") or []:
|
| 241 |
+
cid = item.get("cluster_id")
|
| 242 |
+
if cid in valid_set and cid not in assigned:
|
| 243 |
+
assigned.add(cid)
|
| 244 |
+
dropped.append({"cluster_id": cid, "reason": str(item.get("reason", "")).strip() or "Dropped by LLM filter."})
|
| 245 |
+
|
| 246 |
+
for g in data.get("kept_groups") or []:
|
| 247 |
+
ids = [cid for cid in (g.get("cluster_ids") or []) if cid in valid_set and cid not in assigned]
|
| 248 |
+
if not ids:
|
| 249 |
+
continue
|
| 250 |
+
for cid in ids:
|
| 251 |
+
assigned.add(cid)
|
| 252 |
+
kept.append({
|
| 253 |
+
"cluster_ids": ids,
|
| 254 |
+
"merged_title": str(g.get("merged_title", "")).strip(),
|
| 255 |
+
"merged_key": str(g.get("merged_key", "")).strip(),
|
| 256 |
+
"rationale": str(g.get("rationale", "")).strip(),
|
| 257 |
+
})
|
| 258 |
+
|
| 259 |
+
for cid in valid_ids:
|
| 260 |
+
if cid not in assigned:
|
| 261 |
+
kept.append({
|
| 262 |
+
"cluster_ids": [cid],
|
| 263 |
+
"merged_title": "",
|
| 264 |
+
"merged_key": "",
|
| 265 |
+
"rationale": "Auto-singleton fallback.",
|
| 266 |
+
})
|
| 267 |
+
|
| 268 |
+
order = {cid: i for i, cid in enumerate(valid_ids)}
|
| 269 |
+
kept.sort(key=lambda g: min(order[cid] for cid in g["cluster_ids"]))
|
| 270 |
+
dropped.sort(key=lambda x: order.get(x["cluster_id"], 10**9))
|
| 271 |
+
return kept, dropped
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def refine_paper(paper_dir: Path, overwrite: bool, inplace: bool) -> str:
|
| 275 |
+
disc_path = paper_dir / DISCOVERY_FILE
|
| 276 |
+
contrib_path = paper_dir / CONTRIB_FILE
|
| 277 |
+
meta_path = paper_dir / PAPER_META_FILE
|
| 278 |
+
|
| 279 |
+
disc = load_json(disc_path)
|
| 280 |
+
contrib = load_json(contrib_path)
|
| 281 |
+
meta = load_json(meta_path)
|
| 282 |
+
if not isinstance(contrib, dict):
|
| 283 |
+
return "missing_inputs"
|
| 284 |
+
|
| 285 |
+
if not isinstance(disc, dict):
|
| 286 |
+
disc = {"paper_id": contrib.get("paper_id"), "decision": "", "justification": "", "clusters": []}
|
| 287 |
+
|
| 288 |
+
clusters = disc.get("clusters") or []
|
| 289 |
+
if not clusters:
|
| 290 |
+
clusters = _build_initial_clusters_from_contributions(contrib)
|
| 291 |
+
if not clusters:
|
| 292 |
+
payload = dict(disc)
|
| 293 |
+
payload["clusters"] = []
|
| 294 |
+
payload["dropped_clusters"] = []
|
| 295 |
+
payload["cluster_refine_method"] = "llm_centroid_merge_filter"
|
| 296 |
+
payload["cluster_refine_source"] = CONTRIB_FILE
|
| 297 |
+
out_path = disc_path if inplace else (paper_dir / OUT_FILE)
|
| 298 |
+
out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
| 299 |
+
return "empty_clusters"
|
| 300 |
+
|
| 301 |
+
out_path = disc_path if inplace else (paper_dir / OUT_FILE)
|
| 302 |
+
if out_path.exists() and not overwrite:
|
| 303 |
+
return "skipped"
|
| 304 |
+
|
| 305 |
+
contributions = contrib.get("contributions") or []
|
| 306 |
+
centroids: List[Dict[str, Any]] = []
|
| 307 |
+
auto_dropped: List[Dict[str, Any]] = []
|
| 308 |
+
active_clusters: List[Dict[str, Any]] = []
|
| 309 |
+
|
| 310 |
+
for c in clusters:
|
| 311 |
+
cid = c.get("cluster_id", "")
|
| 312 |
+
summary = _cluster_support_summary(c, contributions)
|
| 313 |
+
rel, _, _ = _parse_key(c.get("cluster_key", ""))
|
| 314 |
+
if summary["uses_count"] + summary["extends_count"] == 0 or rel not in {"USES", "EXTENDS"}:
|
| 315 |
+
auto_dropped.append({"cluster_id": cid, "reason": "No verified USES/EXTENDS support in member contributions."})
|
| 316 |
+
continue
|
| 317 |
+
row = {
|
| 318 |
+
"cluster_id": cid,
|
| 319 |
+
"cluster_key": c.get("cluster_key", ""),
|
| 320 |
+
"cluster_title": c.get("cluster_title") or c.get("representative_claim") or "",
|
| 321 |
+
"count": int(c.get("count", summary["member_count"]) or summary["member_count"]),
|
| 322 |
+
**summary,
|
| 323 |
+
}
|
| 324 |
+
centroids.append(row)
|
| 325 |
+
active_clusters.append(c)
|
| 326 |
+
|
| 327 |
+
if not active_clusters:
|
| 328 |
+
payload = dict(disc)
|
| 329 |
+
payload["clusters"] = []
|
| 330 |
+
payload["dropped_clusters"] = auto_dropped
|
| 331 |
+
payload["cluster_refine_method"] = "llm_centroid_merge_filter"
|
| 332 |
+
payload["cluster_refine_source"] = CONTRIB_FILE if not load_json(disc_path) else DISCOVERY_FILE
|
| 333 |
+
out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
| 334 |
+
return "refined"
|
| 335 |
+
|
| 336 |
+
prompt = build_prompt(_extract_title(meta), centroids)
|
| 337 |
+
client = LLMClient()
|
| 338 |
+
raw = client.call(prompt, schema=REFINE_SCHEMA)
|
| 339 |
+
data = json.loads(raw)
|
| 340 |
+
kept_groups, llm_dropped = _normalize_decision(data, active_clusters)
|
| 341 |
+
|
| 342 |
+
id_to_cluster = {c.get("cluster_id"): c for c in active_clusters if c.get("cluster_id")}
|
| 343 |
+
merged_clusters: List[Dict[str, Any]] = []
|
| 344 |
+
for idx, g in enumerate(kept_groups, start=1):
|
| 345 |
+
member_ids = g["cluster_ids"]
|
| 346 |
+
members = [id_to_cluster[mid] for mid in member_ids if mid in id_to_cluster]
|
| 347 |
+
merged_indices: List[int] = []
|
| 348 |
+
for m in members:
|
| 349 |
+
for i in _to_int_indices(m.get("claim_indices") or []):
|
| 350 |
+
if i not in merged_indices:
|
| 351 |
+
merged_indices.append(i)
|
| 352 |
+
merged_indices.sort()
|
| 353 |
+
merged_key = g.get("merged_key") or _dominant_key(members)
|
| 354 |
+
rel, _, _ = _parse_key(merged_key)
|
| 355 |
+
if rel not in {"USES", "EXTENDS"}:
|
| 356 |
+
merged_key = _dominant_key(members)
|
| 357 |
+
merged_title = g.get("merged_title") or (members[0].get("cluster_title") if members else "")
|
| 358 |
+
if not merged_title:
|
| 359 |
+
merged_title = members[0].get("representative_claim", "") if members else ""
|
| 360 |
+
merged_clusters.append({
|
| 361 |
+
"cluster_id": f"C{idx}",
|
| 362 |
+
"count": str(len(merged_indices)),
|
| 363 |
+
"representative_claim": merged_title,
|
| 364 |
+
"cluster_key": merged_key,
|
| 365 |
+
"cluster_title": merged_title,
|
| 366 |
+
"claim_indices": [str(i) for i in merged_indices],
|
| 367 |
+
"source_cluster_ids": member_ids,
|
| 368 |
+
"merge_rationale": g.get("rationale", ""),
|
| 369 |
+
})
|
| 370 |
+
|
| 371 |
+
payload = dict(disc)
|
| 372 |
+
payload["clusters"] = merged_clusters
|
| 373 |
+
payload["dropped_clusters"] = auto_dropped + llm_dropped
|
| 374 |
+
payload["cluster_refine_method"] = "llm_centroid_merge_filter"
|
| 375 |
+
payload["cluster_refine_source"] = CONTRIB_FILE if not load_json(disc_path) else DISCOVERY_FILE
|
| 376 |
+
out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
| 377 |
+
return "refined"
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def main() -> None:
|
| 381 |
+
parser = argparse.ArgumentParser(description="LLM centroid-level merge/filter pass for downstream contribution clusters.")
|
| 382 |
+
parser.add_argument("--root", type=str, default="runs/processed_papers", help="Root directory containing processed paper directories.")
|
| 383 |
+
parser.add_argument("--overwrite", action="store_true", help="Overwrite output file if it exists.")
|
| 384 |
+
parser.add_argument("--inplace", action="store_true", help="Write back to usage_discovery_from_contributions.json.")
|
| 385 |
+
args = parser.parse_args()
|
| 386 |
+
|
| 387 |
+
root = Path(args.root).expanduser().resolve()
|
| 388 |
+
if not root.exists():
|
| 389 |
+
raise SystemExit(f"Root directory does not exist: {root}")
|
| 390 |
+
|
| 391 |
+
paper_dirs = iter_paper_dirs(root)
|
| 392 |
+
print(f"[INFO] Found {len(paper_dirs)} paper dirs under {root}")
|
| 393 |
+
counts = {"refined": 0, "skipped": 0, "missing_inputs": 0, "empty_clusters": 0}
|
| 394 |
+
for paper_dir in paper_dirs:
|
| 395 |
+
status = refine_paper(paper_dir, overwrite=args.overwrite, inplace=args.inplace)
|
| 396 |
+
counts[status] = counts.get(status, 0) + 1
|
| 397 |
+
print(f"[{status.upper()}] {paper_dir.name}")
|
| 398 |
+
print("[SUMMARY] refined={refined}, skipped={skipped}, missing_inputs={missing_inputs}, empty_clusters={empty_clusters}".format(**counts))
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
if __name__ == "__main__":
|
| 402 |
+
main()
|
src/step_07_extract_and_refine/schemas.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CONTRIBUTION_JSON_SCHEMA = {
|
| 2 |
+
"type": "object",
|
| 3 |
+
"properties": {
|
| 4 |
+
"label": {"type": "string", "enum": ["USES", "EXTENDS", "NOT_CONFIRMED"]},
|
| 5 |
+
"paper_claim": {"type": "string"},
|
| 6 |
+
"cluster_title": {"type": "string"},
|
| 7 |
+
"cluster_key": {"type": "string"},
|
| 8 |
+
"evidence_span": {"type": "string"},
|
| 9 |
+
"rationale": {"type": "string"},
|
| 10 |
+
},
|
| 11 |
+
"required": ["label", "paper_claim", "cluster_title", "cluster_key", "evidence_span", "rationale"],
|
| 12 |
+
}
|
src/step_08_annotation/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .pipeline import TwoPassAnnotationPipeline, TwoPassPipelineResult
|
| 2 |
+
|
| 3 |
+
__all__ = ["TwoPassAnnotationPipeline", "TwoPassPipelineResult"]
|
src/step_08_annotation/cli.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import typer
|
| 7 |
+
|
| 8 |
+
from .paper_package import load_paper_package
|
| 9 |
+
|
| 10 |
+
from .pipeline import TwoPassAnnotationPipeline
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
app = typer.Typer(help="Run step 8: derive target contributions, enabling contributions, and groundings.")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _default_output_root() -> Path:
|
| 17 |
+
return Path("runs/two_pass_outputs")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@app.command()
|
| 21 |
+
def run(
|
| 22 |
+
paper_dir: Path = typer.Option(..., exists=True, file_okay=False, dir_okay=True),
|
| 23 |
+
provider: str = typer.Option("openai", help="Provider family: openai or gemini."),
|
| 24 |
+
model: str = typer.Option("openai/gpt-5", help="Reasoning model used for target-contribution derivation and annotation."),
|
| 25 |
+
formatter_model: str | None = typer.Option(
|
| 26 |
+
None,
|
| 27 |
+
help="Optional model override for pass 2 formatting, e.g. openai/gpt-5-mini or openai/gpt-5.4-pro.",
|
| 28 |
+
),
|
| 29 |
+
judge_model: str | None = typer.Option(
|
| 30 |
+
None,
|
| 31 |
+
help="Optional model override for pass 1 candidate ranking. Ignored when --candidate-count=1.",
|
| 32 |
+
),
|
| 33 |
+
candidate_count: int = typer.Option(
|
| 34 |
+
1,
|
| 35 |
+
help="Number of reasoning candidates to generate. If set to 1, no judge call is made.",
|
| 36 |
+
),
|
| 37 |
+
formatter_max_attempts: int = typer.Option(
|
| 38 |
+
3,
|
| 39 |
+
help="Formatter-only retry attempts after pass 1 has succeeded.",
|
| 40 |
+
),
|
| 41 |
+
include_reference_examples: bool = typer.Option(
|
| 42 |
+
True,
|
| 43 |
+
"--include-reference-examples/--no-include-reference-examples",
|
| 44 |
+
help="Include the built-in reference examples in the pass-1 reasoning prompt.",
|
| 45 |
+
),
|
| 46 |
+
prompt_profile: str = typer.Option(
|
| 47 |
+
"full",
|
| 48 |
+
help="Reasoning prompt profile: full or generic.",
|
| 49 |
+
),
|
| 50 |
+
output_root: Path = typer.Option(
|
| 51 |
+
_default_output_root(),
|
| 52 |
+
help="Directory to store run outputs.",
|
| 53 |
+
),
|
| 54 |
+
run_label: str | None = typer.Option(None, help="Optional label to include in the saved run directory name."),
|
| 55 |
+
annotator_id: str = typer.Option("llm", help="Annotator id to embed in the final UI payload."),
|
| 56 |
+
extracted_claim: str | None = typer.Option(None, help="Optional override for the extracted target contribution."),
|
| 57 |
+
) -> None:
|
| 58 |
+
paper = load_paper_package(paper_dir, extracted_claim_override=extracted_claim)
|
| 59 |
+
pipeline = TwoPassAnnotationPipeline(
|
| 60 |
+
provider=provider,
|
| 61 |
+
model=model,
|
| 62 |
+
formatter_model=formatter_model,
|
| 63 |
+
judge_model=judge_model,
|
| 64 |
+
output_root=output_root,
|
| 65 |
+
run_label=run_label,
|
| 66 |
+
annotator_id=annotator_id,
|
| 67 |
+
candidate_count=candidate_count,
|
| 68 |
+
formatter_max_attempts=formatter_max_attempts,
|
| 69 |
+
include_reference_examples=include_reference_examples,
|
| 70 |
+
prompt_profile=prompt_profile,
|
| 71 |
+
progress_callback=typer.echo,
|
| 72 |
+
)
|
| 73 |
+
result = pipeline.run(paper)
|
| 74 |
+
typer.echo(str(result.run_dir / "run_output.json"))
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@app.command()
|
| 78 |
+
def summarize(run_output: Path = typer.Option(..., exists=True, dir_okay=False, file_okay=True)) -> None:
|
| 79 |
+
data = json.loads(run_output.read_text())
|
| 80 |
+
payload = data.get("ui_payload") or {}
|
| 81 |
+
claims = payload.get("claims") or []
|
| 82 |
+
summary = {
|
| 83 |
+
"paper_id": data.get("paper_id"),
|
| 84 |
+
"target_contribution_count": len(claims),
|
| 85 |
+
"target_contributions": [
|
| 86 |
+
{
|
| 87 |
+
"claim_id": claim.get("claim_id"),
|
| 88 |
+
"rewritten_claim": claim.get("rewritten_claim"),
|
| 89 |
+
"decision": claim.get("decision"),
|
| 90 |
+
"enabling_contribution_count": len(claim.get("ingredients") or []),
|
| 91 |
+
}
|
| 92 |
+
for claim in claims
|
| 93 |
+
],
|
| 94 |
+
}
|
| 95 |
+
typer.echo(json.dumps(summary, indent=2))
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
if __name__ == "__main__":
|
| 99 |
+
app()
|
src/step_08_annotation/final_prompts.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/step_08_annotation/paper_package.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Dict, List
|
| 6 |
+
|
| 7 |
+
from common.paper_package import (
|
| 8 |
+
PaperPackage,
|
| 9 |
+
_collect_bibliography,
|
| 10 |
+
_collect_citation_contexts,
|
| 11 |
+
_collect_full_processed_text,
|
| 12 |
+
_collect_sections,
|
| 13 |
+
_load_json,
|
| 14 |
+
_normalize_dict_payload,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _collect_all_cluster_evidence(paper_dir: Path) -> List[Dict[str, Any]]:
|
| 19 |
+
discovery = _normalize_dict_payload(_load_json(paper_dir / "usage_discovery_from_contributions.json", {}))
|
| 20 |
+
clusters = discovery.get("clusters", [])
|
| 21 |
+
out = []
|
| 22 |
+
for cluster in clusters:
|
| 23 |
+
out.append(
|
| 24 |
+
{
|
| 25 |
+
"cluster_id": cluster.get("cluster_id"),
|
| 26 |
+
"representative_claim": cluster.get("representative_claim") or cluster.get("cluster_title"),
|
| 27 |
+
"cluster_title": cluster.get("cluster_title"),
|
| 28 |
+
"count": cluster.get("count"),
|
| 29 |
+
"cluster_key": cluster.get("cluster_key"),
|
| 30 |
+
"claim_indices": cluster.get("claim_indices", []),
|
| 31 |
+
"source_cluster_ids": cluster.get("source_cluster_ids", []),
|
| 32 |
+
"merge_rationale": cluster.get("merge_rationale"),
|
| 33 |
+
}
|
| 34 |
+
)
|
| 35 |
+
return out
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def load_paper_package(paper_dir: str | Path, extracted_claim_override: str | None = None) -> PaperPackage:
|
| 39 |
+
paper_dir = Path(paper_dir)
|
| 40 |
+
paper_metadata = _normalize_dict_payload(_load_json(paper_dir / "paper_metadata.json", {}))
|
| 41 |
+
cluster_evidence = _collect_all_cluster_evidence(paper_dir)
|
| 42 |
+
seed = extracted_claim_override or ""
|
| 43 |
+
return PaperPackage(
|
| 44 |
+
paper_dir=paper_dir,
|
| 45 |
+
paper_metadata=paper_metadata,
|
| 46 |
+
extracted_discovery_claim=seed,
|
| 47 |
+
downstream_cluster_evidence=cluster_evidence,
|
| 48 |
+
paper_text=_collect_sections(paper_dir),
|
| 49 |
+
full_processed_text=_collect_full_processed_text(paper_dir),
|
| 50 |
+
bibliography=_collect_bibliography(paper_dir),
|
| 51 |
+
citation_contexts=_collect_citation_contexts(paper_dir),
|
| 52 |
+
)
|
src/step_08_annotation/pipeline.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import traceback
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from datetime import datetime, timezone
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Callable, Dict
|
| 9 |
+
|
| 10 |
+
from common.model_client import ModelConfig, MultiProviderLLMClient
|
| 11 |
+
from common.paper_package import PaperPackage
|
| 12 |
+
|
| 13 |
+
from .final_prompts import (
|
| 14 |
+
SYSTEM_TWO_PASS_FORMATTER,
|
| 15 |
+
SYSTEM_TWO_PASS_JUDGE,
|
| 16 |
+
SYSTEM_TWO_PASS_REASONING,
|
| 17 |
+
formatter_prompt,
|
| 18 |
+
judge_prompt,
|
| 19 |
+
reasoning_prompt,
|
| 20 |
+
)
|
| 21 |
+
from .schemas import JudgeResult, UIPayload
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class TwoPassPipelineResult:
|
| 26 |
+
run_dir: Path
|
| 27 |
+
result: Dict[str, Any]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class FormatterStageError(RuntimeError):
|
| 31 |
+
def __init__(self, message: str, run_dir: Path):
|
| 32 |
+
super().__init__(message)
|
| 33 |
+
self.run_dir = run_dir
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class TwoPassAnnotationPipeline:
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
*,
|
| 40 |
+
provider: str,
|
| 41 |
+
model: str,
|
| 42 |
+
formatter_model: str | None,
|
| 43 |
+
judge_model: str | None,
|
| 44 |
+
output_root: Path,
|
| 45 |
+
run_label: str | None = None,
|
| 46 |
+
annotator_id: str = "llm",
|
| 47 |
+
temperature: float = 0.2,
|
| 48 |
+
max_tokens: int = 16000,
|
| 49 |
+
candidate_count: int = 1,
|
| 50 |
+
formatter_max_attempts: int = 3,
|
| 51 |
+
include_reference_examples: bool = True,
|
| 52 |
+
prompt_profile: str = "full",
|
| 53 |
+
progress_callback: Callable[[str], None] | None = None,
|
| 54 |
+
):
|
| 55 |
+
self.output_root = output_root
|
| 56 |
+
self.annotator_id = annotator_id
|
| 57 |
+
self.progress_callback = progress_callback
|
| 58 |
+
self.run_label = run_label
|
| 59 |
+
self.candidate_count = max(1, candidate_count)
|
| 60 |
+
self.formatter_max_attempts = max(1, formatter_max_attempts)
|
| 61 |
+
self.include_reference_examples = include_reference_examples
|
| 62 |
+
self.prompt_profile = prompt_profile
|
| 63 |
+
self.use_judge = self.candidate_count > 1
|
| 64 |
+
stage_models = {}
|
| 65 |
+
if formatter_model:
|
| 66 |
+
stage_models["two_pass_formatter"] = formatter_model
|
| 67 |
+
if judge_model and self.use_judge:
|
| 68 |
+
stage_models["two_pass_judge"] = judge_model
|
| 69 |
+
self.client = MultiProviderLLMClient(
|
| 70 |
+
default_config=ModelConfig(
|
| 71 |
+
provider=provider,
|
| 72 |
+
model=model,
|
| 73 |
+
temperature=temperature,
|
| 74 |
+
max_tokens=max_tokens,
|
| 75 |
+
),
|
| 76 |
+
stage_models=stage_models,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
def run(self, paper: PaperPackage) -> TwoPassPipelineResult:
|
| 80 |
+
run_dir = self._make_run_dir(paper)
|
| 81 |
+
payload = {
|
| 82 |
+
**paper.to_prompt_payload(),
|
| 83 |
+
"paper_dir": paper.paper_dir,
|
| 84 |
+
"full_processed_text": self._load_full_processed_text(paper),
|
| 85 |
+
}
|
| 86 |
+
formatter_config = self.client.config_for_stage("two_pass_formatter")
|
| 87 |
+
self._log(
|
| 88 |
+
f"[run] paper={paper.paper_dir.name} provider={self.client.default_config.provider} model={self.client.default_config.model_name}"
|
| 89 |
+
)
|
| 90 |
+
self._log(f"[run] formatter_model={formatter_config.model_name}")
|
| 91 |
+
self._log(f"[run] include_reference_examples={self.include_reference_examples}")
|
| 92 |
+
self._log(f"[run] prompt_profile={self.prompt_profile}")
|
| 93 |
+
if self.use_judge:
|
| 94 |
+
judge_config = self.client.config_for_stage("two_pass_judge")
|
| 95 |
+
self._log(f"[run] judge_model={judge_config.model_name}")
|
| 96 |
+
else:
|
| 97 |
+
self._log("[run] judge_model=disabled (candidate_count=1)")
|
| 98 |
+
self._log(f"[run] output={run_dir}")
|
| 99 |
+
|
| 100 |
+
reasoning_user_prompt = reasoning_prompt(payload, include_reference_examples=self.include_reference_examples, prompt_profile=self.prompt_profile)
|
| 101 |
+
self._write_text(run_dir / "pass_1_reasoning.prompt.txt", reasoning_user_prompt)
|
| 102 |
+
self._log(
|
| 103 |
+
f"[pass 1] free-form reasoning ({self.candidate_count} candidate{'s' if self.candidate_count != 1 else ''})"
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
candidate_texts: list[str] = []
|
| 107 |
+
candidate_paths: list[str] = []
|
| 108 |
+
for index in range(self.candidate_count):
|
| 109 |
+
reasoning_text = self.client.generate_text(
|
| 110 |
+
stage_name="two_pass_reasoning",
|
| 111 |
+
system_prompt=SYSTEM_TWO_PASS_REASONING,
|
| 112 |
+
user_prompt=reasoning_user_prompt,
|
| 113 |
+
)
|
| 114 |
+
candidate_id = f"candidate_{index + 1}"
|
| 115 |
+
candidate_path = run_dir / f"pass_1_reasoning.output.{candidate_id}.md"
|
| 116 |
+
self._write_text(candidate_path, reasoning_text)
|
| 117 |
+
candidate_texts.append(reasoning_text)
|
| 118 |
+
candidate_paths.append(str(candidate_path))
|
| 119 |
+
|
| 120 |
+
selected_candidate_index = 0
|
| 121 |
+
selected_candidate_id = "candidate_1"
|
| 122 |
+
selected_reasoning_text = candidate_texts[0]
|
| 123 |
+
judge_output_path: Path | None = None
|
| 124 |
+
|
| 125 |
+
if self.use_judge:
|
| 126 |
+
judge_user_prompt = judge_prompt(payload, candidate_texts)
|
| 127 |
+
self._write_text(run_dir / "pass_1_reasoning.judge.prompt.txt", judge_user_prompt)
|
| 128 |
+
self._log("[pass 1] candidate judging")
|
| 129 |
+
judge_result = self.client.generate_structured(
|
| 130 |
+
stage_name="two_pass_judge",
|
| 131 |
+
system_prompt=SYSTEM_TWO_PASS_JUDGE,
|
| 132 |
+
user_prompt=judge_user_prompt,
|
| 133 |
+
response_model=JudgeResult,
|
| 134 |
+
)
|
| 135 |
+
judge_output_path = run_dir / "pass_1_reasoning.judge.output.json"
|
| 136 |
+
self._write_json(judge_output_path, judge_result.model_dump())
|
| 137 |
+
selected_candidate_index = judge_result.selected_candidate_index
|
| 138 |
+
selected_candidate_id = judge_result.selected_candidate_id
|
| 139 |
+
selected_reasoning_text = candidate_texts[selected_candidate_index]
|
| 140 |
+
|
| 141 |
+
selected_reasoning_path = run_dir / "pass_1_reasoning.selected.md"
|
| 142 |
+
self._write_text(selected_reasoning_path, selected_reasoning_text)
|
| 143 |
+
|
| 144 |
+
formatter_user_prompt = formatter_prompt(payload, selected_reasoning_text, self.annotator_id)
|
| 145 |
+
self._write_text(run_dir / "pass_2_formatter.prompt.txt", formatter_user_prompt)
|
| 146 |
+
final_payload: UIPayload | None = None
|
| 147 |
+
formatter_attempts: list[dict[str, Any]] = []
|
| 148 |
+
for attempt in range(1, self.formatter_max_attempts + 1):
|
| 149 |
+
self._log(
|
| 150 |
+
f"[pass 2] strict ui json formatting (attempt {attempt}/{self.formatter_max_attempts})"
|
| 151 |
+
)
|
| 152 |
+
try:
|
| 153 |
+
final_payload = self.client.generate_structured(
|
| 154 |
+
stage_name="two_pass_formatter",
|
| 155 |
+
system_prompt=SYSTEM_TWO_PASS_FORMATTER,
|
| 156 |
+
user_prompt=formatter_user_prompt,
|
| 157 |
+
response_model=UIPayload,
|
| 158 |
+
)
|
| 159 |
+
formatter_attempts.append({"attempt": attempt, "status": "success"})
|
| 160 |
+
break
|
| 161 |
+
except Exception as exc:
|
| 162 |
+
error_text = "".join(traceback.format_exception(exc)).strip()
|
| 163 |
+
error_path = run_dir / f"pass_2_formatter.attempt_{attempt}.error.txt"
|
| 164 |
+
self._write_text(error_path, error_text)
|
| 165 |
+
formatter_attempts.append(
|
| 166 |
+
{
|
| 167 |
+
"attempt": attempt,
|
| 168 |
+
"status": "failed",
|
| 169 |
+
"error": str(exc),
|
| 170 |
+
"error_path": str(error_path),
|
| 171 |
+
}
|
| 172 |
+
)
|
| 173 |
+
if attempt < self.formatter_max_attempts:
|
| 174 |
+
self._log("[pass 2] formatter failed; retrying formatter only")
|
| 175 |
+
|
| 176 |
+
if final_payload is None:
|
| 177 |
+
self._write_json(run_dir / "formatter_attempts.json", {"attempts": formatter_attempts})
|
| 178 |
+
raise FormatterStageError(
|
| 179 |
+
f"Formatter failed after {self.formatter_max_attempts} attempts; pass 1 outputs kept in {run_dir}",
|
| 180 |
+
run_dir,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
self._write_json(run_dir / "formatter_attempts.json", {"attempts": formatter_attempts})
|
| 184 |
+
self._write_json(run_dir / "pass_2_ui_payload.json", final_payload.model_dump())
|
| 185 |
+
|
| 186 |
+
result = {
|
| 187 |
+
"paper_id": paper.paper_dir.name,
|
| 188 |
+
"paper_dir": str(paper.paper_dir),
|
| 189 |
+
"generated_at": datetime.now(timezone.utc).isoformat(),
|
| 190 |
+
"reasoner_model": self.client.default_config.model_name,
|
| 191 |
+
"formatter_model": formatter_config.model_name,
|
| 192 |
+
"judge_model": judge_config.model_name if self.use_judge else None,
|
| 193 |
+
"candidate_count": self.candidate_count,
|
| 194 |
+
"include_reference_examples": self.include_reference_examples,
|
| 195 |
+
"prompt_profile": self.prompt_profile,
|
| 196 |
+
"reasoning_candidate_paths": [str(path) for path in candidate_paths],
|
| 197 |
+
"selected_reasoning_candidate": selected_candidate_id,
|
| 198 |
+
"selected_candidate_index": selected_candidate_index,
|
| 199 |
+
"selected_reasoning_path": str(selected_reasoning_path),
|
| 200 |
+
"judge_output_path": str(judge_output_path) if judge_output_path is not None else None,
|
| 201 |
+
"formatter_attempts": formatter_attempts,
|
| 202 |
+
"ui_payload_path": str(run_dir / "pass_2_ui_payload.json"),
|
| 203 |
+
"ui_payload": final_payload.model_dump(),
|
| 204 |
+
}
|
| 205 |
+
self._write_json(run_dir / "run_output.json", result)
|
| 206 |
+
self._log("[run] complete")
|
| 207 |
+
return TwoPassPipelineResult(run_dir=run_dir, result=result)
|
| 208 |
+
|
| 209 |
+
def _make_run_dir(self, paper: PaperPackage) -> Path:
|
| 210 |
+
stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
| 211 |
+
run_name = stamp
|
| 212 |
+
if self.run_label:
|
| 213 |
+
run_name = f"{self._slugify(self.run_label)}__{stamp}"
|
| 214 |
+
run_dir = self.output_root / paper.paper_dir.name / run_name
|
| 215 |
+
run_dir.mkdir(parents=True, exist_ok=True)
|
| 216 |
+
return run_dir
|
| 217 |
+
|
| 218 |
+
def _write_json(self, path: Path, payload: Dict[str, Any]) -> None:
|
| 219 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 220 |
+
path.write_text(json.dumps(payload, indent=2, ensure_ascii=True) + "\n")
|
| 221 |
+
|
| 222 |
+
def _write_text(self, path: Path, text: str) -> None:
|
| 223 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 224 |
+
path.write_text(text)
|
| 225 |
+
|
| 226 |
+
def _log(self, message: str) -> None:
|
| 227 |
+
if self.progress_callback:
|
| 228 |
+
self.progress_callback(message)
|
| 229 |
+
|
| 230 |
+
@staticmethod
|
| 231 |
+
def _slugify(value: str) -> str:
|
| 232 |
+
slug = "".join(ch if ch.isalnum() or ch in {"-", "_", "."} else "-" for ch in value.strip())
|
| 233 |
+
slug = "-".join(part for part in slug.split("-") if part)
|
| 234 |
+
return slug[:160] or "run"
|
| 235 |
+
|
| 236 |
+
def _load_full_processed_text(self, paper: PaperPackage) -> str:
|
| 237 |
+
processed_path = paper.paper_dir / "processed_main.tex"
|
| 238 |
+
if processed_path.exists():
|
| 239 |
+
try:
|
| 240 |
+
return processed_path.read_text()
|
| 241 |
+
except Exception:
|
| 242 |
+
pass
|
| 243 |
+
|
| 244 |
+
sections_dir = paper.paper_dir / "sections"
|
| 245 |
+
parts: list[str] = []
|
| 246 |
+
if sections_dir.exists():
|
| 247 |
+
for path in sorted(sections_dir.iterdir()):
|
| 248 |
+
if not path.is_file():
|
| 249 |
+
continue
|
| 250 |
+
try:
|
| 251 |
+
text = path.read_text().strip()
|
| 252 |
+
except Exception:
|
| 253 |
+
continue
|
| 254 |
+
if text:
|
| 255 |
+
parts.append(f"[{path.name}]\n{text}")
|
| 256 |
+
return "\n\n".join(parts)
|
src/step_08_annotation/schemas.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
ALLOWED_ARTIFACT_TYPES = ["Resource", "Finding", "Method", "Benchmark", "Dataset", "Tool", "Other"]
|
| 9 |
+
ALLOWED_ROLES = [
|
| 10 |
+
"CONCEPTUAL_FRAMEWORK",
|
| 11 |
+
"CORE_METHOD",
|
| 12 |
+
"DATA_SOURCE",
|
| 13 |
+
"MODEL_INITIALIZATION",
|
| 14 |
+
"EVALUATION_PROTOCOL",
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ReasoningCandidate(BaseModel):
|
| 19 |
+
study: str
|
| 20 |
+
decision: Literal["accepted_canonical", "accepted_additional", "accepted_none", "rejected_candidate"]
|
| 21 |
+
why: str
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ReasoningIngredient(BaseModel):
|
| 25 |
+
ingredient_id: str
|
| 26 |
+
ingredient: str
|
| 27 |
+
necessary: bool
|
| 28 |
+
from_prior_work: bool
|
| 29 |
+
maps_cleanly_to_one_study: bool
|
| 30 |
+
notes: str
|
| 31 |
+
canonical_grounding_decision: Dict[str, Any]
|
| 32 |
+
additional_groundings: List[Dict[str, Any]]
|
| 33 |
+
candidate_studies_considered: List[ReasoningCandidate]
|
| 34 |
+
role: Optional[Literal["CONCEPTUAL_FRAMEWORK", "CORE_METHOD", "DATA_SOURCE", "MODEL_INITIALIZATION", "EVALUATION_PROTOCOL", "IMPLEMENTATION_TOOLING", "TRAINING_DATA"]] = None
|
| 35 |
+
contribution: str
|
| 36 |
+
rationale: str
|
| 37 |
+
evidence_span: str
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ReasoningClaim(BaseModel):
|
| 41 |
+
claim_id: str
|
| 42 |
+
artifact_type: Literal["Resource", "Finding", "Method", "Benchmark", "Dataset", "Tool", "Other"]
|
| 43 |
+
rewritten_claim: str
|
| 44 |
+
cluster_id: str = ""
|
| 45 |
+
decision: Literal["YES_SUFFICIENT", "NO_NOT_DISCOVERY"] = "YES_SUFFICIENT"
|
| 46 |
+
notes: str = ""
|
| 47 |
+
why_this_is_atomic: str
|
| 48 |
+
ingredients: List[ReasoningIngredient]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class ReasoningOutput(BaseModel):
|
| 52 |
+
original_discovery_claim: str
|
| 53 |
+
claim_split_decision: Dict[str, Any]
|
| 54 |
+
rewritten_claims: List[ReasoningClaim]
|
| 55 |
+
paper_level_notes: str = ""
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class GroundingRecord(BaseModel):
|
| 59 |
+
ref_id: Optional[str] = None
|
| 60 |
+
bib_key: Optional[str] = None
|
| 61 |
+
paper_id: Optional[str] = None
|
| 62 |
+
external_ids: Optional[Dict[str, Any]] = None
|
| 63 |
+
ref_title: Optional[str] = None
|
| 64 |
+
ref_year: Optional[str] = None
|
| 65 |
+
ref_authors: Optional[str] = None
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class CanonicalAnnotation(BaseModel):
|
| 69 |
+
role: Optional[Literal["CONCEPTUAL_FRAMEWORK", "CORE_METHOD", "DATA_SOURCE", "MODEL_INITIALIZATION", "EVALUATION_PROTOCOL", "IMPLEMENTATION_TOOLING", "TRAINING_DATA"]] = None
|
| 70 |
+
roles: List[Literal["CONCEPTUAL_FRAMEWORK", "CORE_METHOD", "DATA_SOURCE", "MODEL_INITIALIZATION", "EVALUATION_PROTOCOL", "IMPLEMENTATION_TOOLING", "TRAINING_DATA"]]
|
| 71 |
+
contribution: str
|
| 72 |
+
rationale: str
|
| 73 |
+
evidence_span: str
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class IngredientPayload(BaseModel):
|
| 77 |
+
ingredient_id: str
|
| 78 |
+
ingredient: str
|
| 79 |
+
canonical_ref_id: str
|
| 80 |
+
canonical_grounding: Optional[GroundingRecord] = None
|
| 81 |
+
additional_ref_ids: List[str]
|
| 82 |
+
additional_groundings: List[GroundingRecord]
|
| 83 |
+
canonical_annotation: CanonicalAnnotation
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class EnablingDiscoveryPayload(GroundingRecord):
|
| 87 |
+
ingredient_id: str
|
| 88 |
+
ingredient: str
|
| 89 |
+
role: Optional[Literal["CONCEPTUAL_FRAMEWORK", "CORE_METHOD", "DATA_SOURCE", "MODEL_INITIALIZATION", "EVALUATION_PROTOCOL", "IMPLEMENTATION_TOOLING", "TRAINING_DATA"]] = None
|
| 90 |
+
roles: List[Literal["CONCEPTUAL_FRAMEWORK", "CORE_METHOD", "DATA_SOURCE", "MODEL_INITIALIZATION", "EVALUATION_PROTOCOL", "IMPLEMENTATION_TOOLING", "TRAINING_DATA"]]
|
| 91 |
+
contribution: str
|
| 92 |
+
rationale: str
|
| 93 |
+
evidence_span: str
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class ClaimPayload(BaseModel):
|
| 97 |
+
claim_id: str
|
| 98 |
+
text: str
|
| 99 |
+
rewritten_claim: str
|
| 100 |
+
cluster_id: str = ""
|
| 101 |
+
decision: Literal["YES_SUFFICIENT", "NO_NOT_DISCOVERY", "UNCERTAIN"] = "YES_SUFFICIENT"
|
| 102 |
+
notes: str = ""
|
| 103 |
+
ingredients: List[IngredientPayload]
|
| 104 |
+
enabling_discoveries: List[EnablingDiscoveryPayload]
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class JudgeCandidateScore(BaseModel):
|
| 108 |
+
candidate_id: str
|
| 109 |
+
candidate_index: int
|
| 110 |
+
score: int
|
| 111 |
+
assessment: str
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class JudgeResult(BaseModel):
|
| 115 |
+
selected_candidate_index: int
|
| 116 |
+
selected_candidate_id: str
|
| 117 |
+
selected_reason: str
|
| 118 |
+
candidate_scores: List[JudgeCandidateScore]
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class UIPayload(BaseModel):
|
| 122 |
+
target_paper_id: str
|
| 123 |
+
target_title: Optional[str] = None
|
| 124 |
+
target_year: Optional[int] = None
|
| 125 |
+
annotator_id: str
|
| 126 |
+
active_claim_id: Optional[str] = None
|
| 127 |
+
claims: List[ClaimPayload]
|