surrogate-1/bin/lib/rag_retrieve.py
Ashira Pitchayapakayakul
feat: RAG fully wired — FTS5 + vector hybrid retrieval, no train needed for now
7cbea95
"""
RAG retrieval β€” query FTS5 + vector index for similar past work, inject as context.
Hybrid retrieval:
1. FTS5 keyword match over training-pairs (fast, exact matches)
2. Optional: vector semantic via nomic-embed-text + sqlite-vec (semantic intent)
3. Reciprocal rank fusion of both β†’ top-K to inject
Usage from orchestrate's call_agent BEFORE LLM call:
from rag_retrieve import retrieve_similar
context = retrieve_similar(prompt, top_k=3, max_kb=10)
# inject `context` into prompt as 'Similar past work:'
Cache hits within 60s window (avoid repeat queries during multi-stage pipeline).
"""
from __future__ import annotations

import hashlib
import json
import os
import re
import sqlite3
import sys
import time
import urllib.request
from pathlib import Path
from typing import Iterable
# Base directory for per-user state; falls back to /home/hermes when HOME is unset.
HOME = Path(os.environ.get("HOME", "/home/hermes"))
# SQLite DB holding the FTS5 `pairs` table (keyword index over training pairs).
FTS_DB = HOME / ".surrogate/state/self-ingest.db"
# SQLite DB holding the `vectors` table (float32 embedding blobs per pair).
VEC_DB = HOME / ".surrogate/state/rag-vectors.db"
# On-disk result cache; entries are considered fresh for 60 seconds (by mtime).
CACHE_DIR = HOME / ".surrogate/state/rag-cache"
CACHE_DIR.mkdir(parents=True, exist_ok=True)  # NOTE: side effect at import time
# Local Ollama endpoint and model used to embed queries for the vector branch.
OLLAMA_EMBED_URL = "http://127.0.0.1:11434/api/embeddings"
EMBED_MODEL = "nomic-embed-text"
def _cache_get(key: str) -> str | None:
    """Return the cached value for *key* if it is younger than 60s, else None."""
    path = CACHE_DIR / f"{key}.txt"
    if not path.exists():
        return None
    age = time.time() - path.stat().st_mtime
    if age >= 60:
        return None
    return path.read_text()
def _cache_put(key: str, value: str) -> None:
    """Persist *value* under *key*; freshness is later judged by file mtime."""
    (CACHE_DIR / f"{key}.txt").write_text(value)
def _hash_key(query: str, top_k: int) -> str:
return hashlib.md5(f"{query[:500]}|{top_k}".encode()).hexdigest()[:12]
def _fts_search(query: str, top_k: int = 5) -> list[tuple[str, str, float, str]]:
    """Keyword search over the FTS5 `pairs` table.

    Extracts up to 10 non-stopword keywords from *query*, ORs them into an
    FTS5 MATCH expression, and returns up to *top_k* rows as
    (prompt, response, score, source) tuples.  FTS5 `rank` (lower = better)
    is negated so a larger score means a better match.  Returns [] when the
    index is missing, no keywords survive filtering, or the query fails —
    retrieval is best-effort and must never break the caller.
    """
    if not FTS_DB.exists():
        return []
    # Keywords: 3+ chars, start with a letter; quoting each keeps FTS5 syntax safe.
    words = re.findall(r'\b[a-zA-Z][a-zA-Z0-9_-]{2,}\b', query)
    stop = {"the", "and", "for", "with", "from", "this", "that", "what",
            "when", "where", "how", "why", "which", "into", "your"}
    keywords = [w for w in words if w.lower() not in stop][:10]
    if not keywords:
        return []
    fts_query = " OR ".join(f'"{kw}"' for kw in keywords)
    try:
        con = sqlite3.connect(str(FTS_DB), timeout=3)
        try:
            # LIMIT directly to top_k (was top_k*2 then sliced — wasted fetch).
            rows = con.execute(
                "SELECT prompt, response, rank, source FROM pairs "
                "WHERE pairs MATCH ? "
                "ORDER BY rank LIMIT ?",
                (fts_query, top_k)
            ).fetchall()
        finally:
            # `with sqlite3.connect(...)` only manages transactions, not closing;
            # close explicitly so read-only lookups don't leak connections.
            con.close()
        return [(r[0], r[1], -float(r[2]), r[3]) for r in rows]
    except Exception as e:
        print(f"FTS error: {e}", file=sys.stderr)
        return []
