Spaces:
Runtime error
Runtime error
Commit ·
b62e029
0
Parent(s):
Initial commit
Browse files- .gitattributes +35 -0
- .github/workflows/deploy.yml +45 -0
- .gitignore +1 -0
- Dockerfile +22 -0
- README.md +106 -0
- api/dependencies.py +12 -0
- api/schemas/search.py +44 -0
- api/v1/search.py +91 -0
- api/v1/system.py +12 -0
- core/config.py +54 -0
- core/exceptions.py +72 -0
- core/logger.py +45 -0
- main.py +109 -0
- models/embedder.py +98 -0
- models/reranker.py +77 -0
- requirements.txt +31 -0
- scripts/data_pipeline.py +387 -0
- scripts/setup_db.py +51 -0
- services/search_service.py +146 -0
- storage/qdrant_client.py +74 -0
- storage/sqlite_client.py +72 -0
- templates/index.html +92 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/deploy.yml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CI/CD: mirror the repository to the Hugging Face Space on every push to main
# that touches runtime code, templates, or deployment configuration.
name: Deploy to Hugging Face Spaces

on:
  push:
    branches: [ main ]
    # Only redeploy when files that affect the running Space change.
    paths:
      - "api/**"
      - "core/**"
      - "models/**"
      - "services/**"
      - "storage/**"
      - "scripts/setup_db.py"
      - "templates/**"
      - "static/**"
      - "utils/**"
      - "main.py"
      - "Dockerfile"
      - "requirements.txt"
      - ".github/workflows/deploy.yml"
      - "README.md"

jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          # Full history so the force-push below carries all commits.
          fetch-depth: 0

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'

      - name: Install HF CLI
        run: pip install "huggingface_hub[cli]"

      - name: Push to HF Space
        env:
          # NOTE(review): the token ends up embedded in the git remote URL
          # below. Actions masks secrets in logs, but a credential helper
          # would avoid tokens in URLs entirely — consider confirming/changing.
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          git remote add hf https://m97j:$HF_TOKEN@huggingface.co/spaces/m97j/knowledge-engine

          git push --force hf main
|
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
_old.git/
|
Dockerfile
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dockerfile

# Slim Python base keeps the image small.
FROM python:3.11-slim

# Do not write .pyc files; flush stdout/stderr immediately (container-friendly logging).
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1

WORKDIR /app

# build-essential lets pip compile any native wheels; the apt cache is removed
# in the same layer to keep the image size down.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first so the dependency-install layer is cached
# independently of source-code changes.
COPY requirements.txt .
RUN pip install --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

COPY . .

# Persist the knowledge-base data (Qdrant + SQLite) outside the image layers.
VOLUME ["/app/data"]

