gaurv007 committed on
Commit
579a6a1
·
verified ·
1 Parent(s): 9c2119c

Delete alpha_factory/infra/rag.py, alpha_factory/infra/rag.py

Browse files
Files changed (1) hide show
  1. alpha_factory/infra/rag.py +0 -140
alpha_factory/infra/rag.py DELETED
@@ -1,140 +0,0 @@
1
- """
2
- RAG System — ChromaDB + arXiv paper indexer.
3
- Retrieves relevant academic papers for the Hypothesis Hunter.
4
- """
5
- import os
6
- from pathlib import Path
7
- from typing import Optional
8
-
9
-
10
- class PaperRAG:
11
- """
12
- RAG retrieval over arXiv q-fin papers.
13
- Uses ChromaDB for vector storage + bge-small for embeddings.
14
- """
15
-
16
- def __init__(self, persist_dir: Path, collection_name: str = "qfin_papers"):
17
- try:
18
- import chromadb
19
- self.client = chromadb.PersistentClient(path=str(persist_dir))
20
- self.collection = self.client.get_or_create_collection(
21
- name=collection_name,
22
- metadata={"hnsw:space": "cosine"},
23
- )
24
- self._available = True
25
- except ImportError:
26
- self._available = False
27
- print("[WARN] chromadb not installed — RAG disabled. pip install chromadb")
28
-
29
- @property
30
- def available(self) -> bool:
31
- return self._available
32
-
33
- def index_papers(self, papers: list[dict]):
34
- """
35
- Index papers into ChromaDB.
36
- Each paper: {"id": "arxiv_id", "title": "...", "abstract": "...", "categories": [...]}
37
- """
38
- if not self._available:
39
- return
40
-
41
- ids = [p["id"] for p in papers]
42
- documents = [f"{p['title']}\n\n{p['abstract']}" for p in papers]
43
- metadatas = [{"title": p["title"], "categories": ",".join(p.get("categories", []))} for p in papers]
44
-
45
- # Batch insert (ChromaDB handles embedding automatically with default model)
46
- batch_size = 100
47
- for i in range(0, len(ids), batch_size):
48
- self.collection.upsert(
49
- ids=ids[i:i+batch_size],
50
- documents=documents[i:i+batch_size],
51
- metadatas=metadatas[i:i+batch_size],
52
- )
53
-
54
- def retrieve(self, query: str, n_results: int = 3) -> list[str]:
55
- """
56
- Retrieve top-N relevant paper abstracts for a given theme/query.
57
- Returns list of formatted paper strings.
58
- """
59
- if not self._available or self.collection.count() == 0:
60
- return []
61
-
62
- results = self.collection.query(
63
- query_texts=[query],
64
- n_results=n_results,
65
- )
66
-
67
- papers = []
68
- if results["documents"] and results["documents"][0]:
69
- for doc, meta in zip(results["documents"][0], results["metadatas"][0]):
70
- papers.append(f"[{meta.get('title', 'Unknown')}]\n{doc[:500]}")
71
-
72
- return papers
73
-
74
- def count(self) -> int:
75
- """Number of indexed papers."""
76
- if not self._available:
77
- return 0
78
- return self.collection.count()
79
-
80
-
81
- async def fetch_arxiv_papers(
82
- categories: list[str] = ["q-fin.PM", "q-fin.ST", "q-fin.CP", "stat.AP"],
83
- max_results: int = 500,
84
- start_year: int = 2021,
85
- ) -> list[dict]:
86
- """
87
- Fetch papers from arXiv API.
88
- Returns list of {id, title, abstract, categories, published}.
89
- """
90
- import asyncio
91
- import aiohttp
92
- from xml.etree import ElementTree as ET
93
-
94
- base_url = "http://export.arxiv.org/api/query"
95
- papers = []
96
-
97
- for cat in categories:
98
- query = f"cat:{cat}"
99
- params = {
100
- "search_query": query,
101
- "start": 0,
102
- "max_results": max_results // len(categories),
103
- "sortBy": "submittedDate",
104
- "sortOrder": "descending",
105
- }
106
-
107
- try:
108
- async with aiohttp.ClientSession() as session:
109
- async with session.get(base_url, params=params) as resp:
110
- if resp.status != 200:
111
- continue
112
- text = await resp.text()
113
-
114
- # Parse XML
115
- ns = {"atom": "http://www.w3.org/2005/Atom", "arxiv": "http://arxiv.org/schemas/atom"}
116
- root = ET.fromstring(text)
117
-
118
- for entry in root.findall("atom:entry", ns):
119
- arxiv_id = entry.find("atom:id", ns).text.split("/")[-1]
120
- title = entry.find("atom:title", ns).text.strip().replace("\n", " ")
121
- abstract = entry.find("atom:summary", ns).text.strip().replace("\n", " ")
122
- published = entry.find("atom:published", ns).text[:10]
123
-
124
- cats = [c.attrib["term"] for c in entry.findall("arxiv:primary_category", ns)]
125
-
126
- papers.append({
127
- "id": arxiv_id,
128
- "title": title,
129
- "abstract": abstract,
130
- "categories": cats,
131
- "published": published,
132
- })
133
-
134
- # Rate limit arxiv API
135
- await asyncio.sleep(3)
136
-
137
- except Exception as e:
138
- print(f"[WARN] arXiv fetch failed for {cat}: {e}")
139
-
140
- return papers