muthuk1 committed on
Commit 3ee4528 · verified · 1 Parent(s): 60b14ca

Add dataset preparation script — downloads Wikipedia/arXiv/BBC, verifies 2M+ tokens, ingests into TigerGraph

Files changed (1)
  1. graphrag/prepare_dataset.py +332 -0
graphrag/prepare_dataset.py ADDED
@@ -0,0 +1,332 @@
+ """
+ Dataset Preparation — 2M+ Token Corpus for GraphRAG Hackathon
+ ==============================================================
+ Downloads, tokenizes, and ingests a 2M+ token corpus into TigerGraph.
+
+ Supported sources (pick one or combine):
+ 1. Wikipedia (English) — best entity density, CC-BY-SA
+ 2. arXiv papers (neuralwork/arxiver) — full text, CC-BY-NC-SA
+ 3. BBC News (RealTimeData/bbc_news_alltime) — events, CC-BY
+
+ Usage:
+     python graphrag/prepare_dataset.py --source wikipedia --target-tokens 2500000
+     python graphrag/prepare_dataset.py --source arxiv --target-tokens 2500000
+     python graphrag/prepare_dataset.py --source bbc --target-tokens 2500000
+     python graphrag/prepare_dataset.py --source wikipedia --target-tokens 2500000 --ingest
+ """
+ import argparse
+ import hashlib
+ import json
+ import logging
+ import os
+ import sys
+ import time
+ from typing import Dict, List, Tuple
+
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+
+ def count_tokens(text: str) -> int:
+     """Estimate token count. Uses tiktoken if available, otherwise word-based estimate."""
+     try:
+         import tiktoken
+         enc = tiktoken.encoding_for_model("gpt-4o-mini")
+         return len(enc.encode(text))
+     except ImportError:
+         # Rough estimate: 1 token ≈ 0.75 words (English)
+         return int(len(text.split()) * 1.33)
+
+
+ def load_wikipedia(target_tokens: int, domain: str = "science") -> List[Dict]:
+     """
+     Load Wikipedia articles until target token count is reached.
+
+     Domain filters available:
+     - "science": physics, chemistry, biology, mathematics, astronomy
+     - "history": wars, empires, historical figures, events
+     - "politics": countries, politicians, governments, elections
+     - "technology": computing, AI, internet, software, engineering
+     - "all": no filter (fastest, most diverse)
+     """
+     from datasets import load_dataset
+
+     logger.info(f"Loading Wikipedia (domain={domain}, target={target_tokens:,} tokens)...")
+
+     domain_keywords = {
+         "science": ["physicist", "scientist", "chemist", "biologist", "mathematician",
+                     "theory", "equation", "discovery", "experiment", "research",
+                     "university", "professor", "nobel", "quantum", "evolution"],
+         "history": ["war", "battle", "empire", "dynasty", "revolution", "treaty",
+                     "king", "queen", "president", "ancient", "medieval", "colonial"],
+         "politics": ["election", "government", "parliament", "president", "minister",
+                      "democrat", "republic", "constitution", "legislation", "political"],
+         "technology": ["computer", "software", "algorithm", "internet", "artificial",
+                        "programming", "engineer", "processor", "database", "network"],
+         "all": [],
+     }
+     keywords = domain_keywords.get(domain, [])
+
+     ds = load_dataset("wikimedia/wikipedia", "20231101.en", split="train", streaming=True)
+
+     documents = []
+     total_tokens = 0
+     scanned = 0
+
+     for article in ds:
+         scanned += 1
+         title = article.get("title", "")
+         text = article.get("text", "")
+
+         if not text or len(text) < 200:
+             continue
+
+         # Domain filter
+         if keywords:
+             title_lower = title.lower()
+             text_lower = text[:2000].lower()  # check first 2000 chars only for speed
+             if not any(kw in title_lower or kw in text_lower for kw in keywords):
+                 if scanned % 1000 == 0:
+                     logger.info(f"  Scanned {scanned:,} articles, collected {len(documents)}, "
+                                 f"tokens: {total_tokens:,}/{target_tokens:,}")
+                 continue
+
+         tokens = count_tokens(text)
+         documents.append({
+             "id": hashlib.md5(title.encode()).hexdigest()[:12],
+             "title": title,
+             "content": text,
+             "source": "wikipedia",
+             "tokens": tokens,
+             "url": article.get("url", ""),
+         })
+         total_tokens += tokens
+
+         if len(documents) % 100 == 0:
+             logger.info(f"  Collected {len(documents)} articles, "
+                         f"tokens: {total_tokens:,}/{target_tokens:,} "
+                         f"({total_tokens/target_tokens*100:.1f}%)")
+
+         if total_tokens >= target_tokens:
+             break
+
+     logger.info(f"✅ Wikipedia: {len(documents)} articles, {total_tokens:,} tokens")
+     return documents
+
+
+ def load_arxiv(target_tokens: int) -> List[Dict]:
+     """Load arXiv papers with full markdown text from neuralwork/arxiver."""
+     from datasets import load_dataset
+
+     logger.info(f"Loading arXiv papers (target={target_tokens:,} tokens)...")
+     ds = load_dataset("neuralwork/arxiver", split="train")
+
+     documents = []
+     total_tokens = 0
+
+     for i, paper in enumerate(ds):
+         text = paper.get("markdown", "")
+         if not text or len(text) < 500:
+             continue
+
+         title = paper.get("title", f"Paper {i}")
+         tokens = count_tokens(text)
+
+         documents.append({
+             "id": paper.get("id", hashlib.md5(title.encode()).hexdigest()[:12]),
+             "title": title,
+             "content": text,
+             "source": "arxiv",
+             "tokens": tokens,
+             "authors": paper.get("authors", ""),
+             "published_date": paper.get("published_date", ""),
+         })
+         total_tokens += tokens
+
+         if len(documents) % 50 == 0:
+             logger.info(f"  Collected {len(documents)} papers, "
+                         f"tokens: {total_tokens:,}/{target_tokens:,}")
+
+         if total_tokens >= target_tokens:
+             break
+
+     logger.info(f"✅ arXiv: {len(documents)} papers, {total_tokens:,} tokens")
+     return documents
+
+
+ def load_bbc_news(target_tokens: int, year: str = "2022") -> List[Dict]:
+     """Load BBC News articles from RealTimeData/bbc_news_alltime."""
+     from datasets import load_dataset, concatenate_datasets
+
+     logger.info(f"Loading BBC News (year={year}, target={target_tokens:,} tokens)...")
+
+     months = [f"{year}-{m:02d}" for m in range(1, 13)]
+     all_articles = []
+
+     for month in months:
+         try:
+             ds = load_dataset("RealTimeData/bbc_news_alltime", month, split="train")
+             all_articles.extend([dict(row) for row in ds])
+             logger.info(f"  Loaded {month}: {len(ds)} articles (total: {len(all_articles)})")
+         except Exception as e:
+             logger.warning(f"  {month} not available: {e}")
+             continue
+
+     documents = []
+     total_tokens = 0
+
+     for article in all_articles:
+         text = article.get("content", "")
+         if not text or len(text) < 200:
+             continue
+
+         title = article.get("title", "Untitled")
+         tokens = count_tokens(text)
+
+         documents.append({
+             "id": hashlib.md5(f"{title}:{article.get('published_date','')}".encode()).hexdigest()[:12],
+             "title": title,
+             "content": text,
+             "source": "bbc_news",
+             "tokens": tokens,
+             "section": article.get("section", ""),
+             "published_date": article.get("published_date", ""),
+         })
+         total_tokens += tokens
+
+         if total_tokens >= target_tokens:
+             break
+
+     logger.info(f"✅ BBC News: {len(documents)} articles, {total_tokens:,} tokens")
+     return documents
+
+
+ def save_dataset(documents: List[Dict], output_dir: str = "dataset"):
+     """Save prepared dataset to disk."""
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Save as JSONL
+     output_path = os.path.join(output_dir, "corpus.jsonl")
+     with open(output_path, "w") as f:
+         for doc in documents:
+             f.write(json.dumps(doc, ensure_ascii=False) + "\n")
+
+     # Save metadata
+     total_tokens = sum(d["tokens"] for d in documents)
+     meta = {
+         "num_documents": len(documents),
+         "total_tokens": total_tokens,
+         "sources": list(set(d["source"] for d in documents)),
+         "avg_tokens_per_doc": total_tokens // max(len(documents), 1),
+         "meets_2m_minimum": total_tokens >= 2_000_000,
+         "created_at": time.strftime("%Y-%m-%d %H:%M:%S"),
+     }
+     meta_path = os.path.join(output_dir, "metadata.json")
+     with open(meta_path, "w") as f:
+         json.dump(meta, f, indent=2)
+
+     logger.info(f"\n{'='*60}")
+     logger.info(f"DATASET SAVED: {output_path}")
+     logger.info(f"  Documents: {len(documents):,}")
+     logger.info(f"  Total tokens: {total_tokens:,}")
+     logger.info(f"  Meets 2M minimum: {'✅ YES' if total_tokens >= 2_000_000 else '❌ NO'}")
+     logger.info(f"  Metadata: {meta_path}")
+     logger.info(f"{'='*60}\n")
+     return meta
+
+
+ def ingest_to_tigergraph(documents: List[Dict], max_docs: int = None):
+     """Ingest prepared documents into TigerGraph via the ingestion pipeline."""
+     from graphrag.layers.graph_layer import GraphLayer
+     from graphrag.layers.llm_layer import LLMLayer
+     from graphrag.layers.orchestration_layer import EmbeddingManager
+     from graphrag.ingestion import IngestionPipeline
+
+     logger.info("Connecting to TigerGraph...")
+     graph = GraphLayer(config={
+         "host": os.getenv("TG_HOST", ""),
+         "graphname": os.getenv("TG_GRAPH", "GraphRAG"),
+         "username": os.getenv("TG_USERNAME", "tigergraph"),
+         "password": os.getenv("TG_PASSWORD", ""),
+     })
+     if not graph.connect():
+         logger.error("TigerGraph connection failed. Set TG_HOST and TG_PASSWORD.")
+         return
+
+     graph.create_schema()
+     graph.install_queries()
+
+     llm = LLMLayer(api_key=os.getenv("OPENAI_API_KEY", ""),
+                    model=os.getenv("LLM_MODEL", "gpt-4o-mini"))
+     llm.initialize()
+
+     embedder = EmbeddingManager()
+     embedder.initialize()
+
+     pipeline = IngestionPipeline(graph, llm, embedder)
+
+     docs_to_ingest = documents[:max_docs] if max_docs else documents
+     logger.info(f"Ingesting {len(docs_to_ingest)} documents into TigerGraph...")
+
+     custom_docs = [{"id": d["id"], "title": d["title"], "content": d["content"],
+                     "source": d["source"]} for d in docs_to_ingest]
+     stats = pipeline.ingest_custom_documents(custom_docs, extract_entities=True)
+
+     logger.info(f"✅ Ingestion complete: {stats}")
+     return stats
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Prepare 2M+ token dataset for GraphRAG Hackathon")
+     parser.add_argument("--source", choices=["wikipedia", "arxiv", "bbc", "combined"],
+                         default="wikipedia", help="Dataset source")
+     parser.add_argument("--target-tokens", type=int, default=2_500_000,
+                         help="Target token count (default: 2.5M for safety margin)")
+     parser.add_argument("--domain", default="science",
+                         help="Domain filter for Wikipedia (science/history/politics/technology/all)")
+     parser.add_argument("--year", default="2022",
+                         help="Year for BBC News")
+     parser.add_argument("--output-dir", default="dataset",
+                         help="Output directory")
+     parser.add_argument("--ingest", action="store_true",
+                         help="Also ingest into TigerGraph (requires TG_HOST, TG_PASSWORD)")
+     parser.add_argument("--max-ingest", type=int, default=None,
+                         help="Max docs to ingest (default: all)")
+     args = parser.parse_args()
+
+     # Load dataset
+     if args.source == "wikipedia":
+         documents = load_wikipedia(args.target_tokens, domain=args.domain)
+     elif args.source == "arxiv":
+         documents = load_arxiv(args.target_tokens)
+     elif args.source == "bbc":
+         documents = load_bbc_news(args.target_tokens, year=args.year)
+     elif args.source == "combined":
+         # Mix: 60% Wikipedia + 25% arXiv + 15% BBC
+         wiki_target = int(args.target_tokens * 0.6)
+         arxiv_target = int(args.target_tokens * 0.25)
+         bbc_target = int(args.target_tokens * 0.15)
+         documents = (load_wikipedia(wiki_target, domain=args.domain) +
+                      load_arxiv(arxiv_target) +
+                      load_bbc_news(bbc_target, year=args.year))
+
+     if not documents:
+         logger.error("No documents loaded. Check your internet connection.")
+         sys.exit(1)
+
+     # Save to disk
+     meta = save_dataset(documents, args.output_dir)
+
+     if not meta["meets_2m_minimum"]:
+         logger.warning(f"⚠️ Only {meta['total_tokens']:,} tokens. "
+                        f"Need {2_000_000 - meta['total_tokens']:,} more. "
+                        f"Try --target-tokens {args.target_tokens + 1_000_000}")
+
+     # Ingest into TigerGraph
+     if args.ingest:
+         ingest_to_tigergraph(documents, max_docs=args.max_ingest)
+
+
+ if __name__ == "__main__":
+     main()
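
A quick way to sanity-check a finished run is to re-read the corpus that save_dataset() writes, rather than re-downloading anything. The snippet below is a minimal sketch, assuming the default --output-dir dataset and the per-record "tokens" field produced by the script above; adjust the path if you passed a different output directory.

    import json

    total_tokens = 0
    num_docs = 0
    with open("dataset/corpus.jsonl", encoding="utf-8") as f:
        for line in f:
            doc = json.loads(line)          # fields: id, title, content, source, tokens, ...
            total_tokens += doc["tokens"]   # per-document estimate from count_tokens()
            num_docs += 1

    print(f"{num_docs:,} documents, {total_tokens:,} tokens "
          f"(meets 2M minimum: {total_tokens >= 2_000_000})")

The totals should agree with dataset/metadata.json, which the script writes alongside the corpus.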