nkshirsa committed (verified)
Commit: f6c2b19
Parent: a4f0eec

Add SciRIFF training data integration script (72x more data for training)

phd_research_os_v2/training/sciriff_integration.py ADDED
"""
SciRIFF Training Data Integration
=================================
Converts AllenAI's SciRIFF dataset (137K expert-written examples across
54 scientific tasks) into the PhD Research OS ChatML format.

Filters for tasks relevant to our pipeline:
- Claim verification (SciFact tasks)
- Information extraction (SciERC tasks)
- NER and entity recognition
- Summarization (faithful compression)

Addresses blindspots: D-1, D-6, PA-3
Source: SYSTEM_INSPIRATIONS.md DA-3

Dependencies:
    pip install datasets
"""

import hashlib
import logging
from typing import Optional

logger = logging.getLogger(__name__)

# Tasks from SciRIFF that map to our pipeline
RELEVANT_TASK_FAMILIES = {
    "ie",              # Information extraction → Layer 2
    "classification",  # Classification → epistemic tagging
    "summarization",   # Summarization → faithful claim compression
    "qa",              # Question answering → query decomposition
    "entailment",      # Entailment → claim verification
}

# Specific task prefixes that are highly relevant
HIGH_PRIORITY_TASKS = {
    "scifact",             # Claim verification (SUPPORT/CONTRADICT)
    "scierc",              # Scientific entity + relation extraction
    "evidence_inference",  # RCT outcome extraction
    "biosses",             # Biomedical sentence similarity
    "chemprot",            # Chemical-protein interaction extraction
    "ncbi_disease",        # Disease NER
    "pubmedqa",            # Biomedical QA
    "qasper",              # Full-text scientific QA
}

# System prompts to wrap SciRIFF examples in our format
SYSTEM_PROMPTS = {
    "ie": (
        "You are the Claim Extractor of a PhD Research OS. "
        "Extract structured information from scientific text. "
        "Be precise, preserve qualifiers, and output valid JSON."
    ),
    "classification": (
        "You are the Epistemic Classifier of a PhD Research OS. "
        "Classify the given scientific text according to the specified taxonomy. "
        "Consider context, hedging language, and evidence strength."
    ),
    "summarization": (
        "You are the Synthesis Agent of a PhD Research OS. "
        "Summarize scientific text faithfully. Never add information "
        "not present in the source. Preserve all qualifiers and hedging."
    ),
    "qa": (
        "You are the Query Planner of a PhD Research OS. "
        "Answer questions about scientific papers using evidence from the text. "
        "Cite specific passages. Say 'insufficient evidence' when appropriate."
    ),
    "entailment": (
        "You are the Claim Verifier of a PhD Research OS. "
        "Given a claim and evidence, determine if the evidence SUPPORTS, "
        "CONTRADICTS, or provides NOT_ENOUGH_INFO about the claim."
    ),
}


def load_sciriff(config: str = "4096", split: str = "train",
                 max_examples: Optional[int] = None) -> list[dict]:
    """
    Load SciRIFF from HuggingFace and convert to ChatML format.

    Args:
        config: Token-length config ("4096", "8192", "16384")
        split: Dataset split
        max_examples: Limit for quick testing

    Returns:
        List of {"messages": [{"role": "system", ...}, {"role": "user", ...},
        {"role": "assistant", ...}]} records.
    """
    from datasets import load_dataset

    logger.info(f"Loading SciRIFF ({config}/{split})...")
    ds = load_dataset("allenai/SciRIFF", config, split=split, trust_remote_code=True)

    if max_examples:
        ds = ds.select(range(min(max_examples, len(ds))))

    converted = []
    skipped = 0
    task_counts = {}

    for row in ds:
        input_text = row.get("input", "")
        output_text = row.get("output", "")
        metadata = row.get("metadata", {})
        instance_id = row.get("_instance_id", "")

        if not input_text or not output_text:
            skipped += 1
            continue

        # Determine task family from metadata or instance_id
        task_family = None
        if isinstance(metadata, dict):
            task_family = metadata.get("task_family", "")

        # Also check instance_id for task identification
        task_name = instance_id.split(":")[0] if ":" in instance_id else ""

        # Filter for relevant tasks
        is_relevant = False
        if task_family and task_family.lower() in RELEVANT_TASK_FAMILIES:
            is_relevant = True
        for prefix in HIGH_PRIORITY_TASKS:
            if task_name.lower().startswith(prefix):
                is_relevant = True
                break

        if not is_relevant:
            # Still include non-priority tasks (all scientific tasks help),
            # but keep only ~20% of them to maintain focus. Hashing the
            # instance_id makes the subsample deterministic across runs.
            h = int(hashlib.md5(instance_id.encode()).hexdigest(), 16)
            if h % 5 != 0:  # keep ~20%
                skipped += 1
                continue

        # Select system prompt based on task family, defaulting to "ie"
        system_prompt = SYSTEM_PROMPTS.get(
            task_family.lower() if task_family else "ie",
            SYSTEM_PROMPTS["ie"],
        )

        # Build ChatML messages
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": input_text},
            {"role": "assistant", "content": output_text},
        ]

        converted.append({"messages": messages})

        # Track task distribution
        task_key = task_name or task_family or "unknown"
        task_counts[task_key] = task_counts.get(task_key, 0) + 1

    logger.info(
        f"Converted {len(converted)} SciRIFF examples "
        f"(skipped {skipped}, {len(task_counts)} task types)"
    )

    # Log task distribution
    sorted_tasks = sorted(task_counts.items(), key=lambda x: -x[1])
    for task, count in sorted_tasks[:15]:
        logger.info(f"  {task}: {count} examples")

    return converted
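

# Quick smoke test for the converter (illustrative only; requires network
# access to the HuggingFace Hub, and the size 100 is arbitrary):
#
#   examples = load_sciriff(config="4096", max_examples=100)
#   print(examples[0]["messages"][0]["content"])  # the selected system prompt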


def merge_datasets(existing_path: str = "nkshirsa/phd-research-os-sft-data",
                   sciriff_config: str = "4096",
                   sciriff_max: int = 10000,
                   existing_max: Optional[int] = None) -> dict:
    """
    Merge existing PhD Research OS training data with SciRIFF.

    Returns:
        {
            "merged": list of ChatML examples,
            "stats": {
                "existing_count": int,
                "sciriff_count": int,
                "total": int,
                "expansion_factor": float,
            },
        }
    """
    from datasets import load_dataset

    # Load existing data
    logger.info(f"Loading existing data from {existing_path}...")
    existing_ds = load_dataset(existing_path, split="train", trust_remote_code=True)
    existing_examples = [{"messages": row["messages"]} for row in existing_ds]
    if existing_max:
        existing_examples = existing_examples[:existing_max]

    # Load SciRIFF
    sciriff_examples = load_sciriff(config=sciriff_config, max_examples=sciriff_max)

    # Merge
    merged = existing_examples + sciriff_examples

    stats = {
        "existing_count": len(existing_examples),
        "sciriff_count": len(sciriff_examples),
        "total": len(merged),
        "expansion_factor": round(len(merged) / max(len(existing_examples), 1), 1),
    }

    logger.info(
        f"Merged dataset: {stats['existing_count']} existing + "
        f"{stats['sciriff_count']} SciRIFF = {stats['total']} total "
        f"({stats['expansion_factor']}× expansion)"
    )

    return {"merged": merged, "stats": stats}


def create_merged_hf_dataset(output_path: str = "data/merged_sft",
                             sciriff_max: int = 10000,
                             test_ratio: float = 0.1) -> dict:
    """
    Create a merged HuggingFace dataset on disk, ready for training.

    Args:
        output_path: Where to save the dataset
        sciriff_max: Maximum SciRIFF examples to include
        test_ratio: Fraction held out for the test split
    """
    import random

    from datasets import Dataset, DatasetDict

    result = merge_datasets(sciriff_max=sciriff_max)
    all_examples = result["merged"]

    # Shuffle with a fixed seed so the split is reproducible
    random.seed(42)
    random.shuffle(all_examples)

    # Split
    n_test = int(len(all_examples) * test_ratio)
    test_examples = all_examples[:n_test]
    train_examples = all_examples[n_test:]

    # Create HF dataset
    train_ds = Dataset.from_list(train_examples)
    test_ds = Dataset.from_list(test_examples)

    ds_dict = DatasetDict({"train": train_ds, "test": test_ds})
    ds_dict.save_to_disk(output_path)

    logger.info(
        f"Saved merged dataset to {output_path}: "
        f"{len(train_examples)} train, {len(test_examples)} test"
    )

    return {
        "path": output_path,
        "train_count": len(train_examples),
        "test_count": len(test_examples),
        "stats": result["stats"],
    }
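
To tie it together, a minimal end-to-end sketch (the script itself does not
ship this entry point; it assumes the module is importable as
sciriff_integration, the default dataset paths above are reachable, and the
datasets library is installed):

    import logging
    from datasets import load_from_disk
    from sciriff_integration import create_merged_hf_dataset

    logging.basicConfig(level=logging.INFO)
    info = create_merged_hf_dataset(output_path="data/merged_sft", sciriff_max=10000)
    ds = load_from_disk(info["path"])   # DatasetDict with "train" and "test"
    print(info["stats"])                # counts plus expansion_factor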