Eric Chamoun committed on
Commit 0a55f0f · 0 Parent(s):

Initial SciPaths Space release
Files changed (43)
  1. .dockerignore +7 -0
  2. .gitattributes +1 -0
  3. .gitignore +7 -0
  4. Deep-Citation/Data/acl.tsv +0 -0
  5. Deep-Citation/Data/class_def.json +23 -0
  6. Deep-Citation/Model/__init__.py +1 -0
  7. Deep-Citation/Model/model.py +89 -0
  8. Deep-Citation/Workspace/acl_scicite_wksp_trl/args.txt +21 -0
  9. Deep-Citation/Workspace/acl_scicite_wksp_trl/best_model.pt +3 -0
  10. Deep-Citation/data.py +211 -0
  11. Dockerfile +24 -0
  12. README.md +232 -0
  13. app.py +5 -0
  14. hf_space/requirements.txt +17 -0
  15. hf_space/runner.py +333 -0
  16. hf_space/streamlit_app.py +864 -0
  17. hf_space/streamlit_config.py +30 -0
  18. requirements.txt +1 -0
  19. src/common/__init__.py +0 -0
  20. src/common/llm_client.py +49 -0
  21. src/common/model_client.py +143 -0
  22. src/common/paper_package.py +288 -0
  23. src/step_01_fetch/config.py +6 -0
  24. src/step_01_fetch/fetch_metadata.py +440 -0
  25. src/step_01_fetch/process_tex_source.py +203 -0
  26. src/step_01_fetch/semanticscholar_client.py +158 -0
  27. src/step_02_mark_citations/replace_citation_markers.py +440 -0
  28. src/step_03_usage_contexts/build_usage_contexts.py +184 -0
  29. src/step_04_label_citations/label_citation_functions.py +373 -0
  30. src/step_05_verify_uses_extends/prompts.py +115 -0
  31. src/step_05_verify_uses_extends/schemas.py +22 -0
  32. src/step_05_verify_uses_extends/verify_uses_extends.py +296 -0
  33. src/step_06_extract_paragraphs/extract_arxiv_paragraphs.py +488 -0
  34. src/step_07_extract_and_refine/extract_contributions_from_citations.py +329 -0
  35. src/step_07_extract_and_refine/prompts.py +65 -0
  36. src/step_07_extract_and_refine/refine_and_filter_clusters_llm.py +402 -0
  37. src/step_07_extract_and_refine/schemas.py +12 -0
  38. src/step_08_annotation/__init__.py +3 -0
  39. src/step_08_annotation/cli.py +99 -0
  40. src/step_08_annotation/final_prompts.py +0 -0
  41. src/step_08_annotation/paper_package.py +52 -0
  42. src/step_08_annotation/pipeline.py +256 -0
  43. src/step_08_annotation/schemas.py +127 -0
.dockerignore ADDED
@@ -0,0 +1,7 @@
+ .git
+ __pycache__
+ *.pyc
+ hf_space/runs
+ **/__pycache__
+ **/*.pyc
+ *.zip
.gitattributes ADDED
@@ -0,0 +1 @@
+ Deep-Citation/Workspace/acl_scicite_wksp_trl/best_model.pt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@
+ __pycache__/
+ *.py[cod]
+ .DS_Store
+ .streamlit/secrets.toml
+ hf_space/runs/
+ runs/
+ *.zip
Deep-Citation/Data/acl.tsv ADDED
The diff for this file is too large to render. See raw diff
 
Deep-Citation/Data/class_def.json ADDED
@@ -0,0 +1,23 @@
+ {
+     "acl":
+     {
+         "BACKGROUND": "The citation provides relevant information for the domain that the present paper discusses.",
+         "MOTIVATION": "The citation illustrates the need for data, goals, methods, etc that is proposed in the present paper.",
+         "USES": "The present paper uses data, methods, etc., from the paper associated with the citation.",
+         "EXTENDS": "The present paper extends the data, methods, etc. from the paper associated with the citation.",
+         "COMPAREORCONTRAST": "The present paper expresses similarity / differences to the citation.",
+         "FUTURE": "The citation is a potential avenue for future work of the present paper."
+     },
+     "kim":
+     {
+         "Used": "The present paper uses at least one method that is proposed in the paper associated with the citation.",
+         "Not used": "The present paper does not use or extend any methods that is proposed in the paper associated with the citation.",
+         "Extended": "The present paper uses an extended / modified version of the method proposed in the paper associated with the citation."
+     },
+     "scicite":
+     {
+         "Background": "The citation states, mentions, or points to the background information giving more context about a problem, concept, approach, topic, or importance of the problem that is discussed in the present paper.",
+         "Method": "The present paper uses a method, tool, approach or dataset that is proposed in the paper associated with the citation.",
+         "Result": "The present paper compares its results/findings with the results/findings of the paper associated with the citation."
+     }
+ }
Deep-Citation/Model/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model import LanguageModel, MultiHeadLanguageModel
Deep-Citation/Model/model.py ADDED
@@ -0,0 +1,89 @@
+ import os
+ import torch
+ import torch.nn as nn
+
+ from typing import List
+ from transformers import AutoModel
+
+ def mask_pooling(model_output, attention_mask):
+     token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
+     input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+     return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+ class LanguageModel(nn.Module):
+     def __init__(self,
+                  modelname: str,
+                  device: str,
+                  readout: str
+                  ):
+         super(LanguageModel, self).__init__()
+         self.device = device
+         self.modelname = modelname
+         self.readout_fn = readout
+
+         self.model = AutoModel.from_pretrained(modelname)
+         self.hidden_size = self.model.config.hidden_size
+
+     def readout(self, model_inputs, model_outputs, readout_masks=None):
+         if self.readout_fn == 'cls':
+             if 'bert' in self.modelname or 'deberta' in self.modelname:
+                 text_representations = model_outputs.last_hidden_state[:, 0]
+             elif 'xlnet' in self.modelname:
+                 text_representations = model_outputs.last_hidden_state[:, -1]
+             else:
+                 raise ValueError('Invalid model name {} for the cls readout.'.format(self.modelname))
+         elif self.readout_fn == 'mean':
+             text_representations = mask_pooling(model_outputs, model_inputs['attention_mask'])
+         elif self.readout_fn == 'ch' and readout_masks is not None:
+             text_representations = mask_pooling(model_outputs, readout_masks)
+         else:
+             raise ValueError('Invalid readout function.')
+         return text_representations
+
+     def _lm_forward(self, tokens):
+         tokens = tokens.to(self.device)
+         if 'readout_mask' in tokens:
+             readout_mask = tokens.pop('readout_mask')
+         else:
+             readout_mask = None
+         outputs = self.model(**tokens)
+         return self.readout(tokens, outputs, readout_mask)
+
+     def forward(self):
+         raise NotImplementedError
+
+     def save_pretrained(self, modeldir):
+         model_filename = os.path.join(modeldir, 'checkpoint.pt')
+         torch.save(self.state_dict(), model_filename)
+
+     def load_pretrained(self, modeldir):
+         model_filename = os.path.join(modeldir, 'checkpoint.pt')
+         self.load_state_dict(torch.load(model_filename))
+
+ class MultiHeadLanguageModel(LanguageModel):
+     def __init__(self,
+                  modelname: str,
+                  device: str,
+                  readout: str,
+                  num_classes: List
+                  ):
+         super().__init__(
+             modelname,
+             device,
+             readout
+         )
+
+         self.num_classes = num_classes
+         self.lns = nn.ModuleList([nn.Linear(self.hidden_size, num_class) for num_class in num_classes])
+
+     def forward(self, input_tokens, input_head_indices, class_tokens, class_head_indices):
+         head_indices = torch.unique(input_head_indices)
+         text_representations = self._lm_forward(input_tokens)
+
+         final_preds = {}
+         for i in head_indices:
+             if torch.any(input_head_indices == i):
+                 final_preds[i.item()] = self.lns[i.item()](text_representations[input_head_indices == i])
+             else:
+                 final_preds[i.item()] = torch.tensor([]).to(self.device)
+         return final_preds
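A minimal usage sketch for the multi-head classifier above. It is illustrative only: the SciBERT checkpoint name, the two head sizes, and the `cls` readout are assumptions (the bundled checkpoint in `args.txt` was trained with `readout='ch'`), and it assumes the working directory is `Deep-Citation/`.

```python
import torch
from transformers import AutoTokenizer

from Model.model import MultiHeadLanguageModel  # assumes cwd is Deep-Citation/

# Assumed setup: SciBERT backbone with a 6-way head (ACL) and a 3-way head (SciCite).
device = "cpu"
model = MultiHeadLanguageModel(
    modelname="allenai/scibert_scivocab_uncased",
    device=device,
    readout="cls",  # simplest readout for a demo; the shipped model uses 'ch'
    num_classes=[6, 3],
).to(device)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
tokens = tokenizer(
    ["We adopt the evaluation protocol of <CITED HERE> for all experiments."],
    return_tensors="pt", truncation=True, max_length=512, padding=True,
)
head_indices = torch.zeros(1, dtype=torch.long)  # route this one example to head 0

with torch.no_grad():
    preds = model(tokens, head_indices, class_tokens=None, class_head_indices=None)

print(preds[0].shape)  # -> torch.Size([1, 6]): logits from the first head
```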
Deep-Citation/Workspace/acl_scicite_wksp_trl/args.txt ADDED
@@ -0,0 +1,21 @@
+ Namespace(dataset='acl-scicite',
+           lambdas='1-0.063',
+           data_dir='Data',
+           workspace='Workspace/acl_scicite_wksp_trl',
+           class_definition='Data/class_def.json',
+           batch_size=32,
+           lr=5e-05,
+           decay_rate=0.5,
+           decay_step=5,
+           num_epochs=10,
+           scheduler='slanted',
+           dropout_rate=0.2,
+           l2=0.0,
+           device='cuda',
+           tol=10,
+           inference_only=False,
+           seed=1,
+           lm='scibert',
+           max_length=512,
+           batch_size_factor=2,
+           readout='ch')
Deep-Citation/Workspace/acl_scicite_wksp_trl/best_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e45ab11942439f80a121dad5b2d9da392470e0cedf6a7335991fa0a1f616dcb2
+ size 439784777
Deep-Citation/data.py ADDED
@@ -0,0 +1,211 @@
+ import os
+ import json
+ import copy
+ import torch
+ import scipy
+ import numpy as np
+ import pandas as pd
+ from tqdm import tqdm
+ from scipy.special import softmax
+
+ from transformers import AutoTokenizer
+
+ class CollateFn(object):
+     def __init__(self, modelname, class_definitions=None, instance_weights=False):
+         self.instance_weights = instance_weights
+         use_fast = False if 'deberta' in modelname else True
+         self.tokenizer = AutoTokenizer.from_pretrained(modelname, use_fast=use_fast)
+         cited_ids = self.tokenizer.encode('<CITED HERE>', add_special_tokens=False)
+         self.cited_here_tokens = torch.tensor(cited_ids, dtype=torch.long)
+
+         if class_definitions is not None:
+             self.class_definitions = []
+             self.class_head_indices = []
+             for i, defs in enumerate(class_definitions):
+                 self.class_definitions += defs
+                 self.class_head_indices.append(i * torch.ones(len(defs), dtype=torch.long))
+             self.class_head_indices = torch.cat(self.class_head_indices, dim=0)
+             self.class_tokens = self.tokenizer(
+                 self.class_definitions,
+                 return_tensors="pt",
+                 max_length=512,
+                 truncation=True,
+                 padding=True
+             )
+
+     def _get_readout_mask(self, tokens):
+         # cited_here_tokens = torch.tensor([962, 8412, 1530, 1374])
+         readout_mask = torch.zeros_like(tokens['input_ids'], dtype=torch.bool)
+
+         batch_size = tokens['input_ids'].size(0)
+         l = tokens['input_ids'].size(1)
+         ctk_l = self.cited_here_tokens.size(0)
+         for b in range(batch_size):
+             for i in range(1, l - ctk_l):
+                 if torch.equal(tokens['input_ids'][b, i:i+ctk_l], self.cited_here_tokens):
+                     readout_mask[b, i:i+ctk_l] = True
+             if not readout_mask[b].any():
+                 # Fallback to CLS if the citation marker isn't matched.
+                 readout_mask[b, 0] = True
+         return readout_mask
+
+     def _tokenize_context(self, context):
+         tokens = self.tokenizer(
+             context,
+             return_tensors="pt",
+             max_length=512,
+             truncation=True,
+             padding=True
+         )
+         tokens['readout_mask'] = self._get_readout_mask(
+             tokens
+         )
+
+         return tokens
+
+     def __call__(self, samples):
+         if self.instance_weights:
+             text, labels, ds_indices, instance_weights = list(map(list, zip(*samples)))
+             batched_text = self._tokenize_context(text)
+             labels = torch.stack(labels)
+             ds_indices = torch.stack(ds_indices)
+             instance_weights = torch.stack(instance_weights)
+             return batched_text, labels, ds_indices, instance_weights
+         else:
+             text, labels, ds_indices = list(map(list, zip(*samples)))
+             batched_text = self._tokenize_context(text)
+             labels = torch.stack(labels)
+             ds_indices = torch.stack(ds_indices)
+
+             return batched_text, labels, ds_indices, copy.deepcopy(self.class_tokens), self.class_head_indices
+
+ class Dataset(object):
+     def __init__(self, dataframe, class_definitions, lmbd=1.0):
+         self.class_definitions = class_definitions
+         self.lmbd = lmbd
+         self._load_data(dataframe)
+
+     def __len__(self):
+         return len(self.labels)
+
+     def __getitem__(self, idx):
+         '''Get datapoint with index'''
+         return (self.text[idx], self.labels[idx], self.ds_index[idx])
+
+     def _load_data(self, annotated_data):
+         self.labels = torch.LongTensor(annotated_data['label'].tolist())
+         self.original_labels = torch.LongTensor(annotated_data['label'].tolist())
+         self.ds_index = torch.zeros_like(self.original_labels)
+         self.text = annotated_data['context'].tolist()
+
+ class MultiHeadDatasets(object):
+     def __init__(self, datasets, batch_size_factor=2):
+         self.text = []
+         self.ds_index = []
+         self.labels = []
+         self.class_definitions = []
+         self.lambdas = []
+
+         self.dataset_sizes = [len(d.labels) for d in datasets]
+         if len(self.dataset_sizes) > 1:
+             if sum(self.dataset_sizes) / self.dataset_sizes[0] <= batch_size_factor:
+                 self.sample_auxiliary = False
+                 self.adjusted_batch_size_factor = sum(self.dataset_sizes) / self.dataset_sizes[0]
+             else:
+                 self.sample_auxiliary = True
+                 self.sample_distribution = np.array([d.lmbd for d in datasets[1:]]) / sum([d.lmbd for d in datasets[1:]])
+                 self.adjusted_batch_size_factor = batch_size_factor
+         else:
+             self.sample_auxiliary = False
+             self.adjusted_batch_size_factor = 1
+
+         for i, d in enumerate(datasets):
+             self.text += d.text
+             self.ds_index.append(i * torch.ones(len(d.text), dtype=torch.long))
+             self.labels.append(d.labels)
+             self.class_definitions.append(d.class_definitions)
+             self.lambdas.append(d.lmbd)
+         self.labels = torch.cat(self.labels, dim=0)
+         self.ds_index = torch.cat(self.ds_index, dim=0)
+
+     def sample_auxiliary_instace(self):
+         sampled_dataset_idx = np.random.choice(
+             np.arange(1, len(self.dataset_sizes)),
+             p=self.sample_distribution
+         )
+         instance_idx = np.random.choice(
+             self.dataset_sizes[sampled_dataset_idx]
+         ) + sum(self.dataset_sizes[:sampled_dataset_idx])
+         return instance_idx
+
+     def __len__(self):
+         if self.sample_auxiliary:  # if the auxiliary dataset is larger than the main dataset
+             return self.dataset_sizes[0] * self.adjusted_batch_size_factor
+         return len(self.labels)
+
+     def __getitem__(self, idx):
+         '''Get datapoint with index'''
+         if idx < self.dataset_sizes[0] or not self.sample_auxiliary:
+             return (self.text[idx], self.labels[idx], self.ds_index[idx])
+         else:
+             real_idx = self.sample_auxiliary_instace()
+             return (self.text[real_idx], self.labels[real_idx], self.ds_index[real_idx])
+
+ def load_class_definitions(filename):
+     with open(filename, 'r') as f:
+         class_definitions = json.load(f)
+
+     results = {k:{} for k in class_definitions.keys()}
+     for k, v in class_definitions.items():
+         for kk, vv in v.items():
+             results[k][kk.lower()] = vv
+     return results
+
+ def create_data_channels(filename, class_definition_filename, split=None, lmbd=1.0):
+     data = pd.read_csv(filename, sep='\t')
+     data = data.fillna(' ')
+
+     print('Number of data instance: {}'.format(data.shape[0]))
+
+     # map labels to ids
+     unique_labels = data['label'].unique().tolist()
+     label2id = {lb: i for i, lb in enumerate(unique_labels)}
+
+     data['label'] = data['label'].apply(
+         lambda x: label2id[x])
+
+     data_train = data[data['split'] == 'train'].reset_index()
+     data_val = data[data['split'] == 'val'].reset_index()
+     data_test = data[data['split'] == 'test'].reset_index()
+
+     class_definitions = load_class_definitions(class_definition_filename)
+     dataname = filename.split('/')[-1].split('.')[0]
+     data_class_definitions = [class_definitions[dataname][lb.lower()] for lb in unique_labels]
+
+     train_data = Dataset(data_train, data_class_definitions, lmbd=lmbd)
+     val_data = Dataset(data_val, data_class_definitions, lmbd=lmbd)
+     test_data = Dataset(data_test, data_class_definitions, lmbd=lmbd)
+
+     return train_data, val_data, test_data, unique_labels
+
+ def create_single_data_object(filename, class_definition_filename, split=None, lmbd=1.0):
+     data = pd.read_csv(filename, sep='\t')
+     data = data.fillna(' ')
+
+     print('Number of data instance: {}'.format(data.shape[0]))
+
+     # map labels to ids
+     unique_labels = data['label'].unique()
+     label2id = {lb: i for i, lb in enumerate(unique_labels)}
+
+     data['label'] = data['label'].apply(
+         lambda x: label2id[x])
+
+     class_definitions = load_class_definitions(class_definition_filename)
+     dataname = filename.split('/')[-1].split('.')[0]
+     data_class_definitions = [class_definitions[dataname][lb.lower()] for lb in unique_labels]
+
+     if split is None:
+         return Dataset(data, data_class_definitions, lmbd=lmbd), unique_labels
+     else:
+         return Dataset(data[data['split'] == split].reset_index(), data_class_definitions, lmbd=lmbd), unique_labels
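A minimal sketch of wiring the data utilities above into a `DataLoader`. Paths follow the bundled `Deep-Citation/` layout and the SciBERT checkpoint name is an assumption; it presumes `Data/acl.tsv` has the `context`, `label`, and `split` columns these helpers expect.

```python
from torch.utils.data import DataLoader

from data import CollateFn, MultiHeadDatasets, create_data_channels  # assumes cwd is Deep-Citation/

# Build train/val/test splits from the bundled ACL TSV and class definitions.
train_data, val_data, test_data, labels = create_data_channels(
    "Data/acl.tsv", "Data/class_def.json"
)
multi_train = MultiHeadDatasets([train_data])

# The collate function tokenizes contexts and the class definitions in one go.
collate = CollateFn(
    "allenai/scibert_scivocab_uncased",
    class_definitions=multi_train.class_definitions,
)
loader = DataLoader(multi_train, batch_size=32, shuffle=True, collate_fn=collate)

# Each batch: tokenized contexts, labels, per-dataset head indices, and the
# tokenized class definitions consumed by MultiHeadLanguageModel.
tokens, batch_labels, ds_indices, class_tokens, class_head_indices = next(iter(loader))
print(tokens["input_ids"].shape, batch_labels.shape, class_head_indices.shape)
```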
Dockerfile ADDED
@@ -0,0 +1,24 @@
+ FROM python:3.11-slim
+
+ ENV PYTHONDONTWRITEBYTECODE=1 \
+     PYTHONUNBUFFERED=1 \
+     PIP_NO_CACHE_DIR=1 \
+     STREAMLIT_SERVER_HEADLESS=true
+
+ WORKDIR /app
+
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     git \
+     build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ COPY requirements.txt /app/requirements.txt
+ COPY hf_space/requirements.txt /app/hf_space/requirements.txt
+ RUN python -m pip install --upgrade pip && \
+     pip install -r requirements.txt
+
+ COPY . /app
+
+ EXPOSE 7860
+
+ CMD ["streamlit", "run", "hf_space/streamlit_app.py", "--server.address", "0.0.0.0", "--server.port", "7860"]
README.md ADDED
@@ -0,0 +1,232 @@
+ ---
+ title: SciPaths
+ emoji: 🔬
+ colorFrom: blue
+ colorTo: indigo
+ sdk: docker
+ pinned: false
+ ---
+
+ # SciPaths
+
+ SciPaths runs an end-to-end target-contribution pathway pipeline for arXiv papers. It collects downstream citation evidence, derives target contributions from refined citation clusters, decomposes each target contribution into enabling contributions, and grounds those enabling contributions in prior studies.
+
+ The Hugging Face Space launches the Streamlit app from `hf_space/streamlit_app.py`.
+
+ ## Citation
+
+ If you find this useful, please cite our paper as:
+
+ ```bibtex
+ @misc{chamoun2026scipathsforecastingpathwaysscientific,
+       title={SciPaths: Forecasting Pathways to Scientific Discovery},
+       author={Eric Chamoun and Yizhou Chi and Yulong Chen and Rui Cao and Zifeng Ding and Michalis Korakakis and Andreas Vlachos},
+       year={2026},
+       eprint={2605.14600},
+       archivePrefix={arXiv},
+       primaryClass={cs.CL},
+       url={https://arxiv.org/abs/2605.14600},
+ }
+ ```
+
+ Paper URL: https://arxiv.org/abs/2605.14600
+
+ ## Required Secrets
+
+ Set this in the Space settings before publishing:
+
+ ```text
+ GEMINI_API_KEY=<Google Gemini API key>
+ ```
+
+ Optional, for saving completed run artifacts to a Hugging Face Dataset:
+
+ ```text
+ HF_WRITE_TOKEN=<Hugging Face write token>
+ RUNS_REPO_ID=<owner/dataset-name>
+ RUNS_REPO_TYPE=dataset
+ ```
+
+ Optional, for higher Semantic Scholar limits:
+
+ ```text
+ SEMANTIC_SCHOLAR_API_KEY=<Semantic Scholar API key>
+ ```
+
+ ## Run The Demo Locally
+
+ ```bash
+ pip install -r requirements.txt
+ streamlit run hf_space/streamlit_app.py
+ ```
+
+ Then enter an arXiv URL or ID, for example:
+
+ ```text
+ https://arxiv.org/abs/2211.08788
+ ```
+
+ The app writes each run under:
+
+ ```text
+ hf_space/runs/<job_id>/
+ ```
+
+ ## Run One Example From The Command Line
+
+ This example stores all intermediate files under `runs/example/processed_papers`.
+
+ ```bash
+ mkdir -p runs/example
+ printf '[{"id":"2211.08788","title":"","id_type":"ArXiv"}]\n' > runs/example/input_ids.json
+
+ python src/step_01_fetch/fetch_metadata.py \
+   --ids runs/example/input_ids.json \
+   --outdir runs/example/processed_papers
+
+ python src/step_02_mark_citations/replace_citation_markers.py \
+   --root runs/example/processed_papers
+
+ python src/step_03_usage_contexts/build_usage_contexts.py \
+   --root runs/example/processed_papers \
+   --out-name usage_contexts.json
+
+ python src/step_04_label_citations/label_citation_functions.py \
+   --root runs/example/processed_papers \
+   --model-path Deep-Citation/Workspace/acl_scicite_wksp_trl/best_model.pt \
+   --model-data-dir Deep-Citation/Data \
+   --model-class-def Deep-Citation/Data/class_def.json \
+   --model-lm scibert \
+   --device cpu
+
+ python src/step_05_verify_uses_extends/verify_uses_extends.py \
+   --root runs/example/processed_papers \
+   --k 0 \
+   --batch-size 25
+
+ python src/step_06_extract_paragraphs/extract_arxiv_paragraphs.py \
+   --root runs/example/processed_papers
+
+ python src/step_07_extract_and_refine/extract_contributions_from_citations.py \
+   --root runs/example/processed_papers
+
+ python src/step_07_extract_and_refine/refine_and_filter_clusters_llm.py \
+   --root runs/example/processed_papers \
+   --inplace \
+   --overwrite
+
+ PYTHONPATH=src \
+ python -m step_08_annotation.cli run \
+   --paper-dir runs/example/processed_papers/2211.08788 \
+   --provider gemini \
+   --model gemini/gemini-3.1-pro-preview \
+   --formatter-model gemini/gemini-3.1-pro-preview \
+   --judge-model gemini/gemini-3.1-pro-preview \
+   --candidate-count 3 \
+   --output-root runs/example/two_pass_outputs
+ ```
+
+ The final UI payload is written as `pass_2_ui_payload.json` inside the annotation run directory printed by the last command.
+
+ ## Run Each Step On A Set Of Papers
+
+ Create an ID file with one entry per paper:
+
+ ```json
+ [
+   {"id": "2211.08788", "title": "", "id_type": "ArXiv"},
+   {"id": "2311.14919", "title": "", "id_type": "ArXiv"}
+ ]
+ ```
+
+ Save it as `runs/batch/input_ids.json`, then run:
+
+ ```bash
+ mkdir -p runs/batch
+
+ # 1. Fetch metadata + LaTeX for each input paper.
+ python src/step_01_fetch/fetch_metadata.py \
+   --ids runs/batch/input_ids.json \
+   --outdir runs/batch/processed_papers
+
+ # 2. Add explicit citation markers to the target-paper text.
+ python src/step_02_mark_citations/replace_citation_markers.py \
+   --root runs/batch/processed_papers
+
+ # 3. Build downstream citation usage contexts.
+ python src/step_03_usage_contexts/build_usage_contexts.py \
+   --root runs/batch/processed_papers \
+   --out-name usage_contexts.json
+
+ # 4. Label citation functions with the bundled Deep-Citation classifier.
+ python src/step_04_label_citations/label_citation_functions.py \
+   --root runs/batch/processed_papers \
+   --model-path Deep-Citation/Workspace/acl_scicite_wksp_trl/best_model.pt \
+   --model-data-dir Deep-Citation/Data \
+   --model-class-def Deep-Citation/Data/class_def.json \
+   --model-lm scibert \
+   --device cpu
+
+ # 5. Verify USES/EXTENDS citations with an LLM.
+ python src/step_05_verify_uses_extends/verify_uses_extends.py \
+   --root runs/batch/processed_papers \
+   --k 0 \
+   --batch-size 25
+
+ # 6. Extract arXiv paragraphs from downstream citing papers.
+ python src/step_06_extract_paragraphs/extract_arxiv_paragraphs.py \
+   --root runs/batch/processed_papers
+
+ # 7. Extract downstream contribution clusters, then merge/filter them.
+ python src/step_07_extract_and_refine/extract_contributions_from_citations.py \
+   --root runs/batch/processed_papers
+
+ python src/step_07_extract_and_refine/refine_and_filter_clusters_llm.py \
+   --root runs/batch/processed_papers \
+   --inplace \
+   --overwrite
+
+ # 8. Annotate each ready paper: target contributions, enabling contributions, and groundings.
+ for paper_dir in runs/batch/processed_papers/*; do
+   [ -d "$paper_dir" ] || continue
+   [ -f "$paper_dir/usage_discovery_from_contributions.json" ] || continue
+   PYTHONPATH=src \
+   python -m step_08_annotation.cli run \
+     --paper-dir "$paper_dir" \
+     --provider gemini \
+     --model gemini/gemini-3.1-pro-preview \
+     --formatter-model gemini/gemini-3.1-pro-preview \
+     --judge-model gemini/gemini-3.1-pro-preview \
+     --candidate-count 3 \
+     --output-root runs/batch/two_pass_outputs
+ done
+ ```
+
+ ## Pipeline Steps
+
+ 1. **Fetch metadata + LaTeX.** Downloads target-paper metadata, references, citing-paper metadata, and arXiv source where available.
+ 2. **Add citation markers.** Inserts normalized citation markers into the target paper so downstream citation contexts can be aligned.
+ 3. **Build usage contexts.** Collects text windows around downstream citations to the target paper.
+ 4. **Label citation functions.** Uses the bundled Deep-Citation classifier to label citation contexts as background, use, extension, comparison, and related categories.
+ 5. **Verify USES/EXTENDS.** Uses an LLM to check whether candidate downstream citations genuinely use or extend the target paper.
+ 6. **Extract arXiv paragraphs.** Retrieves fuller paragraphs from citing papers so the system has enough context for contribution extraction.
+ 7. **Extract and refine target-contribution clusters.** Extracts what downstream papers use the target paper for, clusters near-duplicates, and filters weak/non-usage evidence.
+ 8. **Annotate pathways.** Derives target contributions from the refined clusters, decomposes each into enabling contributions, selects primary groundings, and records additional grounding studies.
+
+ ## Important Files
+
+ ```text
+ hf_space/streamlit_app.py           Streamlit UI
+ hf_space/runner.py                  Orchestrates steps 1-7 for the UI
+ hf_space/streamlit_config.py        Example papers and tab names
+ src/common/                         Shared LLM and paper-package utilities
+ src/step_01_fetch/                  Metadata, references, citations, and LaTeX
+ src/step_02_mark_citations/         Citation-marker insertion
+ src/step_03_usage_contexts/         Downstream usage-context construction
+ src/step_04_label_citations/        Deep-Citation citation-function labeling
+ src/step_05_verify_uses_extends/    LLM verification of USES/EXTENDS citations
+ src/step_06_extract_paragraphs/     ArXiv paragraph extraction from citing papers
+ src/step_07_extract_and_refine/     Contribution extraction and cluster refinement
+ src/step_08_annotation/             Target/enabling contribution annotation and grounding
+ Deep-Citation/                      Bundled citation-function classifier assets
+ ```
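A short sketch of inspecting the final payload produced by the command-line example above. The field names (`claims`, `rewritten_claim`, `ingredients`) follow what the Streamlit UI reads; treat them as assumptions if the payload format evolves.

```python
import json
from pathlib import Path

# Locate the payload inside the annotation run directory printed by the last command.
payload_path = next(Path("runs/example/two_pass_outputs").rglob("pass_2_ui_payload.json"))
payload = json.loads(payload_path.read_text(encoding="utf-8"))

for idx, claim in enumerate(payload.get("claims") or [], start=1):
    text = claim.get("rewritten_claim") or claim.get("text") or "(missing text)"
    ingredients = claim.get("ingredients") or []
    print(f"[{idx}] {text}")
    print(f"     enabling contributions: {len(ingredients)}")
```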
app.py ADDED
@@ -0,0 +1,5 @@
+ from hf_space.streamlit_app import main
+
+
+ if __name__ == "__main__":
+     main()
hf_space/requirements.txt ADDED
@@ -0,0 +1,17 @@
+ streamlit>=1.36.0
+ arxiv==2.2.0
+ requests==2.32.5
+ google-generativeai
+ litellm
+ rapidfuzz
+ bibtexparser
+ sentence-transformers
+ transformers
+ torch
+ huggingface_hub
+ typer
+ tqdm
+ pydantic
+ numpy
+ pandas
+ scipy
hf_space/runner.py ADDED
@@ -0,0 +1,333 @@
+ import json
+ import os
+ import re
+ import shutil
+ import subprocess
+ import sys
+ import time
+ import uuid
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Generator, List, Optional, Tuple
+
+
+ @dataclass
+ class PipelineConfig:
+     repo_root: Path
+     source_root: Path
+     paper_input: str
+     llm_provider: str
+     llm_model: str
+     llm_model_step4: str
+     model_path: str
+     model_data_dir: str
+     model_class_def: str
+     model_lm: str
+     device: str
+     embedding_model: str
+
+
+ @dataclass
+ class PipelineResult:
+     job_id: str
+     job_dir: Path
+     paper_dir: Path
+     zip_path: Path
+
+
+ STEP_LABELS = {
+     1: "Fetch metadata + LaTeX for input paper",
+     2: "Add citation markers",
+     3: "Build usage contexts",
+     4: "Label citation functions",
+     5: "Verify USES/EXTENDS",
+     6: "Extract arXiv paragraphs",
+     7: "Extract target contributions and refine clusters",
+ }
+
+ FULL_STEPS = [1, 2, 3, 4, 5, 6, 7]
+ STOP_PREFIX = "Pipeline stopped:"
+
+
+ def parse_arxiv_id(paper_input: str) -> str:
+     s = (paper_input or "").strip()
+     if not s:
+         raise ValueError("paper_input is required")
+     if "arxiv.org" in s:
+         m = re.search(r"arxiv\.org/(abs|pdf)/([^/?#]+)", s)
+         if not m:
+             raise ValueError(f"Could not parse arXiv ID from URL: {s}")
+         s = m.group(2)
+     s = s.replace(".pdf", "")
+     s = re.sub(r"v\d+$", "", s)
+     if not re.match(r"^[0-9]{4}\.[0-9]{4,5}$", s):
+         raise ValueError(f"Invalid arXiv ID format: {s}")
+     return s
+
+
+ def _build_commands(
+     cfg: PipelineConfig,
+     step: int,
+     job_processed_root: Path,
+     paper_id: str,
+     ids_path: Optional[Path],
+ ) -> List[List[str]]:
+     py = sys.executable
+     if step == 1:
+         assert ids_path is not None
+         return [[
+             py,
+             "src/step_01_fetch/fetch_metadata.py",
+             "--ids",
+             str(ids_path),
+             "--outdir",
+             str(job_processed_root),
+         ]]
+     if step == 2:
+         return [[py, "src/step_02_mark_citations/replace_citation_markers.py", "--root", str(job_processed_root)]]
+     if step == 3:
+         return [[py, "src/step_03_usage_contexts/build_usage_contexts.py", "--root", str(job_processed_root), "--out-name", "usage_contexts.json"]]
+     if step == 4:
+         return [[
+             py,
+             "src/step_04_label_citations/label_citation_functions.py",
+             "--root",
+             str(job_processed_root),
+             "--model-path",
+             cfg.model_path,
+             "--model-data-dir",
+             cfg.model_data_dir,
+             "--model-class-def",
+             cfg.model_class_def,
+             "--model-lm",
+             cfg.model_lm,
+             "--device",
+             cfg.device,
+         ]]
+     if step == 5:
+         return [[
+             py,
+             "src/step_05_verify_uses_extends/verify_uses_extends.py",
+             "--root",
+             str(job_processed_root),
+             "--k",
+             "0",
+             "--batch-size",
+             "25",
+         ]]
+     if step == 6:
+         return [[py, "src/step_06_extract_paragraphs/extract_arxiv_paragraphs.py", "--root", str(job_processed_root)]]
+     if step == 7:
+         return [
+             [py, "src/step_07_extract_and_refine/extract_contributions_from_citations.py", "--root", str(job_processed_root)],
+             [py, "src/step_07_extract_and_refine/refine_and_filter_clusters_llm.py", "--root", str(job_processed_root), "--inplace", "--overwrite"],
+         ]
+     raise ValueError(f"Unknown step: {step}")
+
+
+ def _write_single_id_file(job_dir: Path, arxiv_id: str) -> Path:
+     ids_path = job_dir / "input_ids.json"
+     payload = [{"id": arxiv_id, "title": "", "id_type": "ArXiv"}]
+     ids_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
+     return ids_path
+
+
+ def _write_run_metadata(cfg: PipelineConfig, job_dir: Path, paper_id: str, arxiv_id: str) -> None:
+     payload = {
+         "paper_input": cfg.paper_input,
+         "paper_id": paper_id,
+         "arxiv_id": arxiv_id,
+         "source_root": str(cfg.source_root),
+         "steps": FULL_STEPS + ["annotation"],
+         "llm_provider": cfg.llm_provider,
+         "llm_model": cfg.llm_model,
+         "llm_model_step4": cfg.llm_model_step4,
+         "device": cfg.device,
+         "embedding_model": cfg.embedding_model,
+         "timestamp": int(time.time()),
+     }
+     (job_dir / "run_config.json").write_text(json.dumps(payload, indent=2), encoding="utf-8")
+
+
+ def _zip_job_dir(job_dir: Path) -> Path:
+     zip_base = job_dir.parent / job_dir.name
+     archive = shutil.make_archive(str(zip_base), "zip", root_dir=str(job_dir))
+     return Path(archive)
+
+
+ def _tail_log(path: Path, max_lines: int = 60) -> str:
+     try:
+         lines = path.read_text(encoding="utf-8", errors="ignore").splitlines()
+     except Exception:
+         return ""
+     if not lines:
+         return ""
+     return "\n".join(lines[-max_lines:])
+
+
+ def _load_json(path: Path, default=None):
+     try:
+         return json.loads(path.read_text(encoding="utf-8"))
+     except Exception:
+         return default
+
+
+ def _write_summary_and_zip(job_dir: Path, summary_lines: List[str]) -> Path:
+     (job_dir / "summary.txt").write_text("\n".join(summary_lines), encoding="utf-8")
+     return _zip_job_dir(job_dir)
+
+
+ def _count_verified_uses_extends(payload: dict) -> int:
+     records = payload.get("confirmed") or payload.get("verified_contexts") or payload.get("contexts") or payload.get("items") or []
+     if not isinstance(records, list):
+         return 0
+     accepted = {"USES", "EXTENDS", "Uses", "Extends"}
+     return sum(1 for item in records if isinstance(item, dict) and item.get("label") in accepted)
+
+
+ def _stop_reason_after_step(step: int, paper_dir: Path) -> str | None:
+     if step == 1:
+         if not paper_dir.exists():
+             return "metadata could not be fetched for this paper"
+         if not (paper_dir / "processed_main.tex").exists():
+             return "arXiv source could not be retrieved or converted for this paper"
+         citations = _load_json(paper_dir / "citations_metadata.json", [])
+         if not isinstance(citations, list) or not citations:
+             return "Semantic Scholar returned no citing papers for this target paper"
+
+     if step == 3:
+         usage = _load_json(paper_dir / "usage_contexts.json", {})
+         if not isinstance(usage, dict):
+             return "citation usage contexts could not be built"
+         if int(usage.get("num_contexts") or 0) == 0:
+             return "no citation usage contexts were found"
+
+     if step == 4:
+         labels = _load_json(paper_dir / "usage_context_labels.json", {})
+         contexts = labels.get("labels") if isinstance(labels, dict) else None
+         if not isinstance(contexts, list) or not contexts:
+             return "citation-function labeling produced no labeled contexts"
+
+     if step == 5:
+         verified = _load_json(paper_dir / "usage_uses_extends_verified.json", {})
+         if not isinstance(verified, dict):
+             return "USES/EXTENDS verification did not produce an output file"
+         if _count_verified_uses_extends(verified) == 0:
+             return "no downstream citations were verified as USES or EXTENDS"
+
+     if step == 6:
+         paragraphs = _load_json(paper_dir / "usage_citing_paragraphs.json", {})
+         citing = paragraphs.get("citing_papers") if isinstance(paragraphs, dict) else None
+         if not isinstance(citing, list) or not citing:
+             return "no citing-paper paragraphs could be extracted from arXiv"
+         usable = [
+             item for item in citing
+             if isinstance(item, dict)
+             and not item.get("error")
+             and (item.get("matched_paragraphs") or item.get("target_citing_paragraphs"))
+         ]
+         if not usable:
+             return "arXiv paragraph extraction returned no usable citing-paper text"
+
+     if step == 7:
+         contributions = _load_json(paper_dir / "usage_contributions.json", {})
+         items = contributions.get("contributions") if isinstance(contributions, dict) else None
+         if not isinstance(items, list) or not items:
+             return "no downstream target-contribution evidence could be extracted"
+         refined = _load_json(paper_dir / "usage_discovery_from_contributions.json", {})
+         clusters = refined.get("clusters") if isinstance(refined, dict) else None
+         if not isinstance(clusters, list) or not clusters:
+             return "no valid downstream usage clusters survived refinement"
+
+     return None
+
+
+ def run_pipeline(cfg: PipelineConfig, output_root: Path) -> Generator[Tuple[str, Optional[str]], None, PipelineResult]:
+     output_root.mkdir(parents=True, exist_ok=True)
+     job_id = f"job_{int(time.time())}_{uuid.uuid4().hex[:8]}"
+     job_dir = output_root / job_id
+     job_processed_root = job_dir / "processed_papers"
+     job_logs = job_dir / "logs"
+
+     job_processed_root.mkdir(parents=True, exist_ok=True)
+     job_logs.mkdir(parents=True, exist_ok=True)
+
+     arxiv_id = parse_arxiv_id(cfg.paper_input)
+     paper_id = arxiv_id
+     ids_path = _write_single_id_file(job_dir, arxiv_id)
+     _write_run_metadata(cfg, job_dir, paper_id, arxiv_id)
+
+     base_env = os.environ.copy()
+     base_env["LLM_PROVIDER"] = cfg.llm_provider
+     base_env["LLM_MODEL"] = cfg.llm_model
+
+     summary_lines: List[str] = []
+     paper_dir = job_processed_root / paper_id
+
+     max_step = 8
+     for step in FULL_STEPS:
+         label = STEP_LABELS[step]
+         log_file = job_logs / f"step_{step:02d}.log"
+         summary_lines.append(f"[{step}] {label}")
+         yield (f"Step {step}/{max_step}: {label}", None)
+
+         env = base_env.copy()
+         if step == 5 and cfg.llm_model_step4:
+             env["LLM_MODEL"] = cfg.llm_model_step4
+
+         with log_file.open("w", encoding="utf-8") as lf:
+             return_code = 0
+             failed_cmd: List[str] | None = None
+             for cmd in _build_commands(cfg, step, job_processed_root, paper_id, ids_path):
+                 lf.write(f"$ {' '.join(cmd)}\n\n")
+                 proc = subprocess.Popen(
+                     cmd,
+                     cwd=str(cfg.repo_root),
+                     stdout=subprocess.PIPE,
+                     stderr=subprocess.STDOUT,
+                     text=True,
+                     encoding="utf-8",
+                     errors="ignore",
+                     env=env,
+                 )
+                 assert proc.stdout is not None
+                 for line in proc.stdout:
+                     lf.write(line)
+                 return_code = proc.wait()
+                 if return_code != 0:
+                     failed_cmd = cmd
+                     break
+
+         if return_code != 0:
+             summary_lines.append(f"FAILED at step {step}")
+             zip_path = _write_summary_and_zip(job_dir, summary_lines)
+             tail = _tail_log(log_file)
+             if tail:
+                 yield (
+                     f"Step {step} failed.\n\nCommand: {' '.join(failed_cmd or [])}\n\nLast log lines:\n{tail}",
+                     str(zip_path),
+                 )
+             else:
+                 yield (f"Step {step} failed. Command: {' '.join(failed_cmd or [])}", str(zip_path))
+             return PipelineResult(job_id=job_id, job_dir=job_dir, paper_dir=paper_dir, zip_path=zip_path)
+         else:
+             yield (f"Step {step} complete", None)
+
+         if step == 1 and not paper_dir.exists():
+             summary_lines.append("FAILED: fetch_metadata did not create paper directory")
+             zip_path = _write_summary_and_zip(job_dir, summary_lines)
+             yield (f"Step 1 finished but paper dir missing: {paper_dir}", str(zip_path))
+             return PipelineResult(job_id=job_id, job_dir=job_dir, paper_dir=paper_dir, zip_path=zip_path)
+
+         stop_reason = _stop_reason_after_step(step, paper_dir)
+         if stop_reason:
+             message = f"{STOP_PREFIX} {stop_reason}."
+             summary_lines.append(message)
+             zip_path = _write_summary_and_zip(job_dir, summary_lines)
+             yield (message, str(zip_path))
+             return PipelineResult(job_id=job_id, job_dir=job_dir, paper_dir=paper_dir, zip_path=zip_path)
+
+     summary_lines.append("SUCCESS")
+     zip_path = _write_summary_and_zip(job_dir, summary_lines)
+     yield ("Pipeline completed successfully.", str(zip_path))
+     return PipelineResult(job_id=job_id, job_dir=job_dir, paper_dir=paper_dir, zip_path=zip_path)
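A minimal sketch of driving `run_pipeline` outside Streamlit. The config values mirror defaults seen elsewhere in the repo; the embedding-model name is a placeholder assumption, and `GEMINI_API_KEY` must already be set in the environment.

```python
import sys
from pathlib import Path

repo_root = Path.cwd()  # assumes you launch this from the repository root
sys.path.insert(0, str(repo_root / "hf_space"))  # same trick the Streamlit app uses

from runner import PipelineConfig, run_pipeline

cfg = PipelineConfig(
    repo_root=repo_root,
    source_root=repo_root / "src" / "processed_papers",
    paper_input="https://arxiv.org/abs/2211.08788",
    llm_provider="gemini",
    llm_model="gemini/gemini-3.1-pro-preview",
    llm_model_step4="gemini/gemini-3.1-pro-preview",
    model_path="Deep-Citation/Workspace/acl_scicite_wksp_trl/best_model.pt",
    model_data_dir="Deep-Citation/Data",
    model_class_def="Deep-Citation/Data/class_def.json",
    model_lm="scibert",
    device="cpu",
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # placeholder; the UI supplies its own default
)

# run_pipeline is a generator: it yields (status_message, optional_zip_path) after each
# step and returns a PipelineResult when it finishes or stops early.
gen = run_pipeline(cfg, output_root=repo_root / "hf_space" / "runs")
result = None
while True:
    try:
        message, zip_path = next(gen)
    except StopIteration as stop:
        result = stop.value  # PipelineResult
        break
    print(message if zip_path is None else f"{message} (artifacts: {zip_path})")

if result is not None:
    print("job directory:", result.job_dir)
```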
hf_space/streamlit_app.py ADDED
@@ -0,0 +1,864 @@
1
+ import json
2
+ import os
3
+ import sys
4
+ import time
5
+ import html
6
+ from pathlib import Path
7
+ from typing import Any, Optional
8
+
9
+ import streamlit as st
10
+ try:
11
+ from huggingface_hub import HfApi
12
+ except Exception:
13
+ HfApi = None
14
+
15
+ SRC = Path(__file__).resolve().parent
16
+ REPO_ROOT = SRC.parent
17
+ for extra in (SRC, REPO_ROOT / "src"):
18
+ extra_str = str(extra)
19
+ if extra_str not in sys.path:
20
+ sys.path.insert(0, extra_str)
21
+
22
+ import runner as runner_module
23
+ from runner import PipelineConfig
24
+ from common.paper_package import load_paper_package
25
+ from step_08_annotation.pipeline import TwoPassAnnotationPipeline
26
+ from streamlit_config import EXAMPLES, TAB_NAMES
27
+
28
+ DEFAULT_SOURCE_ROOT = str(REPO_ROOT / "src" / "processed_papers")
29
+ DEFAULT_OUTPUT_ROOT = str(REPO_ROOT / "hf_space" / "runs")
30
+
31
+ CUSTOM_CSS = """
32
+ <style>
33
+ .block-container {max-width: 1450px; padding-top: 2rem; padding-bottom: 2rem;}
34
+ [data-testid="stSidebar"] {background: #f5f7fb; border-right: 1px solid #e2e8f0;}
35
+ .hero-title {font-size: 3rem; font-weight: 800; letter-spacing: -0.03em; color: #1f2937; margin-bottom: 0.35rem;}
36
+ .hero-sub {font-size: 1rem; color: #6b7280; max-width: 920px; margin-bottom: 1.25rem;}
37
+ .metric-card {background: #ffffff; border: 1px solid #e5e7eb; border-radius: 16px; padding: 1rem 1.1rem; min-height: 96px;}
38
+ .metric-label {font-size: 0.78rem; font-weight: 700; color: #6b7280; text-transform: uppercase; letter-spacing: 0.04em;}
39
+ .metric-value {font-size: 1.7rem; font-weight: 800; color: #111827; margin-top: 0.35rem;}
40
+ .soft-card {background: #ffffff; border: 1px solid #e5e7eb; border-radius: 16px; padding: 1rem 1.1rem;}
41
+ .claim-card {background: #ffffff; border: 1px solid #e5e7eb; border-radius: 18px; overflow: hidden; margin-bottom: 1rem;}
42
+ .claim-head {padding: 1rem 1.1rem; border-bottom: 1px solid #eef2f7; background: #fcfdff;}
43
+ .claim-kicker {font-size: 0.78rem; font-weight: 800; color: #2563eb; text-transform: uppercase; letter-spacing: 0.04em; margin-bottom: 0.45rem;}
44
+ .claim-text {font-size: 1.05rem; line-height: 1.55; font-weight: 700; color: #111827;}
45
+ .claim-grid {display: grid; grid-template-columns: 1.7fr 1fr;}
46
+ .claim-main, .claim-side {padding: 1rem 1.1rem;}
47
+ .claim-side {border-left: 1px solid #eef2f7; background: #fbfdff;}
48
+ .section-label {font-size: 0.78rem; font-weight: 800; color: #6b7280; text-transform: uppercase; letter-spacing: 0.04em; margin-bottom: 0.7rem;}
49
+ .pill-row {display: flex; flex-wrap: wrap; gap: 0.45rem; margin-top: 0.8rem;}
50
+ .pill {display: inline-block; padding: 0.28rem 0.7rem; border-radius: 999px; border: 1px solid #dbe4f0; background: #f8fbff; color: #1d4ed8; font-size: 0.78rem; font-weight: 700;}
51
+ .ingredient-card {border: 1px solid #e6edf7; border-left: 4px solid #2563eb; border-radius: 12px; background: #ffffff; padding: 0.9rem; margin-bottom: 0.8rem;}
52
+ .ingredient-top {display: flex; justify-content: space-between; gap: 0.7rem; align-items: flex-start; margin-bottom: 0.45rem;}
53
+ .ingredient-name {font-size: 0.98rem; font-weight: 800; color: #111827; line-height: 1.4;}
54
+ .role-pill {display: inline-block; padding: 0.2rem 0.55rem; border-radius: 999px; border: 1px solid #ddd6fe; background: #f5f3ff; color: #6d28d9; font-size: 0.72rem; font-weight: 800; white-space: nowrap;}
55
+ .field {font-size: 0.88rem; line-height: 1.5; color: #374151; margin-top: 0.4rem;}
56
+ .field b {color: #111827;}
57
+ .grounding-block {margin-top: 0.75rem; display: grid; gap: 0.55rem;}
58
+ .grounding-card {border-radius: 10px; padding: 0.65rem 0.75rem; border: 1px solid #bfdbfe; background: #eff6ff;}
59
+ .grounding-card.additional {border-color: #fed7aa; background: #fff7ed;}
60
+ .grounding-label {font-size: 0.7rem; font-weight: 900; text-transform: uppercase; letter-spacing: 0.05em; margin-bottom: 0.25rem;}
61
+ .grounding-label.primary {color: #1d4ed8;}
62
+ .grounding-label.additional {color: #c2410c;}
63
+ .grounding-title {font-size: 0.9rem; font-weight: 800; color: #111827; line-height: 1.35;}
64
+ .grounding-meta {font-size: 0.78rem; color: #64748b; margin-top: 0.2rem;}
65
+ .cluster-card {border: 1px solid #e5e7eb; border-radius: 16px; background: #ffffff; padding: 1rem 1.1rem; margin-bottom: 0.9rem;}
66
+ .cluster-card.additional-study {border-color: #fed7aa; background: #fff7ed;}
67
+ .cluster-title {font-size: 1rem; font-weight: 800; color: #111827; line-height: 1.45; margin-bottom: 0.4rem;}
68
+ .cluster-meta {font-size: 0.86rem; color: #6b7280; margin-bottom: 0.65rem;}
69
+ .empty-card {border: 1px dashed #cbd5e1; border-radius: 14px; padding: 1rem; background: #ffffff; color: #64748b;}
70
+ .example-btn button {border-radius: 999px !important; border: 1px solid #fecaca !important; color: #991b1b !important; background: #fff !important;}
71
+ @media (max-width: 1050px) {.claim-grid {grid-template-columns: 1fr;} .claim-side {border-left: none; border-top: 1px solid #eef2f7;}}
72
+ </style>
73
+ """
74
+
75
+
76
+ def get_secret(name: str, default: str = "") -> str:
77
+ value = os.getenv(name)
78
+ if value:
79
+ return value
80
+ try:
81
+ return st.secrets[name]
82
+ except Exception:
83
+ return default
84
+
85
+
86
+ def run_repo_config() -> tuple[str | None, str, str | None]:
87
+ repo_id = get_secret("RUNS_REPO_ID", "")
88
+ repo_type = get_secret("RUNS_REPO_TYPE", "dataset")
89
+ token = get_secret("HF_WRITE_TOKEN", "") or get_secret("HF_TOKEN", "")
90
+ return repo_id or None, repo_type, token or None
91
+
92
+
93
+ def remote_run_prefix(job_id: str) -> str:
94
+ return f"runs/{job_id}"
95
+
96
+
97
+ def upload_run_artifact(job_dir: Path) -> str:
98
+ repo_id, repo_type, token = run_repo_config()
99
+ if not repo_id or not token:
100
+ return ""
101
+ if HfApi is None:
102
+ return "upload_failed: huggingface_hub is not installed"
103
+
104
+ job_id = job_dir.name
105
+ remote_prefix = remote_run_prefix(job_id)
106
+ uploaded: list[str] = []
107
+ try:
108
+ api = HfApi(token=token)
109
+ for name in ["input_ids.json", "run_config.json", "summary.txt"]:
110
+ path = job_dir / name
111
+ if path.exists():
112
+ api.upload_file(
113
+ path_or_fileobj=str(path),
114
+ path_in_repo=f"{remote_prefix}/{name}",
115
+ repo_id=repo_id,
116
+ repo_type=repo_type,
117
+ commit_message=f"Upload {name} for {job_id}",
118
+ )
119
+ uploaded.append(name)
120
+
121
+ for folder_name in ["logs", "processed_papers", "two_pass_outputs"]:
122
+ folder = job_dir / folder_name
123
+ if not folder.exists():
124
+ continue
125
+ files = [path for path in folder.rglob("*") if path.is_file()]
126
+ if not files:
127
+ continue
128
+ api.upload_folder(
129
+ folder_path=str(folder),
130
+ path_in_repo=f"{remote_prefix}/{folder_name}",
131
+ repo_id=repo_id,
132
+ repo_type=repo_type,
133
+ commit_message=f"Upload {folder_name} for {job_id}",
134
+ ignore_patterns=["__pycache__/*", "*.pyc", "*.zip"],
135
+ )
136
+ uploaded.append(f"{folder_name}[{len(files)} files]")
137
+
138
+ return f"{repo_type}:{repo_id}/{remote_prefix}/ (uploaded: {', '.join(uploaded) or 'nothing'})"
139
+ except Exception as exc:
140
+ return f"upload_failed: {exc}"
141
+
142
+
143
+ def _load_json(path: Path) -> Optional[dict]:
144
+ if not path.exists():
145
+ return None
146
+ try:
147
+ return json.loads(path.read_text(encoding="utf-8"))
148
+ except Exception:
149
+ return None
150
+
151
+
152
+ def _status_from_line(line: str, current: str) -> str:
153
+ text = (line or "").strip()
154
+ text = _display_log_line(text)
155
+ if text.startswith("Pipeline stopped:"):
156
+ return "Stopped"
157
+ if text.startswith("Step "):
158
+ return text
159
+ if "failed" in text.lower():
160
+ return f"Failed: {text}"
161
+ if "completed successfully" in text.lower():
162
+ return "Completed"
163
+ return current
164
+
165
+
166
+ def _display_log_line(line: str) -> str:
167
+ text = (line or "").strip()
168
+ if text.startswith("Step ") and " failed." in text:
169
+ return text.splitlines()[0]
170
+ if text == "[annotation] starting cluster-first two-pass annotation":
171
+ return "Step 8/8: Annotate target contributions and enabling contributions"
172
+ if text.startswith("[annotation] complete:"):
173
+ return "Step 8 complete"
174
+ if text == "Pipeline completed successfully.":
175
+ return text
176
+ return text
177
+
178
+
179
+ def _format_step_event(line: str) -> str:
180
+ text = _display_log_line(line)
181
+ if not text:
182
+ return ""
183
+ if text.startswith("Step ") and "/" in text and ":" in text:
184
+ return f"🛠️ {text}"
185
+ if text.startswith("Step ") and text.endswith(" complete"):
186
+ return f"✅ {text}"
187
+ if text.lower().startswith("stopped after step"):
188
+ return f"⏹️ {text}"
189
+ if text.startswith("Pipeline stopped:"):
190
+ return f"⏹️ {text}"
191
+ if "failed" in text.lower():
192
+ return f"❌ {text}"
193
+ if "completed successfully" in text.lower():
194
+ return f"✅ {text}"
195
+ return f"• {text}"
196
+
197
+
198
+ def _ensure_state():
199
+ defaults = {
200
+ "paper_input": "",
201
+ "run_status": "Idle",
202
+ "run_logs": [],
203
+ "run_events": [],
204
+ "artifact_path": None,
205
+ "run_dir_path": None,
206
+ "paper_dir_path": None,
207
+ "annotation_payload_path": None,
208
+ "run_summary": None,
209
+ "annotation_skipped_reason": None,
210
+ "pipeline_failed_reason": None,
211
+ "remote_artifact_ref": "",
212
+ }
213
+ for key, value in defaults.items():
214
+ st.session_state.setdefault(key, value)
215
+
216
+
217
+ def _metric_card(label: str, value: Any):
218
+ st.markdown(
219
+ f"<div class='metric-card'><div class='metric-label'>{label}</div><div class='metric-value'>{value}</div></div>",
220
+ unsafe_allow_html=True,
221
+ )
222
+
223
+
224
+ def _esc(value: Any) -> str:
225
+ return html.escape("" if value is None else str(value))
226
+
227
+
228
+ def _safe_int(value: Any, default: int = 0) -> int:
229
+ try:
230
+ return int(value)
231
+ except (TypeError, ValueError):
232
+ return default
233
+
234
+
235
+ def _grounding_html(grounding: Optional[dict], label: str, kind: str) -> str:
236
+ if not grounding:
237
+ return ""
238
+ title = (
239
+ grounding.get("ref_title")
240
+ or grounding.get("title")
241
+ or grounding.get("paper_id")
242
+ or grounding.get("ref_id")
243
+ or "__NONE__"
244
+ )
245
+ meta = []
246
+ if grounding.get("paper_id"):
247
+ meta.append(f"paper_id: {grounding.get('paper_id')}")
248
+ elif grounding.get("ref_id"):
249
+ meta.append(f"ref_id: {grounding.get('ref_id')}")
250
+ if grounding.get("ref_year"):
251
+ meta.append(str(grounding.get("ref_year")))
252
+ authors = grounding.get("ref_authors")
253
+ if isinstance(authors, list) and authors:
254
+ meta.append(", ".join(str(author) for author in authors[:3]))
255
+ meta_html = f"<div class='grounding-meta'>{_esc(' · '.join(meta))}</div>" if meta else ""
256
+ extra_class = " additional" if kind == "additional" else ""
257
+ return (
258
+ f"<div class='grounding-card{extra_class}'>"
259
+ f"<div class='grounding-label {kind}'>{_esc(label)}</div>"
260
+ f"<div class='grounding-title'>{_esc(title)}</div>"
261
+ f"{meta_html}"
262
+ "</div>"
263
+ )
264
+
265
+
266
+ def _study_key(item: dict) -> str:
267
+ for key in ["paper_id", "ref_id", "ref_title", "title"]:
268
+ value = item.get(key)
269
+ if value:
270
+ return str(value).lower()
271
+ return ""
272
+
273
+
274
+ def _collect_grounded_studies(discoveries: list[dict], ingredients: list[dict]) -> list[dict]:
275
+ studies: list[dict] = []
276
+ seen: set[str] = set()
277
+ for item in discoveries:
278
+ if not isinstance(item, dict):
279
+ continue
280
+ copied = dict(item)
281
+ copied["_grounding_kind"] = "primary"
282
+ copied["_grounding_label"] = "Primary study"
283
+ key = _study_key(copied)
284
+ if key:
285
+ seen.add(key)
286
+ studies.append(copied)
287
+
288
+ for idx, ingredient in enumerate(ingredients, start=1):
289
+ if not isinstance(ingredient, dict):
290
+ continue
291
+ canonical = ingredient.get("canonical_grounding") or {}
292
+ canonical_key = _study_key(canonical) if isinstance(canonical, dict) else ""
293
+ annotation = ingredient.get("canonical_annotation") or {}
294
+ for ref in ingredient.get("additional_groundings") or []:
295
+ if not isinstance(ref, dict):
296
+ continue
297
+ key = _study_key(ref)
298
+ if key and (key == canonical_key or key in seen):
299
+ continue
300
+ copied = dict(ref)
301
+ copied["_grounding_kind"] = "additional"
302
+ copied["_grounding_label"] = f"Additional study for enabling contribution {idx}"
303
+ copied.setdefault("role", annotation.get("role") or ", ".join(annotation.get("roles") or []))
304
+ copied.setdefault("contribution", annotation.get("contribution"))
305
+ copied.setdefault("rationale", annotation.get("rationale"))
306
+ if key:
307
+ seen.add(key)
308
+ studies.append(copied)
309
+ return studies
310
+
311
+
312
+ def _render_reference_list(discoveries: list[dict], ingredients: Optional[list[dict]] = None):
313
+ studies = _collect_grounded_studies(discoveries, ingredients or [])
314
+ if not studies:
315
+ st.markdown("<div class='empty-card'>No grounded studies listed for this target contribution.</div>", unsafe_allow_html=True)
316
+ return
317
+ for item in studies:
318
+ title = item.get("ref_title") or item.get("title") or item.get("ref_id") or item.get("paper_id") or "Untitled reference"
319
+ is_additional = item.get("_grounding_kind") == "additional"
320
+ meta = []
321
+ if item.get("_grounding_label"):
322
+ meta.append(str(item.get("_grounding_label")))
323
+ if item.get("role"):
324
+ meta.append(str(item.get("role")))
325
+ if item.get("ref_year"):
326
+ meta.append(str(item.get("ref_year")))
327
+ class_name = "cluster-card additional-study" if is_additional else "cluster-card"
328
+ body = [f"<div class='{class_name}'><div class='cluster-title'>{_esc(title)}</div>"]
329
+ if meta:
330
+ body.append(f"<div class='cluster-meta'>{_esc(' · '.join(meta))}</div>")
331
+ if item.get("contribution"):
332
+ body.append(f"<div class='field'><b>Contribution.</b> {_esc(item.get('contribution'))}</div>")
333
+ if item.get("rationale"):
334
+ body.append(f"<div class='field'><b>Rationale.</b> {_esc(item.get('rationale'))}</div>")
335
+ body.append("</div>")
336
+ st.markdown("".join(body), unsafe_allow_html=True)
337
+
338
+
339
+ def _render_claims_tab(payload: Optional[dict]):
340
+ if not payload:
341
+ st.markdown("<div class='empty-card'>No annotation payload is available yet.</div>", unsafe_allow_html=True)
342
+ return
343
+ claims = payload.get("claims") or []
344
+ if not claims:
345
+ st.markdown("<div class='empty-card'>The run completed, but no target contributions were produced.</div>", unsafe_allow_html=True)
346
+ return
347
+
348
+ for idx, claim in enumerate(claims, start=1):
349
+ claim_id = claim.get("claim_id") or f"C{idx}"
350
+ claim_text = claim.get("rewritten_claim") or claim.get("text") or "(missing target contribution text)"
351
+ ingredients = claim.get("ingredients") or []
352
+ discoveries = claim.get("enabling_discoveries") or []
353
+ grounded_studies = _collect_grounded_studies(discoveries, ingredients)
354
+ meta_pills = []
355
+ if claim.get("decision"):
356
+ meta_pills.append(str(claim.get("decision")))
357
+ if claim.get("cluster_id"):
358
+ meta_pills.append(f"cluster {claim.get('cluster_id')}")
359
+ meta_pills.append(f"{len(ingredients)} enabling contribution{'s' if len(ingredients) != 1 else ''}")
360
+ meta_pills.append(f"{len(grounded_studies)} grounded stud{'ies' if len(grounded_studies) != 1 else 'y'}")
361
+
362
+ pills_html = "".join(f"<span class='pill'>{_esc(p)}</span>" for p in meta_pills)
363
+ st.markdown(
364
+ f"""
365
+ <div class='claim-card'>
366
+ <div class='claim-head'>
367
+ <div class='claim-kicker'>Target contribution {idx} · {_esc(claim_id)}</div>
368
+ <div class='claim-text'>{_esc(claim_text)}</div>
369
+ <div class='pill-row'>{pills_html}</div>
370
+ </div>
371
+ </div>
372
+ """,
373
+ unsafe_allow_html=True,
374
+ )
375
+ left, right = st.columns([1.7, 1.0], gap="large")
376
+ with left:
377
+ st.markdown("<div class='section-label'>Decomposition</div>", unsafe_allow_html=True)
378
+ if not ingredients:
379
+ st.markdown("<div class='empty-card'>No enabling contributions for this target contribution.</div>", unsafe_allow_html=True)
380
+ for ingredient_idx, ingredient in enumerate(ingredients, start=1):
381
+ annotation = ingredient.get("canonical_annotation") or {}
382
+ role = annotation.get("role") or ", ".join(annotation.get("roles") or []) or "UNSPECIFIED"
383
+ canonical_grounding = ingredient.get("canonical_grounding") or {}
384
+ extras = ingredient.get("additional_groundings") or []
385
+ grounding_parts = []
386
+ if canonical_grounding:
387
+ grounding_parts.append(
388
+ _grounding_html(canonical_grounding, "Primary grounding", "primary")
389
+ )
390
+ for ref in extras:
391
+ if not isinstance(ref, dict):
392
+ continue
393
+ if canonical_grounding and (
394
+ ref.get("paper_id") == canonical_grounding.get("paper_id")
395
+ or ref.get("ref_id") == canonical_grounding.get("ref_id")
396
+ ):
397
+ continue
398
+ grounding_parts.append(
399
+ _grounding_html(ref, "Additional grounding", "additional")
400
+ )
401
+ if not grounding_parts:
402
+ canonical_ref_id = ingredient.get("canonical_ref_id") or "__NONE__"
403
+ grounding_parts.append(
404
+ "<div class='grounding-card'>"
405
+ "<div class='grounding-label primary'>Grounding</div>"
406
+ f"<div class='grounding-title'>{_esc(canonical_ref_id)}</div>"
407
+ "</div>"
408
+ )
409
+ grounding_block = (
410
+ "<div class='grounding-block'>"
411
+ f"<div class='section-label'>Groundings for enabling contribution {ingredient_idx}</div>"
412
+ + "".join(grounding_parts)
413
+ + "</div>"
414
+ )
415
+ st.markdown(
416
+ f"""
417
+ <div class='ingredient-card'>
418
+ <div class='ingredient-top'>
419
+ <div class='ingredient-name'>{ingredient_idx}. {_esc(ingredient.get('ingredient') or '(missing enabling contribution)')}</div>
420
+ <div class='role-pill'>{_esc(role)}</div>
421
+ </div>
422
+ <div class='field'><b>Contribution.</b> {_esc(annotation.get('contribution') or '')}</div>
423
+ <div class='field'><b>Rationale.</b> {_esc(annotation.get('rationale') or '')}</div>
424
+ <div class='field'><b>Evidence.</b> {_esc(annotation.get('evidence_span') or '')}</div>
425
+ {grounding_block}
426
+ </div>
427
+ """,
428
+ unsafe_allow_html=True,
429
+ )
430
+ with right:
431
+ st.markdown("<div class='section-label'>Grounded and additional studies</div>", unsafe_allow_html=True)
432
+ _render_reference_list(discoveries, ingredients)
433
+
434
+
435
+ def _render_clusters_tab(discovery: Optional[dict], contributions: list[dict]):
436
+ if not discovery:
437
+ st.markdown("<div class='empty-card'>No refined cluster file is available yet.</div>", unsafe_allow_html=True)
438
+ return
439
+ clusters = discovery.get("clusters") or []
440
+ dropped = discovery.get("dropped_clusters") or []
441
+ if not clusters:
442
+ st.markdown("<div class='empty-card'>No valid downstream usage clusters survived refinement and filtering.</div>", unsafe_allow_html=True)
443
+ if dropped:
444
+ with st.expander(f"Dropped clusters ({len(dropped)})", expanded=False):
445
+ st.json(dropped)
446
+ return
447
+
448
+ for cluster in clusters:
449
+ cluster_id = cluster.get("cluster_id", "")
450
+ rep = cluster.get("representative_claim") or cluster.get("cluster_title") or "(missing representative claim)"
451
+ count = _safe_int(cluster.get("count"), len(cluster.get("claim_indices") or []))
452
+ source_ids = cluster.get("source_cluster_ids") or []
453
+ merge_rationale = cluster.get("merge_rationale") or ""
454
+ st.markdown(
455
+ f"""
456
+ <div class='cluster-card'>
457
+ <div class='cluster-title'>{_esc(rep)}</div>
458
+ <div class='cluster-meta'>Cluster {_esc(cluster_id)} · {count} contribution instance{'s' if count != 1 else ''}</div>
459
+ </div>
460
+ """,
461
+ unsafe_allow_html=True,
462
+ )
463
+ meta_cols = st.columns([1.3, 1.3, 1.4])
464
+ with meta_cols[0]:
465
+ st.caption("Cluster ID")
466
+ st.code(str(cluster_id), language="text")
467
+ with meta_cols[1]:
468
+ st.caption("Source clusters")
469
+ st.code(", ".join(str(x) for x in source_ids) if source_ids else "singleton", language="text")
470
+ with meta_cols[2]:
471
+ st.caption("Merge rationale")
472
+ st.write(merge_rationale or "—")
473
+
474
+ claim_indices = cluster.get("claim_indices") or []
475
+ if claim_indices:
476
+ with st.expander(f"Linked contribution instances ({len(claim_indices)})", expanded=False):
477
+ for idx in claim_indices:
478
+ try:
479
+ j = int(idx)
480
+ except Exception:
481
+ continue
482
+ if 0 <= j < len(contributions):
483
+ item = contributions[j] or {}
484
+ title = item.get("citing_title") or item.get("citing_paper_id") or "Unknown citing paper"
485
+ claim = item.get("paper_claim") or item.get("claim") or "(missing claim)"
486
+ rationale = item.get("rationale") or ""
487
+ evidence = item.get("evidence_span") or ""
488
+ st.markdown(f"**{title}**")
489
+ st.write(claim)
490
+ if rationale:
491
+ st.caption(f"Rationale: {rationale}")
492
+ if evidence:
493
+ st.caption(f"Evidence: {evidence}")
494
+ st.divider()
495
+
496
+ if dropped:
497
+ with st.expander(f"Dropped clusters ({len(dropped)})", expanded=False):
498
+ st.json(dropped)
499
+
500
+
501
+ def run_two_pass_annotation(
502
+ paper_dir: Path,
503
+ annotation_output_root: Path,
504
+ llm_provider: str,
505
+ llm_model: str,
506
+ formatter_model: str,
507
+ judge_model: str,
508
+ candidate_count: int,
509
+ ):
510
+ paper = load_paper_package(paper_dir)
511
+ pipeline = TwoPassAnnotationPipeline(
512
+ provider=llm_provider,
513
+ model=llm_model,
514
+ formatter_model=formatter_model or None,
515
+ judge_model=judge_model or None,
516
+ output_root=annotation_output_root,
517
+ annotator_id="streamlit_hf_space",
518
+ candidate_count=max(1, int(candidate_count)),
519
+ formatter_max_attempts=3,
520
+ include_reference_examples=True,
521
+ prompt_profile="full",
522
+ )
523
+ result = pipeline.run(paper)
524
+ return result.result, result.run_dir
525
+
526
+
527
+ def run_pipeline_stream(
528
+ paper_input: str,
529
+ source_root: str,
530
+ output_root: str,
531
+ llm_provider: str,
532
+ llm_model: str,
533
+ llm_model_step4: str,
534
+ formatter_model: str,
535
+ judge_model: str,
536
+ candidate_count: int,
537
+ ):
538
+ gemini_key = get_secret("GEMINI_API_KEY")
539
+ if gemini_key:
540
+ os.environ["GEMINI_API_KEY"] = gemini_key
541
+
542
+ cfg = PipelineConfig(
543
+ repo_root=REPO_ROOT,
544
+ source_root=Path(source_root).expanduser().resolve(),
545
+ paper_input=paper_input.strip(),
546
+ llm_provider=llm_provider.strip() or "gemini",
547
+ llm_model=llm_model.strip() or "gemini-3.1-pro-preview",
548
+ llm_model_step4=llm_model_step4.strip() or "gemini-3-flash-preview",
549
+ model_path="Deep-Citation/Workspace/acl_scicite_wksp_trl/best_model.pt",
550
+ model_data_dir="Deep-Citation/Data",
551
+ model_class_def="Deep-Citation/Data/class_def.json",
552
+ model_lm="scibert",
553
+ device="cpu",
554
+ embedding_model="sentence-transformers/all-mpnet-base-v2",
555
+ )
556
+
557
+ status_placeholder = st.empty()
558
+ activity_placeholder = st.empty()
559
+ status = "Starting"
560
+ logs: list[str] = []
561
+ events: list[str] = []
562
+ seen_events: set[str] = set()
563
+ artifact_path = None
564
+ annotation_payload_path = None
565
+ annotation_skipped_reason = None
566
+ run_summary = None
567
+ pipeline_stopped_reason = None
568
+ pipeline_failed_reason = None
569
+
570
+ def render_activity(items: list[str]):
571
+ if not items:
572
+ activity_placeholder.info("Waiting for first step...")
573
+ return
574
+ activity_placeholder.markdown("### Activity\n" + "\n".join(f"- {item}" for item in items[-20:]))
575
+
576
+ def append_display_line(line: str):
577
+ display_line = _display_log_line(line)
578
+ if not display_line:
579
+ return
580
+ logs.append(display_line)
581
+ event = _format_step_event(display_line)
582
+ if event and event not in seen_events:
583
+ seen_events.add(event)
584
+ events.append(event)
585
+ render_activity(events)
586
+
587
+ for line, maybe_artifact in runner_module.run_pipeline(cfg, Path(output_root).expanduser().resolve()):
588
+ if line:
589
+ if line.strip() == "Pipeline completed successfully.":
590
+ if maybe_artifact:
591
+ artifact_path = maybe_artifact
592
+ continue
593
+ display_line = _display_log_line(line)
594
+ if display_line:
595
+ logs.append(display_line)
596
+ status = _status_from_line(display_line, status)
597
+ if display_line.startswith("Pipeline stopped:"):
598
+ pipeline_stopped_reason = display_line
599
+ if "failed" in display_line.lower():
600
+ pipeline_failed_reason = display_line
601
+ event = _format_step_event(display_line)
602
+ if event and event not in seen_events:
603
+ seen_events.add(event)
604
+ events.append(event)
605
+ if maybe_artifact:
606
+ artifact_path = maybe_artifact
607
+ status_placeholder.info(f"Current status: {status}")
608
+ render_activity(events)
609
+
610
+ run_dir_path = None
611
+ paper_dir_path = None
612
+ remote_artifact_ref = ""
613
+ if artifact_path:
614
+ job_dir = Path(str(artifact_path)).with_suffix("")
615
+ run_dir_path = str(job_dir)
616
+ paper_id = runner_module.parse_arxiv_id(paper_input.strip())
617
+ paper_dir = job_dir / "processed_papers" / paper_id
618
+ paper_dir_path = str(paper_dir)
619
+ if pipeline_failed_reason:
620
+ annotation_skipped_reason = f"{pipeline_failed_reason} Annotation was not run."
621
+ elif pipeline_stopped_reason:
622
+ annotation_skipped_reason = f"{pipeline_stopped_reason} Annotation was not run."
623
+ else:
624
+ discovery = _load_json(paper_dir / "usage_discovery_from_contributions.json") or {}
625
+ refined_clusters = discovery.get("clusters") or []
626
+ if not refined_clusters:
627
+ annotation_skipped_reason = "No valid downstream usage clusters remained after refinement and filtering. Annotation was skipped."
628
+ logs.append("[annotation] skipped: no refined downstream usage clusters")
629
+ else:
630
+ append_display_line("[annotation] starting cluster-first two-pass annotation")
631
+ status_placeholder.info("Current status: Running annotation")
632
+ try:
633
+ run_output, annotation_run_dir = run_two_pass_annotation(
634
+ paper_dir=paper_dir,
635
+ annotation_output_root=job_dir / "two_pass_outputs",
636
+ llm_provider=llm_provider,
637
+ llm_model=llm_model,
638
+ formatter_model=formatter_model,
639
+ judge_model=judge_model,
640
+ candidate_count=candidate_count,
641
+ )
642
+ payload_path = run_output.get("ui_payload_path") if isinstance(run_output, dict) else None
643
+ if payload_path and Path(payload_path).exists():
644
+ annotation_payload_path = str(Path(payload_path))
645
+ append_display_line(f"[annotation] complete: {annotation_run_dir}")
646
+ except Exception as exc:
647
+ pipeline_failed_reason = f"Annotation failed: {exc}"
648
+ annotation_skipped_reason = pipeline_failed_reason
649
+ logs.append(f"[annotation] failed: {exc}")
650
+ logs.append("[upload] uploading run artifact to Hugging Face dataset")
651
+ status_placeholder.info("Current status: Finalizing run")
652
+ remote_artifact_ref = upload_run_artifact(job_dir)
653
+ if remote_artifact_ref:
654
+ logs.append(f"[upload] {remote_artifact_ref}")
655
+ else:
656
+ logs.append("[upload] skipped: RUNS_REPO_ID/HF_WRITE_TOKEN not configured")
657
+ if not pipeline_stopped_reason and not pipeline_failed_reason:
658
+ append_display_line("Pipeline completed successfully.")
659
+
660
+ if pipeline_failed_reason:
661
+ status = "Failed"
662
+ elif artifact_path and pipeline_stopped_reason:
663
+ status = "Stopped"
664
+ else:
665
+ status = "Completed" if artifact_path else "Failed"
666
+ if status == "Completed":
667
+ status_placeholder.success(f"Final status: {status}")
668
+ elif status == "Stopped":
669
+ status_placeholder.warning(f"Final status: {status}")
670
+ else:
671
+ status_placeholder.error("Final status: Failed")
672
+
673
+ st.session_state["run_status"] = status
674
+ st.session_state["run_logs"] = logs
675
+ st.session_state["run_events"] = events
676
+ st.session_state["artifact_path"] = artifact_path
677
+ st.session_state["run_dir_path"] = run_dir_path
678
+ st.session_state["paper_dir_path"] = paper_dir_path
679
+ st.session_state["annotation_payload_path"] = annotation_payload_path
680
+ st.session_state["annotation_skipped_reason"] = annotation_skipped_reason
681
+ st.session_state["pipeline_stopped_reason"] = pipeline_stopped_reason
682
+ st.session_state["pipeline_failed_reason"] = pipeline_failed_reason
683
+ st.session_state["run_summary"] = run_summary
684
+ st.session_state["remote_artifact_ref"] = remote_artifact_ref
685
+
686
+
687
+ def _load_result_bundle():
688
+ paper_dir_path = st.session_state.get("paper_dir_path")
689
+ annotation_payload_path = st.session_state.get("annotation_payload_path")
690
+ paper_dir = Path(paper_dir_path) if paper_dir_path else None
691
+ payload = _load_json(Path(annotation_payload_path)) if annotation_payload_path else None
692
+ discovery = _load_json(paper_dir / "usage_discovery_from_contributions.json") if paper_dir and paper_dir.exists() else None
693
+ contributions_data = _load_json(paper_dir / "usage_contributions.json") if paper_dir and paper_dir.exists() else None
694
+ contributions = (contributions_data or {}).get("contributions") or []
695
+ return paper_dir, discovery, contributions, payload
696
+
697
+
698
+ def _render_overview(payload: Optional[dict], discovery: Optional[dict]):
699
+ claims = (payload or {}).get("claims") or []
700
+ ingredients = sum(len(claim.get("ingredients") or []) for claim in claims)
701
+ studies = sum(
702
+ len(_collect_grounded_studies(claim.get("enabling_discoveries") or [], claim.get("ingredients") or []))
703
+ for claim in claims
704
+ )
705
+ clusters = len((discovery or {}).get("clusters") or [])
706
+
707
+ c1, c2, c3, c4 = st.columns(4)
708
+ with c1:
709
+ _metric_card("Refined clusters", clusters)
710
+ with c2:
711
+ _metric_card("Target contributions", len(claims))
712
+ with c3:
713
+ _metric_card("Enabling contributions", ingredients)
714
+ with c4:
715
+ _metric_card("Grounded studies", studies)
716
+
717
+
718
+ def _build_public_export(discovery: Optional[dict], payload: Optional[dict]) -> dict:
719
+ claims = []
720
+ for claim in (payload or {}).get("claims") or []:
721
+ if not isinstance(claim, dict):
722
+ continue
723
+ ingredients = []
724
+ for ingredient in claim.get("ingredients") or []:
725
+ if not isinstance(ingredient, dict):
726
+ continue
727
+ ingredients.append({
728
+ "ingredient_id": ingredient.get("ingredient_id"),
729
+ "enabling_contribution": ingredient.get("ingredient"),
730
+ "canonical_annotation": ingredient.get("canonical_annotation") or {},
731
+ "primary_grounding": ingredient.get("canonical_grounding") or {},
732
+ "additional_groundings": ingredient.get("additional_groundings") or [],
733
+ })
734
+ claims.append({
735
+ "claim_id": claim.get("claim_id"),
736
+ "target_contribution": claim.get("rewritten_claim") or claim.get("text"),
737
+ "cluster_id": claim.get("cluster_id"),
738
+ "decision": claim.get("decision"),
739
+ "enabling_contributions": ingredients,
740
+ "grounded_studies": _collect_grounded_studies(claim.get("enabling_discoveries") or [], claim.get("ingredients") or []),
741
+ })
742
+
743
+ return {
744
+ "citation_clusters": (discovery or {}).get("clusters") or [],
745
+ "target_contribution_decompositions": claims,
746
+ }
747
+
748
+
749
+ def main():
750
+ llm_provider = os.getenv("LLM_PROVIDER", "gemini")
751
+ llm_model = os.getenv("LLM_MODEL", "gemini-3.1-pro-preview")
752
+ llm_model_step4 = os.getenv("LLM_MODEL_STEP4", "gemini-3-flash-preview")
753
+ formatter_model = os.getenv("ANNOTATION_FORMATTER_MODEL", "gemini/gemini-3.1-pro-preview")
754
+ judge_model = os.getenv("ANNOTATION_JUDGE_MODEL", "gemini/gemini-3.1-pro-preview")
755
+ candidate_count = int(os.getenv("ANNOTATION_CANDIDATE_COUNT", "3"))
756
+ source_root = DEFAULT_SOURCE_ROOT
757
+ output_root = DEFAULT_OUTPUT_ROOT
758
+
759
+ st.set_page_config(page_title="Forecasting Scientific Contribution Pathways", page_icon="📚", layout="wide")
760
+ st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
761
+ _ensure_state()
762
+
763
+ with st.sidebar:
764
+ st.markdown("## SciPaths")
765
+ st.caption("Enter an arXiv paper and run the target-contribution pathway annotation pipeline.")
766
+ st.divider()
767
+ st.markdown("### Citation")
768
+ st.caption("If you find this useful, please cite our paper as:")
769
+ st.code(
770
+ "@misc{chamoun2026scipathsforecastingpathwaysscientific,\n"
771
+ " title={SciPaths: Forecasting Pathways to Scientific Discovery}, \n"
772
+ " author={Eric Chamoun and Yizhou Chi and Yulong Chen and Rui Cao and Zifeng Ding and Michalis Korakakis and Andreas Vlachos},\n"
773
+ " year={2026},\n"
774
+ " eprint={2605.14600},\n"
775
+ " archivePrefix={arXiv},\n"
776
+ " primaryClass={cs.CL},\n"
777
+ " url={https://arxiv.org/abs/2605.14600}, \n"
778
+ "}",
779
+ language="bibtex",
780
+ )
781
+ st.caption("Paper URL: https://arxiv.org/abs/2605.14600")
782
+ st.caption("Questions or feedback: ec806@cam.ac.uk")
783
+ st.divider()
784
+ if st.button("Clear chat / restart", use_container_width=True):
785
+ for key in [
786
+ "paper_input", "run_status", "run_logs", "run_events", "artifact_path",
787
+ "run_dir_path", "paper_dir_path", "annotation_payload_path",
788
+ "run_summary", "annotation_skipped_reason", "pipeline_stopped_reason",
789
+ "pipeline_failed_reason", "remote_artifact_ref",
790
+ ]:
791
+ if key in st.session_state:
792
+ del st.session_state[key]
793
+ st.rerun()
794
+ if not get_secret("GEMINI_API_KEY"):
795
+ st.warning("No GEMINI_API_KEY found in environment or secrets.", icon="🔑")
796
+
797
+ st.markdown("<div class='hero-title'>Forecasting Scientific Contribution Pathways</div>", unsafe_allow_html=True)
798
+ st.markdown(
799
+ "<div class='hero-sub'>Run the SciPaths pipeline through refined downstream citation clusters, then derive target contributions from those clusters and decompose each target contribution into enabling contributions and grounded studies.</div>",
800
+ unsafe_allow_html=True,
801
+ )
802
+
803
+ tabs = st.tabs(TAB_NAMES)
804
+
805
+ with tabs[0]:
806
+ with st.expander("Try an example", expanded=True):
807
+ cols = st.columns(len(EXAMPLES))
808
+ for i, (label, value) in enumerate(EXAMPLES.items()):
809
+ with cols[i]:
810
+ if st.button(label, key=f"example::{label}", use_container_width=True):
811
+ st.session_state["paper_input"] = value
812
+ st.rerun()
813
+
814
+ paper_input = st.text_input(
815
+ "Paper input (arXiv URL or ID)",
816
+ key="paper_input",
817
+ placeholder="https://arxiv.org/abs/2311.14919",
818
+ )
819
+
820
+ if st.button("Run pipeline + annotation", type="primary", use_container_width=True):
821
+ if not paper_input.strip():
822
+ st.error("Paper input is required.")
823
+ else:
824
+ run_pipeline_stream(
825
+ paper_input=paper_input,
826
+ source_root=source_root,
827
+ output_root=output_root,
828
+ llm_provider=llm_provider,
829
+ llm_model=llm_model,
830
+ llm_model_step4=llm_model_step4,
831
+ formatter_model=formatter_model,
832
+ judge_model=judge_model,
833
+ candidate_count=candidate_count,
834
+ )
835
+
836
+ st.markdown("### Latest run")
837
+ st.info(f"Status: {st.session_state.get('run_status', 'Idle')}")
838
+ if st.session_state.get("pipeline_failed_reason"):
839
+ st.error(st.session_state["pipeline_failed_reason"])
840
+ if st.session_state.get("annotation_skipped_reason"):
841
+ st.warning(st.session_state["annotation_skipped_reason"])
842
+
843
+ paper_dir, discovery, contributions, payload = _load_result_bundle()
844
+ public_export = _build_public_export(discovery, payload)
845
+ if public_export["citation_clusters"] or public_export["target_contribution_decompositions"]:
846
+ st.download_button(
847
+ "Download citation clusters and contribution groundings",
848
+ data=json.dumps(public_export, indent=2, ensure_ascii=False),
849
+ file_name="scipaths_run_results.json",
850
+ mime="application/json",
851
+ use_container_width=False,
852
+ )
853
+ _render_overview(payload, discovery)
854
+
855
+ with tabs[1]:
856
+ paper_dir, discovery, contributions, payload = _load_result_bundle()
857
+ _render_clusters_tab(discovery, contributions)
858
+
859
+ with tabs[2]:
860
+ paper_dir, discovery, contributions, payload = _load_result_bundle()
861
+ _render_claims_tab(payload)
862
+
863
+ if __name__ == "__main__":
864
+ main()
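The "Download citation clusters and contribution groundings" button above serializes `_build_public_export(discovery, payload)`, which nests each target contribution with its enabling contributions and a deduplicated list of grounded studies. A minimal sketch of that deduplication on toy data, using a simplified stand-in for `_study_key` (the field names mirror the code above; the records themselves are invented):

```python
# Sketch: how primary and additional groundings collapse into one study list.
# Toy data only; real records come from the annotation payload.

def study_key(item: dict) -> str:
    # simplified mirror of _study_key: first non-empty identifier, lowercased
    for key in ("paper_id", "ref_id", "ref_title", "title"):
        if item.get(key):
            return str(item[key]).lower()
    return ""

discoveries = [{"paper_id": "P1", "ref_title": "Primary study A"}]
ingredients = [{
    "canonical_grounding": {"paper_id": "P1"},  # duplicate of the primary study, so it is skipped
    "additional_groundings": [{"ref_id": "R2", "ref_title": "Additional study B"}],
}]

seen = {study_key(d) for d in discoveries}
extra = [
    ref
    for ing in ingredients
    for ref in ing["additional_groundings"]
    if study_key(ref) not in seen and study_key(ref) != study_key(ing["canonical_grounding"])
]
print([d["ref_title"] for d in discoveries] + [e["ref_title"] for e in extra])
# ['Primary study A', 'Additional study B']
```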
hf_space/streamlit_config.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from runner import STEP_LABELS
2
+
3
+ EXAMPLES = {
4
+ "Confidence-based MBR Decoding": "https://arxiv.org/abs/2311.14919",
5
+ "AVerImaTeC": "https://arxiv.org/abs/2505.17978",
6
+ "CSCD-NS (2022)": "https://arxiv.org/abs/2211.08788",
7
+ }
8
+
9
+ TAB_NAMES = [
10
+ "Pipeline Run",
11
+ "Citation Clusters",
12
+ "Target Contribution Decomposition",
13
+ ]
14
+
15
+ METHOD_NOTES = {
16
+ "Pipeline scope": "Runs steps 0, 1, 2, 3, 4, 5, 6, and 8, then launches cluster-first two-pass annotation.",
17
+ "Input": "Accepts a single arXiv URL or arXiv ID.",
18
+ "Cluster-first annotation": "Uses all refined downstream USES/EXTENDS clusters to derive target contributions, then decomposes each target contribution separately.",
19
+ "Stopping rule": "If no valid downstream usage clusters remain after refinement and filtering, annotation is skipped.",
20
+ }
21
+
22
+ DISPLAY_STEPS = [0, 1, 2, 3, 4, 5, 6, 8]
23
+
24
+
25
+ def pipeline_steps_markdown() -> str:
26
+ lines = []
27
+ for idx in DISPLAY_STEPS:
28
+ lines.append(f"{idx}. {STEP_LABELS[idx]}")
29
+ lines.append("9. Cluster-first target contribution annotation and enabling contribution decomposition")
30
+ return "\n".join(lines)
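As a rough sketch of the string this helper produces, assuming `STEP_LABELS` in `hf_space/runner.py` maps each index in `DISPLAY_STEPS` to a short label (the labels below are placeholders, not the real ones):

```python
# Hypothetical STEP_LABELS stand-in; the real labels live in hf_space/runner.py.
STEP_LABELS = {
    0: "Fetch metadata", 1: "Download LaTeX source", 2: "Mark citations",
    3: "Build usage contexts", 4: "Label citation functions", 5: "Verify USES/EXTENDS",
    6: "Extract citing paragraphs", 8: "Extract and refine contribution clusters",
}
DISPLAY_STEPS = [0, 1, 2, 3, 4, 5, 6, 8]

lines = [f"{idx}. {STEP_LABELS[idx]}" for idx in DISPLAY_STEPS]
lines.append("9. Cluster-first target contribution annotation and enabling contribution decomposition")
print("\n".join(lines))
```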
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ -r hf_space/requirements.txt
src/common/__init__.py ADDED
File without changes
src/common/llm_client.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional
3
+
4
+ try:
+     import google.generativeai as genai
+     from google.generativeai.types import GenerationConfig
+ except ImportError:  # keep the names defined so the None checks below stay meaningful
+     genai = None
+     GenerationConfig = None
6
+
7
+
8
+ class LLMClient:
9
+ def __init__(self):
10
+ self.provider = os.getenv("LLM_PROVIDER", "gemini").lower()
11
+ self.model_name = os.getenv("LLM_MODEL", "gemini-3.1-pro-preview")
12
+
13
+ if self.provider == "gemini":
14
+ if genai is None:
15
+ raise ImportError("google-generativeai not installed.")
16
+ key = os.getenv("GEMINI_API_KEY")
17
+ if not key:
18
+ raise ValueError("GEMINI_API_KEY not set.")
19
+ genai.configure(api_key=key)
20
+ self.model = genai.GenerativeModel(self.model_name)
21
+ else:
22
+ raise NotImplementedError("Only Gemini provider is wired for now.")
23
+
24
+ def call(self, prompt: str, schema: Optional[dict] = None) -> str:
25
+ """
26
+ Call the underlying LLM.
27
+
28
+ If `schema` is provided (as a plain JSON schema dict), and provider is Gemini,
29
+ use it as response_schema with JSON mime type.
30
+ """
31
+ if self.provider == "gemini":
32
+ if schema and GenerationConfig is not None:
33
+ config = GenerationConfig(
34
+ response_schema=schema,
35
+ response_mime_type="application/json",
36
+ )
37
+ response = self.model.generate_content(
38
+ prompt,
39
+ generation_config=config,
40
+ )
41
+ else:
42
+ response = self.model.generate_content(prompt)
43
+
44
+ text = getattr(response, "text", "")
45
+ if not text:
46
+ raise RuntimeError("LLM response did not contain text.")
47
+ return text
48
+
49
+ raise NotImplementedError("Schema-based calls only wired for Gemini right now.")
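A minimal usage sketch of this client (assuming `GEMINI_API_KEY` is set and `src/common` is on the import path; the schema below is illustrative, not one used by the pipeline):

```python
import json

from llm_client import LLMClient  # assuming src/common is importable

client = LLMClient()  # reads LLM_PROVIDER / LLM_MODEL / GEMINI_API_KEY from the environment
schema = {
    "type": "object",
    "properties": {
        "label": {"type": "string"},
        "confidence": {"type": "number"},
    },
    "required": ["label"],
}
raw = client.call(
    "Classify the citation intent of: 'We adopt BERT (Devlin et al., 2019).'",
    schema=schema,
)
result = json.loads(raw)  # JSON mime type is requested when a schema is passed
print(result.get("label"))
```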
src/common/model_client.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import re
6
+ from dataclasses import dataclass
7
+ from typing import Any, Type
8
+
9
+ import litellm
10
+ from litellm import completion
11
+ from pydantic import BaseModel, ValidationError
12
+
13
+
14
+ @dataclass
15
+ class ModelConfig:
16
+ provider: str
17
+ model: str
18
+ temperature: float = 0.2
19
+ max_tokens: int = 12000
20
+
21
+ @property
22
+ def model_name(self) -> str:
23
+ if "/" in self.model:
24
+ return self.model
25
+ if self.provider.lower() == "openai":
26
+ return f"openai/{self.model}"
27
+ if self.provider.lower() == "gemini":
28
+ return f"gemini/{self.model}"
29
+ return self.model
30
+
31
+
32
+ class MultiProviderLLMClient:
33
+ def __init__(self, default_config: ModelConfig, stage_models: dict[str, str] | None = None):
34
+ self.default_config = default_config
35
+ self.stage_models = stage_models or {}
36
+ litellm.drop_params = True
37
+ self._validate_env(default_config.provider)
38
+
39
+ def _validate_env(self, provider: str) -> None:
40
+ provider = provider.lower()
41
+ if provider == "openai" and not os.getenv("OPENAI_API_KEY"):
42
+ raise ValueError("OPENAI_API_KEY is required for provider=openai")
43
+ if provider == "gemini" and not os.getenv("GEMINI_API_KEY"):
44
+ raise ValueError("GEMINI_API_KEY is required for provider=gemini")
45
+
46
+ def config_for_stage(self, stage_name: str) -> ModelConfig:
47
+ model_override = self.stage_models.get(stage_name)
48
+ if not model_override:
49
+ return self.default_config
50
+ provider = self.default_config.provider
51
+ model = model_override
52
+ if "/" in model_override:
53
+ provider, model = model_override.split("/", 1)
54
+ self._validate_env(provider)
55
+ return ModelConfig(
56
+ provider=provider,
57
+ model=model,
58
+ temperature=self.default_config.temperature,
59
+ max_tokens=self.default_config.max_tokens,
60
+ )
61
+
62
+ def generate_structured(
63
+ self,
64
+ *,
65
+ stage_name: str,
66
+ system_prompt: str,
67
+ user_prompt: str,
68
+ response_model: Type[BaseModel],
69
+ ) -> BaseModel:
70
+ config = self.config_for_stage(stage_name)
71
+ completion_kwargs = {
72
+ "model": config.model_name,
73
+ "messages": [
74
+ {"role": "system", "content": system_prompt},
75
+ {"role": "user", "content": user_prompt},
76
+ ],
77
+ "max_tokens": config.max_tokens,
78
+ "response_format": {"type": "json_object"},
79
+ }
80
+ temperature = self._temperature_for_model(config)
81
+ if temperature is not None:
82
+ completion_kwargs["temperature"] = temperature
83
+
84
+ response = completion(
85
+ **completion_kwargs,
86
+ )
87
+ content = response.choices[0].message.content or ""
88
+ payload = self._parse_json(content)
89
+ try:
90
+ return response_model.model_validate(payload)
91
+ except ValidationError as exc:
92
+ if isinstance(payload, list) and len(payload) == 1 and isinstance(payload[0], dict):
93
+ try:
94
+ return response_model.model_validate(payload[0])
95
+ except ValidationError:
96
+ pass
97
+ raise ValueError(
98
+ f"Stage {stage_name} returned invalid JSON for {response_model.__name__}: {exc}\nRaw content:\n{content}"
99
+ ) from exc
100
+
101
+ def generate_text(
102
+ self,
103
+ *,
104
+ stage_name: str,
105
+ system_prompt: str,
106
+ user_prompt: str,
107
+ ) -> str:
108
+ config = self.config_for_stage(stage_name)
109
+ completion_kwargs = {
110
+ "model": config.model_name,
111
+ "messages": [
112
+ {"role": "system", "content": system_prompt},
113
+ {"role": "user", "content": user_prompt},
114
+ ],
115
+ "max_tokens": config.max_tokens,
116
+ }
117
+ temperature = self._temperature_for_model(config)
118
+ if temperature is not None:
119
+ completion_kwargs["temperature"] = temperature
120
+ response = completion(**completion_kwargs)
121
+ return (response.choices[0].message.content or "").strip()
122
+
123
+ @staticmethod
124
+ def _parse_json(text: str) -> Any:
125
+ text = text.strip()
126
+ if text.startswith("```"):
127
+ match = re.search(r"```(?:json)?\s*(.*?)```", text, flags=re.S)
128
+ if match:
129
+ text = match.group(1).strip()
130
+ try:
131
+ return json.loads(text)
132
+ except json.JSONDecodeError:
133
+ match = re.search(r"(\{.*\}|\[.*\])", text, flags=re.S)
134
+ if match:
135
+ return json.loads(match.group(1))
136
+ raise
137
+
138
+ @staticmethod
139
+ def _temperature_for_model(config: ModelConfig) -> float | None:
140
+ model_name = config.model_name.lower()
141
+ if "gpt-5" in model_name:
142
+ return None
143
+ return config.temperature
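A usage sketch of the structured-generation path (the stage name and response model are illustrative; `GEMINI_API_KEY` must be set for the default provider, and `src/common` is assumed to be importable):

```python
from pydantic import BaseModel

from model_client import ModelConfig, MultiProviderLLMClient  # assuming src/common is importable


class IntentLabel(BaseModel):
    label: str
    rationale: str


client = MultiProviderLLMClient(
    default_config=ModelConfig(provider="gemini", model="gemini-3-flash-preview"),
    stage_models={"labeling": "gemini/gemini-3.1-pro-preview"},  # per-stage override as "provider/model"
)
result = client.generate_structured(
    stage_name="labeling",
    system_prompt="You label citation intents. Reply with a single JSON object.",
    user_prompt="Sentence: 'We extend the dataset of Smith et al. (2020).'",
    response_model=IntentLabel,
)
print(result.label, "-", result.rationale)
```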
src/common/paper_package.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import re
5
+ from pathlib import Path
6
+ from typing import Any, Dict, List
7
+
8
+ from pydantic import BaseModel
9
+
10
+
11
+ SECTION_FILES = [
12
+ "abstract.txt",
13
+ "introduction.tex",
14
+ "related_work.tex",
15
+ "tldr.txt",
16
+ ]
17
+
18
+
19
+ class PaperPackage(BaseModel):
20
+ paper_dir: Path
21
+ paper_metadata: Dict[str, Any]
22
+ extracted_discovery_claim: str
23
+ downstream_cluster_evidence: List[Dict[str, Any]]
24
+ paper_text: Dict[str, str]
25
+ full_processed_text: str
26
+ bibliography: List[Dict[str, Any]]
27
+ citation_contexts: List[Dict[str, Any]]
28
+
29
+ def to_prompt_payload(self) -> Dict[str, Any]:
30
+ return {
31
+ "paper_metadata": self.paper_metadata,
32
+ "extracted_discovery_claim": self.extracted_discovery_claim,
33
+ "downstream_cluster_evidence": self.downstream_cluster_evidence,
34
+ "paper_text": self.paper_text,
35
+ "full_processed_text": self.full_processed_text,
36
+ "bibliography": self.bibliography,
37
+ "citation_contexts": self.citation_contexts,
38
+ }
39
+
40
+
41
+ def _load_json(path: Path, default: Any) -> Any:
42
+ try:
43
+ return json.loads(path.read_text())
44
+ except Exception:
45
+ return default
46
+
47
+
48
+ def _read_text(path: Path) -> str:
49
+ try:
50
+ return path.read_text()
51
+ except Exception:
52
+ return ""
53
+
54
+
55
+ def _normalize_dict_payload(value: Any) -> Dict[str, Any]:
56
+ if isinstance(value, dict):
57
+ return value
58
+ if isinstance(value, list):
59
+ for item in value:
60
+ if isinstance(item, dict):
61
+ return item
62
+ return {}
63
+
64
+
65
+ def _collect_sections(paper_dir: Path) -> Dict[str, str]:
66
+ sections_dir = paper_dir / "sections"
67
+ out: Dict[str, str] = {}
68
+ for name in SECTION_FILES:
69
+ text = _read_text(sections_dir / name).strip()
70
+ if text:
71
+ out[name] = text[:12000]
72
+ if not out:
73
+ processed = _read_text(paper_dir / "processed_main.tex").strip()
74
+ if processed:
75
+ out["processed_main.tex"] = processed[:24000]
76
+ return out
77
+
78
+
79
+ def _collect_full_processed_text(paper_dir: Path) -> str:
80
+ processed = _read_text(paper_dir / "processed_main.tex").strip()
81
+ if processed:
82
+ return processed
83
+
84
+ sections_dir = paper_dir / "sections"
85
+ parts: List[str] = []
86
+ if sections_dir.exists():
87
+ for path in sorted(sections_dir.iterdir()):
88
+ if not path.is_file():
89
+ continue
90
+ text = _read_text(path).strip()
91
+ if text:
92
+ parts.append(f"[{path.name}]\n{text}")
93
+ return "\n\n".join(parts)
94
+
95
+
96
+ def _extract_year(value: Any) -> Any:
97
+ if value:
98
+ return value
99
+ return None
100
+
101
+
102
+ def _normalise_reference_record(ref: Dict[str, Any]) -> Dict[str, Any]:
103
+ cited = ref.get("citedPaper")
104
+ source = cited if isinstance(cited, dict) else ref
105
+ external_ids = source.get("external_ids") or source.get("externalIds") or {}
106
+ return {
107
+ "ref_id": (
108
+ ref.get("ref_id")
109
+ or ref.get("bib_key")
110
+ or source.get("ref_id")
111
+ or source.get("bib_key")
112
+ or source.get("paperId")
113
+ or source.get("paper_id")
114
+ or external_ids.get("ACL")
115
+ or external_ids.get("ArXiv")
116
+ or external_ids.get("DOI")
117
+ ),
118
+ "title": source.get("title") or source.get("ref_title"),
119
+ "authors": source.get("authors") or source.get("ref_authors"),
120
+ "year": _extract_year(source.get("year") or source.get("ref_year")),
121
+ "external_ids": external_ids,
122
+ }
123
+
124
+
125
+ def _parse_bibtex_entries(text: str, limit: int) -> List[Dict[str, Any]]:
126
+ entries: List[Dict[str, Any]] = []
127
+ for match in re.finditer(r"@\w+\s*\{\s*([^,]+),(.*?)(?=\n@\w+\s*\{|\Z)", text, re.S):
128
+ key = match.group(1).strip()
129
+ body = match.group(2)
130
+ fields: Dict[str, str] = {}
131
+ for field in ("title", "author", "year", "doi", "url", "eprint"):
132
+ field_match = re.search(
133
+ rf"\b{field}\s*=\s*(\{{(?:[^{{}}]|\{{[^{{}}]*\}})*\}}|\"[^\"]*\"|[^,\n]+)",
134
+ body,
135
+ re.I | re.S,
136
+ )
137
+ if field_match:
138
+ value = field_match.group(1).strip().strip(",")
139
+ if (value.startswith("{") and value.endswith("}")) or (
140
+ value.startswith('"') and value.endswith('"')
141
+ ):
142
+ value = value[1:-1]
143
+ fields[field] = re.sub(r"\s+", " ", value).strip()
144
+ if fields:
145
+ external_ids: Dict[str, Any] = {}
146
+ if fields.get("doi"):
147
+ external_ids["DOI"] = fields["doi"]
148
+ if fields.get("eprint"):
149
+ external_ids["ArXiv"] = fields["eprint"]
150
+ entries.append(
151
+ {
152
+ "ref_id": key,
153
+ "title": fields.get("title"),
154
+ "authors": fields.get("author"),
155
+ "year": fields.get("year"),
156
+ "external_ids": external_ids,
157
+ }
158
+ )
159
+ if len(entries) >= limit:
160
+ break
161
+ return entries
162
+
163
+
164
+ def _collect_bibtex_citation_contexts(paper_dir: Path, limit: int = 60) -> List[Dict[str, Any]]:
165
+ bibtex = _read_text(paper_dir / "references.bib")
166
+ processed = _read_text(paper_dir / "processed_main.tex")
167
+ if not bibtex or not processed:
168
+ return []
169
+
170
+ refs = _parse_bibtex_entries(bibtex, limit=500)
171
+ out: List[Dict[str, Any]] = []
172
+ seen: set[tuple[str, int]] = set()
173
+ for ref in refs:
174
+ ref_id = ref.get("ref_id")
175
+ if not ref_id:
176
+ continue
177
+ for match in re.finditer(rf"\\cite\w*\s*(?:\[[^\]]*\]\s*)*\{{[^}}]*\b{re.escape(str(ref_id))}\b[^}}]*\}}", processed):
178
+ key = (str(ref_id), match.start())
179
+ if key in seen:
180
+ continue
181
+ seen.add(key)
182
+ start = max(0, match.start() - 350)
183
+ end = min(len(processed), match.end() + 350)
184
+ snippet = re.sub(r"\s+", " ", processed[start:end]).strip()
185
+ out.append(
186
+ {
187
+ "ref_id": ref_id,
188
+ "citation_marker": ref.get("title") or ref_id,
189
+ "text": snippet,
190
+ "section": None,
191
+ "intents": [],
192
+ }
193
+ )
194
+ if len(out) >= limit:
195
+ return out
196
+ return out
197
+
198
+
199
+ def _collect_bibliography(paper_dir: Path, limit: int = 80) -> List[Dict[str, Any]]:
200
+ refs = _load_json(paper_dir / "references_metadata.json", [])
201
+ if isinstance(refs, list) and refs:
202
+ return [_normalise_reference_record(ref) for ref in refs[:limit] if isinstance(ref, dict)]
203
+
204
+ bibtex = _read_text(paper_dir / "references.bib")
205
+ if bibtex:
206
+ return _parse_bibtex_entries(bibtex, limit)
207
+ return []
208
+
209
+
210
+ def _collect_citation_contexts(paper_dir: Path, limit: int = 60) -> List[Dict[str, Any]]:
211
+ refs = _load_json(paper_dir / "references_metadata.json", [])
212
+ out = []
213
+ if isinstance(refs, list):
214
+ for ref in refs:
215
+ if not isinstance(ref, dict):
216
+ continue
217
+ ref_record = _normalise_reference_record(ref)
218
+ for context in ref.get("contextsWithIntent") or []:
219
+ if not isinstance(context, dict):
220
+ continue
221
+ text = context.get("context") or context.get("text") or ""
222
+ if not text:
223
+ continue
224
+ out.append(
225
+ {
226
+ "ref_id": ref_record.get("ref_id"),
227
+ "citation_marker": ref_record.get("title"),
228
+ "text": text,
229
+ "section": context.get("section"),
230
+ "intents": context.get("intents", []),
231
+ }
232
+ )
233
+ if len(out) >= limit:
234
+ return out
235
+ contexts = _load_json(paper_dir / "usage_contexts.json", [])
236
+ if isinstance(contexts, list):
237
+ for item in contexts:
238
+ entry = {
239
+ "ref_id": item.get("ref_id") or item.get("bib_key"),
240
+ "citation_marker": item.get("citation_marker"),
241
+ "text": item.get("text") or item.get("text_raw") or "",
242
+ "section": item.get("section"),
243
+ }
244
+ if entry["text"]:
245
+ out.append(entry)
246
+ if len(out) >= limit:
247
+ break
248
+ if not out:
249
+ out = _collect_bibtex_citation_contexts(paper_dir, limit=limit)
250
+ return out
251
+
252
+
253
+ def _collect_downstream_cluster_evidence(paper_dir: Path) -> List[Dict[str, Any]]:
254
+ discovery = _normalize_dict_payload(_load_json(paper_dir / "usage_discovery_from_contributions.json", {}))
255
+ clusters = discovery.get("clusters", [])
256
+ out = []
257
+ for cluster in clusters:
258
+ out.append(
259
+ {
260
+ "cluster_id": cluster.get("cluster_id"),
261
+ "representative_claim": cluster.get("representative_claim") or cluster.get("cluster_title"),
262
+ "cluster_title": cluster.get("cluster_title"),
263
+ "count": cluster.get("count"),
264
+ "merge_rationale": cluster.get("merge_rationale"),
265
+ }
266
+ )
267
+ return out
268
+
269
+
270
+ def load_paper_package(paper_dir: str | Path, extracted_claim_override: str | None = None) -> PaperPackage:
271
+ paper_dir = Path(paper_dir)
272
+ discovery = _normalize_dict_payload(_load_json(paper_dir / "usage_discovery_from_contributions.json", {}))
273
+ paper_metadata = _normalize_dict_payload(_load_json(paper_dir / "paper_metadata.json", {}))
274
+ claim = extracted_claim_override or (
275
+ discovery.get("most_impactful_contribution_self_contained")
276
+ or discovery.get("most_impactful_contribution")
277
+ or ""
278
+ )
279
+ return PaperPackage(
280
+ paper_dir=paper_dir,
281
+ paper_metadata=paper_metadata,
282
+ extracted_discovery_claim=claim,
283
+ downstream_cluster_evidence=_collect_downstream_cluster_evidence(paper_dir),
284
+ paper_text=_collect_sections(paper_dir),
285
+ full_processed_text=_collect_full_processed_text(paper_dir),
286
+ bibliography=_collect_bibliography(paper_dir),
287
+ citation_contexts=_collect_citation_contexts(paper_dir),
288
+ )
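A short sketch of how a downstream stage might consume this loader (the run directory below is hypothetical; it only needs to contain the JSON/TeX artifacts produced by the earlier pipeline steps):

```python
from pathlib import Path

from paper_package import load_paper_package  # assuming src/common is importable

# Hypothetical run directory; it should hold paper_metadata.json, processed_main.tex, etc.
paper_dir = Path("runs/job_001/processed_papers/2311.14919")
package = load_paper_package(paper_dir)

payload = package.to_prompt_payload()
print(package.extracted_discovery_claim[:120])
print(len(payload["bibliography"]), "references,", len(payload["citation_contexts"]), "citation contexts")
for cluster in package.downstream_cluster_evidence[:3]:
    print(cluster["cluster_id"], "-", cluster["representative_claim"])
```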
src/step_01_fetch/config.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import os
3
+
4
+ ACL_IDS_PATH = Path("input_ids.json")
5
+ PAPERS_DIR = Path("papers")
6
+ SEMANTIC_SCHOLAR_API_KEY = os.getenv("SEMANTIC_SCHOLAR_API_KEY", "")
src/step_01_fetch/fetch_metadata.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import argparse
3
+ import json
4
+ import os
5
+ import random
6
+ import re
7
+ import tarfile
8
+ import time
9
+
10
+ import arxiv
11
+ import requests
12
+
13
+ from config import ACL_IDS_PATH
14
+ from process_tex_source import preprocess_tex, extract_introduction_and_related
15
+ from semanticscholar_client import get_paper, get_paper_links, search_by_title
16
+
17
+
18
+ def load_ids(path: Path):
19
+ return json.loads(path.read_text(encoding="utf-8"))
20
+
21
+
22
+ def ensure_dir(path: Path):
23
+ path.mkdir(parents=True, exist_ok=True)
24
+
25
+
26
+ _ARXIV_LAST_TS = 0.0
27
+
28
+
29
+ def _cleanup_partial_source_dir(source_dir: Path) -> None:
30
+ for pattern in ("*.tar.gz", "*.tgz", "*.tar"):
31
+ for path in source_dir.glob(pattern):
32
+ try:
33
+ path.unlink()
34
+ except Exception:
35
+ pass
36
+
37
+
38
+ def _download_arxiv_source_with_retries(paper, source_dir: Path, arxiv_id: str) -> Path | None:
39
+ max_retries = int(os.getenv("ARXIV_SOURCE_MAX_RETRIES", "4"))
40
+ base_sleep = float(os.getenv("ARXIV_SOURCE_BASE_SLEEP", "2.0"))
41
+ max_sleep = float(os.getenv("ARXIV_MAX_BACKOFF", "60"))
42
+ last_exc = None
43
+
44
+ for attempt in range(max_retries):
45
+ _cleanup_partial_source_dir(source_dir)
46
+ try:
47
+ _arxiv_min_interval_sleep()
48
+ tar_path = Path(paper.download_source(dirpath=str(source_dir)))
49
+ if not tar_path.exists():
50
+ raise FileNotFoundError(f"download_source returned {tar_path}, but the file does not exist")
51
+ if tar_path.stat().st_size < 1024:
52
+ raise IOError(f"downloaded source archive is unexpectedly small ({tar_path.stat().st_size} bytes)")
53
+ return tar_path
54
+ except Exception as exc:
55
+ last_exc = exc
56
+ sleep = min(base_sleep * (2**attempt), max_sleep) + random.uniform(0.0, 0.5)
57
+ print(f"[WARN] Failed to download source for {arxiv_id} on attempt {attempt + 1}/{max_retries}: {exc}")
58
+ if attempt + 1 < max_retries:
59
+ print(f"[INFO] Retrying source download in {sleep:.2f}s")
60
+ time.sleep(sleep)
61
+
62
+ print(f"[WARN] Source download failed for {arxiv_id} after {max_retries} attempts: {last_exc}")
63
+ return None
64
+
65
+
66
+ def _arxiv_min_interval_sleep() -> None:
67
+ """Global throttle to avoid arXiv API rate limits."""
68
+ global _ARXIV_LAST_TS
69
+ min_interval = float(os.getenv("ARXIV_MIN_INTERVAL", "1.0"))
70
+ now = time.monotonic()
71
+ elapsed = now - _ARXIV_LAST_TS
72
+ if elapsed < min_interval:
73
+ time.sleep(min_interval - elapsed)
74
+ _ARXIV_LAST_TS = time.monotonic()
75
+
76
+
77
+ def download_arxiv_tex(arxiv_id: str, base_dir: Path) -> Path | None:
78
+ """
79
+ Download LaTeX source from arXiv and return the path to a merged TeX file.
80
+
81
+ - arxiv_id: e.g. "2410.22815"
82
+ - base_dir: paper directory where source should be unpacked
83
+ """
84
+ source_dir = base_dir / f"tex_{arxiv_id}"
85
+ source_dir.mkdir(parents=True, exist_ok=True)
86
+ search = arxiv.Search(id_list=[arxiv_id])
87
+ max_retries = int(os.getenv("ARXIV_MAX_RETRIES", "6"))
88
+ base_sleep = float(os.getenv("ARXIV_BASE_SLEEP", "2.0"))
89
+ max_sleep = float(os.getenv("ARXIV_MAX_BACKOFF", "60"))
90
+ paper = None
91
+
92
+ for attempt in range(max_retries):
93
+ try:
94
+ _arxiv_min_interval_sleep()
95
+ paper = next(search.results())
96
+ break
97
+ except StopIteration:
98
+ print(f"[WARN] No arXiv paper found for ID {arxiv_id}")
99
+ return None
100
+ except arxiv.HTTPError as exc:
101
+ if getattr(exc, "status", None) == 429 or "429" in str(exc):
102
+ sleep = min(base_sleep * (2**attempt), max_sleep) + random.uniform(0.0, 0.5)
103
+ print(f"[WARN] arXiv 429 → retrying in {sleep:.2f}s")
104
+ time.sleep(sleep)
105
+ continue
106
+ print(f"[WARN] arXiv HTTP error for {arxiv_id}: {exc}")
107
+ return None
108
+ except Exception as exc:
109
+ sleep = min(base_sleep * (2**attempt), max_sleep) + random.uniform(0.0, 0.5)
110
+ print(f"[WARN] arXiv error {exc} → retrying in {sleep:.2f}s")
111
+ time.sleep(sleep)
112
+ continue
113
+
114
+ if paper is None:
115
+ print(f"[ERROR] Giving up after {max_retries} attempts for arXiv ID {arxiv_id}")
116
+ return None
117
+
118
+ tar_path = _download_arxiv_source_with_retries(paper, source_dir, arxiv_id)
119
+ if tar_path is None:
120
+ return None
121
+
122
+ try:
123
+ with tarfile.open(tar_path) as tar:
124
+ tar.extractall(path=source_dir)
125
+ os.remove(tar_path)
126
+ except Exception as exc:
127
+ print(f"[WARN] Failed to extract source for {arxiv_id}: {exc}")
128
+ return None
129
+
130
+ processed_tex = preprocess_tex(source_dir)
131
+ if processed_tex:
132
+ extract_introduction_and_related(processed_tex)
133
+
134
+ if not processed_tex or not processed_tex.exists():
135
+ print(f"[WARN] Could not produce merged TeX for {arxiv_id}")
136
+ return None
137
+
138
+ print(f"[INFO] Processed LaTeX for {arxiv_id} at {processed_tex}")
139
+ return processed_tex
140
+
141
+
142
+ def _extract_arxiv_id_from_text(text: str) -> str | None:
143
+ if not text:
144
+ return None
145
+ match = re.search(r"\b(\d{4}\.\d{4,5}(?:v\d+)?)\b", text)
146
+ if match:
147
+ return match.group(1)
148
+ match = re.search(r"arxiv[:\s/]*(\d{4}\.\d{4,5}(?:v\d+)?)", text, re.IGNORECASE)
149
+ if match:
150
+ return match.group(1)
151
+ return None
152
+
153
+
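This helper drives both the OpenReview fallback and the PDF-URL parsing: it accepts anything that looks like a modern arXiv identifier, with or without an "arXiv:" prefix. A small worked example of the two patterns (a standalone mirror of the function above, not an import):

```python
import re

def extract_arxiv_id(text: str) -> str | None:
    # mirrors _extract_arxiv_id_from_text above
    if not text:
        return None
    match = re.search(r"\b(\d{4}\.\d{4,5}(?:v\d+)?)\b", text)
    if match:
        return match.group(1)
    match = re.search(r"arxiv[:\s/]*(\d{4}\.\d{4,5}(?:v\d+)?)", text, re.IGNORECASE)
    if match:
        return match.group(1)
    return None

print(extract_arxiv_id("https://arxiv.org/abs/2311.14919v2"))  # -> 2311.14919v2
print(extract_arxiv_id("See arXiv:2505.17978 for details"))    # -> 2505.17978
print(extract_arxiv_id("no identifier here"))                  # -> None
```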
154
+ def _safe_write_json(path: Path, payload) -> None:
155
+ path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
156
+
157
+
158
+ def _safe_write_text(path: Path, text: str) -> None:
159
+ path.write_text(text, encoding="utf-8")
160
+
161
+
162
+ def _query_openreview_for_paper(openreview_id: str) -> dict | None:
163
+ """Query OpenReview using a real OpenReview note/forum id."""
164
+ if not openreview_id:
165
+ return None
166
+
167
+ try_urls = [
168
+ f"https://api.openreview.net/notes?forum={openreview_id}",
169
+ f"https://api2.openreview.net/notes?forum={openreview_id}",
170
+ f"https://api.openreview.net/notes?id={openreview_id}",
171
+ f"https://api2.openreview.net/notes?id={openreview_id}",
172
+ ]
173
+
174
+ for url in try_urls:
175
+ try:
176
+ response = requests.get(url, timeout=20)
177
+ if response.status_code != 200:
178
+ continue
179
+ payload = response.json()
180
+ except Exception:
181
+ continue
182
+
183
+ notes = None
184
+ if isinstance(payload, dict) and isinstance(payload.get("notes"), list):
185
+ notes = payload["notes"]
186
+ elif isinstance(payload, dict) and payload.get("content"):
187
+ notes = [payload]
188
+ elif isinstance(payload, list):
189
+ notes = payload
190
+
191
+ if not notes:
192
+ continue
193
+
194
+ note = notes[0]
195
+ content = note.get("content") if isinstance(note, dict) else None
196
+ title = None
197
+ arxiv_id = None
198
+ pdf_url = None
199
+
200
+ if isinstance(content, dict):
201
+ raw_title = content.get("title") or content.get("paperTitle")
202
+ title = raw_title.get("value") if isinstance(raw_title, dict) else raw_title
203
+
204
+ raw_pdf = content.get("pdf")
205
+ pdf_url = raw_pdf.get("value") if isinstance(raw_pdf, dict) else raw_pdf
206
+
207
+ for value in content.values():
208
+ if isinstance(value, dict):
209
+ value = value.get("value")
210
+ if isinstance(value, list):
211
+ value = " ".join(str(item) for item in value)
212
+ if isinstance(value, str):
213
+ arxiv_id = _extract_arxiv_id_from_text(value)
214
+ if arxiv_id:
215
+ break
216
+
217
+ if not title and isinstance(note, dict):
218
+ title = note.get("title") or note.get("forumTitle")
219
+
220
+ if not arxiv_id and isinstance(note, dict):
221
+ for value in note.values():
222
+ if isinstance(value, str):
223
+ arxiv_id = _extract_arxiv_id_from_text(value)
224
+ if arxiv_id:
225
+ break
226
+
227
+ return {
228
+ "title": title,
229
+ "arxiv_id": arxiv_id,
230
+ "pdf_url": pdf_url,
231
+ "openreview_id": openreview_id,
232
+ "source_url": url,
233
+ }
234
+
235
+ return None
236
+
237
+
238
+ def _treat_as_openreview(paper: dict) -> bool:
239
+ acl_id = str(paper.get("id", "")).lower()
240
+ id_type = str(paper.get("id_type", "")).lower()
241
+ return (
242
+ id_type == "openreview"
243
+ or bool(paper.get("openreview_id"))
244
+ or acl_id.startswith("neurips-")
245
+ or acl_id.startswith("icml-")
246
+ )
247
+
248
+
249
+ def _fetch_s2_by_title(title: str, acl_id: str) -> tuple[int, dict | None]:
250
+ if not title:
251
+ print(f"[WARN] no title available for {acl_id} → skipping.")
252
+ return 0, None
253
+ hit = search_by_title(title)
254
+ if not hit:
255
+ print(f"[WARN] no S2 match for {acl_id} ({title}) → skipping.")
256
+ return 0, None
257
+ s2_id = hit["paperId"]
258
+ print(f"[DEBUG] title search matched semantic scholar paperId={s2_id}")
259
+ return get_paper(s2_id, id_type="SemanticScholar")
260
+
261
+
262
+ def _best_arxiv_id(*values: str) -> str | None:
263
+ for value in values:
264
+ arxiv_id = _extract_arxiv_id_from_text(value or "")
265
+ if arxiv_id:
266
+ return arxiv_id
267
+ return None
268
+
269
+
270
+ def _write_openreview_snapshot(paper_dir: Path, payload: dict) -> None:
271
+ if payload:
272
+ _safe_write_json(paper_dir / "openreview_metadata.json", payload)
273
+
274
+
275
+ def _write_metadata_outputs(paper_dir: Path, acl_id: str, data: dict) -> None:
276
+ meta_path = paper_dir / "paper_metadata.json"
277
+ _safe_write_json(meta_path, [data])
278
+ print(f"[DEBUG] wrote metadata to {meta_path}")
279
+
280
+ external_ids = data.get("externalIds", {}) or {}
281
+ arxiv_id = external_ids.get("ArXiv")
282
+ if arxiv_id:
283
+ download_arxiv_tex(arxiv_id=arxiv_id, base_dir=paper_dir)
284
+
285
+ sections_dir = paper_dir / "sections"
286
+ sections_dir.mkdir(exist_ok=True)
287
+
288
+ abstract = data.get("abstract")
289
+ if abstract:
290
+ _safe_write_text(sections_dir / "abstract.txt", abstract)
291
+
292
+ tldr_obj = data.get("tldr")
293
+ if isinstance(tldr_obj, dict) and tldr_obj.get("text"):
294
+ _safe_write_text(sections_dir / "tldr.txt", tldr_obj["text"])
295
+
296
+ semantic_id = data.get("paperId")
297
+ if not semantic_id:
298
+ print(f"[WARN] no semantic_id for {acl_id} → skip refs/cites.")
299
+ return
300
+
301
+ citation_count = data.get("citationCount", 0)
302
+ reference_count = data.get("referenceCount", 0)
303
+
304
+ ref_status, refs = get_paper_links(semantic_id, "references", reference_count)
305
+ if ref_status == 200:
306
+ _safe_write_json(paper_dir / "references_metadata.json", refs)
307
+
308
+ cit_status, cits = get_paper_links(semantic_id, "citations", citation_count)
309
+ if cit_status == 200:
310
+ _safe_write_json(paper_dir / "citations_metadata.json", cits)
311
+
312
+ if "ArXiv" not in external_ids:
313
+ _safe_write_text(paper_dir / "no_arxiv.txt", "no arxiv for this paper")
314
+
315
+
316
+ def fetch_one_acl_id(paper: dict, base_dir: Path):
317
+ acl_id = paper["id"]
318
+ title = (paper.get("title") or "").strip()
319
+ id_type = paper.get("id_type", "ACL")
320
+ openreview_id = paper.get("openreview_id", "")
321
+ input_pdf_url = paper.get("pdf_url", "")
322
+ s2_key = os.getenv("SEMANTIC_SCHOLAR_API_KEY", "")
323
+ print(
324
+ f"[DEBUG] fetch_one_acl_id: id={acl_id} id_type={id_type} "
325
+ f"title_len={len(title)} s2_key_present={'yes' if bool(s2_key) else 'no'} "
326
+ f"s2_key_len={len(s2_key)}"
327
+ )
328
+
329
+ paper_dir = base_dir / acl_id
330
+ ensure_dir(paper_dir)
331
+ meta_path = paper_dir / "paper_metadata.json"
332
+
333
+ if meta_path.exists():
334
+ return
335
+
336
+ status, data = 0, None
337
+ fetch_label = f"{id_type}:{acl_id}"
338
+ is_openreview = _treat_as_openreview(paper)
339
+ openreview_meta = None
340
+ attempted_title_search = False
341
+
342
+ if is_openreview:
343
+ try:
344
+ openreview_meta = _query_openreview_for_paper(openreview_id or acl_id)
345
+ except Exception as exc:
346
+ print(f"[WARN] OpenReview lookup failed for {acl_id}: {exc}")
347
+ openreview_meta = None
348
+
349
+ if openreview_meta:
350
+ _write_openreview_snapshot(paper_dir, openreview_meta)
351
+ or_title = (openreview_meta.get("title") or title or "").strip()
352
+ arxiv_id = (
353
+ _best_arxiv_id(
354
+ openreview_meta.get("arxiv_id", ""),
355
+ openreview_meta.get("pdf_url", ""),
356
+ input_pdf_url,
357
+ )
358
+ or ""
359
+ )
360
+ if arxiv_id:
361
+ print(f"[DEBUG] OpenReview -> found ArXiv {arxiv_id} for {acl_id}")
362
+ status, data = get_paper(arxiv_id, id_type="ArXiv")
363
+ fetch_label = f"ArXiv:{arxiv_id}"
364
+ title = or_title or title
365
+ elif or_title:
366
+ print(f"[DEBUG] OpenReview -> no arXiv for {acl_id}, title-searching")
367
+ status, data = _fetch_s2_by_title(or_title, acl_id)
368
+ fetch_label = f"title:{or_title[:80]}"
369
+ title = or_title
370
+ attempted_title_search = True
371
+ else:
372
+ print(f"[WARN] OpenReview metadata for {acl_id} had neither title nor arXiv")
373
+ else:
374
+ print(f"[WARN] no OpenReview metadata for {acl_id} (openreview_id={openreview_id or acl_id})")
375
+
376
+ if data is None and title and not attempted_title_search:
377
+ print(f"[DEBUG] OpenReview fallback -> title-searching extracted title for {acl_id}")
378
+ status, data = _fetch_s2_by_title(title, acl_id)
379
+ fetch_label = f"title:{title[:80]}"
380
+ attempted_title_search = True
381
+
382
+ if data is None and not is_openreview:
383
+ status, data = get_paper(acl_id, id_type=id_type)
384
+ fetch_label = f"{id_type}:{acl_id}"
385
+
386
+ if data is None and not attempted_title_search:
387
+ print(
388
+ f"[WARN] direct fetch failed for {fetch_label} "
389
+ f"(status={status}) → trying title search with title_len={len(title)}"
390
+ )
391
+ status, data = _fetch_s2_by_title(title, acl_id)
392
+
393
+ if status != 200 or data is None:
394
+ print(f"[WARN] still no data for {acl_id} → skipping.")
395
+ return
396
+
397
+ _write_metadata_outputs(paper_dir, acl_id, data)
398
+ print("[SUCCESS]")
399
+
400
+
401
+ def fetch_all_metadata(acl_ids_path: Path, out_dir: Path, start_from: str | None = None, resume: bool = False):
402
+ raw = json.loads(acl_ids_path.read_text(encoding="utf-8"))
403
+ papers = raw if raw and isinstance(raw[0], dict) else [{"id": x, "title": ""} for x in raw]
404
+
405
+ start_seen = start_from is None
406
+ for paper in papers:
407
+ pid = str(paper.get("id", ""))
408
+ if not start_seen:
409
+ if pid == start_from:
410
+ start_seen = True
411
+ else:
412
+ continue
413
+ if resume:
414
+ paper_dir = out_dir / pid
415
+ if (paper_dir / "paper_metadata.json").exists():
416
+ continue
417
+ fetch_one_acl_id(paper, out_dir)
418
+ return "Meta Data Completed"
419
+
420
+
421
+ if __name__ == "__main__":
422
+ parser = argparse.ArgumentParser()
423
+ parser.add_argument("--ids", type=str, required=True, help="Path to JSON file with paper IDs.")
424
+ parser.add_argument("--outdir", type=str, default="papers", help="Output directory for metadata.")
425
+ parser.add_argument("--start-from", type=str, default=None, help="Start from this paper ID.")
426
+ parser.add_argument("--resume", action="store_true", help="Skip papers that already have paper_metadata.json.")
427
+ args = parser.parse_args()
428
+
429
+ ACL_IDS_PATH = Path(args.ids).expanduser().resolve()
430
+ OUTDIR = Path(args.outdir).expanduser().resolve()
431
+
432
+ if not ACL_IDS_PATH.exists():
433
+ raise FileNotFoundError(f"Could not find {ACL_IDS_PATH}")
434
+
435
+ print(f"[INFO] Using ID list from {ACL_IDS_PATH}")
436
+ print(f"[INFO] Output will be saved to {OUTDIR}")
437
+
438
+ start = time.time()
439
+ fetch_all_metadata(acl_ids_path=ACL_IDS_PATH, out_dir=OUTDIR, start_from=args.start_from, resume=args.resume)
440
+ print("done in", time.time() - start, "s")
src/step_01_fetch/process_tex_source.py ADDED
@@ -0,0 +1,203 @@
1
+ import os
2
+ import re
3
+ from pathlib import Path
4
+ import shutil
5
+
6
+ def read_tex(path: Path) -> str:
7
+ try:
8
+ return path.read_text(encoding="utf-8", errors="ignore")
9
+ except Exception:
10
+ return ""
11
+
12
+
13
+ def resolve_inputs(tex: str, base_dir: Path, seen=None) -> str:
14
+ """
15
+ Recursively replace \\input{...} and \\include{...} with file contents.
16
+ """
17
+ if seen is None:
18
+ seen = set()
19
+
20
+ pattern = r'\\(?:input|include)\{([^}]+)\}'
21
+
22
+ def repl(match):
23
+ name = match.group(1)
24
+ if not name.endswith(".tex"):
25
+ name += ".tex"
26
+
27
+ full = base_dir / name
28
+
29
+ if full in seen:
30
+ return f"% WARNING: skipped circular input {full}\n"
31
+
32
+ if not full.exists():
33
+ return f"% WARNING: missing file {full}\n"
34
+
35
+ seen.add(full)
36
+ content = read_tex(full)
37
+ return resolve_inputs(content, full.parent, seen)
38
+
39
+ return re.sub(pattern, repl, tex)
40
+
41
+
42
+ def find_main_tex(source_dir: Path) -> Path | None:
43
+ """
44
+ Heuristic to find the main .tex file:
45
+ 1. match .bbl → .tex
46
+ 2. else top-level .tex that contains \\begin{document}
47
+ 3. else first .tex in directory
48
+ """
49
+ bbls = list(source_dir.glob("*.bbl"))
50
+ if bbls:
51
+ main_candidate = source_dir / (bbls[0].stem + ".tex")
52
+ if main_candidate.exists():
53
+ return main_candidate
54
+
55
+ for tex in source_dir.glob("*.tex"):
56
+ if "\\begin{document}" in read_tex(tex):
57
+ return tex
58
+
59
+ tex_files = list(source_dir.glob("*.tex"))
60
+ return tex_files[0] if tex_files else None
61
+
62
+
63
+ def preprocess_tex(source_dir: Path) -> Path | None:
64
+ """
65
+ Given an extracted arXiv source directory, produce:
66
+ - a merged TeX file named 'processed_main.tex'
67
+ - a concatenated BibTeX file named 'references.bib'
68
+ Both are written in the parent directory of source_dir (the paper dir).
69
+
70
+ Then delete the extracted source_dir.
71
+ """
72
+ main_tex = find_main_tex(source_dir)
73
+ if not main_tex:
74
+ print(f"[WARN] No main .tex found in {source_dir}")
75
+ shutil.rmtree(source_dir, ignore_errors=True)
76
+ return None
77
+
78
+ raw = read_tex(main_tex)
79
+ merged = resolve_inputs(raw, main_tex.parent)
80
+
81
+ paper_dir = source_dir.parent
82
+ out_tex_path = paper_dir / "processed_main.tex"
83
+ out_tex_path.write_text(merged, encoding="utf-8")
84
+
85
+ bib_files = list(source_dir.rglob("*.bib"))
86
+ if bib_files:
87
+ bib_texts = []
88
+ for bib in bib_files:
89
+ try:
90
+ bib_texts.append(bib.read_text(encoding="utf-8", errors="ignore"))
91
+ except Exception:
92
+ print(f"[WARN] Could not read bib file {bib}")
93
+ if bib_texts:
94
+ bib_out = paper_dir / "references.bib"
95
+ bib_out.write_text("\n\n".join(bib_texts), encoding="utf-8")
96
+ print(f"[INFO] Wrote combined BibTeX to {bib_out}")
97
+
98
+ shutil.rmtree(source_dir, ignore_errors=True)
99
+
100
+ return out_tex_path
101
+
102
+ def _load_tex(path: Path) -> str:
103
+ return path.read_text(encoding="utf-8", errors="ignore")
104
+
105
+
106
+ SECTION_PATTERN = re.compile(
107
+ r'\\section\*?\{([^}]*)\}',
108
+ flags=re.IGNORECASE
109
+ )
110
+
111
+
112
+ def _split_into_sections(tex: str):
113
+ """
114
+ Returns a list of (section_title, content) in order.
115
+ Title is the raw LaTeX title text (without braces).
116
+ Content is the text from this \\section line up to (but not including)
117
+ the next \\section or end of document.
118
+ """
119
+ sections = []
120
+ matches = list(SECTION_PATTERN.finditer(tex))
121
+
122
+ if not matches:
123
+ return sections
124
+
125
+ for i, m in enumerate(matches):
126
+ title = m.group(1).strip()
127
+ start = m.start()
128
+ end = matches[i + 1].start() if i + 1 < len(matches) else len(tex)
129
+ content = tex[start:end]
130
+ sections.append((title, content))
131
+
132
+ return sections
133
+
134
+
135
+ def _normalize_title(title: str) -> str:
136
+ """Lowercase and strip punctuation-ish stuff for robust matching."""
137
+ t = title.lower()
138
+ t = re.sub(r'[^a-z0-9\s]', ' ', t)
139
+ t = re.sub(r'\s+', ' ', t).strip()
140
+ return t
141
+
142
+
143
+ def _find_best_section(sections, candidates):
144
+ """
145
+ sections: list of (raw_title, content)
146
+ candidates: list of strings to match against normalized title
147
+ Returns the content of the best-matching section or None.
148
+ """
149
+ norm_candidates = [c.lower() for c in candidates]
150
+
151
+ for raw_title, content in sections:
152
+ nt = _normalize_title(raw_title)
153
+ for cand in norm_candidates:
154
+ if nt == cand or cand in nt:
155
+ return content
156
+ return None
157
+
158
+
159
+ def extract_introduction_and_related(
160
+ processed_tex_path: Path,
161
+ out_dir: Path | None = None,
162
+ ) -> dict:
163
+ """
164
+ Given path to processed_main.tex, extract Introduction and Related Work sections
165
+ into separate .tex files.
166
+
167
+ Returns a dict with keys:
168
+ {
169
+ "introduction": Path | None,
170
+ "related_work": Path | None
171
+ }
172
+ """
173
+ if out_dir is None:
174
+ out_dir = processed_tex_path.parent / "sections"
175
+
176
+ out_dir.mkdir(parents=True, exist_ok=True)
177
+
178
+ tex = _load_tex(processed_tex_path)
179
+ sections = _split_into_sections(tex)
180
+
181
+ intro_candidates = ["introduction"]
182
+ related_candidates = ["related work"]
183
+
184
+ intro_content = _find_best_section(sections, intro_candidates)
185
+ related_content = _find_best_section(sections, related_candidates)
186
+
187
+ results = {"introduction": None, "related_work": None}
188
+
189
+ if intro_content:
190
+ intro_path = out_dir / "introduction.tex"
191
+ intro_path.write_text(intro_content, encoding="utf-8")
192
+ results["introduction"] = intro_path
193
+ else:
194
+ print(f"[WARN] No Introduction section found in {processed_tex_path}")
195
+
196
+ if related_content:
197
+ rw_path = out_dir / "related_work.tex"
198
+ rw_path.write_text(related_content, encoding="utf-8")
199
+ results["related_work"] = rw_path
200
+ else:
201
+ print(f"[WARN] No Related Work section found in {processed_tex_path}")
202
+
203
+ return results
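A short sketch of how these helpers chain together for one paper, assuming the extracted arXiv tarball sits in a source/ subdirectory of the paper folder; the path is illustrative.

from pathlib import Path

from process_tex_source import preprocess_tex, extract_introduction_and_related

source_dir = Path("papers/2401.00001/source")  # hypothetical extracted arXiv source

# Merges \input/\include files, writes processed_main.tex and references.bib
# into the paper directory, then removes source/.
processed = preprocess_tex(source_dir)
if processed is not None:
    sections = extract_introduction_and_related(processed)
    print(sections["introduction"], sections["related_work"])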
src/step_01_fetch/semanticscholar_client.py ADDED
@@ -0,0 +1,158 @@
1
+ import time
2
+ import random
3
+ import requests
4
+ from typing import Optional, Tuple, Any
5
+ from config import SEMANTIC_SCHOLAR_API_KEY
6
+ import os
7
+
8
+ BASE_URL = "https://api.semanticscholar.org/graph/v1/paper"
9
+
10
+ _LAST_REQUEST_TS = 0.0
11
+
12
+
13
+ def _min_interval_sleep() -> None:
14
+ """Global throttle to avoid hammering Semantic Scholar."""
15
+ global _LAST_REQUEST_TS
16
+ min_interval = float(os.getenv("S2_MIN_INTERVAL", "1.0"))
17
+ now = time.monotonic()
18
+ elapsed = now - _LAST_REQUEST_TS
19
+ if elapsed < min_interval:
20
+ time.sleep(min_interval - elapsed)
21
+ _LAST_REQUEST_TS = time.monotonic()
22
+
23
+
24
+ def robust_request(url, params=None, headers=None, max_retries=8, base_sleep=2.0):
25
+ """
26
+ Make a GET request with exponential backoff.
27
+ Retries on:
28
+ - connection errors
29
+ - 429 (Too Many Requests)
30
+ - 500–599 server errors
31
+ - invalid JSON
32
+ Returns (status_code, json_or_None).
33
+ """
34
+
35
+ for attempt in range(max_retries):
36
+ try:
37
+ _min_interval_sleep()
38
+ resp = requests.get(url, params=params, headers=headers, timeout=30)
39
+ status = resp.status_code
40
+
41
+ if status == 200:
42
+ try:
43
+ return 200, resp.json()
44
+ except Exception:
45
+ print(f"[WARN] JSON decode failed on attempt {attempt+1}/{max_retries}")
46
+
47
+ if status == 429:
48
+ retry_after = resp.headers.get("Retry-After")
49
+ if retry_after:
50
+ try:
51
+ sleep = float(retry_after)
52
+ except Exception:
53
+ sleep = base_sleep * (2 ** attempt)
54
+ else:
55
+ sleep = base_sleep * (2 ** attempt)
56
+ max_sleep = float(os.getenv("S2_MAX_BACKOFF", "60"))
57
+ sleep = min(sleep, max_sleep)
58
+ sleep += random.uniform(0.0, 0.5)
59
+ print(f"[WARN] 429 Too Many Requests → retrying in {sleep:.2f}s")
60
+ time.sleep(sleep)
61
+ continue
62
+
63
+ if 500 <= status < 600:
64
+ sleep = base_sleep * (2 ** attempt)
65
+ max_sleep = float(os.getenv("S2_MAX_BACKOFF", "60"))
66
+ sleep = min(sleep, max_sleep)
67
+ sleep += random.uniform(0.0, 0.5)
68
+ print(f"[WARN] Server error {status} → retrying in {sleep:.2f}s")
69
+ time.sleep(sleep)
70
+ continue
71
+
72
+ return status, None
73
+
74
+ except requests.exceptions.RequestException as e:
75
+ sleep = base_sleep * (2 ** attempt)
76
+ max_sleep = float(os.getenv("S2_MAX_BACKOFF", "60"))
77
+ sleep = min(sleep, max_sleep)
78
+ sleep += random.uniform(0.0, 0.5)
79
+ print(f"[WARN] Network error {e} → retrying in {sleep:.2f}s")
80
+ time.sleep(sleep)
81
+ continue
82
+
83
+ print(f"[ERROR] Giving up after {max_retries} attempts for URL: {url}")
84
+ return None, None
85
+ def get_paper(paper_id: str, id_type: str = "ACL") -> Tuple[int, Optional[dict]]:
86
+ """
87
+ id_type can be "ACL" or "SemanticScholar" or "ArXiv" etc.
88
+ """
89
+ if id_type == "SemanticScholar":
90
+ full_id = paper_id
91
+ else:
92
+ full_id = f"{id_type}:{paper_id}"
93
+
94
+ url = f"{BASE_URL}/{full_id}"
95
+ params = {
96
+ "fields": (
97
+ "title,year,publicationDate,authors,url,venue,externalIds,"
98
+ "tldr,abstract,citationCount,referenceCount,openAccessPdf"
99
+ )
100
+ }
101
+
102
+ headers = {"x-api-key": SEMANTIC_SCHOLAR_API_KEY} if SEMANTIC_SCHOLAR_API_KEY else {}
103
+
104
+ status, data = robust_request(url, params=params, headers=headers, max_retries=5, base_sleep=1.0)
105
+ if status == 200 and data is not None:
106
+ return status, data
107
+ else:
108
+ print(f"[WARN] {status} on {full_id}")
109
+ return status or 0, None
110
+
111
+
112
+ def get_paper_links(semantic_id: str, target_type: str, total: int, limit: int = 1000):
113
+ headers = {"x-api-key": SEMANTIC_SCHOLAR_API_KEY} if SEMANTIC_SCHOLAR_API_KEY else {}
114
+ loops = total // limit + 1 if total else 0
115
+ collected = []
116
+
117
+ for i in range(loops):
118
+ offset = i * limit
119
+ url = f"{BASE_URL}/{semantic_id}/{target_type}"
120
+ params = {
121
+ "offset": offset,
122
+ "limit": limit,
123
+ "fields": "paperId,title,isInfluential,externalIds,contextsWithIntent,openAccessPdf",
124
+ }
125
+
126
+ status, data = robust_request(url, params=params, headers=headers, max_retries=5, base_sleep=1.0)
127
+
128
+ if status != 200 or data is None:
129
+ print(f"[WARN] {target_type} fetch failed for {semantic_id} (status {status})")
130
+ return status or 0, []
131
+
132
+ items = data.get("data")
133
+ if not isinstance(items, list):
134
+ print(f"[WARN] malformed {target_type} response for {semantic_id}")
135
+ return status, []
136
+
137
+ collected.extend(items)
138
+
139
+ return 200, collected
140
+
141
+
142
+ def search_by_title(title: str, limit: int = 1):
143
+ """Search Semantic Scholar by paper title."""
144
+ url = "https://api.semanticscholar.org/graph/v1/paper/search"
145
+ params = {
146
+ "query": title,
147
+ "limit": limit,
148
+ "fields": "paperId,title,year,venue,externalIds",
149
+ }
150
+ headers = {"x-api-key": SEMANTIC_SCHOLAR_API_KEY} if SEMANTIC_SCHOLAR_API_KEY else {}
151
+
152
+ status, data = robust_request(url, params=params, headers=headers, max_retries=5, base_sleep=1.0)
153
+ if status == 200 and data is not None:
154
+ items = data.get("data", [])
155
+ return items[0] if items else None
156
+ else:
157
+ print(f"[WARN] title search failed for '{title[:60]}...' (status {status})")
158
+ return None
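A hedged sketch of the client's three entry points; it assumes config.SEMANTIC_SCHOLAR_API_KEY resolves (the module imports it at load time), and the IDs and title are illustrative.

from semanticscholar_client import get_paper, get_paper_links, search_by_title

# Direct lookup by external ID; throttling and retries happen inside robust_request.
status, paper = get_paper("2020.acl-main.1", id_type="ACL")
if status == 200 and paper:
    s2_id = paper.get("paperId")
    _, citations = get_paper_links(s2_id, "citations", total=paper.get("citationCount", 0))
    print(len(citations), "citation records fetched")

# Fallback when no stable external ID is available.
hit = search_by_title("Attention Is All You Need")
print(hit.get("paperId") if hit else None)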
src/step_02_mark_citations/replace_citation_markers.py ADDED
@@ -0,0 +1,440 @@
1
+ import argparse
2
+ import json
3
+ import re
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List, Tuple
6
+
7
+
8
+ PAPER_META_FILE = "paper_metadata.json"
9
+ USAGE_CLAIMS_FILE = "usage_claims.json"
10
+ USAGE_CONTEXTS_FILE = "usage_contexts.json"
11
+ CITATIONS_FILE = "citations_metadata.json"
12
+ PROCESSED_MAIN_FILE = "processed_main.tex"
13
+ REFERENCES_META_FILE = "references_metadata.json"
14
+
15
+
16
+ def load_json(path: Path) -> Any | None:
17
+ if not path.exists():
18
+ return None
19
+ try:
20
+ return json.loads(path.read_text(encoding="utf-8"))
21
+ except Exception:
22
+ return None
23
+
24
+
25
+ def save_json(path: Path, data: Any) -> None:
26
+ path.write_text(json.dumps(data, indent=2), encoding="utf-8")
27
+
28
+
29
+ def iter_paper_dirs(root: Path) -> List[Path]:
30
+ out: List[Path] = []
31
+ for child in root.iterdir():
32
+ if child.is_dir() and (child / PAPER_META_FILE).exists():
33
+ out.append(child)
34
+ return out
35
+
36
+
37
+ def load_paper_metadata(paper_dir: Path) -> Dict[str, Any]:
38
+ meta = load_json(paper_dir / PAPER_META_FILE)
39
+ if isinstance(meta, list) and meta:
40
+ return meta[0]
41
+ if isinstance(meta, dict):
42
+ return meta
43
+ return {}
44
+
45
+
46
+ def _is_structurally_complete(paper_dir: Path) -> bool:
47
+ return (
48
+ (paper_dir / PAPER_META_FILE).exists()
49
+ and (paper_dir / PROCESSED_MAIN_FILE).exists()
50
+ and (paper_dir / REFERENCES_META_FILE).exists()
51
+ )
52
+
53
+
54
+ def _author_last_names(authors: List[Any]) -> List[str]:
55
+ last_names: List[str] = []
56
+ for author in authors:
57
+ if isinstance(author, dict):
58
+ name = author.get("name")
59
+ else:
60
+ name = author
61
+ if not isinstance(name, str):
62
+ continue
63
+ parts = [p for p in re.split(r"\s+", name.strip()) if p]
64
+ if not parts:
65
+ continue
66
+ last_names.append(parts[-1])
67
+ return list(dict.fromkeys(last_names))
68
+
69
+
70
+ def _title_aliases(title: str) -> List[str]:
71
+ aliases = [title]
72
+ if ":" in title:
73
+ aliases.append(title.split(":", 1)[0])
74
+ acronym = "".join([c for c in title if c.isupper()])
75
+ if 3 <= len(acronym) <= 10:
76
+ aliases.append(acronym)
77
+ return list(dict.fromkeys([a for a in aliases if a]))
78
+
79
+
80
+ def _artifact_aliases(paper_dir: Path) -> List[str]:
81
+ aliases: List[str] = []
82
+ usage_claims = load_json(paper_dir / USAGE_CLAIMS_FILE)
83
+ if isinstance(usage_claims, dict):
84
+ caps = usage_claims.get("capabilities") or []
85
+ if isinstance(caps, list):
86
+ for cap in caps:
87
+ if not isinstance(cap, dict):
88
+ continue
89
+ name = cap.get("artifact_name")
90
+ if isinstance(name, str) and name.strip():
91
+ aliases.append(name.strip())
92
+ return list(dict.fromkeys(aliases))
93
+
94
+
95
+ def _loose_alias_pattern(alias: str) -> str:
96
+ parts = re.split(r"[^A-Za-z0-9]+", alias)
97
+ parts = [p for p in parts if p]
98
+ if not parts:
99
+ return ""
100
+ return r"\b" + r"[-\s]*".join(map(re.escape, parts)) + r"\b"
101
+
102
+
103
+ def build_patterns(
104
+ meta: Dict[str, Any],
105
+ paper_dir: Path,
106
+ ) -> Tuple[List[re.Pattern], List[re.Pattern], str | None]:
107
+ year = meta.get("year")
108
+ year_str = str(year) if isinstance(year, int) else None
109
+ authors = meta.get("authors") if isinstance(meta.get("authors"), list) else []
110
+ last_names = _author_last_names(authors)
111
+ title = meta.get("title") if isinstance(meta.get("title"), str) else ""
112
+ aliases = _title_aliases(title) + _artifact_aliases(paper_dir)
113
+ aliases = [a for a in aliases if a]
114
+
115
+ author_patterns: List[re.Pattern] = []
116
+ alias_patterns: List[re.Pattern] = []
117
+
118
+ if year_str and last_names:
119
+ year_pat = rf"{re.escape(year_str)}[a-z]?"
120
+ first_last = re.escape(last_names[0])
121
+ author_patterns.append(
122
+ re.compile(
123
+ rf"\b{first_last}\s+et\s+al\.?\s*(?:,\s*|\s*){year_pat}",
124
+ re.IGNORECASE,
125
+ )
126
+ )
127
+
128
+ for alias in aliases:
129
+ pat = _loose_alias_pattern(alias)
130
+ if pat:
131
+ alias_patterns.append(re.compile(pat, re.IGNORECASE))
132
+
133
+ first_last = last_names[0] if last_names else None
134
+ return author_patterns, alias_patterns, first_last
135
+
136
+
137
+ def _replace_author_span(text: str, first_last: str) -> Tuple[str, bool]:
138
+ occurrences = list(re.finditer(rf"\b{re.escape(first_last)}\b", text, re.IGNORECASE))
139
+ if len(occurrences) != 1:
140
+ return text, False
141
+ author_pat = re.compile(
142
+ rf"\(?\b{re.escape(first_last)}\b"
143
+ rf"\s+(?:et\s+al\.?|and|&)\s*"
144
+ rf"(?:,?\s*\(?\d{{4}}[a-z]?\)?)?"
145
+ rf"\)?",
146
+ re.IGNORECASE,
147
+ )
148
+ new_text, count = author_pat.subn("<CITED HERE>", text, count=1)
149
+ return new_text, count > 0
150
+
151
+
152
+ _BRACKET_NUM_RE = re.compile(r"\[[0-9,;\s]+\]")
153
+ _BRACKET_GROUP_RE = re.compile(r"\[([0-9,;\s]+)\]")
154
+
155
+
156
+ def _extract_bracket_numbers(text: str) -> List[str]:
157
+ numbers: List[str] = []
158
+ for match in _BRACKET_GROUP_RE.finditer(text):
159
+ parts = re.split(r"[,\s;]+", match.group(1).strip())
160
+ for part in parts:
161
+ if part.isdigit():
162
+ numbers.append(part)
163
+ return numbers
164
+
165
+
166
+ def _dominant_bracket(contexts: List[Dict[str, Any]]) -> str | None:
167
+ counts: Dict[str, int] = {}
168
+ for ctx in contexts:
169
+ if not isinstance(ctx, dict):
170
+ continue
171
+ text = ctx.get("context") or ctx.get("text")
172
+ if not isinstance(text, str):
173
+ continue
174
+ for num in _extract_bracket_numbers(text):
175
+ counts[num] = counts.get(num, 0) + 1
176
+ if not counts:
177
+ return None
178
+ best = max(counts.values())
179
+ winners = [num for num, count in counts.items() if count == best]
180
+ if len(winners) == 1:
181
+ return winners[0]
182
+ return None
183
+
184
+
185
+ def _single_bracket_candidate(contexts: List[Dict[str, Any]]) -> str | None:
186
+ counts: Dict[str, int] = {}
187
+ for ctx in contexts:
188
+ if not isinstance(ctx, dict):
189
+ continue
190
+ text = ctx.get("context") or ctx.get("text")
191
+ if not isinstance(text, str):
192
+ continue
193
+ matches = list(_BRACKET_GROUP_RE.finditer(text))
194
+ if len(matches) == 1:
195
+ nums = _extract_bracket_numbers(text)
196
+ if len(nums) != 1:
197
+ continue
198
+ num = nums[0]
199
+ counts[num] = counts.get(num, 0) + 1
200
+ if not counts:
201
+ return None
202
+ best = max(counts.values())
203
+ winners = [num for num, count in counts.items() if count == best]
204
+ if len(winners) == 1:
205
+ return winners[0]
206
+ return None
207
+
208
+
209
+ def _replace_single_bracket(text: str, dominant: str | None) -> Tuple[str, bool]:
210
+ matches = list(_BRACKET_GROUP_RE.finditer(text))
211
+ if len(matches) != 1:
212
+ return text, False
213
+ nums = _extract_bracket_numbers(text)
214
+ if len(nums) != 1:
215
+ return text, False
216
+ num = nums[0]
217
+ if dominant is not None and num != dominant:
218
+ return text, False
219
+ start, end = matches[0].span()
220
+ return text[:start] + "<CITED HERE>" + text[end:], True
221
+
222
+
223
+ def replace_with_marker(
224
+ text: str,
225
+ author_patterns: List[re.Pattern],
226
+ alias_patterns: List[re.Pattern],
227
+ dominant_bracket: str | None = None,
228
+ first_author_last: str | None = None,
229
+ ) -> Tuple[str, bool]:
230
+ def _collapse_markers(value: str) -> str:
231
+ value = re.sub(r"(?:<CITED HERE>[\s()\[\],;:]*){2,}", "<CITED HERE> ", value)
232
+ value = re.sub(r"<CITED HERE>(?:\s+<CITED HERE>)+", "<CITED HERE>", value)
233
+ return value.strip()
234
+
235
+ updated = text
236
+ changed = False
237
+
238
+ author_changed = False
239
+ if first_author_last:
240
+ new, author_changed = _replace_author_span(updated, first_author_last)
241
+ if author_changed:
242
+ changed = True
243
+ updated = _collapse_markers(new)
244
+
245
+ if dominant_bracket:
246
+ def _replace_if_contains(match: re.Match) -> str:
247
+ nums = re.split(r"[,\s;]+", match.group(1).strip())
248
+ if any(n == dominant_bracket for n in nums if n.isdigit()):
249
+ return "<CITED HERE>"
250
+ return match.group(0)
251
+
252
+ new = _BRACKET_GROUP_RE.sub(_replace_if_contains, updated)
253
+ if new != updated:
254
+ changed = True
255
+ updated = _collapse_markers(new)
256
+
257
+ for pat in author_patterns:
258
+ new = pat.sub("<CITED HERE>", updated)
259
+ if new != updated:
260
+ changed = True
261
+ updated = _collapse_markers(new)
262
+
263
+ if not author_changed:
264
+ for pat in alias_patterns:
265
+ new = pat.sub("<CITED HERE>", updated)
266
+ if new != updated:
267
+ changed = True
268
+ updated = _collapse_markers(new)
269
+
270
+ new, bracket_changed = _replace_single_bracket(updated, dominant_bracket)
271
+ if bracket_changed:
272
+ changed = True
273
+ updated = _collapse_markers(new)
274
+
275
+ updated = _collapse_markers(updated)
276
+ return updated, changed
277
+
278
+
279
+ def _process_contexts(
280
+ contexts: List[Dict[str, Any]],
281
+ author_patterns: List[re.Pattern],
282
+ alias_patterns: List[re.Pattern],
283
+ dominant_bracket: str | None,
284
+ first_author_last: str | None,
285
+ ) -> Tuple[int, int]:
286
+ updated_count = 0
287
+ total = 0
288
+ for ctx in contexts:
289
+ if not isinstance(ctx, dict):
290
+ continue
291
+ text = ctx.get("context") or ctx.get("text")
292
+ if not isinstance(text, str):
293
+ continue
294
+ total += 1
295
+ new_text, changed = replace_with_marker(
296
+ text,
297
+ author_patterns=author_patterns,
298
+ alias_patterns=alias_patterns,
299
+ dominant_bracket=dominant_bracket,
300
+ first_author_last=first_author_last,
301
+ )
302
+ if changed:
303
+ updated_count += 1
304
+ ctx["context_with_marker"] = new_text
305
+ return updated_count, total
306
+
307
+
308
+ def update_citations_file(
309
+ paper_dir: Path,
310
+ author_patterns: List[re.Pattern],
311
+ alias_patterns: List[re.Pattern],
312
+ first_author_last: str | None,
313
+ ) -> Tuple[int, int]:
314
+ path = paper_dir / CITATIONS_FILE
315
+ data = load_json(path)
316
+ if not isinstance(data, list):
317
+ return 0, 0
318
+ updated = 0
319
+ total = 0
320
+ for entry in data:
321
+ if not isinstance(entry, dict):
322
+ continue
323
+ ctxs = entry.get("contextsWithIntent") or []
324
+ if isinstance(ctxs, list):
325
+ dominant = _dominant_bracket(ctxs)
326
+ if dominant is None:
327
+ dominant = _single_bracket_candidate(ctxs)
328
+ upd, tot = _process_contexts(
329
+ ctxs,
330
+ author_patterns,
331
+ alias_patterns,
332
+ dominant,
333
+ first_author_last,
334
+ )
335
+ updated += upd
336
+ total += tot
337
+ save_json(path, data)
338
+ return updated, total
339
+
340
+
341
+ def update_usage_contexts_file(
342
+ paper_dir: Path,
343
+ author_patterns: List[re.Pattern],
344
+ alias_patterns: List[re.Pattern],
345
+ first_author_last: str | None,
346
+ ) -> Tuple[int, int]:
347
+ path = paper_dir / USAGE_CONTEXTS_FILE
348
+ data = load_json(path)
349
+ if not isinstance(data, dict):
350
+ return 0, 0
351
+ updated = 0
352
+ total = 0
353
+ for entry in data.get("citing_papers", []) or []:
354
+ if not isinstance(entry, dict):
355
+ continue
356
+ ctxs = entry.get("contexts") or []
357
+ if isinstance(ctxs, list):
358
+ dominant = _dominant_bracket(ctxs)
359
+ if dominant is None:
360
+ dominant = _single_bracket_candidate(ctxs)
361
+ upd, tot = _process_contexts(
362
+ ctxs,
363
+ author_patterns,
364
+ alias_patterns,
365
+ dominant,
366
+ first_author_last,
367
+ )
368
+ updated += upd
369
+ total += tot
370
+ save_json(path, data)
371
+ return updated, total
372
+
373
+
374
+ def main() -> None:
375
+ parser = argparse.ArgumentParser(
376
+ description="Replace citation mentions with <CITED HERE> in context fields."
377
+ )
378
+ parser.add_argument(
379
+ "--root",
380
+ type=str,
381
+ default="runs/processed_papers",
382
+ help="Root directory containing processed paper directories.",
383
+ )
384
+ parser.add_argument(
385
+ "--usage-contexts",
386
+ action="store_true",
387
+ help="Also update usage_contexts.json.",
388
+ )
389
+ args = parser.parse_args()
390
+
391
+ root = Path(args.root).expanduser().resolve()
392
+ if not root.exists():
393
+ raise SystemExit(f"Root directory does not exist: {root}")
394
+
395
+ paper_dirs = sorted(iter_paper_dirs(root), key=lambda p: p.name)
396
+ print(f"[INFO] Found {len(paper_dirs)} paper dirs under {root}")
397
+ total_updated = 0
398
+ total_contexts = 0
399
+ skipped_incomplete = 0
400
+
401
+ for paper_dir in paper_dirs:
402
+ if not _is_structurally_complete(paper_dir):
403
+ skipped_incomplete += 1
404
+ continue
405
+ meta = load_paper_metadata(paper_dir)
406
+ if not meta:
407
+ continue
408
+ author_patterns, alias_patterns, first_author_last = build_patterns(meta, paper_dir)
409
+ if not (author_patterns or alias_patterns):
410
+ continue
411
+ updated, total = update_citations_file(
412
+ paper_dir,
413
+ author_patterns,
414
+ alias_patterns,
415
+ first_author_last,
416
+ )
417
+ total_updated += updated
418
+ total_contexts += total
419
+ if args.usage_contexts:
420
+ upd_usage, tot_usage = update_usage_contexts_file(
421
+ paper_dir,
422
+ author_patterns,
423
+ alias_patterns,
424
+ first_author_last,
425
+ )
426
+ updated += upd_usage
427
+ total += tot_usage
428
+ total_updated += upd_usage
429
+ total_contexts += tot_usage
430
+ if total:
431
+ print(f"[OK] {paper_dir.name}: updated {updated} contexts over {total}")
432
+
433
+ print(
434
+ f"[SUMMARY] total_updated={total_updated} over {total_contexts}; "
435
+ f"skipped_incomplete={skipped_incomplete}"
436
+ )
437
+
438
+
439
+ if __name__ == "__main__":
440
+ main()
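A toy sketch of the core replacement helper in isolation; the regexes mirror what build_patterns emits, and the sentence, author name, and alias are made up.

import re

from replace_citation_markers import replace_with_marker

# Hypothetical patterns of the kind build_patterns() would produce for "Smith et al., 2023".
author_patterns = [re.compile(r"\bSmith\s+et\s+al\.?\s*(?:,\s*|\s*)2023[a-z]?", re.IGNORECASE)]
alias_patterns = [re.compile(r"\bSomeToolkit\b", re.IGNORECASE)]

text = "We adopt the SomeToolkit pipeline of Smith et al., 2023 for citation labelling."
marked, changed = replace_with_marker(
    text,
    author_patterns=author_patterns,
    alias_patterns=alias_patterns,
    first_author_last="Smith",
)
print(changed, marked)  # e.g. True, with the author mention collapsed to <CITED HERE>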
src/step_03_usage_contexts/build_usage_contexts.py ADDED
@@ -0,0 +1,184 @@
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ from typing import Any, Dict, List, Optional
5
+
6
+
7
+ PAPER_META_FILE = "paper_metadata.json"
8
+ CITATIONS_FILE = "citations_metadata.json"
9
+ DEFAULT_OUT_NAME = "usage_contexts.json"
10
+
11
+
12
+ def load_json(path: Path) -> Any | None:
13
+ if not path.exists():
14
+ return None
15
+ try:
16
+ return json.loads(path.read_text(encoding="utf-8"))
17
+ except Exception as e:
18
+ print(f"[WARN] could not parse JSON at {path}: {e}")
19
+ return None
20
+
21
+
22
+ def iter_paper_dirs(root: Path) -> List[Path]:
23
+ out: List[Path] = []
24
+ for child in root.iterdir():
25
+ if child.is_dir() and (child / PAPER_META_FILE).exists():
26
+ out.append(child)
27
+ return out
28
+
29
+
30
+ def _extract_contexts(item: Dict[str, Any]) -> List[Dict[str, Any]]:
31
+ contexts: List[Dict[str, Any]] = []
32
+
33
+ raw = item.get("contextsWithIntent") or []
34
+ if isinstance(raw, list) and raw:
35
+ for entry in raw:
36
+ if not isinstance(entry, dict):
37
+ continue
38
+ text_raw = (entry.get("context") or "").strip()
39
+ text = (entry.get("context_with_marker") or text_raw).strip()
40
+ intents = entry.get("intents") or []
41
+ contexts.append(
42
+ {
43
+ "text": text,
44
+ "text_raw": text_raw,
45
+ "intents": intents,
46
+ }
47
+ )
48
+
49
+ # Fallback for older schema that only stores raw context strings.
50
+ if not contexts:
51
+ raw_alt = item.get("contexts") or []
52
+ if isinstance(raw_alt, list):
53
+ for text in raw_alt:
54
+ if not isinstance(text, str):
55
+ continue
56
+ text = text.strip()
57
+ if text:
58
+ contexts.append(
59
+ {
60
+ "text": text,
61
+ "intents": [],
62
+ }
63
+ )
64
+
65
+ return contexts
66
+
67
+
68
+ def build_usage_contexts_for_paper(paper_dir: Path) -> Optional[Dict[str, Any]]:
69
+ citations_path = paper_dir / CITATIONS_FILE
70
+ data = load_json(citations_path)
71
+ if data is None:
72
+ return None
73
+
74
+ if not isinstance(data, list):
75
+ print(f"[WARN] {paper_dir.name}: {CITATIONS_FILE} is not a list")
76
+ return None
77
+
78
+ citing_entries: List[Dict[str, Any]] = []
79
+ total_contexts = 0
80
+ citing_with_context = 0
81
+ influential_citations = 0
82
+ influential_with_context = 0
83
+ influential_contexts: List[Dict[str, Any]] = []
84
+
85
+ for item in data:
86
+ if not isinstance(item, dict):
87
+ continue
88
+ citing = item.get("citingPaper") or {}
89
+
90
+ contexts = _extract_contexts(item)
91
+ is_influential = bool(item.get("isInfluential", False))
92
+ if is_influential:
93
+ influential_citations += 1
94
+ if contexts:
95
+ citing_with_context += 1
96
+ total_contexts += len(contexts)
97
+ if is_influential:
98
+ influential_with_context += 1
99
+
100
+ citing_entries.append(
101
+ {
102
+ "citing_paper_id": citing.get("paperId"),
103
+ "title": citing.get("title"),
104
+ "external_ids": citing.get("externalIds") or {},
105
+ "is_influential": is_influential,
106
+ "contexts": contexts,
107
+ }
108
+ )
109
+ if is_influential and contexts:
110
+ influential_contexts.append(
111
+ {
112
+ "citing_paper_id": citing.get("paperId"),
113
+ "title": citing.get("title"),
114
+ "external_ids": citing.get("externalIds") or {},
115
+ "contexts": contexts,
116
+ }
117
+ )
118
+
119
+ payload = {
120
+ "paper_id": paper_dir.name,
121
+ "total_citations": len(data),
122
+ "num_contexts": total_contexts,
123
+ "num_citing_with_context": citing_with_context,
124
+ "num_citing_without_context": len(data) - citing_with_context,
125
+ "num_influential_citations": influential_citations,
126
+ "num_influential_with_context": influential_with_context,
127
+ "influential_contexts": influential_contexts,
128
+ "citing_papers": citing_entries,
129
+ }
130
+ return payload
131
+
132
+
133
+ def run(root: Path, out_name: str, overwrite: bool) -> None:
134
+ root = root.resolve()
135
+ if not root.exists():
136
+ raise SystemExit(f"Root directory does not exist: {root}")
137
+
138
+ paper_dirs = sorted(iter_paper_dirs(root), key=lambda p: p.name)
139
+ print(f"[INFO] Found {len(paper_dirs)} paper dirs under {root}")
140
+
141
+ for paper_dir in paper_dirs:
142
+ out_path = paper_dir / out_name
143
+ if out_path.exists() and not overwrite:
144
+ print(f"[SKIP] {paper_dir.name}: {out_name} already exists")
145
+ continue
146
+ payload = build_usage_contexts_for_paper(paper_dir)
147
+ if payload is None:
148
+ continue
149
+
150
+ out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
151
+ print(
152
+ f"[OK] {paper_dir.name}: wrote {out_name} "
153
+ f"({payload['num_contexts']} contexts from {payload['total_citations']} citations)"
154
+ )
155
+
156
+
157
+ def main() -> None:
158
+ parser = argparse.ArgumentParser(
159
+ description="Build usage_contexts.json from citations_metadata.json files."
160
+ )
161
+ parser.add_argument(
162
+ "--root",
163
+ type=str,
164
+ default="processed_papers/acl_2024",
165
+ help="Root directory containing processed_papers/acl_2024/<paper_id> dirs.",
166
+ )
167
+ parser.add_argument(
168
+ "--out-name",
169
+ type=str,
170
+ default=DEFAULT_OUT_NAME,
171
+ help="Output filename to write inside each paper dir.",
172
+ )
173
+ parser.add_argument(
174
+ "--overwrite",
175
+ action="store_true",
176
+ help="Overwrite existing usage_contexts.json files.",
177
+ )
178
+ args = parser.parse_args()
179
+
180
+ run(Path(args.root), out_name=args.out_name, overwrite=args.overwrite)
181
+
182
+
183
+ if __name__ == "__main__":
184
+ main()
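A one-call sketch of running this step programmatically instead of via the CLI; the root path is illustrative.

from pathlib import Path

from build_usage_contexts import DEFAULT_OUT_NAME, run

# Writes usage_contexts.json into every paper directory that has citations_metadata.json.
run(Path("runs/processed_papers"), out_name=DEFAULT_OUT_NAME, overwrite=False)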
src/step_04_label_citations/label_citation_functions.py ADDED
@@ -0,0 +1,373 @@
1
+ import argparse
2
+ import json
3
+ import sys
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List
6
+
7
+ DEEP_CITATION_ROOT = Path(__file__).resolve().parents[2] / "Deep-Citation"
8
+ if not DEEP_CITATION_ROOT.exists():
9
+ raise SystemExit(f"Deep-Citation repo not found at {DEEP_CITATION_ROOT}")
10
+
11
+ sys.path.insert(0, str(DEEP_CITATION_ROOT))
12
+
13
+ from data import CollateFn, create_data_channels
14
+ from Model import MultiHeadLanguageModel
15
+ import torch
16
+ from torch.utils.data import DataLoader
17
+
18
+
19
+ PAPER_META_FILE = "paper_metadata.json"
20
+ USAGE_CONTEXTS_FILE = "usage_contexts.json"
21
+ OUT_FILE = "usage_context_labels.json"
22
+
23
+ LABEL_SET = [
24
+ "Background",
25
+ "Uses",
26
+ "Extends",
27
+ "CompareOrContrast",
28
+ "Motivation",
29
+ "Future",
30
+ ]
31
+
32
+
33
+ def load_json(path: Path) -> Any | None:
34
+ if not path.exists():
35
+ return None
36
+ try:
37
+ return json.loads(path.read_text(encoding="utf-8"))
38
+ except Exception:
39
+ return None
40
+
41
+
42
+ def iter_paper_dirs(root: Path) -> List[Path]:
43
+ out: List[Path] = []
44
+ for child in root.iterdir():
45
+ if child.is_dir() and (child / PAPER_META_FILE).exists():
46
+ out.append(child)
47
+ return out
48
+
49
+
50
+ def flatten_contexts(usage: Dict[str, Any]) -> List[Dict[str, Any]]:
51
+ contexts: List[Dict[str, Any]] = []
52
+ idx = 1
53
+ for entry in usage.get("citing_papers", []) or []:
54
+ if not isinstance(entry, dict):
55
+ continue
56
+ citing_title = entry.get("title") or "Unknown citing paper"
57
+ citing_paper_id = entry.get("citing_paper_id") or ""
58
+ for c in entry.get("contexts", []) or []:
59
+ if not isinstance(c, dict):
60
+ continue
61
+ text = (c.get("text") or "").strip()
62
+ if not text:
63
+ continue
64
+ contexts.append(
65
+ {
66
+ "id": idx,
67
+ "text": text,
68
+ "citing_title": citing_title,
69
+ "citing_paper_id": citing_paper_id,
70
+ }
71
+ )
72
+ idx += 1
73
+ return contexts
74
+
75
+
76
+ def _resolve_model_name(lm: str) -> str:
77
+ if lm == "scibert":
78
+ return "allenai/scibert_scivocab_uncased"
79
+ if lm == "bert":
80
+ return "bert-base-uncased"
81
+ if lm == "deberta":
82
+ return "microsoft/deberta-v3-base"
83
+ if lm == "deberta-large":
84
+ return "microsoft/deberta-v3-large"
85
+ return lm
86
+
87
+
88
+ def _infer_head_sizes(state_dict: Dict[str, Any]) -> List[int]:
89
+ head_weights = [
90
+ (k, v) for k, v in state_dict.items() if k.startswith("lns.") and k.endswith(".weight")
91
+ ]
92
+ head_weights.sort(key=lambda x: int(x[0].split(".")[1]))
93
+ return [int(weight.shape[0]) for _, weight in head_weights]
94
+
95
+
96
+ class _ContextDataset:
97
+ def __init__(self, texts: List[str]):
98
+ self.texts = texts
99
+
100
+ def __len__(self) -> int:
101
+ return len(self.texts)
102
+
103
+ def __getitem__(self, idx: int):
104
+ return (self.texts[idx], torch.tensor(0), torch.tensor(0))
105
+
106
+
107
+ def label_with_model(
108
+ contexts: List[Dict[str, Any]],
109
+ model_path: Path,
110
+ data_dir: Path,
111
+ class_definition: Path,
112
+ lm: str,
113
+ device: str,
114
+ batch_size: int,
115
+ ) -> Dict[int, Dict[str, Any]]:
116
+ data_file = data_dir / "acl.tsv"
117
+ train_data, _, _, label_names = create_data_channels(
118
+ str(data_file),
119
+ str(class_definition),
120
+ lmbd=1.0,
121
+ )
122
+ modelname = _resolve_model_name(lm)
123
+ state_dict = torch.load(model_path, map_location=device)
124
+ head_sizes = _infer_head_sizes(state_dict)
125
+ model = MultiHeadLanguageModel(
126
+ modelname=modelname,
127
+ device=device,
128
+ readout="ch",
129
+ num_classes=head_sizes,
130
+ ).to(device)
131
+ model.load_state_dict(state_dict)
132
+ model.eval()
133
+
134
+ collate_fn = CollateFn(
135
+ modelname=modelname,
136
+ class_definitions=train_data.class_definitions,
137
+ instance_weights=False,
138
+ )
139
+ dataset = _ContextDataset([ctx["text"] for ctx in contexts])
140
+ loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
141
+
142
+ outputs: Dict[int, Dict[str, Any]] = {}
143
+ idx_offset = 0
144
+ with torch.no_grad():
145
+ for batched_text, labels, ds_indices, class_tokens, class_ds_indices in loader:
146
+ ds_indices = ds_indices.to(device)
147
+ class_ds_indices = class_ds_indices.to(device)
148
+ logits = model(batched_text, ds_indices, class_tokens, class_ds_indices)[0]
149
+ probs = torch.softmax(logits, dim=1)
150
+ preds = logits.argmax(dim=1).cpu().tolist()
151
+ pred_confidences = probs.max(dim=1).values.cpu().tolist()
152
+ top2 = torch.topk(probs, k=2, dim=1).values.cpu()
153
+ margins = (top2[:, 0] - top2[:, 1]).tolist()
154
+ for i, pred in enumerate(preds):
155
+ raw_label = label_names[pred]
156
+ outputs[idx_offset + i + 1] = {
157
+ "id": idx_offset + i + 1,
158
+ "label": raw_label,
159
+ "confidence": float(pred_confidences[i]),
160
+ "confidence_margin": float(margins[i]),
161
+ "cue_span": "",
162
+ "rationale": "scibert_model",
163
+ }
164
+ idx_offset += len(preds)
165
+ return outputs
166
+
167
+
168
+ def aggregate_citing_labels(labels: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
169
+ by_citing: Dict[str, List[Dict[str, Any]]] = {}
170
+ for item in labels:
171
+ citing_id = item.get("citing_paper_id") or ""
172
+ by_citing.setdefault(citing_id, []).append(item)
173
+
174
+ aggregated: List[Dict[str, Any]] = []
175
+ for citing_id, items in by_citing.items():
176
+ title = items[0].get("citing_title", "")
177
+ labels_set = {it.get("label") for it in items}
178
+
179
+ if "Extends" in labels_set:
180
+ label = "Extends"
181
+ evidence_ids = [it["id"] for it in items if it.get("label") == "Extends"]
182
+ elif "Uses" in labels_set:
183
+ label = "Uses"
184
+ evidence_ids = [it["id"] for it in items if it.get("label") == "Uses"]
185
+ elif "CompareOrContrast" in labels_set:
186
+ label = "CompareOrContrast"
187
+ evidence_ids = [
188
+ it["id"] for it in items if it.get("label") == "CompareOrContrast"
189
+ ]
190
+ else:
191
+ label = "Background"
192
+ evidence_ids = []
193
+
194
+ aggregated.append(
195
+ {
196
+ "citing_paper_id": citing_id,
197
+ "citing_title": title,
198
+ "label": label,
199
+ "evidence_context_ids": evidence_ids,
200
+ }
201
+ )
202
+
203
+ return aggregated
204
+
205
+
206
+ def aggregate_final_label(citing_labels: List[Dict[str, Any]]) -> str:
207
+ labels_set = {item.get("label") for item in citing_labels}
208
+ if "Extends" in labels_set:
209
+ return "Extends"
210
+ if "Uses" in labels_set:
211
+ return "Uses"
212
+ if "CompareOrContrast" in labels_set:
213
+ return "CompareOrContrast"
214
+ return "Background"
215
+
216
+
217
+ def score_for_paper(
218
+ paper_dir: Path,
219
+ batch_size: int,
220
+ overwrite: bool,
221
+ model_path: Path,
222
+ model_data_dir: Path,
223
+ model_class_def: Path,
224
+ model_lm: str,
225
+ device: str,
226
+ ) -> str:
227
+ usage_path = paper_dir / USAGE_CONTEXTS_FILE
228
+ usage = load_json(usage_path)
229
+ if not isinstance(usage, dict):
230
+ return "missing_usage"
231
+
232
+ contexts = flatten_contexts(usage)
233
+ if not contexts:
234
+ return "empty_contexts"
235
+
236
+ out_path = paper_dir / OUT_FILE
237
+ if out_path.exists() and not overwrite:
238
+ return "skipped"
239
+
240
+ labeled = label_with_model(
241
+ contexts=contexts,
242
+ model_path=model_path,
243
+ data_dir=model_data_dir,
244
+ class_definition=model_class_def,
245
+ lm=model_lm,
246
+ device=device,
247
+ batch_size=batch_size,
248
+ )
249
+
250
+ labels_sorted = []
251
+ for context in contexts:
252
+ context_id = context["id"]
253
+ item = labeled.get(context_id)
254
+ if not item:
255
+ item = {
256
+ "id": context_id,
257
+ "label": "Background",
258
+ "confidence": 0.0,
259
+ "cue_span": "",
260
+ "rationale": "missing label",
261
+ }
262
+ item = dict(item)
263
+ item["citing_paper_id"] = context.get("citing_paper_id", "")
264
+ item["citing_title"] = context.get("citing_title", "")
265
+ item["text"] = context.get("text", "")
266
+ labels_sorted.append(item)
267
+
268
+ citing_labels = aggregate_citing_labels(labels_sorted)
269
+ payload = {
270
+ "paper_id": usage.get("paper_id"),
271
+ "num_contexts": len(contexts),
272
+ "label_set": LABEL_SET,
273
+ "labels": labels_sorted,
274
+ "citing_paper_labels": citing_labels,
275
+ "final_label": aggregate_final_label(citing_labels),
276
+ }
277
+ out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
278
+ return "labeled"
279
+
280
+
281
+ def main() -> None:
282
+ parser = argparse.ArgumentParser(
283
+ description="Label citation functions using a Deep-Citation checkpoint."
284
+ )
285
+ parser.add_argument(
286
+ "--root",
287
+ type=str,
288
+ default="runs/processed_papers",
289
+ help="Root directory containing processed paper directories.",
290
+ )
291
+ parser.add_argument(
292
+ "--batch-size",
293
+ type=int,
294
+ default=32,
295
+ help="Batch size for model inference.",
296
+ )
297
+ parser.add_argument(
298
+ "--overwrite",
299
+ action="store_true",
300
+ help="Overwrite existing usage_context_labels.json files.",
301
+ )
302
+ parser.add_argument(
303
+ "--model-path",
304
+ type=str,
305
+ required=True,
306
+ help="Path to Deep-Citation best_model.pt checkpoint.",
307
+ )
308
+ parser.add_argument(
309
+ "--model-data-dir",
310
+ type=str,
311
+ default="Deep-Citation/Data",
312
+ help="Deep-Citation data directory (for label order).",
313
+ )
314
+ parser.add_argument(
315
+ "--model-class-def",
316
+ type=str,
317
+ default="Deep-Citation/Data/class_def.json",
318
+ help="Deep-Citation class_def.json path.",
319
+ )
320
+ parser.add_argument(
321
+ "--model-lm",
322
+ type=str,
323
+ default="scibert",
324
+ help="Model name used for the Deep-Citation checkpoint.",
325
+ )
326
+ parser.add_argument(
327
+ "--device",
328
+ type=str,
329
+ default="cuda",
330
+ help="Device for model inference (cuda/cpu).",
331
+ )
332
+ args = parser.parse_args()
333
+
334
+ model_path = Path(args.model_path).expanduser().resolve()
335
+ if not model_path.exists():
336
+ raise SystemExit(f"Model path does not exist: {model_path}")
337
+
338
+ root = Path(args.root).expanduser().resolve()
339
+ if not root.exists():
340
+ raise SystemExit(f"Root directory does not exist: {root}")
341
+
342
+ paper_dirs = sorted(iter_paper_dirs(root), key=lambda p: p.name)
343
+ print(f"[INFO] Found {len(paper_dirs)} paper dirs under {root}")
344
+
345
+ counts = {
346
+ "labeled": 0,
347
+ "skipped": 0,
348
+ "missing_usage": 0,
349
+ "empty_contexts": 0,
350
+ }
351
+
352
+ for paper_dir in paper_dirs:
353
+ status = score_for_paper(
354
+ paper_dir,
355
+ args.batch_size,
356
+ args.overwrite,
357
+ model_path=model_path,
358
+ model_data_dir=Path(args.model_data_dir).expanduser().resolve(),
359
+ model_class_def=Path(args.model_class_def).expanduser().resolve(),
360
+ model_lm=args.model_lm,
361
+ device=args.device,
362
+ )
363
+ counts[status] = counts.get(status, 0) + 1
364
+ print(f"[{status.upper()}] {paper_dir.name}")
365
+
366
+ print(
367
+ "[SUMMARY] labeled={labeled}, skipped={skipped}, missing_usage={missing_usage}, "
368
+ "empty_contexts={empty_contexts}".format(**counts)
369
+ )
370
+
371
+
372
+ if __name__ == "__main__":
373
+ main()
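A toy sketch of the label-precedence logic (Extends > Uses > CompareOrContrast > Background) used when collapsing per-context labels into per-citing-paper labels; importing this module pulls in torch and the Deep-Citation checkout, so this assumes both are available. The ids and titles are made up.

from label_citation_functions import aggregate_citing_labels, aggregate_final_label

# Toy per-context labels: citing paper p1 has both a Uses and an Extends context.
labels = [
    {"id": 1, "citing_paper_id": "p1", "citing_title": "Paper One", "label": "Uses"},
    {"id": 2, "citing_paper_id": "p1", "citing_title": "Paper One", "label": "Extends"},
    {"id": 3, "citing_paper_id": "p2", "citing_title": "Paper Two", "label": "Background"},
]

per_citing = aggregate_citing_labels(labels)  # p1 -> Extends (evidence id 2), p2 -> Background
print(aggregate_final_label(per_citing))      # "Extends"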
src/step_05_verify_uses_extends/prompts.py ADDED
@@ -0,0 +1,115 @@
1
+ from typing import Dict, List
2
+
3
+
4
+ USES_DEFINITION = (
5
+ "USES: The CITING_PAPER explicitly uses/adopts/evaluates on/includes/relies on "
6
+ "a dataset, benchmark, method, tool, or reported results from TARGET_PAPER "
7
+ "as part of the CITING_PAPER's own methodology or evaluation."
8
+ )
9
+
10
+ EXTENDS_DEFINITION = (
11
+ "EXTENDS: The CITING_PAPER explicitly extends/modifies/adapts/builds upon "
12
+ "TARGET_PAPER's method/dataset/benchmark/tool."
13
+ )
14
+
15
+ NOTES_DEFINITION = (
16
+ "NOT USES/EXTENDS: Merely describing what TARGET_PAPER introduces/offers/proposes "
17
+ "or listing it among related work or benchmarks (without stating adoption). "
18
+ "If no explicit adoption/extension cue is present, label NOT_CONFIRMED."
19
+ )
20
+
21
+
22
+ FEW_SHOT_USES = [
23
+ "We use the same splits as <CITED HERE> .",
24
+ "The Praat tool was used ( <CITED HERE> ) .",
25
+ "CCGBank ( <CITED HERE> ) is used to train the model .",
26
+ "This design idea was adopted from TANKA ( <CITED HERE>b ) .",
27
+ "Our strategy is based on the approach presented by <CITED HERE> .",
28
+ ]
29
+
30
+ FEW_SHOT_EXTENDS = [
31
+ "The features can be easily obtained by modifying the TAT extraction algorithm described in ( <CITED HERE> ) .",
32
+ "Our own work ( <CITED HERE> ) extends the first idea to paraphrase fragment extraction on monolingual parallel and comparable corpora .",
33
+ "This article represents an extension of our previous work on unsupervised event coreference resolution ( Bejan et al. 2009 ; <CITED HERE> ) .",
34
+ "This evaluation set-up is an improvement versus the one we previously reported ( <CITED HERE> ) , in which fixed partitions were used for training , development , and testing .",
35
+ "The computational treatment of lexical rules proposed can be seen as an extension to the principled method discussed by Gotz and <CITED HERE> , 1996 , 1997b ) for encoding the main building block of HPSG grammars -- the implicative constraints -- as a logic program .",
36
+ ]
37
+
38
+ FEW_SHOT_NOT_CONFIRMED = [
39
+ "<CITED HERE> introduced factored SMT .",
40
+ "See ( <CITED HERE> ) for a discussion .",
41
+ "See , among others , ( <CITED HERE> ) .",
42
+ "<CITED HERE> reported a correlation of r = .69 .",
43
+ "See <CITED HERE> for further discussion .",
44
+ ]
45
+
46
+
47
+ def build_uses_extends_verification_prompt(
48
+ target_info: Dict[str, str],
49
+ candidates: List[Dict[str, str]],
50
+ ) -> str:
51
+ header = [
52
+ "You are verifying citation function for a TARGET paper inside a citing sentence.",
53
+ "Be strict: lists of related work or benchmarks are NOT USES/EXTENDS unless there is an explicit action",
54
+ "like \"use\", \"build on\", \"adopt\", \"extend\", \"based on\", \"trained on\", \"evaluate on\", \"implement\".",
55
+ "",
56
+ "Actor test (CRITICAL for USES/EXTENSION):",
57
+ "- Only label USES or EXTENSION if the ACTION is performed by the CITING_PAPER.",
58
+ "- The cue_span for USES/EXTENSION must include an explicit citing-paper actor phrase such as:",
59
+ " \"we\", \"our\", \"in this work\", \"in this paper\", \"we use\", \"we evaluate\",",
60
+ " \"our evaluation includes\", \"we extend\", \"we build on\", \"we adapt\".",
61
+ "- If the context says the TARGET_PAPER (or some other paper/system) uses/extends something",
62
+ " (e.g., \"TARGET_PAPER uses...\", \"TARGET_PAPER extends...\"),",
63
+ " then it is NOT USES/EXTENSION. Label NOT_CONFIRMED.",
64
+ "",
65
+ "Task: Label each sentence as USES, EXTENDS, or NOT_CONFIRMED.",
66
+ "Return JSON only with one entry per input sentence.",
67
+ "",
68
+ "Definitions:",
69
+ f"- {USES_DEFINITION}",
70
+ f"- {EXTENDS_DEFINITION}",
71
+ f"- {NOTES_DEFINITION}",
72
+ "",
73
+ "Output rules:",
74
+ "- label must be one of: USES, EXTENDS, NOT_CONFIRMED",
75
+ "- cue_span: exact substring from the sentence that justifies USES/EXTENDS, else empty",
76
+ "- rationale: one short sentence",
77
+ "- If cue_span is empty => label must be NOT_CONFIRMED",
78
+ "",
79
+ "Few-shot examples:",
80
+ "USES:",
81
+ ]
82
+ for ex in FEW_SHOT_USES:
83
+ header.append(f"- {ex}")
84
+ header.append("EXTENDS:")
85
+ for ex in FEW_SHOT_EXTENDS:
86
+ header.append(f"- {ex}")
87
+ header.append("NOT_CONFIRMED:")
88
+ for ex in FEW_SHOT_NOT_CONFIRMED:
89
+ header.append(f"- {ex}")
90
+
91
+ header.extend(
92
+ [
93
+ "",
94
+ "TARGET_PAPER:",
95
+ f"- title: {target_info.get('title', '')}",
96
+ f"- first_author_last: {target_info.get('first_author_last', '')}",
97
+ f"- year: {target_info.get('year', '')}",
98
+ "",
99
+ "CANDIDATES:",
100
+ ]
101
+ )
102
+
103
+ for item in candidates:
104
+ header.extend(
105
+ [
106
+ f"ID: {item['id']}",
107
+ f"Citing paper: {item.get('citing_title', '')}",
108
+ f"Sentence: {item.get('text', '')}",
109
+ "",
110
+ ]
111
+ )
112
+
113
+ header.append("JSON OUTPUT:")
114
+ header.append("{\"labels\": [{\"id\": 1, \"label\": \"USES\", \"cue_span\": \"...\", \"rationale\": \"...\"}]}")
115
+ return "\n".join(header)
src/step_05_verify_uses_extends/schemas.py ADDED
@@ -0,0 +1,22 @@
1
+ USES_EXTENDS_VERIFICATION_JSON_SCHEMA = {
2
+ "type": "object",
3
+ "properties": {
4
+ "labels": {
5
+ "type": "array",
6
+ "items": {
7
+ "type": "object",
8
+ "properties": {
9
+ "id": {"type": "integer"},
10
+ "label": {
11
+ "type": "string",
12
+ "enum": ["USES", "EXTENDS", "NOT_CONFIRMED"],
13
+ },
14
+ "cue_span": {"type": "string"},
15
+ "rationale": {"type": "string"},
16
+ },
17
+ "required": ["id", "label", "cue_span", "rationale"],
18
+ },
19
+ },
20
+ },
21
+ "required": ["labels"],
22
+ }
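An example response object of the shape this schema describes; jsonschema is not a declared dependency here, so the validation call is optional and only illustrative.

from schemas import USES_EXTENDS_VERIFICATION_JSON_SCHEMA

example = {
    "labels": [
        {"id": 1, "label": "USES", "cue_span": "we use the same splits", "rationale": "Explicit adoption by the citing paper."},
        {"id": 2, "label": "NOT_CONFIRMED", "cue_span": "", "rationale": "Target is only listed among related work."},
    ]
}

try:
    import jsonschema  # optional check of the example against the schema
    jsonschema.validate(example, USES_EXTENDS_VERIFICATION_JSON_SCHEMA)
except ImportError:
    pass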
src/step_05_verify_uses_extends/verify_uses_extends.py ADDED
@@ -0,0 +1,296 @@
1
+ import argparse
2
+ import json
3
+ import sys
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List
6
+
7
+ SRC_ROOT = Path(__file__).resolve().parents[1]
8
+ if str(SRC_ROOT) not in sys.path:
9
+ sys.path.insert(0, str(SRC_ROOT))
10
+
11
+ from common.llm_client import LLMClient
12
+
13
+ from prompts import build_uses_extends_verification_prompt
14
+ from schemas import USES_EXTENDS_VERIFICATION_JSON_SCHEMA
15
+
16
+
17
+ PAPER_META_FILE = "paper_metadata.json"
18
+ USAGE_LABELS_FILE = "usage_context_labels.json"
19
+ OUT_FILE = "usage_uses_extends_verified.json"
20
+
21
+ USE_LABELS = {"Uses", "Extends"}
22
+
23
+
24
+ def load_json(path: Path) -> Any | None:
25
+ if not path.exists():
26
+ return None
27
+ try:
28
+ return json.loads(path.read_text(encoding="utf-8"))
29
+ except Exception:
30
+ return None
31
+
32
+
33
+ def iter_paper_dirs(root: Path) -> List[Path]:
34
+ out: List[Path] = []
35
+ for child in root.iterdir():
36
+ if child.is_dir() and (child / PAPER_META_FILE).exists():
37
+ out.append(child)
38
+ return out
39
+
40
+
41
+ def _normalize_author_last(name: str) -> str:
42
+ parts = [p for p in (name or "").split() if p.strip()]
43
+ return parts[-1] if parts else ""
44
+
45
+
46
+ def extract_target_info(meta: Any) -> Dict[str, str]:
47
+ if isinstance(meta, list) and meta:
48
+ meta = meta[0]
49
+ if not isinstance(meta, dict):
50
+ return {"title": "", "first_author_last": "", "year": ""}
51
+ authors = meta.get("authors") or []
52
+ first_author = authors[0]["name"] if authors else ""
53
+ return {
54
+ "title": meta.get("title", ""),
55
+ "first_author_last": _normalize_author_last(first_author),
56
+ "year": str(meta.get("year", "")),
57
+ }
58
+
59
+
60
+ def verify_candidates(
61
+ client: LLMClient,
62
+ target_info: Dict[str, str],
63
+ candidates: List[Dict[str, Any]],
64
+ ) -> List[Dict[str, Any]]:
65
+ prompt = build_uses_extends_verification_prompt(target_info, candidates)
66
+ try:
67
+ raw = client.call(prompt, schema=USES_EXTENDS_VERIFICATION_JSON_SCHEMA)
68
+ except Exception as exc:
69
+ print(f"[WARN] LLM call failed: {exc}. Marking all candidates NOT_CONFIRMED.")
70
+ return [
71
+ {
72
+ "id": item.get("id"),
73
+ "label": "NOT_CONFIRMED",
74
+ "cue_span": "",
75
+ "rationale": "",
76
+ "text": item.get("text", ""),
77
+ "citing_paper_id": item.get("citing_paper_id", ""),
78
+ "citing_title": item.get("citing_title", ""),
79
+ "original_label": item.get("original_label", ""),
80
+ }
81
+ for item in candidates
82
+ ]
83
+ data = _parse_llm_json(raw)
84
+ if not isinstance(data, dict):
85
+ print("[WARN] Failed to parse LLM JSON response; marking all candidates NOT_CONFIRMED.")
86
+ return [
87
+ {
88
+ "id": item.get("id"),
89
+ "label": "NOT_CONFIRMED",
90
+ "cue_span": "",
91
+ "rationale": "",
92
+ "text": item.get("text", ""),
93
+ "citing_paper_id": item.get("citing_paper_id", ""),
94
+ "citing_title": item.get("citing_title", ""),
95
+ "original_label": item.get("original_label", ""),
96
+ }
97
+ for item in candidates
98
+ ]
99
+ labels = data.get("labels", [])
100
+ by_id = {item.get("id"): item for item in labels if isinstance(item, dict)}
101
+
102
+ verified: List[Dict[str, Any]] = []
103
+ for candidate in candidates:
104
+ item_id = candidate["id"]
105
+ model = by_id.get(item_id, {})
106
+ label = model.get("label", "NOT_CONFIRMED")
107
+ cue_span = model.get("cue_span", "")
108
+ if not cue_span:
109
+ label = "NOT_CONFIRMED"
110
+ verified.append(
111
+ {
112
+ "id": item_id,
113
+ "label": label,
114
+ "cue_span": cue_span,
115
+ "rationale": model.get("rationale", ""),
116
+ "text": candidate.get("text", ""),
117
+ "citing_paper_id": candidate.get("citing_paper_id", ""),
118
+ "citing_title": candidate.get("citing_title", ""),
119
+ "original_label": candidate.get("original_label", ""),
120
+ }
121
+ )
122
+ return verified
123
+
124
+
125
+ def _parse_llm_json(raw: str) -> Any | None:
126
+ try:
127
+ return json.loads(raw)
128
+ except json.JSONDecodeError:
129
+ pass
130
+
131
+ cleaned = raw.strip()
132
+ if cleaned.startswith("```"):
133
+ cleaned = cleaned.strip("`")
134
+ cleaned = cleaned.replace("json", "", 1).strip()
135
+
136
+ start = cleaned.find("{")
137
+ end = cleaned.rfind("}")
138
+ if start == -1 or end == -1 or end <= start:
139
+ return None
140
+
141
+ snippet = cleaned[start : end + 1]
142
+ try:
143
+ return json.loads(snippet)
144
+ except json.JSONDecodeError:
145
+ return None
146
+
147
+
148
+ def process_paper(
149
+ paper_dir: Path,
150
+ client: LLMClient,
151
+ k: int,
152
+ batch_size: int,
153
+ overwrite: bool,
154
+ resume: bool,
155
+ ) -> str:
156
+ labels_path = paper_dir / USAGE_LABELS_FILE
157
+ payload = load_json(labels_path)
158
+ if not isinstance(payload, dict):
159
+ return "missing_labels"
160
+
161
+ out_path = paper_dir / OUT_FILE
162
+ if out_path.exists() and (resume or not overwrite):
163
+ return "skipped"
164
+
165
+ labels = payload.get("labels", [])
166
+ candidates_all = []
167
+ for item in labels:
168
+ if item.get("label") in USE_LABELS:
169
+ candidates_all.append(
170
+ {
171
+ "id": item.get("id"),
172
+ "text": item.get("text", ""),
173
+ "citing_paper_id": item.get("citing_paper_id", ""),
174
+ "citing_title": item.get("citing_title", ""),
175
+ "original_label": item.get("label"),
176
+ "confidence": float(item.get("confidence", 0.0) or 0.0),
177
+ }
178
+ )
179
+
180
+ if not candidates_all:
181
+ result = {
182
+ "paper_id": payload.get("paper_id"),
183
+ "target": {},
184
+ "candidates_total": 0,
185
+ "candidates_considered": 0,
186
+ "verified": [],
187
+ "confirmed": [],
188
+ }
189
+ out_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
190
+ return "no_candidates"
191
+
192
+ # Keep top-k highest-confidence USES/EXTENDS contexts for LLM verification.
193
+ # If k <= 0, verify all candidates.
194
+ candidates_all = sorted(
195
+ candidates_all,
196
+ key=lambda x: x.get("confidence", 0.0),
197
+ reverse=True,
198
+ )
199
+ candidates = candidates_all if k <= 0 else candidates_all[:k]
200
+
201
+ target_info = extract_target_info(load_json(paper_dir / PAPER_META_FILE))
202
+ verified: List[Dict[str, Any]] = []
203
+ if batch_size <= 0:
204
+ batch_size = 25
205
+ for i in range(0, len(candidates), batch_size):
206
+ batch = candidates[i : i + batch_size]
207
+ verified.extend(verify_candidates(client, target_info, batch))
208
+ confirmed = [v for v in verified if v["label"] in {"USES", "EXTENDS"}]
209
+ if any(item["label"] == "EXTENDS" for item in confirmed):
210
+ final_label = "EXTENDS"
211
+ elif confirmed:
212
+ final_label = "USES"
213
+ else:
214
+ final_label = "NOT_CONFIRMED"
215
+
216
+ result = {
217
+ "paper_id": payload.get("paper_id"),
218
+ "target": target_info,
219
+ "candidates_total": len(candidates_all),
220
+ "candidates_considered": len(candidates),
221
+ "verification_batch_size": int(batch_size),
222
+ "verification_num_batches": (len(candidates) + batch_size - 1) // batch_size if candidates else 0,
223
+ "candidates_selected": len(confirmed),
224
+ "verified": verified,
225
+ "confirmed": confirmed,
226
+ "confirmed_extends": sum(1 for x in confirmed if x.get("label") == "EXTENDS"),
227
+ "confirmed_uses": sum(1 for x in confirmed if x.get("label") == "USES"),
228
+ "final_label": final_label,
229
+ }
230
+ out_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
231
+ return "verified"
232
+
233
+
234
+ def main() -> None:
235
+ parser = argparse.ArgumentParser(
236
+ description="Verify USES/EXTENDS candidates via LLM and select top-K."
237
+ )
238
+ parser.add_argument(
239
+ "--root",
240
+ type=str,
241
+ default="runs/processed_papers",
242
+ help="Root directory containing processed paper directories.",
243
+ )
244
+ parser.add_argument(
245
+ "--k",
246
+ type=int,
247
+ default=0,
248
+ help="Verify top-k USES/EXTENDS candidates ranked by classifier confidence (<=0 means all).",
249
+ )
250
+ parser.add_argument(
251
+ "--batch-size",
252
+ type=int,
253
+ default=25,
254
+ help="Number of candidates per LLM verification batch.",
255
+ )
256
+ parser.add_argument(
257
+ "--overwrite",
258
+ action="store_true",
259
+ help="Overwrite existing usage_uses_extends_verified.json files.",
260
+ )
261
+ parser.add_argument(
262
+ "--resume",
263
+ action="store_true",
264
+ help="Skip papers with existing output files (even if --overwrite is set).",
265
+ )
266
+ args = parser.parse_args()
267
+
268
+ root = Path(args.root).expanduser().resolve()
269
+ if not root.exists():
270
+ raise SystemExit(f"Root directory does not exist: {root}")
271
+
272
+ client = LLMClient()
273
+ paper_dirs = sorted(iter_paper_dirs(root), key=lambda p: p.name)
274
+ print(f"[INFO] Found {len(paper_dirs)} paper dirs under {root}")
275
+
276
+ counts = {"verified": 0, "skipped": 0, "missing_labels": 0, "no_candidates": 0}
277
+ for paper_dir in paper_dirs:
278
+ status = process_paper(
279
+ paper_dir,
280
+ client,
281
+ args.k,
282
+ args.batch_size,
283
+ args.overwrite,
284
+ args.resume,
285
+ )
286
+ counts[status] = counts.get(status, 0) + 1
287
+ print(f"[{status.upper()}] {paper_dir.name}")
288
+
289
+ print(
290
+ "[SUMMARY] verified={verified}, skipped={skipped}, missing_labels={missing_labels}, "
291
+ "no_candidates={no_candidates}".format(**counts)
292
+ )
293
+
294
+
295
+ if __name__ == "__main__":
296
+ main()
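
A note on the selection logic above: the script keeps the top-k USES/EXTENDS contexts by classifier confidence, verifies them in fixed-size batches, and then collapses the per-context verdicts into one paper-level label, with EXTENDS taking precedence over USES. The following minimal sketch reproduces just that selection and aggregation step (no LLM call is made; the candidate ids, confidences, and stub verdicts are invented for illustration):

    from typing import Any, Dict, List

    def select_top_k(candidates: List[Dict[str, Any]], k: int) -> List[Dict[str, Any]]:
        # Rank by classifier confidence, descending; k <= 0 means verify everything.
        ranked = sorted(candidates, key=lambda x: x.get("confidence", 0.0), reverse=True)
        return ranked if k <= 0 else ranked[:k]

    def aggregate_final_label(verified: List[Dict[str, Any]]) -> str:
        # EXTENDS dominates USES; anything else means the usage was not confirmed.
        confirmed = [v for v in verified if v["label"] in {"USES", "EXTENDS"}]
        if any(v["label"] == "EXTENDS" for v in confirmed):
            return "EXTENDS"
        return "USES" if confirmed else "NOT_CONFIRMED"

    candidates = [
        {"id": 1, "confidence": 0.91},
        {"id": 2, "confidence": 0.55},
        {"id": 3, "confidence": 0.12},
    ]
    top = select_top_k(candidates, k=2)                          # keeps ids 1 and 2
    batches = [top[i:i + 25] for i in range(0, len(top), 25)]    # one batch of size 2
    stub_verdicts = [{"id": 1, "label": "USES"}, {"id": 2, "label": "NOT_CONFIRMED"}]
    print(aggregate_final_label(stub_verdicts))                  # -> USES
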
src/step_06_extract_paragraphs/extract_arxiv_paragraphs.py ADDED
@@ -0,0 +1,488 @@
1
+ import argparse
2
+ import json
3
+ import random
4
+ import re
5
+ import sys
6
+ import tarfile
7
+ import tempfile
8
+ import time
9
+ import urllib.request
10
+ from pathlib import Path
11
+ from typing import Any, Dict, List, Optional, Tuple
12
+ import os
13
+
14
+ SRC_ROOT = Path(__file__).resolve().parents[1]
15
+ if str(SRC_ROOT) not in sys.path:
16
+ sys.path.insert(0, str(SRC_ROOT))
17
+
18
+
19
+ PAPER_META_FILE = "paper_metadata.json"
20
+ USAGE_CONTEXTS_FILE = "usage_contexts.json"
21
+ VERIFIED_FILE = "usage_uses_extends_verified.json"
22
+ OUT_FILE = "usage_citing_paragraphs.json"
23
+
24
+
25
+ def load_json(path: Path) -> Any | None:
26
+ if not path.exists():
27
+ return None
28
+ try:
29
+ return json.loads(path.read_text(encoding="utf-8"))
30
+ except Exception:
31
+ return None
32
+
33
+
34
+ def iter_paper_dirs(root: Path) -> List[Path]:
35
+ out: List[Path] = []
36
+ for child in root.iterdir():
37
+ if child.is_dir() and (child / PAPER_META_FILE).exists():
38
+ out.append(child)
39
+ return out
40
+
41
+
42
+ def safe_extract(tar: tarfile.TarFile, path: Path) -> None:
43
+ for member in tar.getmembers():
44
+ member_path = path / member.name
45
+ if not str(member_path.resolve()).startswith(str(path.resolve())):
46
+ raise RuntimeError(f"Blocked path traversal in tar: {member.name}")
47
+ tar.extractall(path)
48
+
49
+
50
+ _ARXIV_LAST_TS = 0.0
51
+
52
+
53
+ def _arxiv_min_interval_sleep() -> None:
54
+ """Global throttle to avoid arXiv API rate limits."""
55
+ global _ARXIV_LAST_TS
56
+ min_interval = float(os.getenv("ARXIV_MIN_INTERVAL", "1.0"))
57
+ now = time.monotonic()
58
+ elapsed = now - _ARXIV_LAST_TS
59
+ if elapsed < min_interval:
60
+ time.sleep(min_interval - elapsed)
61
+ _ARXIV_LAST_TS = time.monotonic()
62
+
63
+
64
+ def download_arxiv_source(arxiv_id: str, tmpdir: Path) -> Optional[Path]:
65
+ url = f"https://arxiv.org/e-print/{arxiv_id}"
66
+ archive_path = tmpdir / f"{arxiv_id.replace('/', '_')}.tar"
67
+ max_retries = int(os.getenv("ARXIV_MAX_RETRIES", "6"))
68
+ base_sleep = float(os.getenv("ARXIV_BASE_SLEEP", "2.0"))
69
+ max_sleep = float(os.getenv("ARXIV_MAX_BACKOFF", "60"))
70
+
71
+ for attempt in range(max_retries):
72
+ try:
73
+ _arxiv_min_interval_sleep()
74
+ urllib.request.urlretrieve(url, archive_path) # noqa: S310
75
+ try:
76
+ with tarfile.open(archive_path) as tar:
77
+ safe_extract(tar, tmpdir)
78
+ return tmpdir
79
+ except tarfile.ReadError as exc:
80
+ print(f"[WARN] Invalid arXiv archive for {arxiv_id}: {exc}")
81
+ return None
82
+ except Exception as exc:
83
+ # arXiv sometimes returns 429; treat any network error as retryable.
84
+ sleep = min(base_sleep * (2 ** attempt), max_sleep) + random.uniform(0.0, 0.5)
85
+ print(f"[WARN] Failed to download arXiv source for {arxiv_id}: {exc}")
86
+ print(f"[WARN] arXiv download retrying in {sleep:.2f}s")
87
+ time.sleep(sleep)
88
+ continue
89
+
90
+ print(f"[ERROR] Giving up after {max_retries} attempts for arXiv {arxiv_id}")
91
+ return None
92
+
93
+
94
+ def find_main_tex(root: Path) -> Optional[Path]:
95
+ tex_files = list(root.rglob("*.tex"))
96
+ if not tex_files:
97
+ return None
98
+
99
+ candidates: List[Tuple[int, Path]] = []
100
+ for path in tex_files:
101
+ try:
102
+ text = path.read_text(encoding="utf-8", errors="ignore")
103
+ except Exception:
104
+ continue
105
+ score = 0
106
+ if "\\begin{document}" in text:
107
+ score += 3
108
+ if "\\documentclass" in text:
109
+ score += 2
110
+ score += len(text) // 1000
111
+ candidates.append((score, path))
112
+
113
+ candidates.sort(key=lambda x: x[0], reverse=True)
114
+ return candidates[0][1] if candidates else None
115
+
116
+
117
+ def read_bib_files(root: Path) -> Dict[str, str]:
118
+ bibs: Dict[str, str] = {}
119
+ for path in root.rglob("*.bib"):
120
+ try:
121
+ bibs[str(path.relative_to(root))] = path.read_text(encoding="utf-8", errors="ignore")
122
+ except Exception:
123
+ continue
124
+ return bibs
125
+
126
+
127
+ def normalize_text(text: str) -> str:
128
+ text = re.sub(r"[^a-z0-9\s]", " ", text.lower())
129
+ return re.sub(r"\s+", " ", text).strip()
130
+
131
+
132
+ def tokenize(text: str) -> List[str]:
133
+ return [t for t in normalize_text(text).split() if t]
134
+
135
+
136
+ def paragraphize(text: str) -> List[str]:
137
+ text = text.replace("\r\n", "\n")
138
+ text = re.sub(r"\n\s*\n", "\n\n", text)
139
+ paragraphs = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
140
+ return paragraphs
141
+
142
+
143
+ def strip_latex_comments(text: str) -> str:
144
+ # Remove explicit comment environments first.
145
+ text = re.sub(r"\\begin\{comment\}.*?\\end\{comment\}", "", text, flags=re.DOTALL)
146
+
147
+ cleaned_lines: List[str] = []
148
+ for line in text.splitlines():
149
+ out_chars: List[str] = []
150
+ i = 0
151
+ while i < len(line):
152
+ ch = line[i]
153
+ if ch == "%":
154
+ # Keep escaped percent (\%) and continue parsing.
155
+ if i > 0 and line[i - 1] == "\\":
156
+ out_chars.append(ch)
157
+ i += 1
158
+ continue
159
+ # Unescaped percent starts a LaTeX comment; ignore rest of the line.
160
+ break
161
+ out_chars.append(ch)
162
+ i += 1
163
+ cleaned_lines.append("".join(out_chars))
164
+ return "\n".join(cleaned_lines)
165
+
166
+
167
+ def parse_bib_entries(bib_text: str) -> List[Dict[str, str]]:
168
+ entries: List[Dict[str, str]] = []
169
+ matches = list(re.finditer(r"@[\w]+\s*\{\s*([^,]+),", bib_text))
170
+ for i, match in enumerate(matches):
171
+ key = match.group(1).strip()
172
+ start = match.end()
173
+ end = matches[i + 1].start() if i + 1 < len(matches) else len(bib_text)
174
+ body = bib_text[start:end]
175
+ fields = {}
176
+ for f_match in re.finditer(r"(\w+)\s*=\s*[{\"](.+?)[}\"]\s*,", body, re.DOTALL):
177
+ fields[f_match.group(1).lower()] = f_match.group(2).strip()
178
+ entries.append({"key": key, **fields})
179
+ return entries
180
+
181
+
182
+ def find_target_bib_keys(
183
+ bib_texts: Dict[str, str],
184
+ target_info: Dict[str, str],
185
+ ) -> List[str]:
186
+ target_title = normalize_text(target_info.get("title", ""))
187
+ target_author = normalize_text(target_info.get("first_author_last", ""))
188
+ target_year = target_info.get("year", "")
189
+ if not target_title and not target_author:
190
+ return []
191
+
192
+ keys: List[str] = []
193
+ for bib_text in bib_texts.values():
194
+ for entry in parse_bib_entries(bib_text):
195
+ title = normalize_text(entry.get("title", ""))
196
+ author = normalize_text(entry.get("author", ""))
197
+ year = str(entry.get("year", ""))
198
+ has_title = bool(title)
199
+ title_match = target_title and (target_title in title or title in target_title)
200
+ author_match = target_author and target_author in author
201
+ year_match = target_year and target_year in year
202
+
203
+ if title_match and author_match:
204
+ keys.append(entry["key"])
205
+ elif not has_title and author_match and year_match:
206
+ keys.append(entry["key"])
207
+ elif author_match and year_match:
208
+ keys.append(entry["key"])
209
+ return keys
210
+
211
+
212
+ def replace_target_citations(text: str, target_keys: List[str], target_info: Dict[str, str]) -> str:
213
+ key_set = set(target_keys or [])
214
+ author = target_info.get("first_author_last", "").lower()
215
+ year = target_info.get("year", "")
216
+ alt_years = {year}
217
+ if year.isdigit():
218
+ alt_years.add(str(int(year) - 1))
219
+ alt_years.add(str(int(year) + 1))
220
+
221
+ def repl(match: re.Match) -> str:
222
+ keys = [k.strip() for k in match.group(1).split(",")]
223
+ for key in keys:
224
+ if key in key_set:
225
+ return "<CITED HERE>"
226
+ key_lc = key.lower()
227
+ if author and author in key_lc and any(y in key_lc for y in alt_years if y):
228
+ return "<CITED HERE>"
229
+ return match.group(0)
230
+
231
+ return re.sub(r"\\cite[a-zA-Z]*\s*\{([^}]+)\}", repl, text)
232
+
233
+
234
+ def match_paragraphs(
235
+ paragraphs: List[str],
236
+ contexts: List[Dict[str, str]],
237
+ ) -> List[Dict[str, Any]]:
238
+ results: List[Dict[str, Any]] = []
239
+ para_tokens = [set(tokenize(p)) for p in paragraphs]
240
+
241
+ for idx, ctx in enumerate(contexts, start=1):
242
+ ctx_text = ctx.get("text", "")
243
+ ctx_tokens = set(tokenize(ctx_text))
244
+ if not ctx_tokens:
245
+ continue
246
+ best = None
247
+ best_score = 0.0
248
+ for p_idx, tokens in enumerate(para_tokens):
249
+ if not tokens:
250
+ continue
251
+ overlap = len(ctx_tokens & tokens) / max(1, len(ctx_tokens))
252
+ if overlap > best_score:
253
+ best = p_idx
254
+ best_score = overlap
255
+ if best is not None and best_score >= 0.5:
256
+ paragraph = paragraphs[best]
257
+ results.append(
258
+ {
259
+ "context_id": idx,
260
+ "context": ctx_text,
261
+ "context_with_marker": ctx.get("text_with_marker", ctx_text),
262
+ "paragraph": paragraph,
263
+ "overlap": round(best_score, 3),
264
+ }
265
+ )
266
+ return results
267
+
268
+
269
+ def _normalize_text(text: str) -> str:
270
+ return " ".join(text.split()).strip().lower()
271
+
272
+
273
+ def _normalize_for_match(text: str) -> str:
274
+ text = text.replace("<CITED HERE>", "")
275
+ text = re.sub(r"\[[^\]]+\]", "", text)
276
+ return _normalize_text(text)
277
+
278
+
279
+ def _normalize_author_last(name: str) -> str:
280
+ parts = [p for p in (name or "").split() if p.strip()]
281
+ return parts[-1] if parts else ""
282
+
283
+
284
+ def extract_target_info(meta: Any) -> Dict[str, str]:
285
+ if isinstance(meta, list) and meta:
286
+ meta = meta[0]
287
+ if not isinstance(meta, dict):
288
+ return {"title": "", "first_author_last": "", "year": ""}
289
+ authors = meta.get("authors") or []
290
+ first_author = authors[0]["name"] if authors else ""
291
+ return {
292
+ "title": meta.get("title", ""),
293
+ "first_author_last": _normalize_author_last(first_author),
294
+ "year": str(meta.get("year", "")),
295
+ }
296
+
297
+
298
+ def build_citing_contexts_map(
299
+ usage: Dict[str, Any],
300
+ confirmed_texts_by_citing: Dict[str, set] | None,
301
+ ) -> Dict[str, Dict[str, Any]]:
302
+ citing_map: Dict[str, Dict[str, Any]] = {}
303
+ for entry in usage.get("citing_papers", []) or []:
304
+ if not isinstance(entry, dict):
305
+ continue
306
+ citing_id = entry.get("citing_paper_id") or ""
307
+ allowed_texts = confirmed_texts_by_citing.get(citing_id) if confirmed_texts_by_citing else None
308
+ allowed_norms = (
309
+ {_normalize_for_match(text) for text in allowed_texts} if allowed_texts else None
310
+ )
311
+ contexts = []
312
+ seen = set()
313
+ for c in entry.get("contexts", []) or []:
314
+ if not isinstance(c, dict):
315
+ continue
316
+ text_raw = (c.get("text") or "").strip()
317
+ text_with_marker = (c.get("context_with_marker") or text_raw).strip()
318
+ if not text_raw:
319
+ continue
320
+ norm = _normalize_for_match(text_raw)
321
+ if allowed_norms is not None and norm not in allowed_norms:
322
+ continue
323
+ if norm in seen:
324
+ continue
325
+ seen.add(norm)
326
+ contexts.append({"text": text_raw, "text_with_marker": text_with_marker})
327
+ if allowed_texts is not None and not contexts:
328
+ for text in allowed_texts:
329
+ norm = _normalize_for_match(text)
330
+ if norm in seen:
331
+ continue
332
+ seen.add(norm)
333
+ contexts.append({"text": text, "text_with_marker": text})
334
+ citing_map[citing_id] = {
335
+ "title": entry.get("title", ""),
336
+ "paper_id": citing_id,
337
+ "arxiv_id": (entry.get("external_ids") or {}).get("ArXiv", ""),
338
+ "contexts": contexts,
339
+ }
340
+ return citing_map
341
+
342
+
343
+ def process_citing_paper(citing: Dict[str, Any]) -> Dict[str, Any]:
344
+ target_info = citing.get("target_info", {})
345
+ arxiv_id = citing.get("arxiv_id", "")
346
+ if not arxiv_id:
347
+ return {"error": "missing_arxiv_id", **citing}
348
+
349
+ with tempfile.TemporaryDirectory() as tmp:
350
+ tmpdir = Path(tmp)
351
+ if not download_arxiv_source(arxiv_id, tmpdir):
352
+ return {"error": "bad_arxiv_archive", **citing}
353
+ main_tex = find_main_tex(tmpdir)
354
+ if not main_tex:
355
+ return {"error": "missing_main_tex", **citing}
356
+
357
+ tex_text = main_tex.read_text(encoding="utf-8", errors="ignore")
358
+ tex_text = strip_latex_comments(tex_text)
359
+ bibs = read_bib_files(tmpdir)
360
+ target_keys = find_target_bib_keys(bibs, target_info)
361
+ tex_text = replace_target_citations(tex_text, target_keys, target_info)
362
+ paragraphs = paragraphize(tex_text)
363
+ target_citing_paragraphs = [p for p in paragraphs if "<CITED HERE>" in p]
364
+ matched = match_paragraphs(paragraphs, citing.get("contexts", []))
365
+
366
+ return {
367
+ "citing_paper_id": citing.get("paper_id", ""),
368
+ "citing_title": citing.get("title", ""),
369
+ "arxiv_id": arxiv_id,
370
+ "main_tex_file": str(main_tex.relative_to(tmpdir)),
371
+ "bib_files": list(bibs.keys()),
372
+ "bib_texts": bibs,
373
+ "target_bib_keys": target_keys,
374
+ "contexts": citing.get("contexts", []),
375
+ "target_citing_paragraphs": target_citing_paragraphs,
376
+ "matched_paragraphs": matched,
377
+ }
378
+
379
+
380
+ def process_paper(root: Path, overwrite: bool, include_all: bool, resume: bool) -> str:
381
+ usage = load_json(root / USAGE_CONTEXTS_FILE)
382
+ if not isinstance(usage, dict):
383
+ return "missing_usage"
384
+
385
+ out_path = root / OUT_FILE
386
+ if out_path.exists() and (resume or not overwrite):
387
+ return "skipped"
388
+
389
+ verified = None
390
+ confirmed_texts_by_citing: Dict[str, set] = {}
391
+ if not include_all:
392
+ verified = load_json(root / VERIFIED_FILE)
393
+ if not isinstance(verified, dict):
394
+ return "missing_verified"
395
+ for item in verified.get("confirmed", []) or []:
396
+ citing_id = item.get("citing_paper_id") or ""
397
+ text = item.get("text") or ""
398
+ if not citing_id or not text:
399
+ continue
400
+ confirmed_texts_by_citing.setdefault(citing_id, set()).add(text)
401
+
402
+ target_info = extract_target_info(load_json(root / PAPER_META_FILE))
403
+ citing_map = build_citing_contexts_map(
404
+ usage,
405
+ confirmed_texts_by_citing if confirmed_texts_by_citing else None,
406
+ )
407
+ if not citing_map:
408
+ out_path.write_text(
409
+ json.dumps({"paper_id": usage.get("paper_id"), "citing_papers": []}, indent=2),
410
+ encoding="utf-8",
411
+ )
412
+ return "empty_citing"
413
+
414
+ confirmed_ids: Optional[set] = None
415
+ if not include_all and isinstance(verified, dict):
416
+ confirmed = verified.get("confirmed", [])
417
+ confirmed_ids = {
418
+ item.get("citing_paper_id")
419
+ for item in confirmed
420
+ if item.get("citing_paper_id")
421
+ }
422
+
423
+ citing_papers = []
424
+ for citing_id, citing in citing_map.items():
425
+ if confirmed_ids is not None and citing_id not in confirmed_ids:
426
+ continue
427
+ citing["target_info"] = target_info
428
+ citing_papers.append(process_citing_paper(citing))
429
+
430
+ payload = {"paper_id": usage.get("paper_id"), "citing_papers": citing_papers}
431
+ out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
432
+ return "processed"
433
+
434
+
435
+ def main() -> None:
436
+ parser = argparse.ArgumentParser(
437
+ description="Download arXiv sources and extract citation-local paragraphs."
438
+ )
439
+ parser.add_argument(
440
+ "--root",
441
+ type=str,
442
+ default="runs/processed_papers",
443
+ help="Root directory containing processed paper directories.",
444
+ )
445
+ parser.add_argument(
446
+ "--overwrite",
447
+ action="store_true",
448
+ help="Overwrite existing usage_citing_paragraphs.json files.",
449
+ )
450
+ parser.add_argument(
451
+ "--all",
452
+ action="store_true",
453
+ help="Process all citing papers (not just confirmed USES/EXTENDS).",
454
+ )
455
+ parser.add_argument(
456
+ "--resume",
457
+ action="store_true",
458
+ help="Skip papers with existing output files (even if --overwrite is set).",
459
+ )
460
+ args = parser.parse_args()
461
+
462
+ root = Path(args.root).expanduser().resolve()
463
+ if not root.exists():
464
+ raise SystemExit(f"Root directory does not exist: {root}")
465
+
466
+ paper_dirs = sorted(iter_paper_dirs(root), key=lambda p: p.name)
467
+ print(f"[INFO] Found {len(paper_dirs)} paper dirs under {root}")
468
+
469
+ counts = {
470
+ "processed": 0,
471
+ "skipped": 0,
472
+ "missing_usage": 0,
473
+ "missing_verified": 0,
474
+ "empty_citing": 0,
475
+ }
476
+ for paper_dir in paper_dirs:
477
+ status = process_paper(paper_dir, args.overwrite, args.all, args.resume)
478
+ counts[status] = counts.get(status, 0) + 1
479
+ print(f"[{status.upper()}] {paper_dir.name}")
480
+
481
+ print(
482
+ "[SUMMARY] processed={processed}, skipped={skipped}, missing_usage={missing_usage}, "
483
+ "missing_verified={missing_verified}, empty_citing={empty_citing}".format(**counts)
484
+ )
485
+
486
+
487
+ if __name__ == "__main__":
488
+ main()
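
The arXiv downloader above combines a global minimum interval between requests with capped exponential backoff plus jitter, all tunable through ARXIV_MIN_INTERVAL, ARXIV_MAX_RETRIES, ARXIV_BASE_SLEEP, and ARXIV_MAX_BACKOFF. A standalone sketch of that retry schedule, with the real download swapped for a deliberately flaky placeholder, behaves as follows:

    import os
    import random
    import time

    def with_backoff(fetch):
        # Mirrors the retry schedule used by download_arxiv_source above.
        max_retries = int(os.getenv("ARXIV_MAX_RETRIES", "6"))
        base_sleep = float(os.getenv("ARXIV_BASE_SLEEP", "2.0"))
        max_sleep = float(os.getenv("ARXIV_MAX_BACKOFF", "60"))
        for attempt in range(max_retries):
            try:
                return fetch()
            except Exception as exc:
                sleep = min(base_sleep * (2 ** attempt), max_sleep) + random.uniform(0.0, 0.5)
                print(f"[WARN] attempt {attempt + 1} failed ({exc}); retrying in {sleep:.2f}s")
                time.sleep(sleep)
        return None

    calls = {"n": 0}

    def flaky_fetch():
        # Placeholder that fails twice, then succeeds (illustrative only).
        calls["n"] += 1
        if calls["n"] < 3:
            raise RuntimeError("HTTP 429")
        return "ok"

    print(with_backoff(flaky_fetch))  # -> ok, after two backoff sleeps
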
src/step_07_extract_and_refine/extract_contributions_from_citations.py ADDED
@@ -0,0 +1,329 @@
1
+ import argparse
2
+ import json
3
+ import sys
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List
6
+
7
+ SRC_ROOT = Path(__file__).resolve().parents[1]
8
+ if str(SRC_ROOT) not in sys.path:
9
+ sys.path.insert(0, str(SRC_ROOT))
10
+
11
+ from common.llm_client import LLMClient
12
+
13
+ from prompts import build_contribution_prompt
14
+ from schemas import CONTRIBUTION_JSON_SCHEMA
15
+
16
+
17
+ PAPER_META_FILE = "paper_metadata.json"
18
+ USAGE_CONTEXTS_FILE = "usage_contexts.json"
19
+ ARXIV_PARAGRAPHS_FILE = "usage_citing_paragraphs.json"
20
+ VERIFIED_FILE = "usage_uses_extends_verified.json"
21
+ OUT_FILE = "usage_contributions.json"
22
+
23
+
24
+ def load_json(path: Path) -> Any | None:
25
+ if not path.exists():
26
+ return None
27
+ try:
28
+ return json.loads(path.read_text(encoding="utf-8"))
29
+ except Exception:
30
+ return None
31
+
32
+
33
+ def iter_paper_dirs(root: Path) -> List[Path]:
34
+ out: List[Path] = []
35
+ for child in root.iterdir():
36
+ if child.is_dir() and (child / PAPER_META_FILE).exists():
37
+ out.append(child)
38
+ return out
39
+
40
+
41
+ def _normalize_author_last(name: str) -> str:
42
+ parts = [p for p in (name or "").split() if p.strip()]
43
+ return parts[-1] if parts else ""
44
+
45
+
46
+ def extract_target_info(meta: Any) -> Dict[str, str]:
47
+ if isinstance(meta, list) and meta:
48
+ meta = meta[0]
49
+ if not isinstance(meta, dict):
50
+ return {
51
+ "title": "",
52
+ "first_author_last": "",
53
+ "year": "",
54
+ "tldr": "",
55
+ "abstract": "",
56
+ }
57
+ authors = meta.get("authors") or []
58
+ first_author = authors[0]["name"] if authors else ""
59
+ tldr = ""
60
+ tldr_obj = meta.get("tldr")
61
+ if isinstance(tldr_obj, dict):
62
+ tldr = tldr_obj.get("text", "")
63
+ return {
64
+ "title": meta.get("title", ""),
65
+ "first_author_last": _normalize_author_last(first_author),
66
+ "year": str(meta.get("year", "")),
67
+ "tldr": tldr,
68
+ "abstract": meta.get("abstract", ""),
69
+ }
70
+
71
+
72
+ def build_citing_contexts_map_from_paragraphs(
73
+ arxiv_data: Dict[str, Any],
74
+ ) -> Dict[str, Dict[str, Any]]:
75
+ citing_map: Dict[str, Dict[str, Any]] = {}
76
+ for entry in arxiv_data.get("citing_papers", []) or []:
77
+ if not isinstance(entry, dict):
78
+ continue
79
+ citing_id = entry.get("citing_paper_id") or ""
80
+ contexts = []
81
+ seen = set()
82
+ for paragraph in entry.get("target_citing_paragraphs", []) or []:
83
+ paragraph = (paragraph or "").strip()
84
+ if not paragraph:
85
+ continue
86
+ combined = f"Target-citing paragraph: {paragraph}"
87
+ norm = " ".join(combined.split()).lower()
88
+ if norm in seen:
89
+ continue
90
+ seen.add(norm)
91
+ contexts.append(combined)
92
+ citing_map[citing_id] = {
93
+ "title": entry.get("citing_title", ""),
94
+ "paper_id": citing_id,
95
+ "contexts": contexts,
96
+ "source": "arxiv_paragraphs",
97
+ }
98
+ return citing_map
99
+
100
+
101
+ def build_citing_contexts_map_from_usage(
102
+ usage: Dict[str, Any],
103
+ confirmed_texts_by_citing: Dict[str, set] | None,
104
+ ) -> Dict[str, Dict[str, Any]]:
105
+ citing_map: Dict[str, Dict[str, Any]] = {}
106
+ for entry in usage.get("citing_papers", []) or []:
107
+ if not isinstance(entry, dict):
108
+ continue
109
+ citing_id = entry.get("citing_paper_id") or ""
110
+ allowed_texts = confirmed_texts_by_citing.get(citing_id) if confirmed_texts_by_citing else None
111
+ contexts = []
112
+ seen = set()
113
+ for c in entry.get("contexts", []) or []:
114
+ if not isinstance(c, dict):
115
+ continue
116
+ text = (c.get("context_with_marker") or c.get("text") or "").strip()
117
+ if not text:
118
+ continue
119
+ if allowed_texts is not None and text not in allowed_texts:
120
+ continue
121
+ norm = " ".join(text.split()).lower()
122
+ if norm in seen:
123
+ continue
124
+ seen.add(norm)
125
+ contexts.append(f"Target sentence: {text}")
126
+ citing_map[citing_id] = {
127
+ "title": entry.get("title", ""),
128
+ "paper_id": citing_id,
129
+ "contexts": contexts,
130
+ "source": "usage_contexts_fallback",
131
+ }
132
+ return citing_map
133
+
134
+
135
+ def extract_contribution(
136
+ client: LLMClient,
137
+ target_info: Dict[str, str],
138
+ citing_info: Dict[str, Any],
139
+ ) -> Dict[str, Any]:
140
+ contexts = citing_info.get("contexts", [])
141
+ prompt = build_contribution_prompt(target_info, citing_info, contexts)
142
+ raw = client.call(prompt, schema=CONTRIBUTION_JSON_SCHEMA)
143
+ data = _parse_llm_json(raw)
144
+ if not isinstance(data, dict):
145
+ return {
146
+ "citing_paper_id": citing_info.get("paper_id", ""),
147
+ "citing_title": citing_info.get("title", ""),
148
+ "label": "NOT_CONFIRMED",
149
+ "paper_claim": "",
150
+ "claim": "",
151
+ "cluster_title": "",
152
+ "cluster_key": "",
153
+ "evidence_span": "",
154
+ "rationale": "",
155
+ "contexts": contexts,
156
+ "source": citing_info.get("source", "unknown"),
157
+ }
158
+ label = data.get("label", "NOT_CONFIRMED")
159
+ paper_claim = data.get("paper_claim", "") or data.get("claim", "")
160
+ cluster_title = data.get("cluster_title", "") or data.get("cluster_claim", "")
161
+ cluster_key = data.get("cluster_key", "")
162
+ evidence_span = data.get("evidence_span", "")
163
+ if not evidence_span:
164
+ label = "NOT_CONFIRMED"
165
+ paper_claim = ""
166
+ cluster_title = ""
167
+ cluster_key = ""
168
+ if label in {"USES", "EXTENDS"} and not cluster_title:
169
+ cluster_title = paper_claim
170
+ if label in {"USES", "EXTENDS"} and not cluster_key:
171
+ cluster_key = f"{label}|contribution|unspecified"
172
+ return {
173
+ "citing_paper_id": citing_info.get("paper_id", ""),
174
+ "citing_title": citing_info.get("title", ""),
175
+ "label": label,
176
+ "paper_claim": paper_claim,
177
+ "claim": paper_claim,
178
+ "cluster_title": cluster_title,
179
+ "cluster_key": cluster_key,
180
+ "evidence_span": evidence_span,
181
+ "rationale": data.get("rationale", ""),
182
+ "contexts": contexts,
183
+ "source": citing_info.get("source", "unknown"),
184
+ }
185
+
186
+
187
+ def _parse_llm_json(raw: str) -> Any | None:
188
+ try:
189
+ return json.loads(raw)
190
+ except json.JSONDecodeError:
191
+ pass
192
+
193
+ cleaned = raw.strip()
194
+ if cleaned.startswith("```"):
195
+ cleaned = cleaned.strip("`")
196
+ cleaned = cleaned.replace("json", "", 1).strip()
197
+
198
+ start = cleaned.find("{")
199
+ end = cleaned.rfind("}")
200
+ if start == -1 or end == -1 or end <= start:
201
+ return None
202
+
203
+ snippet = cleaned[start : end + 1]
204
+ try:
205
+ return json.loads(snippet)
206
+ except json.JSONDecodeError:
207
+ return None
208
+
209
+
210
+ def process_paper(
211
+ paper_dir: Path,
212
+ client: LLMClient,
213
+ overwrite: bool,
214
+ resume: bool,
215
+ ) -> str:
216
+ verified = load_json(paper_dir / VERIFIED_FILE)
217
+ if not isinstance(verified, dict):
218
+ return "missing_verified"
219
+ out_path = paper_dir / OUT_FILE
220
+ if out_path.exists() and (resume or not overwrite):
221
+ return "skipped"
222
+
223
+ if verified.get("final_label") == "NOT_CONFIRMED":
224
+ payload = {
225
+ "paper_id": verified.get("paper_id"),
226
+ "final_label": "NOT_CONFIRMED",
227
+ "contributions": [],
228
+ }
229
+ out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
230
+ return "no_confirmed"
231
+
232
+ arxiv_data = load_json(paper_dir / ARXIV_PARAGRAPHS_FILE)
233
+ if not isinstance(arxiv_data, dict):
234
+ return "missing_arxiv_paragraphs"
235
+
236
+ target_info = extract_target_info(load_json(paper_dir / PAPER_META_FILE))
237
+ citing_map = build_citing_contexts_map_from_paragraphs(arxiv_data)
238
+ usage = load_json(paper_dir / USAGE_CONTEXTS_FILE)
239
+ confirmed_texts_by_citing: Dict[str, set] = {}
240
+ for item in verified.get("confirmed", []) or []:
241
+ citing_id = item.get("citing_paper_id") or ""
242
+ text = item.get("text") or ""
243
+ if not citing_id or not text:
244
+ continue
245
+ confirmed_texts_by_citing.setdefault(citing_id, set()).add(text)
246
+ usage_map = (
247
+ build_citing_contexts_map_from_usage(usage, confirmed_texts_by_citing)
248
+ if isinstance(usage, dict)
249
+ else {}
250
+ )
251
+
252
+ confirmed = verified.get("confirmed", [])
253
+ confirmed_ids = {item.get("citing_paper_id") for item in confirmed if item.get("citing_paper_id")}
254
+ contributions: List[Dict[str, Any]] = []
255
+ fallback_citing_ids: List[str] = []
256
+ for citing_id in confirmed_ids:
257
+ citing_info = citing_map.get(citing_id)
258
+ if citing_info and not citing_info.get("contexts"):
259
+ citing_info = None
260
+ if not citing_info:
261
+ fallback = usage_map.get(citing_id)
262
+ if fallback and fallback.get("contexts"):
263
+ citing_info = fallback
264
+ fallback_citing_ids.append(citing_id)
265
+ else:
266
+ continue
267
+ contributions.append(extract_contribution(client, target_info, citing_info))
268
+
269
+ payload = {
270
+ "paper_id": verified.get("paper_id"),
271
+ "final_label": verified.get("final_label"),
272
+ "contributions": contributions,
273
+ "source": "arxiv_paragraphs",
274
+ "fallback_citing_ids": fallback_citing_ids,
275
+ }
276
+ out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
277
+ return "labeled"
278
+
279
+
280
+ def main() -> None:
281
+ parser = argparse.ArgumentParser(
282
+ description="Extract per-citing-paper contribution claims from verified USES/EXTENDS."
283
+ )
284
+ parser.add_argument(
285
+ "--root",
286
+ type=str,
287
+ default="runs/processed_papers",
288
+ help="Root directory containing processed paper directories.",
289
+ )
290
+ parser.add_argument(
291
+ "--overwrite",
292
+ action="store_true",
293
+ help="Overwrite existing usage_contributions.json files.",
294
+ )
295
+ parser.add_argument(
296
+ "--resume",
297
+ action="store_true",
298
+ help="Skip papers with existing output files (even if --overwrite is set).",
299
+ )
300
+ args = parser.parse_args()
301
+
302
+ root = Path(args.root).expanduser().resolve()
303
+ if not root.exists():
304
+ raise SystemExit(f"Root directory does not exist: {root}")
305
+
306
+ client = LLMClient()
307
+ paper_dirs = sorted(iter_paper_dirs(root), key=lambda p: p.name)
308
+ print(f"[INFO] Found {len(paper_dirs)} paper dirs under {root}")
309
+
310
+ counts = {
311
+ "labeled": 0,
312
+ "skipped": 0,
313
+ "missing_verified": 0,
314
+ "missing_arxiv_paragraphs": 0,
315
+ "no_confirmed": 0,
316
+ }
317
+ for paper_dir in paper_dirs:
318
+ status = process_paper(paper_dir, client, args.overwrite, args.resume)
319
+ counts[status] = counts.get(status, 0) + 1
320
+ print(f"[{status.upper()}] {paper_dir.name}")
321
+
322
+ print(
323
+ "[SUMMARY] labeled={labeled}, skipped={skipped}, missing_verified={missing_verified}, "
324
+ "missing_arxiv_paragraphs={missing_arxiv_paragraphs}, no_confirmed={no_confirmed}".format(**counts)
325
+ )
326
+
327
+
328
+ if __name__ == "__main__":
329
+ main()
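
For each confirmed citing paper, the extractor above prefers the citation-local paragraphs recovered from arXiv LaTeX (source arxiv_paragraphs) and only falls back to the sentence-level usage contexts (source usage_contexts_fallback) when no paragraphs were matched. A minimal sketch of that preference order over toy maps (ids and texts are made up for illustration):

    from typing import Any, Dict, Optional

    def pick_contexts(
        citing_id: str,
        paragraph_map: Dict[str, Dict[str, Any]],
        usage_map: Dict[str, Dict[str, Any]],
    ) -> Optional[Dict[str, Any]]:
        # Prefer arXiv paragraphs; fall back to sentence-level contexts; else skip.
        info = paragraph_map.get(citing_id)
        if info and info.get("contexts"):
            return info
        fallback = usage_map.get(citing_id)
        if fallback and fallback.get("contexts"):
            return fallback
        return None

    paragraph_map = {"P1": {"contexts": [], "source": "arxiv_paragraphs"}}
    usage_map = {"P1": {"contexts": ["Target sentence: We use <CITED HERE>."],
                        "source": "usage_contexts_fallback"}}
    print(pick_contexts("P1", paragraph_map, usage_map)["source"])  # -> usage_contexts_fallback
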
src/step_07_extract_and_refine/prompts.py ADDED
@@ -0,0 +1,65 @@
1
+ from typing import Dict, List
2
+
3
+
4
+ def build_contribution_prompt(
5
+ target_info: Dict[str, str],
6
+ citing_info: Dict[str, str],
7
+ contexts: List[str],
8
+ ) -> str:
9
+ header = [
10
+ "You are extracting how a citing paper uses or extends a target paper.",
11
+ "Read the paragraph(s) below and write ONE concise contribution claim.",
12
+ "Focus only on what the citing paper actually does with the target paper.",
13
+ "",
14
+ "Rules:",
15
+ "- If the citing paper explicitly uses/adopts/evaluates on the target paper's method/data/benchmark, label USES.",
16
+ "- If it explicitly extends/modifies/adapts/builds upon the target paper, label EXTENDS.",
17
+ "- If the paragraph is only descriptive/background or only compares/mentions the target paper, return label NOT_CONFIRMED and empty fields.",
18
+ "- Do not output comparison-only claims (e.g., 'compares to <CITED HERE>'); those are NOT_CONFIRMED.",
19
+ "- Output paper_claim: one concise, paper-specific contribution claim.",
20
+ "- Output cluster_title: concise natural-language cluster summary (6-14 words), generic across papers.",
21
+ "- Also output cluster_key in this exact format: RELATION|artifact|purpose",
22
+ "- cluster_key must be generic and reusable across papers.",
23
+ "- artifact and purpose must be short snake_case phrases (e.g., dataset, evaluation_protocol, evaluation).",
24
+ "- cluster_key RELATION must exactly match label.",
25
+ "- Avoid overly specific keys (no paper names, no model/version numbers, no citation keys).",
26
+ "- Prefer stable generic keys such as: USES|dataset|evaluation, EXTENDS|dataset|dataset_creation, USES|evaluation_protocol|evaluation.",
27
+ "- If label is NOT_CONFIRMED, paper_claim, cluster_title, cluster_key, and evidence_span must be empty.",
28
+ "- The evidence_span must be a verbatim substring from the provided contexts.",
29
+ "- The TARGET_PAPER abstract/TLDR is for background only; do not use it as evidence.",
30
+ "",
31
+ "Negative example (NOT_CONFIRMED):",
32
+ "Paragraph: \"We compare our method to <CITED HERE> and other baselines.\"",
33
+ "Output: {\"label\":\"NOT_CONFIRMED\",\"paper_claim\":\"\",\"cluster_title\":\"\",\"cluster_key\":\"\",\"evidence_span\":\"\",\"rationale\":\"Comparison only.\"}",
34
+ "",
35
+ "Return JSON only.",
36
+ "",
37
+ "TARGET_PAPER:",
38
+ f"- title: {target_info.get('title', '')}",
39
+ f"- first_author_last: {target_info.get('first_author_last', '')}",
40
+ f"- year: {target_info.get('year', '')}",
41
+ f"- tldr: {target_info.get('tldr', '')}",
42
+ f"- abstract: {target_info.get('abstract', '')}",
43
+ "",
44
+ "CITING_PAPER:",
45
+ f"- title: {citing_info.get('title', '')}",
46
+ f"- paper_id: {citing_info.get('paper_id', '')}",
47
+ "",
48
+ "CONTEXTS (verbatim, same order as extracted):",
49
+ ]
50
+ for i, text in enumerate(contexts, start=1):
51
+ header.append(f"({i}) {text}")
52
+
53
+ header.append("")
54
+ header.append("JSON OUTPUT:")
55
+ header.append(
56
+ "{"
57
+ "\"label\":\"USES\","
58
+ "\"paper_claim\":\"...\","
59
+ "\"cluster_title\":\"Uses target dataset for evaluation\","
60
+ "\"cluster_key\":\"USES|dataset|evaluation\","
61
+ "\"evidence_span\":\"...\","
62
+ "\"rationale\":\"...\""
63
+ "}"
64
+ )
65
+ return "\n".join(header)
src/step_07_extract_and_refine/refine_and_filter_clusters_llm.py ADDED
@@ -0,0 +1,402 @@
1
+ import argparse
2
+ import json
3
+ import sys
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List, Tuple
6
+
7
+ SRC_ROOT = Path(__file__).resolve().parents[1]
8
+ if str(SRC_ROOT) not in sys.path:
9
+ sys.path.insert(0, str(SRC_ROOT))
10
+
11
+ from common.llm_client import LLMClient
12
+
13
+ PAPER_META_FILE = "paper_metadata.json"
14
+ CONTRIB_FILE = "usage_contributions.json"
15
+ DISCOVERY_FILE = "usage_discovery_from_contributions.json"
16
+ OUT_FILE = "usage_discovery_from_contributions_refined.json"
17
+
18
+ REFINE_SCHEMA = {
19
+ "type": "object",
20
+ "properties": {
21
+ "kept_groups": {
22
+ "type": "array",
23
+ "items": {
24
+ "type": "object",
25
+ "properties": {
26
+ "cluster_ids": {"type": "array", "items": {"type": "string"}},
27
+ "merged_title": {"type": "string"},
28
+ "merged_key": {"type": "string"},
29
+ "rationale": {"type": "string"},
30
+ },
31
+ "required": ["cluster_ids", "merged_title", "merged_key", "rationale"],
32
+ },
33
+ },
34
+ "dropped_clusters": {
35
+ "type": "array",
36
+ "items": {
37
+ "type": "object",
38
+ "properties": {
39
+ "cluster_id": {"type": "string"},
40
+ "reason": {"type": "string"},
41
+ },
42
+ "required": ["cluster_id", "reason"],
43
+ },
44
+ },
45
+ },
46
+ "required": ["kept_groups", "dropped_clusters"],
47
+ }
48
+
49
+
50
+ def load_json(path: Path) -> Any | None:
51
+ if not path.exists():
52
+ return None
53
+ try:
54
+ return json.loads(path.read_text(encoding="utf-8"))
55
+ except Exception:
56
+ return None
57
+
58
+
59
+ def iter_paper_dirs(root: Path) -> List[Path]:
60
+ return sorted([p for p in root.iterdir() if p.is_dir() and (p / PAPER_META_FILE).exists()], key=lambda p: p.name)
61
+
62
+
63
+ def _to_int_indices(raw_indices: List[Any]) -> List[int]:
64
+ out: List[int] = []
65
+ for i in raw_indices or []:
66
+ try:
67
+ out.append(int(i))
68
+ except Exception:
69
+ continue
70
+ return out
71
+
72
+
73
+ def _parse_key(key: str) -> Tuple[str, str, str]:
74
+ parts = [p.strip() for p in str(key or "").split("|")]
75
+ if len(parts) >= 3:
76
+ return parts[0].upper(), parts[1], parts[2]
77
+ return "", "", ""
78
+
79
+
80
+ def _dominant_key(member_clusters: List[Dict[str, Any]]) -> str:
81
+ rel_count: Dict[str, int] = {}
82
+ art_count: Dict[str, int] = {}
83
+ pur_count: Dict[str, int] = {}
84
+ for c in member_clusters:
85
+ rel, art, pur = _parse_key(c.get("cluster_key", ""))
86
+ if rel:
87
+ rel_count[rel] = rel_count.get(rel, 0) + 1
88
+ if art:
89
+ art_count[art] = art_count.get(art, 0) + 1
90
+ if pur:
91
+ pur_count[pur] = pur_count.get(pur, 0) + 1
92
+ rel = max(rel_count, key=rel_count.get) if rel_count else "USES"
93
+ art = max(art_count, key=art_count.get) if art_count else "contribution"
94
+ pur = max(pur_count, key=pur_count.get) if pur_count else "unspecified"
95
+ return f"{rel}|{art}|{pur}"
96
+
97
+
98
+ def _extract_title(meta: Any) -> str:
99
+ if isinstance(meta, list) and meta:
100
+ meta = meta[0]
101
+ if not isinstance(meta, dict):
102
+ return ""
103
+ return str(meta.get("title", ""))
104
+
105
+
106
+ def _title_from_cluster_key(cluster_key: str) -> str:
107
+ parts = [p.strip() for p in str(cluster_key or "").split("|")]
108
+ if len(parts) >= 3:
109
+ relation, artifact, purpose = parts[0], parts[1], parts[2]
110
+ relation_txt = "Uses" if relation.upper() == "USES" else "Extends"
111
+ artifact_txt = artifact.replace("_", " ")
112
+ purpose_txt = purpose.replace("_", " ")
113
+ return f"{relation_txt} {artifact_txt} for {purpose_txt}".strip()
114
+ return cluster_key or ""
115
+
116
+
117
+ def _cluster_by_exact_keys(keys: List[str]) -> List[List[int]]:
118
+ groups: Dict[str, List[int]] = {}
119
+ order: List[str] = []
120
+ for i, key in enumerate(keys):
121
+ k = (key or "").strip()
122
+ if not k:
123
+ k = f"__EMPTY__::{i}"
124
+ if k not in groups:
125
+ groups[k] = []
126
+ order.append(k)
127
+ groups[k].append(i)
128
+ return [groups[k] for k in order]
129
+
130
+
131
+ def _build_initial_clusters_from_contributions(contrib: Dict[str, Any]) -> List[Dict[str, Any]]:
132
+ contributions = [
133
+ c for c in contrib.get("contributions", []) or []
134
+ if c.get("label") in {"USES", "EXTENDS"} and (c.get("paper_claim") or c.get("claim"))
135
+ ]
136
+ if not contributions:
137
+ return []
138
+ cluster_keys_all: List[str] = []
139
+ for c in contributions:
140
+ key = (c.get("cluster_key") or "").strip()
141
+ if not key:
142
+ label = str(c.get("label", "USES")).upper()
143
+ if label not in {"USES", "EXTENDS"}:
144
+ label = "USES"
145
+ key = f"{label}|contribution|unspecified"
146
+ cluster_keys_all.append(key)
147
+ clusters = _cluster_by_exact_keys(cluster_keys_all)
148
+ out: List[Dict[str, Any]] = []
149
+ for idx, cluster in enumerate(clusters, start=1):
150
+ first = contributions[cluster[0]]
151
+ key = cluster_keys_all[cluster[0]]
152
+ title = (first.get("cluster_title") or "").strip() or _title_from_cluster_key(key)
153
+ out.append({
154
+ "cluster_id": f"C{idx}",
155
+ "count": str(len(cluster)),
156
+ "representative_claim": title,
157
+ "cluster_key": key,
158
+ "cluster_title": title,
159
+ "claim_indices": [str(i) for i in cluster],
160
+ })
161
+ return out
162
+
163
+
164
+ def _cluster_support_summary(cluster: Dict[str, Any], contributions: List[Dict[str, Any]]) -> Dict[str, Any]:
165
+ indices = _to_int_indices(cluster.get("claim_indices") or [])
166
+ items: List[Dict[str, Any]] = []
167
+ for i in indices:
168
+ if 0 <= i < len(contributions):
169
+ items.append(contributions[i])
170
+ labels = [str(item.get("label", "")).upper() for item in items if item.get("label")]
171
+ examples: List[str] = []
172
+ for item in items:
173
+ text = str(item.get("paper_claim") or item.get("claim") or "").strip()
174
+ if text:
175
+ examples.append(text)
176
+ if len(examples) >= 3:
177
+ break
178
+ rationales = [str(item.get("rationale", "")).strip() for item in items if item.get("rationale")][:2]
179
+ use_count = sum(1 for x in labels if x == "USES")
180
+ ext_count = sum(1 for x in labels if x == "EXTENDS")
181
+ return {
182
+ "examples": examples,
183
+ "rationales": rationales,
184
+ "uses_count": use_count,
185
+ "extends_count": ext_count,
186
+ "member_count": len(items),
187
+ }
188
+
189
+
190
+ def build_prompt(paper_title: str, centroids: List[Dict[str, Any]]) -> str:
191
+ lines: List[str] = [
192
+ "You are refining downstream citation contribution clusters for one target paper.",
193
+ "Input clusters are already built. Your job is to (a) conservatively merge near-duplicate downstream-usage clusters and (b) drop clusters that do not actually show substantive downstream usage of the target contribution.",
194
+ "",
195
+ f"Target paper: {paper_title}",
196
+ "",
197
+ "Rules:",
198
+ "- Operate only at cluster level. Do not invent new instances.",
199
+ "- Prefer conservative merges. If unsure, keep clusters separate.",
200
+ "- You may drop clusters only when they fail to show real downstream use or extension of the target contribution.",
201
+ "- Drop clusters that are clearly mere mention, loose comparison, background citation, noisy extraction, or off-target usage.",
202
+ "- Never merge USES and EXTENDS clusters together.",
203
+ "- Every input cluster_id must either appear in exactly one kept group or in dropped_clusters.",
204
+ "- kept merged_key must be in format RELATION|artifact|purpose.",
205
+ "- merged_title must be a short natural-language summary (5-12 words).",
206
+ "",
207
+ "Input clusters:",
208
+ ]
209
+ for c in centroids:
210
+ lines.append(
211
+ f"- {c['cluster_id']}: key={c.get('cluster_key','')}; title={c.get('cluster_title','')}; count={c.get('count', 0)}; uses={c.get('uses_count',0)}; extends={c.get('extends_count',0)}; examples={' | '.join(c.get('examples',[])[:2])}; rationales={' | '.join(c.get('rationales',[])[:1])}"
212
+ )
213
+ lines += [
214
+ "",
215
+ "Return JSON only with this shape:",
216
+ "{",
217
+ ' "kept_groups": [',
218
+ " {",
219
+ ' "cluster_ids": ["C1","C3"],',
220
+ ' "merged_title": "Uses target dataset for evaluation",',
221
+ ' "merged_key": "USES|dataset|evaluation",',
222
+ ' "rationale": "Both clusters describe the same downstream dataset use."',
223
+ " }",
224
+ " ],",
225
+ ' "dropped_clusters": [',
226
+ ' {"cluster_id": "C7", "reason": "Only background mention; no substantive downstream use."}',
227
+ " ]",
228
+ "}",
229
+ ]
230
+ return "\n".join(lines)
231
+
232
+
233
+ def _normalize_decision(data: Dict[str, Any], original_clusters: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
234
+ valid_ids = [c.get("cluster_id", "") for c in original_clusters if c.get("cluster_id")]
235
+ valid_set = set(valid_ids)
236
+ assigned = set()
237
+ kept: List[Dict[str, Any]] = []
238
+ dropped: List[Dict[str, Any]] = []
239
+
240
+ for item in data.get("dropped_clusters") or []:
241
+ cid = item.get("cluster_id")
242
+ if cid in valid_set and cid not in assigned:
243
+ assigned.add(cid)
244
+ dropped.append({"cluster_id": cid, "reason": str(item.get("reason", "")).strip() or "Dropped by LLM filter."})
245
+
246
+ for g in data.get("kept_groups") or []:
247
+ ids = [cid for cid in (g.get("cluster_ids") or []) if cid in valid_set and cid not in assigned]
248
+ if not ids:
249
+ continue
250
+ for cid in ids:
251
+ assigned.add(cid)
252
+ kept.append({
253
+ "cluster_ids": ids,
254
+ "merged_title": str(g.get("merged_title", "")).strip(),
255
+ "merged_key": str(g.get("merged_key", "")).strip(),
256
+ "rationale": str(g.get("rationale", "")).strip(),
257
+ })
258
+
259
+ for cid in valid_ids:
260
+ if cid not in assigned:
261
+ kept.append({
262
+ "cluster_ids": [cid],
263
+ "merged_title": "",
264
+ "merged_key": "",
265
+ "rationale": "Auto-singleton fallback.",
266
+ })
267
+
268
+ order = {cid: i for i, cid in enumerate(valid_ids)}
269
+ kept.sort(key=lambda g: min(order[cid] for cid in g["cluster_ids"]))
270
+ dropped.sort(key=lambda x: order.get(x["cluster_id"], 10**9))
271
+ return kept, dropped
272
+
273
+
274
+ def refine_paper(paper_dir: Path, overwrite: bool, inplace: bool) -> str:
275
+ disc_path = paper_dir / DISCOVERY_FILE
276
+ contrib_path = paper_dir / CONTRIB_FILE
277
+ meta_path = paper_dir / PAPER_META_FILE
278
+
279
+ disc = load_json(disc_path)
280
+ contrib = load_json(contrib_path)
281
+ meta = load_json(meta_path)
282
+ if not isinstance(contrib, dict):
283
+ return "missing_inputs"
284
+
285
+ if not isinstance(disc, dict):
286
+ disc = {"paper_id": contrib.get("paper_id"), "decision": "", "justification": "", "clusters": []}
287
+
288
+ clusters = disc.get("clusters") or []
289
+ if not clusters:
290
+ clusters = _build_initial_clusters_from_contributions(contrib)
291
+ if not clusters:
292
+ payload = dict(disc)
293
+ payload["clusters"] = []
294
+ payload["dropped_clusters"] = []
295
+ payload["cluster_refine_method"] = "llm_centroid_merge_filter"
296
+ payload["cluster_refine_source"] = CONTRIB_FILE
297
+ out_path = disc_path if inplace else (paper_dir / OUT_FILE)
298
+ out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
299
+ return "empty_clusters"
300
+
301
+ out_path = disc_path if inplace else (paper_dir / OUT_FILE)
302
+ if out_path.exists() and not overwrite:
303
+ return "skipped"
304
+
305
+ contributions = contrib.get("contributions") or []
306
+ centroids: List[Dict[str, Any]] = []
307
+ auto_dropped: List[Dict[str, Any]] = []
308
+ active_clusters: List[Dict[str, Any]] = []
309
+
310
+ for c in clusters:
311
+ cid = c.get("cluster_id", "")
312
+ summary = _cluster_support_summary(c, contributions)
313
+ rel, _, _ = _parse_key(c.get("cluster_key", ""))
314
+ if summary["uses_count"] + summary["extends_count"] == 0 or rel not in {"USES", "EXTENDS"}:
315
+ auto_dropped.append({"cluster_id": cid, "reason": "No verified USES/EXTENDS support in member contributions."})
316
+ continue
317
+ row = {
318
+ "cluster_id": cid,
319
+ "cluster_key": c.get("cluster_key", ""),
320
+ "cluster_title": c.get("cluster_title") or c.get("representative_claim") or "",
321
+ "count": int(c.get("count", summary["member_count"]) or summary["member_count"]),
322
+ **summary,
323
+ }
324
+ centroids.append(row)
325
+ active_clusters.append(c)
326
+
327
+ if not active_clusters:
328
+ payload = dict(disc)
329
+ payload["clusters"] = []
330
+ payload["dropped_clusters"] = auto_dropped
331
+ payload["cluster_refine_method"] = "llm_centroid_merge_filter"
332
+ payload["cluster_refine_source"] = CONTRIB_FILE if not load_json(disc_path) else DISCOVERY_FILE
333
+ out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
334
+ return "refined"
335
+
336
+ prompt = build_prompt(_extract_title(meta), centroids)
337
+ client = LLMClient()
338
+ raw = client.call(prompt, schema=REFINE_SCHEMA)
339
+ data = json.loads(raw)
340
+ kept_groups, llm_dropped = _normalize_decision(data, active_clusters)
341
+
342
+ id_to_cluster = {c.get("cluster_id"): c for c in active_clusters if c.get("cluster_id")}
343
+ merged_clusters: List[Dict[str, Any]] = []
344
+ for idx, g in enumerate(kept_groups, start=1):
345
+ member_ids = g["cluster_ids"]
346
+ members = [id_to_cluster[mid] for mid in member_ids if mid in id_to_cluster]
347
+ merged_indices: List[int] = []
348
+ for m in members:
349
+ for i in _to_int_indices(m.get("claim_indices") or []):
350
+ if i not in merged_indices:
351
+ merged_indices.append(i)
352
+ merged_indices.sort()
353
+ merged_key = g.get("merged_key") or _dominant_key(members)
354
+ rel, _, _ = _parse_key(merged_key)
355
+ if rel not in {"USES", "EXTENDS"}:
356
+ merged_key = _dominant_key(members)
357
+ merged_title = g.get("merged_title") or (members[0].get("cluster_title") if members else "")
358
+ if not merged_title:
359
+ merged_title = members[0].get("representative_claim", "") if members else ""
360
+ merged_clusters.append({
361
+ "cluster_id": f"C{idx}",
362
+ "count": str(len(merged_indices)),
363
+ "representative_claim": merged_title,
364
+ "cluster_key": merged_key,
365
+ "cluster_title": merged_title,
366
+ "claim_indices": [str(i) for i in merged_indices],
367
+ "source_cluster_ids": member_ids,
368
+ "merge_rationale": g.get("rationale", ""),
369
+ })
370
+
371
+ payload = dict(disc)
372
+ payload["clusters"] = merged_clusters
373
+ payload["dropped_clusters"] = auto_dropped + llm_dropped
374
+ payload["cluster_refine_method"] = "llm_centroid_merge_filter"
375
+ payload["cluster_refine_source"] = CONTRIB_FILE if not load_json(disc_path) else DISCOVERY_FILE
376
+ out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
377
+ return "refined"
378
+
379
+
380
+ def main() -> None:
381
+ parser = argparse.ArgumentParser(description="LLM centroid-level merge/filter pass for downstream contribution clusters.")
382
+ parser.add_argument("--root", type=str, default="runs/processed_papers", help="Root directory containing processed paper directories.")
383
+ parser.add_argument("--overwrite", action="store_true", help="Overwrite output file if it exists.")
384
+ parser.add_argument("--inplace", action="store_true", help="Write back to usage_discovery_from_contributions.json.")
385
+ args = parser.parse_args()
386
+
387
+ root = Path(args.root).expanduser().resolve()
388
+ if not root.exists():
389
+ raise SystemExit(f"Root directory does not exist: {root}")
390
+
391
+ paper_dirs = iter_paper_dirs(root)
392
+ print(f"[INFO] Found {len(paper_dirs)} paper dirs under {root}")
393
+ counts = {"refined": 0, "skipped": 0, "missing_inputs": 0, "empty_clusters": 0}
394
+ for paper_dir in paper_dirs:
395
+ status = refine_paper(paper_dir, overwrite=args.overwrite, inplace=args.inplace)
396
+ counts[status] = counts.get(status, 0) + 1
397
+ print(f"[{status.upper()}] {paper_dir.name}")
398
+ print("[SUMMARY] refined={refined}, skipped={skipped}, missing_inputs={missing_inputs}, empty_clusters={empty_clusters}".format(**counts))
399
+
400
+
401
+ if __name__ == "__main__":
402
+ main()
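
One robustness note: refine_paper parses the LLM reply with a bare json.loads, whereas the earlier scripts route replies through _parse_llm_json, which also salvages fenced or prose-wrapped JSON. A hedged sketch of applying the same salvage strategy here (copied from the sibling scripts, not currently wired into this file) would be:

    import json
    from typing import Any, Optional

    def parse_llm_json(raw: str) -> Optional[Any]:
        # Same salvage order as _parse_llm_json in the sibling scripts: strict JSON,
        # then strip a leading code fence, then cut down to the outermost brace pair.
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            pass
        cleaned = raw.strip()
        if cleaned.startswith("```"):
            cleaned = cleaned.strip("`")
            cleaned = cleaned.replace("json", "", 1).strip()
        start, end = cleaned.find("{"), cleaned.rfind("}")
        if start == -1 or end <= start:
            return None
        try:
            return json.loads(cleaned[start:end + 1])
        except json.JSONDecodeError:
            return None

    raw = "```json\n{\"kept_groups\": [], \"dropped_clusters\": []}\n```"
    print(parse_llm_json(raw))  # -> {'kept_groups': [], 'dropped_clusters': []}
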
src/step_07_extract_and_refine/schemas.py ADDED
@@ -0,0 +1,12 @@
1
+ CONTRIBUTION_JSON_SCHEMA = {
2
+ "type": "object",
3
+ "properties": {
4
+ "label": {"type": "string", "enum": ["USES", "EXTENDS", "NOT_CONFIRMED"]},
5
+ "paper_claim": {"type": "string"},
6
+ "cluster_title": {"type": "string"},
7
+ "cluster_key": {"type": "string"},
8
+ "evidence_span": {"type": "string"},
9
+ "rationale": {"type": "string"},
10
+ },
11
+ "required": ["label", "paper_claim", "cluster_title", "cluster_key", "evidence_span", "rationale"],
12
+ }
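
The schema above is handed to LLMClient.call as a structured-output hint. For an offline sanity check of a candidate reply, one could also validate it locally; the sketch below uses the third-party jsonschema package, which is an extra assumption here and not something the pipeline itself depends on:

    # Requires `pip install jsonschema`; purely a local sanity check, not pipeline code.
    from jsonschema import ValidationError, validate

    from schemas import CONTRIBUTION_JSON_SCHEMA  # assumes this directory is importable

    reply = {
        "label": "USES",
        "paper_claim": "Trains on the target paper's dataset.",
        "cluster_title": "Uses target dataset for evaluation",
        "cluster_key": "USES|dataset|evaluation",
        "evidence_span": "We train on the dataset of <CITED HERE>.",
        "rationale": "Explicit adoption of the released dataset.",
    }
    try:
        validate(instance=reply, schema=CONTRIBUTION_JSON_SCHEMA)
        print("reply matches CONTRIBUTION_JSON_SCHEMA")
    except ValidationError as exc:
        print("schema violation:", exc.message)
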
src/step_08_annotation/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .pipeline import TwoPassAnnotationPipeline, TwoPassPipelineResult
2
+
3
+ __all__ = ["TwoPassAnnotationPipeline", "TwoPassPipelineResult"]
src/step_08_annotation/cli.py ADDED
@@ -0,0 +1,99 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from pathlib import Path
5
+
6
+ import typer
7
+
8
+ from .paper_package import load_paper_package
9
+
10
+ from .pipeline import TwoPassAnnotationPipeline
11
+
12
+
13
+ app = typer.Typer(help="Run step 8: derive target contributions, enabling contributions, and groundings.")
14
+
15
+
16
+ def _default_output_root() -> Path:
17
+ return Path("runs/two_pass_outputs")
18
+
19
+
20
+ @app.command()
21
+ def run(
22
+ paper_dir: Path = typer.Option(..., exists=True, file_okay=False, dir_okay=True),
23
+ provider: str = typer.Option("openai", help="Provider family: openai or gemini."),
24
+ model: str = typer.Option("openai/gpt-5", help="Reasoning model used for target-contribution derivation and annotation."),
25
+ formatter_model: str | None = typer.Option(
26
+ None,
27
+ help="Optional model override for pass 2 formatting, e.g. openai/gpt-5-mini or openai/gpt-5.4-pro.",
28
+ ),
29
+ judge_model: str | None = typer.Option(
30
+ None,
31
+ help="Optional model override for pass 1 candidate ranking. Ignored when --candidate-count=1.",
32
+ ),
33
+ candidate_count: int = typer.Option(
34
+ 1,
35
+ help="Number of reasoning candidates to generate. If set to 1, no judge call is made.",
36
+ ),
37
+ formatter_max_attempts: int = typer.Option(
38
+ 3,
39
+ help="Formatter-only retry attempts after pass 1 has succeeded.",
40
+ ),
41
+ include_reference_examples: bool = typer.Option(
42
+ True,
43
+ "--include-reference-examples/--no-include-reference-examples",
44
+ help="Include the built-in reference examples in the pass-1 reasoning prompt.",
45
+ ),
46
+ prompt_profile: str = typer.Option(
47
+ "full",
48
+ help="Reasoning prompt profile: full or generic.",
49
+ ),
50
+ output_root: Path = typer.Option(
51
+ _default_output_root(),
52
+ help="Directory to store run outputs.",
53
+ ),
54
+ run_label: str | None = typer.Option(None, help="Optional label to include in the saved run directory name."),
55
+ annotator_id: str = typer.Option("llm", help="Annotator id to embed in the final UI payload."),
56
+ extracted_claim: str | None = typer.Option(None, help="Optional override for the extracted target contribution."),
57
+ ) -> None:
58
+ paper = load_paper_package(paper_dir, extracted_claim_override=extracted_claim)
59
+ pipeline = TwoPassAnnotationPipeline(
60
+ provider=provider,
61
+ model=model,
62
+ formatter_model=formatter_model,
63
+ judge_model=judge_model,
64
+ output_root=output_root,
65
+ run_label=run_label,
66
+ annotator_id=annotator_id,
67
+ candidate_count=candidate_count,
68
+ formatter_max_attempts=formatter_max_attempts,
69
+ include_reference_examples=include_reference_examples,
70
+ prompt_profile=prompt_profile,
71
+ progress_callback=typer.echo,
72
+ )
73
+ result = pipeline.run(paper)
74
+ typer.echo(str(result.run_dir / "run_output.json"))
75
+
76
+
77
+ @app.command()
78
+ def summarize(run_output: Path = typer.Option(..., exists=True, dir_okay=False, file_okay=True)) -> None:
79
+ data = json.loads(run_output.read_text())
80
+ payload = data.get("ui_payload") or {}
81
+ claims = payload.get("claims") or []
82
+ summary = {
83
+ "paper_id": data.get("paper_id"),
84
+ "target_contribution_count": len(claims),
85
+ "target_contributions": [
86
+ {
87
+ "claim_id": claim.get("claim_id"),
88
+ "rewritten_claim": claim.get("rewritten_claim"),
89
+ "decision": claim.get("decision"),
90
+ "enabling_contribution_count": len(claim.get("ingredients") or []),
91
+ }
92
+ for claim in claims
93
+ ],
94
+ }
95
+ typer.echo(json.dumps(summary, indent=2))
96
+
97
+
98
+ if __name__ == "__main__":
99
+ app()
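A minimal sketch of invoking these commands in-process with Typer's bundled test runner; the import path assumes the repository root is on sys.path, and the paper directory is a placeholder rather than part of this commit:

    from typer.testing import CliRunner

    from src.step_08_annotation.cli import app  # import path assumes the repo-root layout of this commit

    runner = CliRunner()

    # Equivalent to: run --paper-dir <dir> --candidate-count 3 --run-label demo
    result = runner.invoke(
        app,
        ["run", "--paper-dir", "runs/processed_papers/example_paper",
         "--candidate-count", "3", "--run-label", "demo"],
    )
    print(result.output)  # the command's last echoed line is the run_output.json path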
src/step_08_annotation/final_prompts.py ADDED
The diff for this file is too large to render. See raw diff
 
src/step_08_annotation/paper_package.py ADDED
@@ -0,0 +1,52 @@
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+ from typing import Any, Dict, List
+
+ from common.paper_package import (
+     PaperPackage,
+     _collect_bibliography,
+     _collect_citation_contexts,
+     _collect_full_processed_text,
+     _collect_sections,
+     _load_json,
+     _normalize_dict_payload,
+ )
+
+
+ def _collect_all_cluster_evidence(paper_dir: Path) -> List[Dict[str, Any]]:
+     discovery = _normalize_dict_payload(_load_json(paper_dir / "usage_discovery_from_contributions.json", {}))
+     clusters = discovery.get("clusters", [])
+     out = []
+     for cluster in clusters:
+         out.append(
+             {
+                 "cluster_id": cluster.get("cluster_id"),
+                 "representative_claim": cluster.get("representative_claim") or cluster.get("cluster_title"),
+                 "cluster_title": cluster.get("cluster_title"),
+                 "count": cluster.get("count"),
+                 "cluster_key": cluster.get("cluster_key"),
+                 "claim_indices": cluster.get("claim_indices", []),
+                 "source_cluster_ids": cluster.get("source_cluster_ids", []),
+                 "merge_rationale": cluster.get("merge_rationale"),
+             }
+         )
+     return out
+
+
+ def load_paper_package(paper_dir: str | Path, extracted_claim_override: str | None = None) -> PaperPackage:
+     paper_dir = Path(paper_dir)
+     paper_metadata = _normalize_dict_payload(_load_json(paper_dir / "paper_metadata.json", {}))
+     cluster_evidence = _collect_all_cluster_evidence(paper_dir)
+     seed = extracted_claim_override or ""
+     return PaperPackage(
+         paper_dir=paper_dir,
+         paper_metadata=paper_metadata,
+         extracted_discovery_claim=seed,
+         downstream_cluster_evidence=cluster_evidence,
+         paper_text=_collect_sections(paper_dir),
+         full_processed_text=_collect_full_processed_text(paper_dir),
+         bibliography=_collect_bibliography(paper_dir),
+         citation_contexts=_collect_citation_contexts(paper_dir),
+     )
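For orientation, _collect_all_cluster_evidence only forwards the keys it reads from each cluster entry. A made-up discovery file that would satisfy it, written here via Python for illustration only (field values are invented, not taken from this repository):

    import json
    from pathlib import Path

    # Hypothetical minimal usage_discovery_from_contributions.json; only the keys
    # accessed by _collect_all_cluster_evidence are shown.
    example = {
        "clusters": [
            {
                "cluster_id": "c1",
                "cluster_key": "pretrained-checkpoint",
                "cluster_title": "Pretrained checkpoint reuse",
                "representative_claim": "The paper fine-tunes a released checkpoint.",
                "count": 4,
                "claim_indices": [0, 2, 5, 7],
                "source_cluster_ids": ["raw_3", "raw_9"],
                "merge_rationale": "Both raw clusters describe the same checkpoint.",
            }
        ]
    }
    Path("usage_discovery_from_contributions.json").write_text(json.dumps(example, indent=2))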
src/step_08_annotation/pipeline.py ADDED
@@ -0,0 +1,256 @@
+ from __future__ import annotations
+
+ import json
+ import traceback
+ from dataclasses import dataclass
+ from datetime import datetime, timezone
+ from pathlib import Path
+ from typing import Any, Callable, Dict
+
+ from common.model_client import ModelConfig, MultiProviderLLMClient
+ from common.paper_package import PaperPackage
+
+ from .final_prompts import (
+     SYSTEM_TWO_PASS_FORMATTER,
+     SYSTEM_TWO_PASS_JUDGE,
+     SYSTEM_TWO_PASS_REASONING,
+     formatter_prompt,
+     judge_prompt,
+     reasoning_prompt,
+ )
+ from .schemas import JudgeResult, UIPayload
+
+
+ @dataclass
+ class TwoPassPipelineResult:
+     run_dir: Path
+     result: Dict[str, Any]
+
+
+ class FormatterStageError(RuntimeError):
+     def __init__(self, message: str, run_dir: Path):
+         super().__init__(message)
+         self.run_dir = run_dir
+
+
+ class TwoPassAnnotationPipeline:
+     def __init__(
+         self,
+         *,
+         provider: str,
+         model: str,
+         formatter_model: str | None,
+         judge_model: str | None,
+         output_root: Path,
+         run_label: str | None = None,
+         annotator_id: str = "llm",
+         temperature: float = 0.2,
+         max_tokens: int = 16000,
+         candidate_count: int = 1,
+         formatter_max_attempts: int = 3,
+         include_reference_examples: bool = True,
+         prompt_profile: str = "full",
+         progress_callback: Callable[[str], None] | None = None,
+     ):
+         self.output_root = output_root
+         self.annotator_id = annotator_id
+         self.progress_callback = progress_callback
+         self.run_label = run_label
+         self.candidate_count = max(1, candidate_count)
+         self.formatter_max_attempts = max(1, formatter_max_attempts)
+         self.include_reference_examples = include_reference_examples
+         self.prompt_profile = prompt_profile
+         self.use_judge = self.candidate_count > 1
+         stage_models = {}
+         if formatter_model:
+             stage_models["two_pass_formatter"] = formatter_model
+         if judge_model and self.use_judge:
+             stage_models["two_pass_judge"] = judge_model
+         self.client = MultiProviderLLMClient(
+             default_config=ModelConfig(
+                 provider=provider,
+                 model=model,
+                 temperature=temperature,
+                 max_tokens=max_tokens,
+             ),
+             stage_models=stage_models,
+         )
+
+     def run(self, paper: PaperPackage) -> TwoPassPipelineResult:
+         run_dir = self._make_run_dir(paper)
+         payload = {
+             **paper.to_prompt_payload(),
+             "paper_dir": paper.paper_dir,
+             "full_processed_text": self._load_full_processed_text(paper),
+         }
+         formatter_config = self.client.config_for_stage("two_pass_formatter")
+         self._log(
+             f"[run] paper={paper.paper_dir.name} provider={self.client.default_config.provider} model={self.client.default_config.model_name}"
+         )
+         self._log(f"[run] formatter_model={formatter_config.model_name}")
+         self._log(f"[run] include_reference_examples={self.include_reference_examples}")
+         self._log(f"[run] prompt_profile={self.prompt_profile}")
+         if self.use_judge:
+             judge_config = self.client.config_for_stage("two_pass_judge")
+             self._log(f"[run] judge_model={judge_config.model_name}")
+         else:
+             self._log("[run] judge_model=disabled (candidate_count=1)")
+         self._log(f"[run] output={run_dir}")
+
+         reasoning_user_prompt = reasoning_prompt(payload, include_reference_examples=self.include_reference_examples, prompt_profile=self.prompt_profile)
+         self._write_text(run_dir / "pass_1_reasoning.prompt.txt", reasoning_user_prompt)
+         self._log(
+             f"[pass 1] free-form reasoning ({self.candidate_count} candidate{'s' if self.candidate_count != 1 else ''})"
+         )
+
+         candidate_texts: list[str] = []
+         candidate_paths: list[str] = []
+         for index in range(self.candidate_count):
+             reasoning_text = self.client.generate_text(
+                 stage_name="two_pass_reasoning",
+                 system_prompt=SYSTEM_TWO_PASS_REASONING,
+                 user_prompt=reasoning_user_prompt,
+             )
+             candidate_id = f"candidate_{index + 1}"
+             candidate_path = run_dir / f"pass_1_reasoning.output.{candidate_id}.md"
+             self._write_text(candidate_path, reasoning_text)
+             candidate_texts.append(reasoning_text)
+             candidate_paths.append(str(candidate_path))
+
+         selected_candidate_index = 0
+         selected_candidate_id = "candidate_1"
+         selected_reasoning_text = candidate_texts[0]
+         judge_output_path: Path | None = None
+
+         if self.use_judge:
+             judge_user_prompt = judge_prompt(payload, candidate_texts)
+             self._write_text(run_dir / "pass_1_reasoning.judge.prompt.txt", judge_user_prompt)
+             self._log("[pass 1] candidate judging")
+             judge_result = self.client.generate_structured(
+                 stage_name="two_pass_judge",
+                 system_prompt=SYSTEM_TWO_PASS_JUDGE,
+                 user_prompt=judge_user_prompt,
+                 response_model=JudgeResult,
+             )
+             judge_output_path = run_dir / "pass_1_reasoning.judge.output.json"
+             self._write_json(judge_output_path, judge_result.model_dump())
+             selected_candidate_index = judge_result.selected_candidate_index
+             selected_candidate_id = judge_result.selected_candidate_id
+             selected_reasoning_text = candidate_texts[selected_candidate_index]
+
+         selected_reasoning_path = run_dir / "pass_1_reasoning.selected.md"
+         self._write_text(selected_reasoning_path, selected_reasoning_text)
+
+         formatter_user_prompt = formatter_prompt(payload, selected_reasoning_text, self.annotator_id)
+         self._write_text(run_dir / "pass_2_formatter.prompt.txt", formatter_user_prompt)
+         final_payload: UIPayload | None = None
+         formatter_attempts: list[dict[str, Any]] = []
+         for attempt in range(1, self.formatter_max_attempts + 1):
+             self._log(
+                 f"[pass 2] strict ui json formatting (attempt {attempt}/{self.formatter_max_attempts})"
+             )
+             try:
+                 final_payload = self.client.generate_structured(
+                     stage_name="two_pass_formatter",
+                     system_prompt=SYSTEM_TWO_PASS_FORMATTER,
+                     user_prompt=formatter_user_prompt,
+                     response_model=UIPayload,
+                 )
+                 formatter_attempts.append({"attempt": attempt, "status": "success"})
+                 break
+             except Exception as exc:
+                 error_text = "".join(traceback.format_exception(exc)).strip()
+                 error_path = run_dir / f"pass_2_formatter.attempt_{attempt}.error.txt"
+                 self._write_text(error_path, error_text)
+                 formatter_attempts.append(
+                     {
+                         "attempt": attempt,
+                         "status": "failed",
+                         "error": str(exc),
+                         "error_path": str(error_path),
+                     }
+                 )
+                 if attempt < self.formatter_max_attempts:
+                     self._log("[pass 2] formatter failed; retrying formatter only")
+
+         if final_payload is None:
+             self._write_json(run_dir / "formatter_attempts.json", {"attempts": formatter_attempts})
+             raise FormatterStageError(
+                 f"Formatter failed after {self.formatter_max_attempts} attempts; pass 1 outputs kept in {run_dir}",
+                 run_dir,
+             )
+
+         self._write_json(run_dir / "formatter_attempts.json", {"attempts": formatter_attempts})
+         self._write_json(run_dir / "pass_2_ui_payload.json", final_payload.model_dump())
+
+         result = {
+             "paper_id": paper.paper_dir.name,
+             "paper_dir": str(paper.paper_dir),
+             "generated_at": datetime.now(timezone.utc).isoformat(),
+             "reasoner_model": self.client.default_config.model_name,
+             "formatter_model": formatter_config.model_name,
+             "judge_model": judge_config.model_name if self.use_judge else None,
+             "candidate_count": self.candidate_count,
+             "include_reference_examples": self.include_reference_examples,
+             "prompt_profile": self.prompt_profile,
+             "reasoning_candidate_paths": [str(path) for path in candidate_paths],
+             "selected_reasoning_candidate": selected_candidate_id,
+             "selected_candidate_index": selected_candidate_index,
+             "selected_reasoning_path": str(selected_reasoning_path),
+             "judge_output_path": str(judge_output_path) if judge_output_path is not None else None,
+             "formatter_attempts": formatter_attempts,
+             "ui_payload_path": str(run_dir / "pass_2_ui_payload.json"),
+             "ui_payload": final_payload.model_dump(),
+         }
+         self._write_json(run_dir / "run_output.json", result)
+         self._log("[run] complete")
+         return TwoPassPipelineResult(run_dir=run_dir, result=result)
+
+     def _make_run_dir(self, paper: PaperPackage) -> Path:
+         stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
+         run_name = stamp
+         if self.run_label:
+             run_name = f"{self._slugify(self.run_label)}__{stamp}"
+         run_dir = self.output_root / paper.paper_dir.name / run_name
+         run_dir.mkdir(parents=True, exist_ok=True)
+         return run_dir
+
+     def _write_json(self, path: Path, payload: Dict[str, Any]) -> None:
+         path.parent.mkdir(parents=True, exist_ok=True)
+         path.write_text(json.dumps(payload, indent=2, ensure_ascii=True) + "\n")
+
+     def _write_text(self, path: Path, text: str) -> None:
+         path.parent.mkdir(parents=True, exist_ok=True)
+         path.write_text(text)
+
+     def _log(self, message: str) -> None:
+         if self.progress_callback:
+             self.progress_callback(message)
+
+     @staticmethod
+     def _slugify(value: str) -> str:
+         slug = "".join(ch if ch.isalnum() or ch in {"-", "_", "."} else "-" for ch in value.strip())
+         slug = "-".join(part for part in slug.split("-") if part)
+         return slug[:160] or "run"
+
+     def _load_full_processed_text(self, paper: PaperPackage) -> str:
+         processed_path = paper.paper_dir / "processed_main.tex"
+         if processed_path.exists():
+             try:
+                 return processed_path.read_text()
+             except Exception:
+                 pass
+
+         sections_dir = paper.paper_dir / "sections"
+         parts: list[str] = []
+         if sections_dir.exists():
+             for path in sorted(sections_dir.iterdir()):
+                 if not path.is_file():
+                     continue
+                 try:
+                     text = path.read_text().strip()
+                 except Exception:
+                     continue
+                 if text:
+                     parts.append(f"[{path.name}]\n{text}")
+         return "\n\n".join(parts)
src/step_08_annotation/schemas.py ADDED
@@ -0,0 +1,127 @@
+ from __future__ import annotations
+
+ from typing import Any, Dict, List, Literal, Optional
+
+ from pydantic import BaseModel
+
+
+ ALLOWED_ARTIFACT_TYPES = ["Resource", "Finding", "Method", "Benchmark", "Dataset", "Tool", "Other"]
+ ALLOWED_ROLES = [
+     "CONCEPTUAL_FRAMEWORK",
+     "CORE_METHOD",
+     "DATA_SOURCE",
+     "MODEL_INITIALIZATION",
+     "EVALUATION_PROTOCOL",
+ ]
+
+
+ class ReasoningCandidate(BaseModel):
+     study: str
+     decision: Literal["accepted_canonical", "accepted_additional", "accepted_none", "rejected_candidate"]
+     why: str
+
+
+ class ReasoningIngredient(BaseModel):
+     ingredient_id: str
+     ingredient: str
+     necessary: bool
+     from_prior_work: bool
+     maps_cleanly_to_one_study: bool
+     notes: str
+     canonical_grounding_decision: Dict[str, Any]
+     additional_groundings: List[Dict[str, Any]]
+     candidate_studies_considered: List[ReasoningCandidate]
+     role: Optional[Literal["CONCEPTUAL_FRAMEWORK", "CORE_METHOD", "DATA_SOURCE", "MODEL_INITIALIZATION", "EVALUATION_PROTOCOL", "IMPLEMENTATION_TOOLING", "TRAINING_DATA"]] = None
+     contribution: str
+     rationale: str
+     evidence_span: str
+
+
+ class ReasoningClaim(BaseModel):
+     claim_id: str
+     artifact_type: Literal["Resource", "Finding", "Method", "Benchmark", "Dataset", "Tool", "Other"]
+     rewritten_claim: str
+     cluster_id: str = ""
+     decision: Literal["YES_SUFFICIENT", "NO_NOT_DISCOVERY"] = "YES_SUFFICIENT"
+     notes: str = ""
+     why_this_is_atomic: str
+     ingredients: List[ReasoningIngredient]
+
+
+ class ReasoningOutput(BaseModel):
+     original_discovery_claim: str
+     claim_split_decision: Dict[str, Any]
+     rewritten_claims: List[ReasoningClaim]
+     paper_level_notes: str = ""
+
+
+ class GroundingRecord(BaseModel):
+     ref_id: Optional[str] = None
+     bib_key: Optional[str] = None
+     paper_id: Optional[str] = None
+     external_ids: Optional[Dict[str, Any]] = None
+     ref_title: Optional[str] = None
+     ref_year: Optional[str] = None
+     ref_authors: Optional[str] = None
+
+
+ class CanonicalAnnotation(BaseModel):
+     role: Optional[Literal["CONCEPTUAL_FRAMEWORK", "CORE_METHOD", "DATA_SOURCE", "MODEL_INITIALIZATION", "EVALUATION_PROTOCOL", "IMPLEMENTATION_TOOLING", "TRAINING_DATA"]] = None
+     roles: List[Literal["CONCEPTUAL_FRAMEWORK", "CORE_METHOD", "DATA_SOURCE", "MODEL_INITIALIZATION", "EVALUATION_PROTOCOL", "IMPLEMENTATION_TOOLING", "TRAINING_DATA"]]
+     contribution: str
+     rationale: str
+     evidence_span: str
+
+
+ class IngredientPayload(BaseModel):
+     ingredient_id: str
+     ingredient: str
+     canonical_ref_id: str
+     canonical_grounding: Optional[GroundingRecord] = None
+     additional_ref_ids: List[str]
+     additional_groundings: List[GroundingRecord]
+     canonical_annotation: CanonicalAnnotation
+
+
+ class EnablingDiscoveryPayload(GroundingRecord):
+     ingredient_id: str
+     ingredient: str
+     role: Optional[Literal["CONCEPTUAL_FRAMEWORK", "CORE_METHOD", "DATA_SOURCE", "MODEL_INITIALIZATION", "EVALUATION_PROTOCOL", "IMPLEMENTATION_TOOLING", "TRAINING_DATA"]] = None
+     roles: List[Literal["CONCEPTUAL_FRAMEWORK", "CORE_METHOD", "DATA_SOURCE", "MODEL_INITIALIZATION", "EVALUATION_PROTOCOL", "IMPLEMENTATION_TOOLING", "TRAINING_DATA"]]
+     contribution: str
+     rationale: str
+     evidence_span: str
+
+
+ class ClaimPayload(BaseModel):
+     claim_id: str
+     text: str
+     rewritten_claim: str
+     cluster_id: str = ""
+     decision: Literal["YES_SUFFICIENT", "NO_NOT_DISCOVERY", "UNCERTAIN"] = "YES_SUFFICIENT"
+     notes: str = ""
+     ingredients: List[IngredientPayload]
+     enabling_discoveries: List[EnablingDiscoveryPayload]
+
+
+ class JudgeCandidateScore(BaseModel):
+     candidate_id: str
+     candidate_index: int
+     score: int
+     assessment: str
+
+
+ class JudgeResult(BaseModel):
+     selected_candidate_index: int
+     selected_candidate_id: str
+     selected_reason: str
+     candidate_scores: List[JudgeCandidateScore]
+
+
+ class UIPayload(BaseModel):
+     target_paper_id: str
+     target_title: Optional[str] = None
+     target_year: Optional[int] = None
+     annotator_id: str
+     active_claim_id: Optional[str] = None
+     claims: List[ClaimPayload]
+ claims: List[ClaimPayload]