# 7860 is the port Hugging Face Spaces expects for Docker apps (see README app_port).
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Knowledge Engine
|
| 3 |
+
emoji: 🔍
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: gray
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
license: apache-2.0
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# 🔍 Knowledge Engine
|
| 13 |
+
|
| 14 |
+
[](https://huggingface.co/spaces/m97j/knowledge-engine)
|
| 15 |
+
[](https://www.python.org/downloads/release/python-3100/)
|
| 16 |
+
[](https://opensource.org/licenses/Apache-2.0)
|
| 17 |
+
|
| 18 |
+
> **High-performance Hybrid Search & Reranking Engine based on BGE-M3.**
> An advanced knowledge retrieval API system that combines Dense/Sparse embeddings and optimizes precision with Cross-Encoders.
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
---
|
| 22 |
+
|
| 23 |
+
## 🚀 Key Features
|
| 24 |
+
* **Hybrid Search:** Seamlessly combines Dense & Sparse vector retrieval using Qdrant's Native Fusion API (BGE-M3).
|
| 25 |
+
* **Re-ranking:** Ensures top-tier precision by re-ordering search results via Cross-Encoder models.
|
| 26 |
+
* **Clean Architecture:** Highly modularized layers (API, Service, Storage, Models) for superior maintainability and scalability.
|
| 27 |
+
* **CI/CD Pipeline:** Fully automated deployment to Hugging Face Spaces using GitHub Actions and Docker.
|
| 28 |
+
* **Auto-Healing Data:** Robust startup logic via FastAPI `lifespan` that automatically synchronizes and validates the knowledge base.
|
| 29 |
+
|
| 30 |
+
---
|
| 31 |
+
|
| 32 |
+
## 🏗 Project Structure
|
| 33 |
+
This project follows the **Separation of Concerns (SoC)** principle to ensure the system remains extensible and testable.
|
| 34 |
+
|
| 35 |
+
```text
|
| 36 |
+
├── api/ # API Routing & Dependency Injection (DI)
|
| 37 |
+
├── core/ # Global Configuration (Pydantic Settings) & Exception Handling
|
| 38 |
+
├── models/ # AI Model Inference (Embedder, Reranker)
|
| 39 |
+
├── services/ # Business Logic & Search Pipeline Orchestration
|
| 40 |
+
├── storage/ # Infrastructure Layer (Qdrant, SQLite Clients)
|
| 41 |
+
├── scripts/ # Data Pipeline & Database Setup Scripts
|
| 42 |
+
├── templates/ # Demo UI (Jinja2 Templates)
|
| 43 |
+
└── main.py # App Entry Point & Lifespan Management
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
---
|
| 47 |
+
|
| 48 |
+
## 🛠 Tech Stack
|
| 49 |
+
* **Framework:** FastAPI
|
| 50 |
+
* **Vector DB:** Qdrant (Local Path Mode)
|
| 51 |
+
* **RDBMS:** SQLite (Metadata & Corpus Storage)
|
| 52 |
+
* **ML Models:**
|
| 53 |
+
* `BAAI/bge-m3` (Multi-functional Embedding)
|
| 54 |
+
* `BAAI/bge-reranker-v2-m3` (Cross-Encoder)
|
| 55 |
+
* **DevOps:** Docker, GitHub Actions, Hugging Face Hub
|
| 56 |
+
|
| 57 |
+
---
|
| 58 |
+
|
| 59 |
+
## 🔧 Installation & Setup
|
| 60 |
+
|
| 61 |
+
### Prerequisites
|
| 62 |
+
* Python 3.10 or higher
|
| 63 |
+
* Hugging Face Access Token (Read/Write)
|
| 64 |
+
|
| 65 |
+
### Running Locally
|
| 66 |
+
1. Clone the repository:
|
| 67 |
+
```bash
|
| 68 |
+
git clone https://github.com/m97j/knowledge-engine.git
|
| 69 |
+
cd knowledge-engine
|
| 70 |
+
```
|
| 71 |
+
2. Install dependencies:
|
| 72 |
+
```bash
|
| 73 |
+
pip install -r requirements.txt
|
| 74 |
+
```
|
| 75 |
+
3. Run the application (The system will automatically download the necessary DB files on startup):
|
| 76 |
+
```bash
|
| 77 |
+
python main.py
|
| 78 |
+
# OR using uvicorn
|
| 79 |
+
uvicorn main:app --host 0.0.0.0 --port 7860
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
---
|
| 83 |
+
|
| 84 |
+
## 📡 API Endpoints
|
| 85 |
+
| Method | Endpoint | Description |
|
| 86 |
+
| :--- | :--- | :--- |
|
| 87 |
+
| `GET` | `/` | Redirects to Search Demo UI |
|
| 88 |
+
| `POST` | `/api/v1/search/` | Executes JSON-based Hybrid Search |
|
| 89 |
+
| `GET` | `/api/v1/system/health/ping` | System health check (Heartbeat) |
|
| 90 |
+
|
| 91 |
+
---
|
| 92 |
+
|
| 93 |
+
## 💡 Architecture Insights
|
| 94 |
+
1. **Dependency Injection:** Uses FastAPI `app.state` to manage singletons of AI models and DB clients, allowing for easy mocking during unit testing.
|
| 95 |
+
2. **Hybrid RAG Pipeline:** Beyond simple vector similarity, this engine leverages Sparse embeddings for keyword-level precision, merged via Reciprocal Rank Fusion (RRF).
|
| 96 |
+
3. **Deployment Ready:** Optimized for PaaS environments (like HF Spaces) through a containerized Docker setup and automated CI/CD.
|
| 97 |
+
|
| 98 |
+
---
|
| 99 |
+
|
| 100 |
+
## 📄 Documentation
|
| 101 |
+
For more detailed technical documentation, design decisions, and troubleshooting, please visit:
|
| 102 |
+
* [Personal Archive Link](https://minjae-portfolio.vercel.app/projects/ke)
|
| 103 |
+
* [Technical Design Blog](https://minjae-portfolio.vercel.app/blogs/ke-pd)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
---
|
api/dependencies.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/dependencies.py
|
| 2 |
+
|
| 3 |
+
from fastapi import Request
|
| 4 |
+
|
| 5 |
+
from services.search_service import HybridSearchService
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_search_service(request: Request) -> HybridSearchService:
    """
    FastAPI dependency resolver for the shared search service.

    The service is constructed once during the app's lifespan and stored on
    ``app.state``; every request receives that same singleton instance.
    """
    service: HybridSearchService = request.app.state.search_service
    return service
|
api/schemas/search.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/schemas/search.py
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# ---------------------------
|
| 9 |
+
# Request
|
| 10 |
+
# ---------------------------
|
| 11 |
+
class SearchRequest(BaseModel):
    """JSON request body for the hybrid-search endpoint."""

    # Free-text search query (required).
    query: str = Field(..., description="Search query")
    # Number of results to return; validation clamps it to 1..50.
    top_k: int = Field(default=5, ge=1, le=50)

    # optional
    # NOTE(review): this flag is accepted here but the visible endpoints do
    # not forward it to the search service — confirm before relying on it.
    use_reranker: Optional[bool] = True
| 17 |
+
|
| 18 |
+
# ---------------------------
|
| 19 |
+
# Document metadata
|
| 20 |
+
# ---------------------------
|
| 21 |
+
class DocumentMetadata(BaseModel):
    """Document-level metadata attached to each returned chunk."""

    # Identifier of the source document (distinct from the chunk id).
    doc_id: int
    title: str
    # Language code of the document; exact format not shown here — presumably
    # an ISO code, verify against the data pipeline.
    lang: str
    # Source URL, when the document has one.
    url: Optional[str] = None
    # Last-modified date as a string; serialization format defined upstream.
    date_modified: Optional[str] = None
|
| 27 |
+
|
| 28 |
+
# ---------------------------
|
| 29 |
+
# Result item (LLM-friendly)
|
| 30 |
+
# ---------------------------
|
| 31 |
+
class SearchResultItem(BaseModel):
    """A single ranked chunk, shaped to be easy to feed to an LLM."""

    chunk_id: int
    # Raw text content of the chunk.
    text: str
    score: float = Field(..., description="Reranking score (0.0 to 1.0)")
    # Parent-document metadata for attribution/citation.
    metadata: DocumentMetadata
    # Extra per-stage scoring info (populated only when requested upstream).
    scoring_details: Optional[Dict[str, Any]] = None  # optional
|
| 37 |
+
|
| 38 |
+
# ---------------------------
|
| 39 |
+
# Response
|
| 40 |
+
# ---------------------------
|
| 41 |
+
class SearchResponse(BaseModel):
    """Response envelope returned by the search endpoint."""

    # Echo of the query that produced these results.
    query: str
    results: List[SearchResultItem]
    # End-to-end search latency in whole milliseconds.
    latency_ms: int
|
api/v1/search.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/v1/search.py
|
| 2 |
+
|
| 3 |
+
from fastapi import APIRouter, Depends, Form, HTTPException, Request
|
| 4 |
+
from fastapi.responses import HTMLResponse
|
| 5 |
+
from fastapi.templating import Jinja2Templates
|
| 6 |
+
|
| 7 |
+
from api.dependencies import get_search_service
|
| 8 |
+
from api.schemas.search import SearchRequest, SearchResponse
|
| 9 |
+
from core.logger import setup_logger
|
| 10 |
+
from services.search_service import HybridSearchService
|
| 11 |
+
|
| 12 |
+
# Module-level singletons: logger, the router (mounted under /api/v1 by
# main.py), and the Jinja2 loader used by the HTML demo endpoints.
logger = setup_logger("search_api")

router = APIRouter(prefix="/search", tags=["Search"])
templates = Jinja2Templates(directory="templates")
|
| 16 |
+
|
| 17 |
+
# -------------------------------------
|
| 18 |
+
# Json API Endpoint for Hybrid Search
|
| 19 |
+
# -------------------------------------
|
| 20 |
+
@router.post("/", response_model=SearchResponse, summary="Execute Hybrid Search (JSON)")
async def execute_search(
    request_data: SearchRequest,
    search_service: HybridSearchService = Depends(get_search_service)
):
    """
    Execute a hybrid search using the provided query and parameters.

    Returns a ``SearchResponse`` with ranked results and total latency.
    Maps ``ValueError`` from the service to HTTP 400 and any other failure
    to a generic HTTP 500.

    NOTE(review): ``request_data.use_reranker`` is accepted by the schema but
    not passed to ``search_service.search`` here — confirm intent.
    NOTE(review): ``search_service.search`` is called synchronously inside an
    async endpoint; if it does heavy work it blocks the event loop — consider
    ``starlette.concurrency.run_in_threadpool``.
    """
    try:
        # Delegate the full pipeline (embed -> retrieve -> rerank) to the service.
        search_output = search_service.search(
            query=request_data.query,
            top_k=request_data.top_k
        )
        return SearchResponse(
            query=search_output["query"],
            results=search_output["results"],
            latency_ms=search_output["latency_ms"]
        )

    except ValueError as ve:
        # Client-side problem (e.g. invalid query): surface as 400 with detail.
        logger.warning(f"Invalid search request: {ve}")
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        # Server fault: log the stack trace but hide internals from the client.
        logger.error(f"Search Execution Failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error during search.")
|
| 45 |
+
|
| 46 |
+
# -------------------------------------
|
| 47 |
+
# HTML Demo Endpoint for Manual Testing
|
| 48 |
+
# -------------------------------------
|
| 49 |
+
@router.get("/demo", response_class=HTMLResponse, summary="Search Demo UI (GET)")
async def demo_page_get(request: Request):
    """
    Serve the empty demo search form.

    Renders ``index.html`` with no results and a blank query string so the
    page can be used for manual, browser-based testing of hybrid search.
    """
    context = {"request": request, "results": None, "query": ""}
    return templates.TemplateResponse("index.html", context)
|
| 58 |
+
|
| 59 |
+
@router.post("/demo", response_class=HTMLResponse, summary="Search Demo UI (POST)")
async def demo_page_post(
    request: Request,
    query: str = Form(...),
    search_service: HybridSearchService = Depends(get_search_service)
):
    """
    Handle form submission from the demo page, execute the search, and render results in the same template.

    On failure the same template is re-rendered with an error banner and a
    500 status (HTML, not JSON) so the demo UX stays in the browser.
    """
    try:
        # Fixed top_k=5 for the demo page; the JSON API exposes the knob.
        search_output = search_service.search(query=query, top_k=5)

        return templates.TemplateResponse(
            "index.html",
            {
                "request": request,
                "results": search_output["results"],
                "query": query,
                "latency_ms": search_output["latency_ms"]
            }
        )
    except Exception as e:
        # Log with stack trace, then re-render the form with a friendly banner.
        logger.error(f"Demo Search Failed: {e}", exc_info=True)
        return templates.TemplateResponse(
            "index.html",
            {
                "request": request,
                "results": None,
                "query": query,
                "error_message": "An error occurred while processing your search. Please try again."
            },
            status_code=500
        )
|
api/v1/system.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/v1/system.py
|
| 2 |
+
|
| 3 |
+
from fastapi import APIRouter
|
| 4 |
+
|
| 5 |
+
# Router is mounted under /api/v1 by main.py -> final path: /api/v1/health/ping.
router = APIRouter(prefix="/health", tags=["Health Check"])

# ---------------------------
# Debug endpoint (optional)
# ---------------------------
@router.get("/ping")
def ping():
    """Liveness probe: returns a constant payload so monitors can verify the API is up."""
    return {"message": "pong"}
|
core/config.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# core/config.py
|
| 2 |
+
|
| 3 |
+
from functools import lru_cache
|
| 4 |
+
|
| 5 |
+
from pydantic import Field
|
| 6 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Settings(BaseSettings):
    """
    Application-wide configuration.

    Reads values from a ``.env`` file or system environment variables and
    strictly validates types using Pydantic. Access via ``get_settings()``
    or the module-level ``settings`` singleton.
    """

    # 1. Project Info
    PROJECT_NAME: str = Field(default="Knowledge Engine", description="Project name")
    VERSION: str = Field(default="1.0.0", description="API version")
    ENVIRONMENT: str = Field(default="development", description="Execution environment (development, staging, production)")
    LOG_LEVEL: str = Field(default="INFO", description="Global logging level")
    DATA_DIR: str = Field(default="./data", description="Data storage directory path")
    REPO_ID: str = Field(default="m97j/ke-store", description="Hugging Face repository ID")

    # 2. Storage Settings (Vector DB & RDBMS)
    QDRANT_PATH: str = Field(default="./data/qdrant", description="Qdrant local storage path")
    QDRANT_COLLECTION: str = Field(default="knowledge_base", description="Qdrant collection name")
    SQLITE_PATH: str = Field(default="./data/corpus/corpus.sqlite", description="SQLite DB file path")

    # 3. Model Settings (Embedder & Reranker)
    EMBEDDER_NAME: str = Field(default="BAAI/bge-m3", description="FlagEmbedding model name")
    RERANKER_NAME: str = Field(default="BAAI/bge-reranker-v2-m3", description="Cross-Encoder model name")
    USE_FP16: bool = Field(default=True, description="Whether to use FP16 precision in GPU environment")

    # 4. Search Hyperparameters
    DEFAULT_TOP_K: int = Field(default=5, description="Final number of documents to return")
    QDRANT_FETCH_LIMIT: int = Field(default=50, description="Number of candidates to fetch from Vector DB before reranking")

    # Pydantic v2 settings
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=True,  # case-sensitive environment variables
        extra="ignore"  # ignore unexpected fields in .env or environment variables
    )
|
| 44 |
+
|
| 45 |
+
@lru_cache()
def get_settings() -> Settings:
    """
    Return the Settings object, cached as a process-wide singleton.

    ``lru_cache`` guarantees the environment/.env file is read and validated
    only once, no matter how many modules call this.
    """
    return Settings()

# Instantiate as a global variable so that it can be easily imported from other modules
settings = get_settings()
|
core/exceptions.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# core/exceptions.py
|
| 2 |
+
|
| 3 |
+
from fastapi import Request, status
|
| 4 |
+
from fastapi.responses import JSONResponse
|
| 5 |
+
|
| 6 |
+
from core.logger import setup_logger
|
| 7 |
+
|
| 8 |
+
logger = setup_logger("exception_handler")
|
| 9 |
+
|
| 10 |
+
# ---------------------------------------------------
|
| 11 |
+
# Base Exception (Parent class of all custom errors)
|
| 12 |
+
# ---------------------------------------------------
|
| 13 |
+
class KnowledgeEngineException(Exception):
    """Base custom exception class for the Knowledge Engine application.

    Carries an HTTP status alongside the message so the global handler can
    turn any domain error into a consistent JSON response.
    """
    def __init__(self, message: str, status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR):
        # Human-readable description, echoed in the JSON error body.
        self.message = message
        # HTTP status the exception handler should respond with.
        self.status_code = status_code
        super().__init__(self.message)
|
| 19 |
+
|
| 20 |
+
# ---------------------------------------------------
|
| 21 |
+
# Domain Specific Exceptions (Hierarchical error)
|
| 22 |
+
# ---------------------------------------------------
|
| 23 |
+
class ModelLoadError(KnowledgeEngineException):
    """models/ layer where model (Embedder/Reranker) loading fails."""
    def __init__(self, message: str):
        # 503: the service is up but a required model is unavailable.
        super().__init__(message, status_code=status.HTTP_503_SERVICE_UNAVAILABLE)
|
| 27 |
+
|
| 28 |
+
class DatabaseError(KnowledgeEngineException):
    """storage/ layer where Qdrant or SQLite integration fails."""
    def __init__(self, message: str):
        # 503: backing store unreachable/broken — retryable from the client's view.
        super().__init__(message, status_code=status.HTTP_503_SERVICE_UNAVAILABLE)
|
| 32 |
+
|
| 33 |
+
class SearchExecutionError(KnowledgeEngineException):
    """services/ layer where the search pipeline (Hybrid Search) encounters a logical error."""
    def __init__(self, message: str):
        # 500: internal pipeline fault, not attributable to the caller.
        super().__init__(message, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
| 37 |
+
|
| 38 |
+
class InvalidQueryError(KnowledgeEngineException):
    """api/ layer where user input is invalid (e.g., empty query, unsupported parameters)."""
    def __init__(self, message: str):
        # 400: the caller can fix this by changing the request.
        super().__init__(message, status_code=status.HTTP_400_BAD_REQUEST)
|
| 42 |
+
|
| 43 |
+
# -----------------------------------
|
| 44 |
+
# FastAPI Exception Handler
|
| 45 |
+
# -----------------------------------
|
| 46 |
+
async def custom_exception_handler(request: Request, exc: KnowledgeEngineException):
    """
    When a custom exception occurs in a FastAPI app,
    catch it and convert it into a consistent JSON error response.
    """
    # One-line audit trail: status, method, full URL and the domain message.
    logger.error(f"[{exc.status_code}] {request.method} {request.url} - {exc.message}")
    return JSONResponse(
        status_code=exc.status_code,
        content={
            # The exception class name doubles as a machine-readable error code.
            "error": exc.__class__.__name__,
            "message": exc.message,
            "path": str(request.url.path)
        }
    )
|
| 60 |
+
async def global_exception_handler(request: Request, exc: Exception):
    """Catch any unhandled exceptions that are not instances of KnowledgeEngineException,
    log them, and return a generic error response."""
    logger.critical(f"Unhandled Exception: {str(exc)}", exc_info=True)  # Log stack trace for debugging
    return JSONResponse(
        status_code=500,
        # Deliberately generic body: never leak internal details to clients.
        content={"error": "InternalServerError", "message": "An unexpected error occurred."}
    )
|
| 68 |
+
|
| 69 |
+
def setup_exception_handlers(app):
    """Register custom exception handlers to the FastAPI app."""
    # Domain errors get the structured handler; everything else falls through
    # to the generic 500 handler.
    handler_map = (
        (KnowledgeEngineException, custom_exception_handler),
        (Exception, global_exception_handler),
    )
    for exc_type, handler in handler_map:
        app.add_exception_handler(exc_type, handler)
|
core/logger.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# core/logger.py
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import sys
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
# Resolve the default log level from app settings when available; fall back
# to INFO when core.config cannot be imported (e.g. standalone use of this module).
try:
    from core.config import settings
    DEFAULT_LOG_LEVEL = settings.LOG_LEVEL
except ImportError:
    DEFAULT_LOG_LEVEL = "INFO"

# logging format: timestamp | log level | logger name | message
LOG_FORMAT = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

def setup_logger(name: str, level: Optional[str] = None) -> logging.Logger:
    """
    Return a standardized, console-only logger for the given module name.

    Repeated calls with the same ``name`` hand back the already-configured
    instance untouched, so handlers are never duplicated.
    Usage: logger = setup_logger(__name__)
    """
    log = logging.getLogger(name)

    # Configured by a previous call: return it unchanged (the new level, if
    # any, is intentionally ignored).
    if log.handlers:
        return log

    # Resolve the effective level; unrecognized names silently fall back to INFO.
    chosen = (level or DEFAULT_LOG_LEVEL).upper()
    log.setLevel(getattr(logging, chosen, logging.INFO))

    # Keep records out of ancestor loggers to avoid double emission.
    log.propagate = False

    # Single stdout handler mirroring the logger's own level.
    stream = logging.StreamHandler(sys.stdout)
    stream.setLevel(log.level)
    stream.setFormatter(logging.Formatter(fmt=LOG_FORMAT, datefmt=DATE_FORMAT))
    log.addHandler(stream)

    return log
|
main.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# main.py
|
| 2 |
+
|
| 3 |
+
from contextlib import asynccontextmanager
|
| 4 |
+
|
| 5 |
+
from fastapi import FastAPI
|
| 6 |
+
from fastapi.responses import RedirectResponse
|
| 7 |
+
from starlette.middleware.cors import CORSMiddleware
|
| 8 |
+
|
| 9 |
+
from api.v1 import search, system
|
| 10 |
+
from core.config import settings
|
| 11 |
+
from core.exceptions import setup_exception_handlers
|
| 12 |
+
from core.logger import setup_logger
|
| 13 |
+
from models.embedder import TextEmbedder
|
| 14 |
+
from models.reranker import TextReranker
|
| 15 |
+
from scripts.setup_db import download_knowledge_base
|
| 16 |
+
from services.search_service import HybridSearchService
|
| 17 |
+
from storage.qdrant_client import QdrantStorage
|
| 18 |
+
from storage.sqlite_client import SQLiteStorage
|
| 19 |
+
|
| 20 |
+
logger = setup_logger("knowledge_engine")
|
| 21 |
+
|
| 22 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI Lifespan function to manage startup and shutdown events.
    On startup, it initializes all necessary components (DB connections, models, services) and injects them into the app state.
    On shutdown, it ensures that all resources are properly cleaned up (e.g., closing DB connections).
    - This approach centralizes all initialization logic in one place, making it easier to manage dependencies and handle errors during startup.
    - If any critical error occurs during startup, it logs the error and prevents the server from starting in an unstable state.
    """
    logger.info("🚀 Starting Knowledge Engine API...")

    # Pre-declare so the finally block can close whichever clients were created.
    qdrant_client = None
    sqlite_client = None

    try:
        # 0. Prepare dependency data (DB) (Download if unavailable, skip if available)
        logger.info("Checking and preparing Knowledge Base data...")
        download_knowledge_base()

        # 1. Infrastructure Connection (Database)
        qdrant_client = QdrantStorage(path=settings.QDRANT_PATH, collection_name=settings.QDRANT_COLLECTION)
        sqlite_client = SQLiteStorage(db_path=settings.SQLITE_PATH)

        # 2. Load AI Model (Singleton)
        # NOTE(review): use_fp16 is hard-coded True here instead of using
        # settings.USE_FP16 — confirm this is intentional.
        embedder = TextEmbedder(model_name=settings.EMBEDDER_NAME, use_fp16=True)
        reranker = TextReranker(model_name=settings.RERANKER_NAME)

        # 3. Business Service Orchestration (Instantiate the HybridSearchService with all dependencies)
        search_service = HybridSearchService(
            qdrant=qdrant_client,
            sqlite=sqlite_client,
            embedder=embedder,
            reranker=reranker
        )

        # 4. Injecting services into FastAPI app state for global accessibility in routers
        app.state.search_service = search_service
        logger.info("✅ All services and models initialized successfully.")

        yield  # --- From this point, the server starts receiving traffic ---

    except Exception as e:
        # NOTE(review): because ``yield`` is inside this try, the branch also
        # catches exceptions raised while the app is running, not only
        # startup failures — the "failed to start" wording can mislead then.
        logger.critical(f"❌ Application failed to start: {e}", exc_info=True)
        raise e

    finally:
        logger.info("🛑 Shutting down. Cleaning up resources...")
        # Safe termination of DB connections, etc.
        if qdrant_client is not None: qdrant_client.close()
        if sqlite_client is not None: sqlite_client.close()
        logger.info("Resources cleaned up.")
|
| 73 |
+
|
| 74 |
+
# ---------------------------
|
| 75 |
+
# FastAPI Instance Creation
|
| 76 |
+
# ---------------------------
|
| 77 |
+
app = FastAPI(
    title="Hybrid RAG Knowledge Engine API",
    description="Qdrant and BGE-M3-based high-performance hybrid search engine API",
    version="0.1.0",
    lifespan=lifespan  # startup/shutdown wiring defined above
)

# CORS Setup
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# effectively rejected by browsers under the CORS spec for credentialed
# requests — consider listing explicit origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files (CSS, JS, etc.) if needed (e.g., for demo pages)
# app.mount("/static", StaticFiles(directory="static"), name="static")

# ---------------------------
# Router Registration
# ---------------------------
# Final paths: /api/v1/health/... and /api/v1/search/...
app.include_router(system.router, prefix="/api/v1")
app.include_router(search.router, prefix="/api/v1")
|
| 101 |
+
|
| 102 |
+
@app.get("/", include_in_schema=False)
|
| 103 |
+
async def root():
|
| 104 |
+
return RedirectResponse(url="/api/v1/search/demo")
|
| 105 |
+
|
| 106 |
+
# -----------------------------------
|
| 107 |
+
# Register global exception handlers
|
| 108 |
+
# -----------------------------------
|
| 109 |
+
setup_exception_handlers(app)
|
models/embedder.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# models/embedder.py
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, List
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from FlagEmbedding import BGEM3FlagModel
|
| 7 |
+
from pydantic import BaseModel
|
| 8 |
+
|
| 9 |
+
from core.exceptions import ModelLoadError
|
| 10 |
+
from core.logger import setup_logger
|
| 11 |
+
|
| 12 |
+
logger = setup_logger("embedder")
|
| 13 |
+
|
| 14 |
+
# Data structure for return (Type Hinting)
|
| 15 |
+
class EmbedderResult(BaseModel):
    """Hybrid-search embedding of a single query, in Qdrant-ready form."""
    # Dense semantic vector produced by BGE-M3 (1024-dim by default).
    dense_vector: List[float]
    # Sparse lexical vector as two parallel arrays: token ids ...
    sparse_indices: List[int]
    # ... and their corresponding lexical weights (same order/length).
    sparse_values: List[float]
|
| 19 |
+
|
| 20 |
+
class TextEmbedder:
    """
    Converts input text into a dense vector and a sparse vector (lexical
    weights) using the BGE-M3 model, in the format expected by Qdrant
    hybrid search.
    """

    def __init__(self, model_name: str = "BAAI/bge-m3", use_fp16: bool = False):
        """
        :param model_name: Hugging Face model identifier to load.
        :param use_fp16: Request half precision (only honored on CUDA).
        :raises ModelLoadError: If the underlying model cannot be loaded.
        """
        self.model_name = model_name
        self.device = self._get_device()

        try:
            logger.info(f"⏳ Loading Embedder Model: {self.model_name} on {self.device}")
            # fp16 is only safe on CUDA; force fp32 on CPU/MPS.
            self.model = BGEM3FlagModel(
                self.model_name,
                use_fp16=(use_fp16 and self.device.startswith("cuda"))
            )
            self._warmup()
            logger.info("✅ Embedder Model loaded successfully.")
        except Exception as e:
            logger.critical(f"❌ Failed to load Embedder Model: {e}", exc_info=True)
            raise ModelLoadError(f"Embedder initialization failed: {e}")

    def _get_device(self) -> str:
        # Best available device: CUDA > Apple Silicon (MPS) > CPU.
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"  # Apple Silicon
        return "cpu"

    def _warmup(self):
        # One dummy forward pass so the first real request does not pay
        # the lazy-initialization cost.
        logger.info("Warming up embedder model with a dummy input.")
        self.encode_query("Hello world")

    def encode_query(self, text: str) -> EmbedderResult:
        """
        Converts a single query text into Qdrant hybrid search format.

        :param text: Query string to embed.
        :return: EmbedderResult with the dense vector plus sparse indices/values.
        :raises RuntimeError: If model inference fails.
        """
        try:
            # 1. Model inference: dense vector and sparse lexical weights.
            output = self.model.encode(
                text,
                return_dense=True,
                return_sparse=True,
                return_colbert_vecs=False
            )

            dense_vec = output['dense_vecs'].tolist()
            lexical_weights: Dict[str, float] = output['lexical_weights']

            # 2. Sparse vector transformation (Qdrant spec: token-id array + weight array).
            sparse_indices = []
            sparse_values = []

            # FIX: the keys of `lexical_weights` are token IDs rendered as
            # strings — the indexing pipeline (scripts/data_pipeline.py,
            # build_qdrant_index) stores sparse indices via `int(key)`.
            # Feeding the keys through `convert_tokens_to_ids` treats them
            # as surface tokens and yields different (mostly <unk>) IDs, so
            # the query-side sparse vector could never match the index.
            # Parse the key as an integer first; fall back to the tokenizer
            # only if the key is a genuine token string.
            for token_str, weight in lexical_weights.items():
                try:
                    token_id = int(token_str)
                except (TypeError, ValueError):
                    token_id = self.model.tokenizer.convert_tokens_to_ids(token_str)
                if token_id is not None:
                    sparse_indices.append(token_id)
                    sparse_values.append(float(weight))

            return EmbedderResult(
                dense_vector=dense_vec,
                sparse_indices=sparse_indices,
                sparse_values=sparse_values
            )

        except Exception as e:
            logger.error(f"Failed to encode query '{text}': {e}")
            raise RuntimeError(f"Embedding generation failed: {e}")

    # Batch encoding option, kept for future document-ingestion needs.
    def encode_documents(self, texts: List[str], batch_size: int = 12) -> Dict[str, Any]:
        """Embed a batch of documents; returns the raw FlagEmbedding output dict."""
        return self.model.encode(
            texts,
            batch_size=batch_size,
            max_length=8192,  # BGE-M3's max token length
            return_dense=True,
            return_sparse=True,
            return_colbert_vecs=False
        )
|
models/reranker.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# models/reranker.py
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, List
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from FlagEmbedding import FlagReranker
|
| 7 |
+
|
| 8 |
+
from core.exceptions import ModelLoadError
|
| 9 |
+
from core.logger import setup_logger
|
| 10 |
+
|
| 11 |
+
logger = setup_logger("reranker")
|
| 12 |
+
|
| 13 |
+
class TextReranker:
    """
    Re-orders first-stage retrieval hits by cross-encoding each
    (query, document) pair with the BGE reranker model.
    """

    def __init__(self, model_name: str = "BAAI/bge-reranker-v2-m3", use_fp16: bool = False):
        """
        :param model_name: Hugging Face model identifier to load.
        :param use_fp16: Request half precision (only honored on CUDA).
        :raises ModelLoadError: If the reranker model cannot be loaded.
        """
        self.model_name = model_name
        self.device = self._get_device()

        try:
            logger.info(f"⏳ Loading Reranker Model: {self.model_name} on {self.device}")
            self.reranker = FlagReranker(
                self.model_name,
                use_fp16=(use_fp16 and self.device.startswith("cuda"))
            )
            # FIX: warm-up must run AFTER the model is loaded. The original
            # called self._warmup() before this try-block, which raised
            # AttributeError because self.reranker did not exist yet
            # (the sibling TextEmbedder also warms up after loading).
            self._warmup()
            logger.info("✅ Reranker Model loaded successfully.")
        except Exception as e:
            logger.critical(f"❌ Failed to load Reranker Model: {e}", exc_info=True)
            raise ModelLoadError(f"Reranker initialization failed: {e}")

    def _get_device(self) -> str:
        # Best available device: CUDA > Apple Silicon (MPS) > CPU.
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def _warmup(self):
        # One dummy scoring pass so the first real request is fast.
        logger.info("Warming up reranker model with a dummy input.")
        self.rerank(query="Hello world", documents=[{"text": "Hello world"}])

    def rerank(self, query: str, documents: List[Dict[str, Any]], text_key: str = "text") -> List[Dict[str, Any]]:
        """
        Takes a list of documents as input, recalculates their similarity to the query, and returns the results sorted by score.

        :param query: The original search query string
        :param documents: A list of dictionaries in the form [{'chunk_id': 1, 'text': '...'}, ...]
        :param text_key: The key name in the document dictionary containing the body text
        :return: The same dicts, each with a ``rerank_score`` added, sorted descending.
        :raises RuntimeError: If score computation fails.
        """
        if not documents:
            return []

        # Generate pairs for Cross-Encoder input: [[query, doc1], [query, doc2], ...]
        sentence_pairs = [[query, doc[text_key]] for doc in documents]

        try:
            # 1. Batch score calculation (normalized to [0, 1]).
            scores = self.reranker.compute_score(sentence_pairs, normalize=True)

            # compute_score returns a bare float for a single pair — wrap it.
            if isinstance(scores, float):
                scores = [scores]

            # 2. Inject rerank_score into the source document dictionaries.
            for i, doc in enumerate(documents):
                doc["rerank_score"] = float(scores[i])

            # 3. Sort by score (descending).
            reranked_docs = sorted(documents, key=lambda x: x["rerank_score"], reverse=True)

            return reranked_docs

        except Exception as e:
            logger.error(f"Reranking failed for query '{query}': {e}")
            raise RuntimeError(f"Reranking process failed: {e}")
|
requirements.txt
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Web Server
|
| 2 |
+
fastapi==0.135.1
|
| 3 |
+
uvicorn[standard]==0.40.0
|
| 4 |
+
jinja2==3.1.5
|
| 5 |
+
python-multipart==0.0.22
|
| 6 |
+
starlette==0.52.1
|
| 7 |
+
|
| 8 |
+
# Vector Search & Embeddings
|
| 9 |
+
qdrant-client==1.16.2
|
| 10 |
+
FlagEmbedding==1.3.5
|
| 11 |
+
# CPU-only PyTorch wheels live on a separate index. Appending `--index-url`
# to a requirement line is not valid requirements.txt syntax; the index
# option must be declared on its own line.
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.9.1
|
| 12 |
+
numpy==2.4.2
|
| 13 |
+
sentence-transformers==5.2.3
|
| 14 |
+
|
| 15 |
+
# Data Modeling & Settings
|
| 16 |
+
pydantic==2.12.2
|
| 17 |
+
pydantic-settings==2.13.1
|
| 18 |
+
python-dotenv==1.2.1
|
| 19 |
+
|
| 20 |
+
# Hugging Face Stack
|
| 21 |
+
huggingface_hub>=0.25.0
|
| 22 |
+
transformers>=4.44.0
|
| 23 |
+
tokenizers>=0.19.0
|
| 24 |
+
accelerate>=0.34.0
|
| 25 |
+
|
| 26 |
+
# ONNX Runtime
|
| 27 |
+
onnxruntime>=1.19.0
|
| 28 |
+
|
| 29 |
+
# Utils
|
| 30 |
+
tqdm==4.67.2
|
| 31 |
+
requests==2.32.5
|
scripts/data_pipeline.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# scripts/data_pipeline.py
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import sqlite3
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from datasets import load_dataset
|
| 10 |
+
from FlagEmbedding import BGEM3FlagModel
|
| 11 |
+
from qdrant_client import QdrantClient
|
| 12 |
+
from qdrant_client.models import (Distance, OptimizersConfigDiff, PointStruct,
|
| 13 |
+
ScalarQuantization, ScalarQuantizationConfig,
|
| 14 |
+
ScalarType, SparseIndexParams, SparseVector,
|
| 15 |
+
SparseVectorParams, VectorParams)
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
from transformers import AutoTokenizer
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class KnowledgeEngineBuilder:
    """
    Offline pipeline that builds the knowledge base in three phases:

    1. ``ingest``             - stream a wiki dump into SQLite (documents/chunks/spans).
    2. ``embed_corpus``       - embed every chunk with BGE-M3 and cache vectors on disk.
    3. ``build_qdrant_index`` - bulk-load the cached vectors into a Qdrant hybrid index.
    """

    def __init__(self, base_dir="ke_store", dim=1024):
        """
        :param base_dir: Root directory for SQLite, Qdrant and embedding caches.
        :param dim: Dense vector dimensionality (BGE-M3 emits 1024-d vectors).
        """
        self.base_dir = base_dir
        self.dim = dim

        print("Loading BGE-M3 Model and Tokenizer...")
        self.model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
        self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')

        # Chunking policy: token budget per chunk and how many trailing
        # sentences are carried into the next chunk for context.
        self.max_tokens = 384
        self.overlap_count = 2

        self._init_dirs()
        self._init_sqlite()
        self._init_meta()
        self._init_qdrant()

    # ---------------------------
    # INIT & SETUP
    # ---------------------------
    def _init_dirs(self):
        # Create the on-disk layout (idempotent).
        for d in ["corpus", "qdrant", "build_cache/embeddings"]:
            os.makedirs(os.path.join(self.base_dir, d), exist_ok=True)

    def _init_qdrant(self):
        """Open the embedded Qdrant store and create the hybrid collection if missing."""
        self.qdrant_path = f"{self.base_dir}/qdrant"
        self.qdrant_client = QdrantClient(path=self.qdrant_path)
        self.collection_name = "knowledge_base"

        if not self.qdrant_client.collection_exists(self.collection_name):
            print(f"Creating Qdrant collection: {self.collection_name}")
            self.qdrant_client.create_collection(
                collection_name=self.collection_name,
                vectors_config={
                    "dense": VectorParams(size=self.dim, distance=Distance.COSINE, on_disk=True)
                },
                sparse_vectors_config={
                    "sparse": SparseVectorParams(index=SparseIndexParams(on_disk=True))
                },
                # INT8 scalar quantization kept in RAM for fast scoring.
                quantization_config=ScalarQuantization(
                    scalar=ScalarQuantizationConfig(type=ScalarType.INT8, always_ram=True)
                ),
                # Disable HNSW indexing during bulk upload; re-enabled in
                # build_qdrant_index() once all points are in.
                optimizers_config=OptimizersConfigDiff(indexing_threshold=0)
            )

    def _optimize_sqlite(self, conn):
        # Bulk-load friendly PRAGMAs: WAL journaling, relaxed fsync,
        # in-memory temp storage, ~2 GB page cache.
        conn.execute("PRAGMA journal_mode=WAL;")
        conn.execute("PRAGMA synchronous=NORMAL;")
        conn.execute("PRAGMA temp_store=MEMORY;")
        conn.execute("PRAGMA cache_size=-2000000")

    def _init_sqlite(self):
        """Open the corpus DB and create the documents/chunks/spans schema."""
        self.conn = sqlite3.connect(f"{self.base_dir}/corpus/corpus.sqlite")
        self._optimize_sqlite(self.conn)
        cur = self.conn.cursor()

        cur.execute("""
            CREATE TABLE IF NOT EXISTS documents (
                doc_id INTEGER PRIMARY KEY AUTOINCREMENT,
                external_id TEXT, title TEXT, lang TEXT, url TEXT,
                wikidata_id TEXT, date_modified TEXT, full_text TEXT)
        """)

        cur.execute("""
            CREATE TABLE IF NOT EXISTS chunks (
                chunk_id INTEGER PRIMARY KEY AUTOINCREMENT,
                doc_id INTEGER, chunk_index INTEGER, text TEXT,
                token_length INTEGER, section TEXT, lang TEXT)
        """)

        cur.execute("""
            CREATE TABLE IF NOT EXISTS spans (
                span_id INTEGER PRIMARY KEY AUTOINCREMENT,
                chunk_id INTEGER, span_index INTEGER, text TEXT, char_length INTEGER)
        """)

        cur.execute("CREATE INDEX IF NOT EXISTS idx_chunks_doc_id ON chunks(doc_id)")
        cur.execute("CREATE INDEX IF NOT EXISTS idx_spans_chunk_id ON spans(chunk_id)")
        cur.execute("CREATE INDEX IF NOT EXISTS idx_chunks_lang ON chunks(lang)")
        self.conn.commit()

    def _init_meta(self):
        """Derive the next free IDs from the DB (authoritative) and persist them."""
        self.meta_path = f"{self.base_dir}/corpus/meta.json"
        cur = self.conn.cursor()
        cur.execute("SELECT MAX(doc_id) FROM documents")
        db_doc = cur.fetchone()[0] or 0
        cur.execute("SELECT MAX(chunk_id) FROM chunks")
        db_chunk = cur.fetchone()[0] or 0
        cur.execute("SELECT MAX(span_id) FROM spans")
        db_span = cur.fetchone()[0] or 0

        self.meta = {
            "last_doc_id": db_doc + 1,
            "last_chunk_id": db_chunk + 1,
            "last_span_id": db_span + 1
        }
        self._save_meta()

    def _save_meta(self):
        # Persist the ID counters so progress is visible between runs.
        with open(self.meta_path, "w") as f:
            json.dump(self.meta, f, indent=4)

    # ---------------------------
    # TEXT PROCESSING & INGESTION
    # ---------------------------
    def split_sentences(self, text):
        """Split raw text into sentence-like spans.

        Splits on sentence-final punctuation (Latin and CJK full-width),
        skipping common abbreviations and single-capital initials, then
        further splits on newlines and drops trivial fragments.
        """
        text = re.sub(r'[ \t]+', ' ', text)
        # FIX: the original pattern used `[Ar|Dr|Mr|Ms|St]\.` in a
        # lookbehind — that is a character class matching any SINGLE one of
        # those letters (or '|') before the dot, not the abbreviations
        # Dr./Mr./Ms./St. as intended. Python `re` only supports fixed-width
        # lookbehinds, so each abbreviation gets its own assertion.
        pattern = r'(?<=[.!?。!?])(?<!Dr\.)(?<!Mr\.)(?<!Ms\.)(?<!St\.)(?<![A-Z]\.)\s+'
        sentences = re.split(pattern, text)
        final_sentences = []
        for s in sentences:
            sub_parts = [p.strip() for p in s.split('\n') if p.strip()]
            final_sentences.extend(sub_parts)
        return [s for s in final_sentences if len(s) > 1]

    def count_tokens(self, text):
        """Token count of *text* under the BGE-M3 tokenizer (no special tokens)."""
        return len(self.tokenizer.encode(text, add_special_tokens=False))

    def get_token_counts_batch(self, texts):
        """Batched token counting — one tokenizer call for the whole list."""
        if not texts: return []
        encodings = self.tokenizer(texts, add_special_tokens=False, padding=False, truncation=False)
        return [len(ids) for ids in encodings['input_ids']]

    def _split_monster_sentence(self, sentence):
        """Break a sentence longer than max_tokens into word-bounded sub-spans."""
        words = sentence.split(' ')
        sub_spans, current_sub, current_toks = [], [], 0

        for word in words:
            word_toks = self.count_tokens(word)
            if word_toks > self.max_tokens:
                # A single pathological "word" over the budget: flush the
                # accumulator and split the word in half by characters.
                if current_sub:
                    sub_spans.append(" ".join(current_sub))
                    current_sub, current_toks = [], 0
                half = len(word) // 2
                sub_spans.extend([word[:half], word[half:]])
                continue

            space_tok = 1 if current_sub else 0
            if current_toks + word_toks + space_tok > self.max_tokens and current_sub:
                sub_spans.append(" ".join(current_sub))
                current_sub, current_toks = [word], word_toks
            else:
                current_sub.append(word)
                current_toks += word_toks + space_tok

        if current_sub: sub_spans.append(" ".join(current_sub))
        return sub_spans

    def chunk_text(self, text):
        """Group sentence spans into token-bounded chunks with sentence overlap.

        :return: List of ``(chunk_body, token_count, span_list)`` tuples.
        """
        raw_sentences = self.split_sentences(text)
        sentence_lengths = self.get_token_counts_batch(raw_sentences)

        refined_spans = []
        for s, length in zip(raw_sentences, sentence_lengths):
            if length > self.max_tokens: refined_spans.extend(self._split_monster_sentence(s))
            else: refined_spans.append(s)

        span_toks_list = self.get_token_counts_batch(refined_spans)
        chunks, current_spans, current_tokens = [], [], 0

        for span, span_toks in zip(refined_spans, span_toks_list):
            if current_tokens + span_toks > self.max_tokens and current_spans:
                # (renamed from `chunk_text` to avoid shadowing this method)
                chunk_body = " ".join(current_spans)
                chunks.append((chunk_body, self.count_tokens(chunk_body), list(current_spans)))

                # Carry the last few sentences into the next chunk for context.
                actual_overlap = min(self.overlap_count, len(current_spans) - 1)
                if actual_overlap > 0:
                    current_spans = current_spans[-actual_overlap:]
                    current_tokens = self.count_tokens(" ".join(current_spans)) + 1
                else:
                    current_spans, current_tokens = [], 0

            current_spans.append(span)
            current_tokens += span_toks + 1

        if current_spans:
            chunk_body = " ".join(current_spans)
            chunks.append((chunk_body, self.count_tokens(chunk_body), list(current_spans)))
        return chunks

    def ingest(self, lang="ko", batch_size=32, limit=None):
        """
        - The dataset is read in a streaming manner to handle large corpora without memory issues.
        - Each document is processed to create chunks based on token limits, with an overlap strategy to ensure comprehensive coverage of the text.
        - The processed documents, chunks, and spans are stored in SQLite with appropriate indexing for efficient retrieval during search.
        """
        ds = load_dataset("HuggingFaceFW/finewiki", lang, split="train", streaming=True)
        cur = self.conn.cursor()
        count = 0
        batch_docs, batch_chunks, batch_spans = [], [], []

        for item in tqdm(ds, desc=f"Ingesting {lang}"):
            if limit and count >= limit: break
            doc_id = self.meta["last_doc_id"]
            batch_docs.append((doc_id, item["id"], item["title"], lang, item["url"], item.get("wikidata_id", ""), item.get("date_modified", ""), item["text"]))

            for c_idx, (c_text, token_len, span_list) in enumerate(self.chunk_text(item["text"])):
                chunk_id = self.meta["last_chunk_id"]
                batch_chunks.append((chunk_id, doc_id, c_idx, c_text, token_len, item["title"], lang))
                for s_idx, span_text in enumerate(span_list):
                    batch_spans.append((self.meta["last_span_id"], chunk_id, s_idx, span_text, len(span_text)))
                    self.meta["last_span_id"] += 1
                self.meta["last_chunk_id"] += 1
            self.meta["last_doc_id"] += 1
            count += 1

            if len(batch_docs) >= batch_size:
                self._commit_batch(cur, batch_docs, batch_chunks, batch_spans)
                batch_docs, batch_chunks, batch_spans = [], [], []
                if count % (batch_size * 10) == 0: self._save_meta()

        # Flush the final partial batch, checkpoint the WAL and persist IDs.
        self._commit_batch(cur, batch_docs, batch_chunks, batch_spans)
        self.conn.commit()
        self.conn.execute("PRAGMA wal_checkpoint(FULL);")
        self._save_meta()

    def _commit_batch(self, cur, docs, chunks, spans):
        """Bulk-insert one accumulated batch and commit it.

        FIX: the original inserted without committing, so no row was durable
        until the very end of ``ingest`` even though ``_save_meta`` persisted
        progress along the way; a crash lost everything. Committing per
        batch bounds the possible data loss.
        """
        if not docs: return
        cur.executemany("INSERT INTO documents VALUES (?,?,?,?,?,?,?,?)", docs)
        cur.executemany("INSERT INTO chunks VALUES (?,?,?,?,?,?,?)", chunks)
        cur.executemany("INSERT INTO spans VALUES (?,?,?,?,?)", spans)
        self.conn.commit()

    # ---------------------------
    # EMBED TO DISK
    # ---------------------------
    def embed_corpus(self, lang="ko", batch_size=128, save_interval=100000):
        """
        Text is read in batches from SQLite, embeddings are generated using BGE-M3, and then saved to disk.
        - Embedding generation is performed on the GPU, and data is saved to disk in fixed batches to manage memory.
        - Dense vectors are saved in NumPy's .npz format to ensure fast loading and low disk usage.
        - Sparse vectors are saved in JSONL format to provide flexibility and readability.
        - The saved embeddings are subsequently uploaded to Qdrant for use in searches.
        - This method is designed to reliably generate and save embeddings even on large-scale datasets.
        """
        cur = self.conn.cursor()
        cur.execute("SELECT chunk_id, text FROM chunks WHERE lang=?", (lang,))
        rows = cur.fetchall()

        part_id = 0
        id_buffer = []
        dense_buffer = []
        sparse_buffer = []

        save_dir = f"{self.base_dir}/build_cache/embeddings"

        for i in tqdm(range(0, len(rows), batch_size), desc=f"1/2 GPU Embedding ({lang})"):
            batch = rows[i:i+batch_size]
            ids = [r[0] for r in batch]
            texts = [r[1] for r in batch]

            output = self.model.encode(
                texts, batch_size=len(texts), max_length=self.max_tokens,
                return_dense=True, return_sparse=True, return_colbert_vecs=False
            )

            id_buffer.extend(ids)
            dense_buffer.append(output['dense_vecs'])

            # lexical_weights keys are token ids (as strings) → keep as str here;
            # build_qdrant_index converts them back with int().
            for sp_dict in output['lexical_weights']:
                sparse_buffer.append({str(k): float(v) for k, v in sp_dict.items()})

            # Save to disk when a certain number is reached (prevents memory explosion)
            if len(id_buffer) >= save_interval:
                self._save_embedding_part(save_dir, lang, part_id, id_buffer, dense_buffer, sparse_buffer)
                part_id += 1
                id_buffer, dense_buffer, sparse_buffer = [], [], []

        # Save the last remaining scraps
        self._save_embedding_part(save_dir, lang, part_id, id_buffer, dense_buffer, sparse_buffer)
        print(f"Embedding Generation Complete. Saved to {save_dir}")

    def _save_embedding_part(self, save_dir, lang, part_id, ids, dense_chunks, sparse_list):
        """Persist one part: dense vectors + ids as .npz, sparse dicts as JSONL."""
        if not ids: return

        # Dense & IDs: high-speed storage as NumPy binaries.
        np.savez(f"{save_dir}/ebd_{lang}_{part_id}.npz",
                 ids=np.array(ids, dtype=np.int64),
                 dense=np.vstack(dense_chunks))

        # Sparse: save in JSONL format (one record per line).
        with open(f"{save_dir}/sparse_{lang}_{part_id}.jsonl", 'w', encoding='utf-8') as f:
            for sp in sparse_list:
                f.write(json.dumps(sp) + '\n')

    # ---------------------------
    # BUILD QDRANT INDEX
    # ---------------------------
    def build_qdrant_index(self, lang="ko", batch_size=2000):
        """
        The generated embeddings are read from disk and uploaded to Qdrant in batches.
        - This method reads the saved dense and sparse embeddings, constructs the appropriate data structures for Qdrant, and uploads them in batches to manage memory and ensure efficient indexing.
        - After all data is uploaded, it triggers Qdrant's indexing process to optimize search performance.
        - The use of batch uploads and on-disk storage allows this process to scale to large datasets without overwhelming system memory.
        """
        save_dir = f"{self.base_dir}/build_cache/embeddings"
        files = sorted([f for f in os.listdir(save_dir) if f.startswith(f"ebd_{lang}_") and f.endswith(".npz")])

        for file_name in files:
            part_id = file_name.split("_")[-1].split(".")[0]

            # 1. Load one part and pair it with its sparse JSONL file.
            npz_path = os.path.join(save_dir, file_name)
            sparse_path = os.path.join(save_dir, f"sparse_{lang}_{part_id}.jsonl")

            data = np.load(npz_path)
            ids = data['ids']
            dense_vecs = data['dense']

            with open(sparse_path, 'r', encoding='utf-8') as f:
                sparse_vecs = [json.loads(line) for line in f]

            points_batch = []

            # 2. Qdrant upload loop (point id == SQLite chunk_id).
            for i in tqdm(range(len(ids)), desc=f"2/2 Qdrant Uploading (Part {part_id})"):
                chunk_id = int(ids[i])
                sparse_dict = sparse_vecs[i]

                point = PointStruct(
                    id=chunk_id,
                    vector={
                        "dense": dense_vecs[i].tolist(),
                        "sparse": SparseVector(
                            indices=[int(k) for k in sparse_dict.keys()],
                            values=list(sparse_dict.values())
                        )
                    },
                    payload={"chunk_id": chunk_id, "lang": lang}
                )
                points_batch.append(point)

                # Upload when stacked to batch size
                if len(points_batch) >= batch_size:
                    self.qdrant_client.upload_points(
                        collection_name=self.collection_name,
                        points=points_batch
                    )
                    points_batch = []

            # Uploading leftover scraps
            if points_batch:
                self.qdrant_client.upload_points(
                    collection_name=self.collection_name,
                    points=points_batch
                )

        print("Data upload complete. Enabling HNSW Indexing...")

        # 3. [Key] After all uploads are complete, re-enable indexing (default 20,000) to optimize the graph
        # FIX: the keyword is `optimizers_config` (as used by create_collection
        # above); the original passed `optimizer_config`, which the qdrant
        # client does not accept.
        self.qdrant_client.update_collection(
            collection_name=self.collection_name,
            optimizers_config=OptimizersConfigDiff(indexing_threshold=20000)
        )
        print("Qdrant Indexing Complete!")

    def close(self):
        """Close the SQLite connection if it was opened."""
        if hasattr(self, 'conn') and self.conn:
            self.conn.close()
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
if __name__ == "__main__":
    # Example run: small Korean corpus end-to-end (ingest → embed → index).
    pipeline = KnowledgeEngineBuilder()
    try:
        # Process only 10,000 documents as an example
        pipeline.ingest(lang="ko", batch_size=32, limit=10000)
        pipeline.embed_corpus(lang="ko", batch_size=128, save_interval=5000)
        pipeline.build_qdrant_index(lang="ko", batch_size=2000)
    finally:
        # Always release the SQLite handle, even on failure.
        pipeline.close()
|
scripts/setup_db.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# scripts/setup_db.py
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
from huggingface_hub import snapshot_download
|
| 7 |
+
from huggingface_hub.utils import HfHubHTTPError
|
| 8 |
+
|
| 9 |
+
from core.config import settings
|
| 10 |
+
from core.logger import setup_logger
|
| 11 |
+
|
| 12 |
+
logger = setup_logger("setup_db")
|
| 13 |
+
|
| 14 |
+
def download_knowledge_base():
    """
    Ensure the knowledge-base data is present locally.

    Checks for the SQLite DB file and the Qdrant data directory; when either
    is missing, fetches both from the configured Hugging Face dataset repo
    via ``snapshot_download`` (restricted to corpus/qdrant files via
    ``allow_patterns`` to save bandwidth and disk). Exits the process with
    status 1 on any download failure.
    """
    sqlite_path = settings.SQLITE_PATH
    qdrant_dir = settings.QDRANT_PATH

    data_ready = os.path.exists(sqlite_path) and os.path.isdir(qdrant_dir)
    if data_ready:
        logger.info(f"⚡ SQLite DB and Qdrant data already exist at {sqlite_path} and {qdrant_dir}. Skipping download.")
        return

    repo_id = settings.REPO_ID
    local_dir = settings.DATA_DIR

    logger.info(f"📥 Downloading DBs from HF Repo: {repo_id} to {local_dir}...")

    try:
        download_path = snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            local_dir=local_dir,
            allow_patterns=["corpus/*", "qdrant/*"],
            ignore_patterns=["build_cache/*", ".gitattributes"],
            max_workers=4
        )
    except HfHubHTTPError as e:
        logger.error(f"❌ HTTP Error during download: {e}")
        sys.exit(1)
    except Exception as e:
        logger.error(f"❌ Unexpected error during download: {e}", exc_info=True)
        sys.exit(1)
    else:
        logger.info(f"✅ Download complete! Data is ready at: {download_path}")

if __name__ == "__main__":
    download_knowledge_base()
|
services/search_service.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# services/search_service.py
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
from typing import Any, Dict, List
|
| 5 |
+
|
| 6 |
+
from api.schemas.search import DocumentMetadata, SearchResultItem
|
| 7 |
+
from core.exceptions import SearchExecutionError
|
| 8 |
+
from core.logger import setup_logger
|
| 9 |
+
from models.embedder import TextEmbedder
|
| 10 |
+
from models.reranker import TextReranker
|
| 11 |
+
from storage.qdrant_client import QdrantStorage
|
| 12 |
+
from storage.sqlite_client import SQLiteStorage
|
| 13 |
+
|
| 14 |
+
logger = setup_logger("search_service")
|
| 15 |
+
|
| 16 |
+
class HybridSearchService:
    """
    Business-logic service that derives final search results by integrating
    Qdrant (vector DB), SQLite (RDBMS), an Embedder, and a Reranker.
    """

    def __init__(self, qdrant: QdrantStorage, sqlite: SQLiteStorage, embedder: TextEmbedder, reranker: TextReranker):
        self.qdrant = qdrant      # vector store (hybrid dense/sparse search)
        self.sqlite = sqlite      # source of chunk text + document metadata
        self.embedder = embedder  # query -> dense/sparse vector encoder
        self.reranker = reranker  # cross-encoder for precise rescoring

    def search(self, query: str, top_k: int = 5, limit: int = 50) -> Dict[str, Any]:
        """
        Receive a user query and perform hybrid search and reranking.

        :param query: User search query.
        :param top_k: Number of documents to return (after reranking).
        :param limit: Number of candidate documents to fetch from Qdrant
                      (after RRF fusion, before reranking).
        :returns: ``{"query": ..., "results": [...], "latency_ms": ...}``.
        :raises SearchExecutionError: if any stage of the pipeline fails.
        """
        start_time = time.time()
        logger.info(f"🔍 Starting search pipeline for query: '{query}'")

        try:
            # 1. Query embedding (dense + sparse extraction)
            encoded_query = self.embedder.encode_query(query)

            # 2. Qdrant hybrid search (extract `limit` candidates fused via RRF)
            qdrant_results = self.qdrant.hybrid_search(
                dense_vector=encoded_query.dense_vector,
                sparse_indices=encoded_query.sparse_indices,
                sparse_values=encoded_query.sparse_values,
                limit=limit
            )

            if not qdrant_results:
                logger.warning("No results found in Vector DB.")
                return self._build_empty_response(query, start_time)

            chunk_ids = [res.id for res in qdrant_results]

            # 3. Fetch a dict from SQLite for O(1) mapping of source text and metadata
            sqlite_data_map = self.sqlite.get_enriched_chunks_dict(chunk_ids)

            # 4. Merge Qdrant scores with SQLite text/metadata for reranking
            chunks_for_reranking = []
            for rank, res in enumerate(qdrant_results, start=1):
                # Defense logic: skip chunks present in the vector DB but missing in SQLite (desync)
                chunk_info = sqlite_data_map.get(res.id)
                if not chunk_info:
                    logger.warning(f"Data Desync: chunk_id {res.id} found in Qdrant but missing in SQLite.")
                    continue

                chunks_for_reranking.append({
                    "chunk_id": res.id,
                    "text": chunk_info["text"],
                    "metadata": chunk_info["metadata"],
                    "rrf_score": res.score,
                    "rrf_rank": rank
                })

            if not chunks_for_reranking:
                return self._build_empty_response(query, start_time)

            # 5. Cross-encoder reranking: returns a list sorted by precise,
            #    context-aware score in descending order.
            reranked_docs = self.reranker.rerank(
                query=query,
                documents=chunks_for_reranking,
                text_key="text"
            )

            # 6. Top-K truncation and mapping onto the Pydantic schema (SearchResultItem)
            final_results = []
            for doc in reranked_docs[:top_k]:
                final_results.append(SearchResultItem(
                    chunk_id=doc["chunk_id"],
                    text=doc["text"],
                    score=round(doc["rerank_score"], 4),  # neatly rounded to 4 decimal places
                    metadata=DocumentMetadata(**doc["metadata"])
                ).model_dump())  # convert to dict for FastAPI compatibility

            latency_ms = int((time.time() - start_time) * 1000)
            logger.info(f"✅ Search completed in {latency_ms}ms. Found {len(final_results)} final chunks.")

            return {
                "query": query,
                "results": final_results,
                "latency_ms": latency_ms
            }

        except Exception as e:
            # Wrap unexpected errors in a custom error and throw them to the router
            logger.error(f"❌ Pipeline failed: {str(e)}", exc_info=True)
            raise SearchExecutionError(f"Search pipeline failed: {str(e)}")

    def _build_empty_response(self, query: str, start_time: float) -> Dict[str, Any]:
        """Build the standard response format when no search results are found."""
        return {
            "query": query,
            "results": [],
            "latency_ms": int((time.time() - start_time) * 1000)
        }

    # ---------------------------------------------------------
    # LLM-Friendly Prompt Formatter
    # (Utility used when injecting into Agents or VLMs)
    # ---------------------------------------------------------
    def format_for_llm(self, search_results: List[Dict[str, Any]]) -> str:
        """
        Convert retrieved JSON results into a Markdown/XML mixed format best understood by LLMs.
        (This method can be optionally called by API routers or other Agent systems.)

        Note: SQLite NULL columns surface here as ``None`` *values* under
        present keys, so ``dict.get(key, default)`` does not fall back for
        them; ``or`` fallbacks are used instead to avoid emitting the literal
        string "None" into the prompt.
        """
        if not search_results:
            return "No relevant knowledge (documents) available."

        context_blocks = []
        for i, res in enumerate(search_results, start=1):
            meta = res["metadata"]
            # `or` (not a .get default) so a present-but-None title still falls back.
            source = meta.get("title") or f"Document_{meta.get('doc_id')}"

            # LLMs recognize text enclosed in XML tags (<doc>) as the clearest 'referencing context'.
            block = (
                f"<doc id=\"{i}\" source=\"{source}\" "
                f"url=\"{meta.get('url') or 'N/A'}\" "
                f"relevance_score=\"{res['score']}\">\n"
                f"{res['text']}\n"
                f"</doc>"
            )
            context_blocks.append(block)

        return "\n\n".join(context_blocks)
|
storage/qdrant_client.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# storage/qdrant_client.py
|
| 2 |
+
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
from qdrant_client import QdrantClient, models
|
| 6 |
+
|
| 7 |
+
from core.exceptions import DatabaseError
|
| 8 |
+
from core.logger import setup_logger
|
| 9 |
+
|
| 10 |
+
logger = setup_logger("qdrant_client")
|
| 11 |
+
|
| 12 |
+
class QdrantStorage:
    """
    Qdrant client performing hybrid search based on dense and sparse vectors.
    """

    def __init__(self, path: str, collection_name: str = "knowledge_base"):
        self.path = path
        self.collection_name = collection_name
        try:
            # Embedded, file-system-backed Qdrant connection (v1.10+).
            self.client = QdrantClient(path=self.path)
            logger.info(f"✅ Connected to local Qdrant at {self.path} (Collection: {self.collection_name})")
        except Exception as e:
            logger.critical(f"❌ Qdrant connection failed: {e}")
            raise e

    def hybrid_search(
        self,
        dense_vector: List[float],
        sparse_indices: List[int],
        sparse_values: List[float],
        limit: int = 100
    ) -> List[models.ScoredPoint]:
        """
        Run a hybrid (dense + sparse) search through Qdrant's native fusion API.

        RRF (Reciprocal Rank Fusion) is computed at the database level, so the
        returned points already carry fused scores.
        """
        try:
            # Qdrant v1.10+ syntax: run both branches as Prefetch, then fuse.
            sparse_branch = models.Prefetch(
                query=models.SparseVector(
                    indices=sparse_indices,
                    values=sparse_values
                ),
                using="sparse",
                limit=limit,
            )
            dense_branch = models.Prefetch(
                query=dense_vector,
                using="dense",
                limit=limit,
            )
            # Merge the two candidate sets into one ranking via RRF.
            response = self.client.query_points(
                collection_name=self.collection_name,
                prefetch=[sparse_branch, dense_branch],
                query=models.FusionQuery(fusion=models.Fusion.RRF),
                limit=limit,
                with_payload=True
            )
            return response.points

        except Exception as e:
            logger.error(f"❌ Hybrid search failed: {e}", exc_info=True)
            raise DatabaseError(f"Qdrant Hybrid search execution failed: {e}")

    def close(self):
        """Release the embedded Qdrant client handle, if one was created."""
        client = getattr(self, 'client', None)
        if client:
            client.close()
            logger.info("🛑 Qdrant client connection closed.")
|
storage/sqlite_client.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# storage/sqlite_client.py
|
| 2 |
+
|
| 3 |
+
import sqlite3
|
| 4 |
+
from typing import Any, Dict, List
|
| 5 |
+
|
| 6 |
+
from core.exceptions import DatabaseError
|
| 7 |
+
from core.logger import setup_logger
|
| 8 |
+
|
| 9 |
+
logger = setup_logger("sqlite_client")
|
| 10 |
+
|
| 11 |
+
class SQLiteStorage:
    """Read-only access layer for the chunk/document corpus stored in SQLite."""

    def __init__(self, db_path: str):
        self.db_path = db_path
        try:
            # NOTE(review): a single connection shared across threads
            # (check_same_thread=False) assumes serialized access upstream.
            self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
            self.conn.row_factory = sqlite3.Row
            logger.info(f"✅ Connected to SQLite at {self.db_path}")
        except sqlite3.Error as e:
            logger.critical(f"❌ SQLite connection failed: {e}")
            raise DatabaseError(f"Database connection failed: {e}")

    def get_enriched_chunks_dict(self, chunk_ids: List[int]) -> Dict[int, Dict[str, Any]]:
        """
        Map each chunk_id to its text plus parent-document metadata.

        Returns ``{chunk_id: {"text": ..., "metadata": {...}}}`` so the search
        service can enrich Qdrant hits via O(1) lookups. A single JOIN over the
        ``chunks`` and ``documents`` tables fetches everything at once; an
        empty input short-circuits to ``{}`` without touching the database.
        Database failures are logged and re-raised as ``DatabaseError``.
        """
        if not chunk_ids:
            return {}

        marks = ",".join("?" * len(chunk_ids))
        sql = f"""
            SELECT
                c.chunk_id, c.text AS chunk_text,
                d.doc_id, d.title, d.lang, d.url, d.date_modified
            FROM chunks c
            JOIN documents d ON c.doc_id = d.doc_id
            WHERE c.chunk_id IN ({marks})
        """

        try:
            rows = self.conn.execute(sql, chunk_ids).fetchall()
        except sqlite3.Error as e:
            logger.error(f"Failed to fetch enriched chunks: {e}")
            raise DatabaseError(f"Query execution failed: {e}")

        # Shape rows into { chunk_id: { "text": ..., "metadata": {...} } }
        return {
            row["chunk_id"]: {
                "text": row["chunk_text"],
                "metadata": {
                    "doc_id": row["doc_id"],
                    "title": row["title"],
                    "lang": row["lang"],
                    "url": row["url"],
                    "date_modified": row["date_modified"],
                },
            }
            for row in rows
        }

    def close(self):
        """Close the underlying SQLite connection, if one was opened."""
        conn = getattr(self, 'conn', None)
        if conn:
            conn.close()
            logger.info("🛑 SQLite connection closed.")
logger.info("🛑 SQLite connection closed.")
|
templates/index.html
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
<html>

<head>
    <meta charset="UTF-8">
    <title>Hybrid Knowledge Engine</title>
    <style>
        body {
            max-width: 800px;
            margin: 0 auto;
            padding: 20px;
            line-height: 1.6;
        }

        .search-box {
            background: #f4f4f4;
            padding: 20px;
            border-radius: 8px;
            margin-bottom: 30px;
        }

        .result-item {
            border-bottom: 1px solid #eee;
            padding: 15px 0;
        }

        .metadata {
            font-size: 0.85em;
            color: #666;
            margin-bottom: 5px;
        }

        .score {
            color: #2c3e50;
            font-weight: bold;
            background: #ecf0f1;
            padding: 2px 6px;
            border-radius: 4px;
        }

        .content {
            margin-top: 10px;
            color: #333;
        }

        .latency {
            color: #999;
            font-size: 0.9em;
            text-align: right;
        }
    </style>
</head>

<body>
    <h1>Knowledge Engine</h1>

    {# Search form: posts the query string to the server-rendered demo endpoint. #}
    <div class="search-box">
        <form method="post" action="/api/v1/search/demo">
            <input type="text" name="query" value="{{ query }}" placeholder="Enter Query" style="width: 80%; padding: 10px;"
                required>
            <button type="submit" style="padding: 10px 20px; cursor: pointer;">Search</button>
        </form>
    </div>

    {# `results` is None on the initial page load (no search performed yet),
       and a (possibly empty) list after a search — hence the two-level check. #}
    {% if results is not none %}
    <div class="latency">Search time: {{ latency_ms }}ms</div>
    <h2>Results for "{{ query }}"</h2>

    {% if results|length > 0 %}
    {% for r in results %}
    <div class="result-item">
        <div class="metadata">
            <span class="score">Score: {{ r.score }}</span> |
            <strong>source: {{ r.metadata.title }}</strong>
            {% if r.metadata.url %} | <a href="{{ r.metadata.url }}" target="_blank">Link</a>{% endif %}
        </div>
        <div class="content">
            {{ r.text }}
        </div>
    </div>
    {% endfor %}
    {% else %}
    <p>No search results found.</p>
    {% endif %}
    {% endif %}

    {# Server-side error message (e.g. pipeline failure), if any. #}
    {% if error_message %}
    <p style="color: red;">{{ error_message }}</p>
    {% endif %}
</body>

</html>
|