def _embed_query(text: str) -> list[float] | None:
    """Embed *text* (truncated to 2000 chars) via the local Ollama endpoint.

    Returns the embedding vector, or None when Ollama is unreachable, times
    out, or responds without an embedding.
    """
    try:
        payload = {"model": EMBED_MODEL, "prompt": text[:2000]}
        request = urllib.request.Request(
            OLLAMA_EMBED_URL,
            data=json.dumps(payload).encode(),
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(request, timeout=8) as resp:
            embedding = json.load(resp).get("embedding")
        return embedding or None
    except Exception:
        return None
def _vec_search(query_vec: list[float], top_k: int = 5) -> list[tuple[str, str, float, str]]:
    """Brute-force cosine similarity over the `vectors` table.

    Loads up to 50k (prompt, response, embedding-blob, source) rows, decodes
    each blob as float32, and returns the *top_k* rows most similar to
    *query_vec* as (prompt, response, cosine, source).  Rows whose embedding
    dimension differs from the query's are skipped.  Returns [] when the DB
    or query vector is missing, or on any error (best-effort retrieval).
    """
    if not VEC_DB.exists() or not query_vec:
        return []
    try:
        # Imported lazily so numpy stays an optional dependency of this module.
        import numpy as np
        con = sqlite3.connect(str(VEC_DB), timeout=3)
        try:
            rows = con.execute(
                "SELECT prompt, response, embedding, source FROM vectors LIMIT 50000"
            ).fetchall()
        finally:
            # `with sqlite3.connect(...)` only manages transactions, not closing;
            # close explicitly so read-only lookups don't leak connections.
            con.close()
        if not rows:
            return []
        q = np.array(query_vec, dtype=np.float32)
        q /= (np.linalg.norm(q) + 1e-9)  # epsilon guards a zero-norm query
        scored: list[tuple[str, str, float, str]] = []
        for prompt, response, emb_blob, src in rows:
            emb = np.frombuffer(emb_blob, dtype=np.float32)
            if emb.shape[0] != q.shape[0]:
                continue  # dimension mismatch (e.g. index built with another model)
            cos = float(np.dot(q, emb / (np.linalg.norm(emb) + 1e-9)))
            scored.append((prompt, response, cos, src))
        scored.sort(key=lambda x: -x[2])
        return scored[:top_k]
    except Exception as e:
        print(f"Vec search err: {e}", file=sys.stderr)
        return []
def _fuse(fts: list, vec: list, top_k: int = 3) -> list[tuple[str, str, str, float]]:
"""Reciprocal rank fusion β€” combine FTS + vec rankings."""
seen: dict[str, dict] = {}
for rank, (prompt, response, _, src) in enumerate(fts):
key = prompt[:100]
seen.setdefault(key, {"prompt": prompt, "response": response, "source": src,
"rrf": 0.0})
seen[key]["rrf"] += 1.0 / (60 + rank)
for rank, (prompt, response, _, src) in enumerate(vec):
key = prompt[:100]
seen.setdefault(key, {"prompt": prompt, "response": response, "source": src,
"rrf": 0.0})
seen[key]["rrf"] += 1.0 / (60 + rank)
ranked = sorted(seen.values(), key=lambda x: -x["rrf"])
return [(r["prompt"], r["response"], r["source"], r["rrf"]) for r in ranked[:top_k]]
def retrieve_similar(query: str, top_k: int = 3, max_kb: int = 10) -> str:
    """Build a markdown 'Similar past work' block for injection into a prompt.

    Runs keyword (FTS5) and — when Ollama embeddings are available — vector
    retrieval concurrently, fuses the rankings, and formats up to *top_k*
    matches capped at roughly *max_kb* KiB.  Returns "" for short queries or
    when nothing matches.  Results are cached on disk for 60 seconds.
    """
    if not query or len(query) < 30:
        return ""

    cache_key = _hash_key(query, top_k)
    cached = _cache_get(cache_key)
    if cached is not None:
        return cached

    import concurrent.futures

    # Best-effort parallel retrieval; either branch may fail or time out.
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        keyword_job = pool.submit(_fts_search, query, top_k)
        # Vector branch is optional — it only contributes when Ollama answers.
        semantic_job = pool.submit(lambda: _vec_search(_embed_query(query) or [], top_k))
        try:
            keyword_hits = keyword_job.result(timeout=5)
        except Exception:
            keyword_hits = []
        try:
            semantic_hits = semantic_job.result(timeout=10)
        except Exception:
            semantic_hits = []

    fused = _fuse(keyword_hits, semantic_hits, top_k)
    if not fused:
        _cache_put(cache_key, "")  # cache the miss too, to skip repeat lookups
        return ""

    parts = ["### Similar past work (from training-pairs.jsonl):\n"]
    remaining = max_kb * 1024
    for idx, (prompt, response, source, score) in enumerate(fused, 1):
        block = (
            f"\n#### Match {idx} (source: {source}, score: {score:.3f})\n"
            f"**Q:** {prompt[:600]}\n"
            f"**A:** {response[:1200]}\n"
        )
        if len(block) > remaining:
            break
        parts.append(block)
        remaining -= len(block)

    result = "".join(parts)
    _cache_put(cache_key, result)
    return result
if __name__ == "__main__":
import sys
if len(sys.argv) < 2:
print("usage: rag_retrieve.py <query>", file=sys.stderr)
sys.exit(2)
q = " ".join(sys.argv[1:])
print(retrieve_similar(q, top_k=3))