gaurv007 committed
Commit 08897f5 · verified · 1 Parent(s): 713381e

Upload alpha_factory/infra/rag.py with huggingface_hub

Files changed (1)
  1. alpha_factory/infra/rag.py +140 -0
alpha_factory/infra/rag.py ADDED
@@ -0,0 +1,140 @@
"""
RAG System: ChromaDB + arXiv paper indexer.
Retrieves relevant academic papers for the Hypothesis Hunter.
"""
from pathlib import Path


class PaperRAG:
    """
    RAG retrieval over arXiv q-fin papers.
    Uses ChromaDB for vector storage; embeddings come from ChromaDB's default model.
    """

    def __init__(self, persist_dir: Path, collection_name: str = "qfin_papers"):
        try:
            import chromadb
            self.client = chromadb.PersistentClient(path=str(persist_dir))
            self.collection = self.client.get_or_create_collection(
                name=collection_name,
                metadata={"hnsw:space": "cosine"},
            )
            self._available = True
        except ImportError:
            self._available = False
            print("[WARN] chromadb not installed; RAG disabled. pip install chromadb")

    @property
    def available(self) -> bool:
        return self._available

    def index_papers(self, papers: list[dict]):
        """
        Index papers into ChromaDB.
        Each paper: {"id": "arxiv_id", "title": "...", "abstract": "...", "categories": [...]}
        """
        if not self._available:
            return

        ids = [p["id"] for p in papers]
        documents = [f"{p['title']}\n\n{p['abstract']}" for p in papers]
        metadatas = [{"title": p["title"], "categories": ",".join(p.get("categories", []))} for p in papers]

        # Batch insert (ChromaDB handles embedding automatically with its default model)
        batch_size = 100
        for i in range(0, len(ids), batch_size):
            self.collection.upsert(
                ids=ids[i:i + batch_size],
                documents=documents[i:i + batch_size],
                metadatas=metadatas[i:i + batch_size],
            )

    def retrieve(self, query: str, n_results: int = 3) -> list[str]:
        """
        Retrieve the top-N relevant paper abstracts for a given theme/query.
        Returns a list of formatted paper strings.
        """
        if not self._available or self.collection.count() == 0:
            return []

        results = self.collection.query(
            query_texts=[query],
            n_results=n_results,
        )

        papers = []
        if results["documents"] and results["documents"][0]:
            for doc, meta in zip(results["documents"][0], results["metadatas"][0]):
                papers.append(f"[{meta.get('title', 'Unknown')}]\n{doc[:500]}")

        return papers

    def count(self) -> int:
        """Number of indexed papers."""
        if not self._available:
            return 0
        return self.collection.count()

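
# Usage sketch for PaperRAG (illustrative, not part of the uploaded module's API):
# the persist path and the paper dict below are hypothetical placeholders.
def _demo_retrieval() -> None:
    rag = PaperRAG(persist_dir=Path("./chroma_store"))  # hypothetical storage location
    if rag.available:
        rag.index_papers([{
            "id": "0000.00000",  # placeholder arXiv id
            "title": "Illustrative momentum paper",
            "abstract": "A placeholder abstract about cross-sectional momentum.",
            "categories": ["q-fin.PM"],
        }])
        for paper in rag.retrieve("momentum factor timing", n_results=1):
            print(paper)
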
async def fetch_arxiv_papers(
    categories: list[str] = ["q-fin.PM", "q-fin.ST", "q-fin.CP", "stat.AP"],
    max_results: int = 500,
    start_year: int = 2021,
) -> list[dict]:
    """
    Fetch papers from the arXiv API.
    Returns a list of {id, title, abstract, categories, published}.
    """
    import asyncio
    import aiohttp
    from xml.etree import ElementTree as ET

    base_url = "http://export.arxiv.org/api/query"
    papers = []

    for cat in categories:
        query = f"cat:{cat}"
        params = {
            "search_query": query,
            "start": 0,
            "max_results": max_results // len(categories),
            "sortBy": "submittedDate",
            "sortOrder": "descending",
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(base_url, params=params) as resp:
                    if resp.status != 200:
                        continue
                    text = await resp.text()

            # Parse the Atom XML feed
            ns = {"atom": "http://www.w3.org/2005/Atom", "arxiv": "http://arxiv.org/schemas/atom"}
            root = ET.fromstring(text)

            for entry in root.findall("atom:entry", ns):
                arxiv_id = entry.find("atom:id", ns).text.split("/")[-1]
                title = entry.find("atom:title", ns).text.strip().replace("\n", " ")
                abstract = entry.find("atom:summary", ns).text.strip().replace("\n", " ")
                published = entry.find("atom:published", ns).text[:10]

                # Skip papers published before start_year
                if int(published[:4]) < start_year:
                    continue

                cats = [c.attrib["term"] for c in entry.findall("arxiv:primary_category", ns)]

                papers.append({
                    "id": arxiv_id,
                    "title": title,
                    "abstract": abstract,
                    "categories": cats,
                    "published": published,
                })

            # Stay under the arXiv API rate limit between category requests
            await asyncio.sleep(3)

        except Exception as e:
            print(f"[WARN] arXiv fetch failed for {cat}: {e}")

    return papers
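
# End-to-end sketch (assumed wiring, not shown in the original upload): fetch recent
# q-fin papers and build the index. Requires aiohttp and chromadb; the persist path
# and the small max_results value are illustrative choices.
if __name__ == "__main__":
    import asyncio

    async def _build_index() -> None:
        papers = await fetch_arxiv_papers(max_results=40)
        rag = PaperRAG(persist_dir=Path("./chroma_store"))  # hypothetical storage location
        if rag.available and papers:
            rag.index_papers(papers)
            print(f"Indexed {rag.count()} papers; sample query:")
            for hit in rag.retrieve("volatility forecasting"):
                print(hit[:120])

    asyncio.run(_build_index())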