DevelopedBy-Siva committed
Commit b378103
1 Parent(s): 466d417
.dockerignore ADDED
@@ -0,0 +1,16 @@
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ .venv/
+ venv/
+ .env
+ .git/
+ .gitignore
+ *.db
+ faiss/
+ uploads/
+ temp_uploads/
+ data/
+ rag_system.db
+
.gitignore ADDED
@@ -0,0 +1,214 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[codz]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py.cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+ #poetry.toml
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+ #pdm.lock
+ #pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # pixi
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+ #pixi.lock
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+ # in the .venv directory. It is recommended not to include this directory in version control.
+ .pixi
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .envrc
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # Abstra
+ # Abstra is an AI-powered process automation framework.
+ # Ignore directories containing user credentials, local state, and settings.
+ # Learn more at https://abstra.io/docs
+ .abstra/
+
+ # Visual Studio Code
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
+ # you could uncomment the following to ignore the entire vscode folder
+ # .vscode/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+
+ # Cursor
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
+ # refer to https://docs.cursor.com/context/ignore-files
+ .cursorignore
+ .cursorindexingignore
+
+ # Marimo
+ marimo/_static/
+ marimo/_lsp/
+ __marimo__/
+
+ uploads/
+ temp_uploads/
+ data/
+ rag_system.db
+
+ test_demo.py
Dockerfile ADDED
@@ -0,0 +1,17 @@
+ FROM python:3.11-slim
+
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /app
+
+ COPY requirements.txt /app/requirements.txt
+ RUN pip install --no-cache-dir -r /app/requirements.txt
+
+ COPY . /app
+
+ ENV PYTHONUNBUFFERED=1
+ EXPOSE 7860
+
+ CMD ["uvicorn", "server_app:app", "--host", "0.0.0.0", "--port", "7860"]
requirements.txt ADDED
@@ -0,0 +1,22 @@
+ fastapi==0.109.2
+ uvicorn[standard]==0.27.1
+ python-multipart==0.0.9
+
+ sentence-transformers==2.3.1
+ faiss-cpu==1.9.0.post1
+ langchain==0.1.9
+ langchain-community==0.0.21
+ openai==1.12.0
+
+ pypdf==4.0.1
+ python-docx==1.1.0
+ python-magic==0.4.27
+
+ sqlalchemy==2.0.25
+
+ python-dotenv==1.0.1
+ pydantic==2.6.1
+ numpy==1.26.4
+ pandas==2.2.0
+
+ httpx==0.27.2
server_app.py ADDED
@@ -0,0 +1,629 @@
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Form
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse
+ from pydantic import BaseModel
+ from typing import Optional
+ import shutil
+ import os
+ from pathlib import Path
+ import sys
+ from openai import OpenAI
+ import json
+
+ sys.path.insert(0, str(Path(__file__).parent))
+
+ from src.rag_system import IncrementalRAGSystem
+ from src.database import get_db_session, DocumentVersion, DocumentChunk
+
+
+ client = OpenAI(
+     api_key=os.getenv("GROQ_API_KEY"), base_url="https://api.groq.com/openai/v1"
+ )
+
+
+ app = FastAPI(
+     title="Incremental RAG API",
+     description="API for document Q&A RAG System",
+     version="1.0.0",
+ )
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=[
+         "http://localhost:3000",
+         "https://document-qa-rag-system.vercel.app",
+     ],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ rag_system = None
+
+
+ @app.on_event("startup")
+ def startup():
+     global rag_system
+     rag_system = IncrementalRAGSystem()
+
+
+ TEMP_UPLOAD_DIR = "./temp_uploads"
+ Path(TEMP_UPLOAD_DIR).mkdir(exist_ok=True)
+
+
+ class QueryRequest(BaseModel):
+     question: str
+     version_id: Optional[int] = None
+     k: int = 5
+
+
+ class ComparisonRequest(BaseModel):
+     question: str
+     version_id_1: int
+     version_id_2: int
+     k: int = 3
+
+
+ @app.get("/")
+ async def root():
+     return {
+         "status": "online",
+         "message": "Document Q&A RAG API is running",
+     }
+
+
+ @app.post("/api/documents/upload")
+ async def upload_document(
+     file: UploadFile = File(...), doc_name: Optional[str] = Form(None)
+ ):
+     temp_file_path = None
+     try:
+         allowed_extensions = {".pdf", ".txt", ".docx"}
+         file_ext = Path(file.filename).suffix.lower()
+
+         if file_ext not in allowed_extensions:
+             raise HTTPException(
+                 status_code=400, detail=f"File type {file_ext} not supported"
+             )
+
+         temp_file_path = Path(TEMP_UPLOAD_DIR) / file.filename
+         with open(temp_file_path, "wb") as buffer:
+             shutil.copyfileobj(file.file, buffer)
+
+         if not doc_name:
+             doc_name = Path(file.filename).stem
+
+         result = rag_system.add_document(
+             file_path=str(temp_file_path), doc_name=doc_name
+         )
+
+         temp_file_path.unlink()
+
+         return JSONResponse(
+             content={
+                 "success": True,
+                 "message": f"Document uploaded as version {result['version_number']}",
+                 "data": result,
+             }
+         )
+
+     except Exception as e:
+         if temp_file_path and temp_file_path.exists():
+             temp_file_path.unlink()
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ def build_source_context(results):
+     parts = []
+     for i, r in enumerate(results, start=1):
+         excerpt = r["content"]
+         if len(excerpt) > 2000:
+             excerpt = excerpt[:2000] + "..."
+         parts.append(f"[Source {i}]\n{excerpt}")
+     return "\n\n".join(parts)
+
+
+ def extract_document_topics(chunks: list, max_topics: int = 5) -> list:
+
+     sample_text = "\n".join([c["content"] for c in chunks[:3]])
+
+     try:
+         prompt = f"""
+ Extract the main topics covered in this document.
+
+ Document sample:
+ {sample_text[:1000]}
+
+ Return JSON with main topics/sections:
+ {{
+ "topics": ["Topic 1", "Topic 2", "Topic 3"]
+ }}
+
+ Keep topics concise (2-4 words each). Maximum {max_topics} topics.
+ """
+
+         resp = client.chat.completions.create(
+             model="llama-3.3-70b-versatile",
+             messages=[{"role": "user", "content": prompt}],
+             temperature=0.3,
+             max_tokens=200,
+             response_format={"type": "json_object"},
+         )
+
+         result = json.loads(resp.choices[0].message.content)
+         return result.get("topics", [])[:max_topics]
+
+     except Exception as e:
+         words = sample_text.lower().split()
+         fallback_topics = []
+         policy_keywords = [
+             "policy",
+             "work",
+             "remote",
+             "vacation",
+             "benefits",
+             "security",
+             "equipment",
+             "eligibility",
+         ]
+         for keyword in policy_keywords:
+             if keyword in words:
+                 fallback_topics.append(keyword.title())
+
+         return (
+             fallback_topics[:max_topics] if fallback_topics else ["General Information"]
+         )
+
+
+ @app.post("/api/query/generate")
+ async def query_with_llm(query_request: QueryRequest):
+     question = query_request.question.strip()
+
+     if len(question) < 3:
+         return {
+             "question": question,
+             "not_found": True,
+             "answer": "",
+             "message": "Question too short (minimum 3 characters)",
+             "sources": [],
+         }
+
+     results = rag_system.query(
+         question=question,
+         version_id=query_request.version_id,
+         k=query_request.k,
+     )
+
+     if not results:
+         return {
+             "question": question,
+             "not_found": True,
+             "answer": "",
+             "message": "No content found in this document version",
+             "suggestion": "Check if you selected the correct version or try searching all versions",
+             "sources": [],
+         }
+
+     top_score = results[0]["similarity_score"]
+
+     if top_score < 0.35:
+         topics = extract_document_topics(results)
+
+         return {
+             "question": question,
+             "not_found": True,
+             "answer": "",
+             "message": "No direct match for your question",
+             "topics": topics,
+             "suggestions": [
+                 "Try asking about specific topics listed above",
+                 "Use keywords from the document",
+                 (
+                     f"Example: 'What is the {topics[0].lower()}?'"
+                     if topics
+                     else "Be more specific"
+                 ),
+             ],
+             "top_score": round(top_score, 3),
+             "sources": [],
+         }
+
+     force_low_confidence = False
+
+     if top_score < 0.4:
+         filtered = results[:3]
+         force_low_confidence = True
+     elif top_score > 0.6:
+         filtered = [r for r in results if r["similarity_score"] > 0.5][:3]
+     elif top_score > 0.45:
+         filtered = [r for r in results if r["similarity_score"] > 0.4][:2]
+     else:
+         filtered = results[:1]
+
+     context = build_source_context(filtered)
+     avg_sim = sum(r["similarity_score"] for r in filtered) / len(filtered)
+
+     system_msg = """You are a helpful document Q&A assistant.
+
+ IMPORTANT RULES:
+ 1. Answer using ONLY the provided context
+ 2. If context is relevant, provide an answer even if partial
+ 3. Only return not_found=true if context is COMPLETELY unrelated
+ 4. For general questions (like "policy" or "document"), summarize key points
+
+ You must return valid JSON in this format:
+ {
+ "not_found": false,
+ "answer": "Your answer here",
+ "confidence": "high|medium|low"
+ }
+
+ Only use not_found=true if truly nothing relevant exists."""
+
+     user_prompt = f"""
+ Context (avg similarity: {avg_sim:.2f}):
+ {context}
+
+ Question: {question}
+
+ Provide a helpful answer based on the context. If the question is general, summarize the main points."""
+
+     try:
+         resp = client.chat.completions.create(
+             model="llama-3.3-70b-versatile",
+             messages=[
+                 {"role": "system", "content": system_msg},
+                 {"role": "user", "content": user_prompt},
+             ],
+             temperature=0.1,
+             max_tokens=800,
+             response_format={"type": "json_object"},
+         )
+
+         text = resp.choices[0].message.content.strip()
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"LLM API error: {str(e)}")
+
+     try:
+         j = json.loads(text)
+     except json.JSONDecodeError:
+         start = text.find("{")
+         end = text.rfind("}")
+         if start != -1 and end != -1:
+             try:
+                 j = json.loads(text[start : end + 1])
+             except json.JSONDecodeError:
+                 j = {
+                     "not_found": False,
+                     "answer": text,
+                     "confidence": "low",
+                     "note": "Response format was non-standard",
+                 }
+         else:
+             raise HTTPException(
+                 status_code=500, detail="Failed to parse LLM response as JSON"
+             )
+
+     j["sources"] = filtered
+     j["question"] = question
+     j["avg_similarity"] = round(avg_sim, 3)
+
+     if "confidence" not in j:
+         if avg_sim > 0.6:
+             j["confidence"] = "high"
+         elif avg_sim > 0.45:
+             j["confidence"] = "medium"
+         else:
+             j["confidence"] = "low"
+
+     if force_low_confidence:
+         j["confidence"] = "low"
+         j["warning"] = "Answer based on limited context relevance"
+
+     return j
+
+
+ @app.get("/api/documents")
+ async def list_documents():
+     try:
+         documents = rag_system.get_all_documents()
+         return documents
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.get("/api/documents/{doc_name}/versions")
+ async def get_document_versions(doc_name: str):
+     try:
+         versions = rag_system.get_document_versions(doc_name)
+         if not versions:
+             raise HTTPException(
+                 status_code=404, detail=f"Document '{doc_name}' not found"
+             )
+         return versions
+     except HTTPException:
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.get("/api/documents/{doc_name}/versions/{version_id}/diff")
+ async def get_version_diff(doc_name: str, version_id: int):
+     try:
+         session = get_db_session()
+
+         try:
+             current_version = (
+                 session.query(DocumentVersion).filter_by(id=version_id).first()
+             )
+
+             if not current_version:
+                 raise HTTPException(status_code=404, detail="Version not found")
+
+             prev_version = (
+                 session.query(DocumentVersion)
+                 .filter_by(
+                     document_id=current_version.document_id,
+                     version_number=current_version.version_number - 1,
+                 )
+                 .first()
+             )
+
+             if not prev_version:
+                 return {
+                     "success": True,
+                     "message": "This is the first version",
+                     "is_first_version": True,
+                     "current_version": current_version.version_number,
+                 }
+
+             current_chunks = [chunk.content for chunk in current_version.chunks]
+             prev_chunks = [chunk.content for chunk in prev_version.chunks]
+
+             current_text = "\n\n".join(current_chunks)
+             prev_text = "\n\n".join(prev_chunks)
+
+             stats = {
+                 "chunks_added": len(current_chunks) - len(prev_chunks),
+                 "current_chunks": len(current_chunks),
+                 "previous_chunks": len(prev_chunks),
+                 "current_version": current_version.version_number,
+                 "previous_version": prev_version.version_number,
+             }
+
+             system_msg = """You are analyzing document changes.
+ Identify what changed between two versions.
+ Be specific and concise.
+ You must respond with valid JSON only."""
+
+             user_prompt = f"""
+ Previous Version:
+ {prev_text[:3000]}...
+
+ Current Version:
+ {current_text[:3000]}...
+
+ Analyze the changes and return valid JSON in this format:
+ {{
+ "summary": "Brief overview of changes",
+ "key_changes": [
+ {{"type": "added|modified|removed", "description": "what changed"}},
+ ],
+ "impact": "low|medium|high"
+ }}
+ """
+             try:
+                 resp = client.chat.completions.create(
+                     model="llama-3.3-70b-versatile",
+                     messages=[
+                         {"role": "system", "content": system_msg},
+                         {"role": "user", "content": user_prompt},
+                     ],
+                     temperature=0.1,
+                     max_tokens=500,
+                     response_format={"type": "json_object"},  # ask the model for a JSON object response
+                 )
+
+                 llm_response = resp.choices[0].message.content.strip()
+
+                 try:
+                     diff_analysis = json.loads(llm_response)
+                 except json.JSONDecodeError as e:
+                     print(f"Failed to parse LLM response: {llm_response}")
+                     diff_analysis = {
+                         "summary": f"Version {current_version.version_number} has {len(current_chunks) - len(prev_chunks)} more chunks than version {prev_version.version_number}",
+                         "key_changes": [
+                             {
+                                 "type": "modified",
+                                 "description": f"Content updated with {abs(len(current_chunks) - len(prev_chunks))} chunk difference",
+                             }
+                         ],
+                         "impact": "medium",
+                     }
+
+             except Exception as llm_error:
+                 print(f"LLM API error: {llm_error}")
+                 diff_analysis = {
+                     "summary": "Unable to generate detailed analysis",
+                     "key_changes": [
+                         {
+                             "type": "modified",
+                             "description": f"{len(current_chunks)} chunks in current version vs {len(prev_chunks)} in previous",
+                         }
+                     ],
+                     "impact": "unknown",
+                 }
+
+             return {
+                 "success": True,
+                 "is_first_version": False,
+                 "stats": stats,
+                 "analysis": diff_analysis,
+                 "version_info": {
+                     "current": {
+                         "id": current_version.id,
+                         "number": current_version.version_number,
+                         "date": current_version.upload_date.isoformat(),
+                     },
+                     "previous": {
+                         "id": prev_version.id,
+                         "number": prev_version.version_number,
+                         "date": prev_version.upload_date.isoformat(),
+                     },
+                 },
+             }
+
+         finally:
+             session.close()
+
+     except HTTPException:
+         raise
+     except json.JSONDecodeError as e:
+         raise HTTPException(
+             status_code=500, detail=f"Failed to parse LLM response: {str(e)}"
+         )
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.post("/api/compare/detailed")
+ async def compare_versions_detailed(comparison: ComparisonRequest):
+
+     try:
+         session = get_db_session()
+
+         try:
+             v1 = (
+                 session.query(DocumentVersion)
+                 .filter_by(id=comparison.version_id_1)
+                 .first()
+             )
+             v2 = (
+                 session.query(DocumentVersion)
+                 .filter_by(id=comparison.version_id_2)
+                 .first()
+             )
+
+             if not v1 or not v2:
+                 raise HTTPException(status_code=404, detail="Version not found")
+
+             v1_chunks = [chunk.content for chunk in v1.chunks]
+             v2_chunks = [chunk.content for chunk in v2.chunks]
+
+             v1_text = "\n\n".join(v1_chunks)
+             v2_text = "\n\n".join(v2_chunks)
+
+             if comparison.question:
+                 results_v1 = rag_system.query(
+                     question=comparison.question,
+                     version_id=comparison.version_id_1,
+                     k=comparison.k,
+                 )
+
+                 results_v2 = rag_system.query(
+                     question=comparison.question,
+                     version_id=comparison.version_id_2,
+                     k=comparison.k,
+                 )
+
+                 context_v1 = "\n".join([r["content"] for r in results_v1[:2]])
+                 context_v2 = "\n".join([r["content"] for r in results_v2[:2]])
+
+                 system_msg = """Compare how two document versions answer the same question.
+ Identify specific differences."""
+
+                 user_prompt = f"""
+ Question: {comparison.question}
+
+ Version {v1.version_number} says:
+ {context_v1}
+
+ Version {v2.version_number} says:
+ {context_v2}
+
+ Return JSON:
+ {{
+ "answer_v1": "Answer from version 1",
+ "answer_v2": "Answer from version 2",
+ "changed": true/false,
+ "differences": [
+ {{"aspect": "what changed", "v1": "old value", "v2": "new value"}}
+ ],
+ "summary": "Overall comparison"
+ }}
+ """
+             else:
+                 system_msg = """Compare two document versions.
+ Identify all significant changes."""
+
+                 user_prompt = f"""
+ Version {v1.version_number}:
+ {v1_text[:4000]}...
+
+ Version {v2.version_number}:
+ {v2_text[:4000]}...
+
+ Return JSON:
+ {{
+ "overall_change": "high|medium|low",
+ "summary": "What changed overall",
+ "sections_changed": ["section 1", "section 2"],
+ "key_differences": [
+ {{"category": "category", "description": "what changed", "type": "added|modified|removed"}}
+ ],
+ "recommendations": "Who should review these changes"
+ }}
+ """
+
+             resp = client.chat.completions.create(
+                 model="llama-3.3-70b-versatile",
+                 messages=[
+                     {"role": "system", "content": system_msg},
+                     {"role": "user", "content": user_prompt},
+                 ],
+                 temperature=0.1,
+                 max_tokens=1000,
+             )
+
+             analysis = json.loads(resp.choices[0].message.content)
+
+             return {
+                 "success": True,
+                 "question": comparison.question if comparison.question else None,
+                 "version_info": {
+                     "version_1": {
+                         "id": v1.id,
+                         "number": v1.version_number,
+                         "date": v1.upload_date.isoformat(),
+                         "chunks": len(v1_chunks),
+                     },
+                     "version_2": {
+                         "id": v2.id,
+                         "number": v2.version_number,
+                         "date": v2.upload_date.isoformat(),
+                         "chunks": len(v2_chunks),
+                     },
+                 },
+                 "analysis": analysis,
+                 "stats": {
+                     "chunks_difference": len(v2_chunks) - len(v1_chunks),
+                     "text_length_v1": len(v1_text),
+                     "text_length_v2": len(v2_text),
+                 },
+             }
+
+         finally:
+             session.close()
+
+     except HTTPException:
+         raise
+     except json.JSONDecodeError:
+         raise HTTPException(status_code=500, detail="Failed to parse LLM response")
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     uvicorn.run("server_app:app", host="0.0.0.0", port=8000, reload=True)
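
For reference, a minimal client sketch for the endpoints above, assuming the server is running locally on port 8000 (as in the __main__ block) and using a placeholder file name:

import httpx

# upload a document (multipart form: file plus an optional doc_name field)
with open("policy.pdf", "rb") as f:
    r = httpx.post(
        "http://localhost:8000/api/documents/upload",
        files={"file": ("policy.pdf", f, "application/pdf")},
        data={"doc_name": "policy"},
        timeout=120,
    )
print(r.json())

# ask a question; k controls how many chunks are retrieved as context
r = httpx.post(
    "http://localhost:8000/api/query/generate",
    json={"question": "What is the remote work policy?", "k": 5},
    timeout=120,
)
print(r.json().get("answer"))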
src/__init__.py ADDED
@@ -0,0 +1,16 @@
+ """
+ Incremental RAG System - A production-ready RAG with document versioning
+ """
+
+ from .rag_system import IncrementalRAGSystem
+ from .embeddings import EmbeddingGenerator
+ from .vector_store import FAISSVectorStore
+ from .document_processor import DocumentProcessor
+
+ __version__ = "1.0.0"
+ __all__ = [
+     "IncrementalRAGSystem",
+     "EmbeddingGenerator",
+     "FAISSVectorStore",
+     "DocumentProcessor",
+ ]
src/database.py ADDED
@@ -0,0 +1,83 @@
+ from sqlalchemy import (
+     create_engine,
+     Column,
+     Integer,
+     String,
+     DateTime,
+     Text,
+     ForeignKey,
+ )
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.orm import sessionmaker, relationship
+ from datetime import datetime
+ import os
+
+ Base = declarative_base()
+
+
+ class Document(Base):
+
+     __tablename__ = "documents"
+
+     id = Column(Integer, primary_key=True)
+     doc_name = Column(String(255), nullable=False)
+     created_at = Column(DateTime, default=datetime.utcnow)
+
+     versions = relationship(
+         "DocumentVersion", back_populates="document", cascade="all, delete-orphan"
+     )
+
+     def __repr__(self):
+         return f"<Document(id={self.id}, name='{self.doc_name}')>"
+
+
+ class DocumentVersion(Base):
+
+     __tablename__ = "document_versions"
+
+     id = Column(Integer, primary_key=True)
+     document_id = Column(Integer, ForeignKey("documents.id"), nullable=False)
+     version_number = Column(Integer, nullable=False)
+     file_path = Column(String(512), nullable=False)
+     upload_date = Column(DateTime, default=datetime.utcnow)
+     file_hash = Column(String(64))
+     doc_metadata = Column(Text)
+     document = relationship("Document", back_populates="versions")
+     chunks = relationship(
+         "DocumentChunk", back_populates="version", cascade="all, delete-orphan"
+     )
+
+     def __repr__(self):
+         return f"<DocumentVersion(doc_id={self.document_id}, v{self.version_number})>"
+
+
+ class DocumentChunk(Base):
+
+     __tablename__ = "document_chunks"
+
+     id = Column(Integer, primary_key=True)
+     version_id = Column(Integer, ForeignKey("document_versions.id"), nullable=False)
+     chunk_index = Column(Integer, nullable=False)
+     content = Column(Text, nullable=False)
+     faiss_index = Column(Integer)
+
+     version = relationship("DocumentVersion", back_populates="chunks")
+
+     def __repr__(self):
+         return f"<DocumentChunk(id={self.id}, chunk_index={self.chunk_index})>"
+
+
+ def init_db(database_url: str = None):
+     if database_url is None:
+         database_url = os.getenv("DATABASE_URL", "sqlite:///./rag_system.db")
+
+     engine = create_engine(database_url, echo=False)
+     Base.metadata.create_all(engine)
+
+     SessionLocal = sessionmaker(bind=engine)
+     return engine, SessionLocal
+
+
+ def get_db_session(database_url: str = None):
+     _, SessionLocal = init_db(database_url)
+     return SessionLocal()
src/document_processor.py ADDED
@@ -0,0 +1,87 @@
+ import hashlib
+ from typing import List, Tuple
+ from pathlib import Path
+ import pypdf
+
+
+ class DocumentProcessor:
+
+     def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
+         self.chunk_size = chunk_size
+         self.chunk_overlap = chunk_overlap
+
+     def extract_text_from_pdf(self, file_path: str) -> str:
+         text = ""
+         try:
+             with open(file_path, "rb") as file:
+                 pdf_reader = pypdf.PdfReader(file)
+                 for page in pdf_reader.pages:
+                     text += page.extract_text() + "\n"
+         except Exception as e:
+             raise ValueError(f"Error reading PDF: {str(e)}")
+
+         return text.strip()
+
+     def extract_text_from_docx(self, file_path: str) -> str:
+         # .docx is accepted by the upload API and python-docx is already a declared dependency
+         try:
+             from docx import Document as DocxDocument
+             docx_file = DocxDocument(file_path)
+             text = "\n".join(paragraph.text for paragraph in docx_file.paragraphs)
+         except Exception as e:
+             raise ValueError(f"Error reading DOCX: {str(e)}")
+
+         return text.strip()
+
+     def chunk_text(self, text: str) -> List[str]:
+         if not text:
+             return []
+
+         chunks = []
+         start = 0
+         text_length = len(text)
+
+         while start < text_length:
+             end = start + self.chunk_size
+             chunk = text[start:end]
+
+             if end < text_length:
+                 last_period = chunk.rfind(".")
+                 last_newline = chunk.rfind("\n")
+                 break_point = max(last_period, last_newline)
+
+                 if break_point > self.chunk_size * 0.5:
+                     chunk = chunk[: break_point + 1]
+                     end = start + break_point + 1
+
+             chunks.append(chunk.strip())
+
+             start = end - self.chunk_overlap
+
+         return [c for c in chunks if c]
+
+     def process_document(self, file_path: str) -> Tuple[str, List[str]]:
+
+         file_ext = Path(file_path).suffix.lower()
+
+         if file_ext == ".pdf":
+             text = self.extract_text_from_pdf(file_path)
+         elif file_ext == ".docx":
+             text = self.extract_text_from_docx(file_path)
+         elif file_ext == ".txt":
+             with open(file_path, "r", encoding="utf-8") as f:
+                 text = f.read()
+         else:
+             raise ValueError(f"Unsupported file type: {file_ext}")
+
+         chunks = self.chunk_text(text)
+
+         return text, chunks
+
+     @staticmethod
+     def compute_file_hash(file_path: str) -> str:
+         hash_md5 = hashlib.md5()
+         with open(file_path, "rb") as f:
+             for chunk in iter(lambda: f.read(4096), b""):
+                 hash_md5.update(chunk)
+         return hash_md5.hexdigest()
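
To illustrate the sliding-window chunking implemented above, a small sketch run from the repository root (the text and sizes are placeholders chosen only to make the overlap visible; the real defaults are 512/50):

from src.document_processor import DocumentProcessor

# a 20-character window with a 5-character overlap
p = DocumentProcessor(chunk_size=20, chunk_overlap=5)
chunks = p.chunk_text("First sentence here. Second sentence follows. Third one ends it.")
# each window tries to end on the last '.' or newline it contains,
# and the next window starts chunk_overlap characters earlier
print(chunks)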
src/embeddings.py ADDED
@@ -0,0 +1,33 @@
+ from typing import List
+ import numpy as np
+ from sentence_transformers import SentenceTransformer
+ import os
+
+
+ class EmbeddingGenerator:
+
+     def __init__(self, model_name: str = None):
+         self.model_name = model_name or os.getenv("EMBEDDING_MODEL", "all-MiniLM-L6-v2")
+         print(f"Loading embedding model: {self.model_name}")
+         self.model = SentenceTransformer(self.model_name)
+         self.embedding_dim = self.model.get_sentence_embedding_dimension()
+         print(f"Model loaded. Embedding dimension: {self.embedding_dim}")
+
+     def embed_text(self, text: str) -> np.ndarray:
+         return self.model.encode(text, convert_to_numpy=True)
+
+     def embed_batch(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
+         if not texts:
+             return np.array([])
+
+         embeddings = self.model.encode(
+             texts,
+             batch_size=batch_size,
+             convert_to_numpy=True,
+             show_progress_bar=len(texts) > 10,
+         )
+
+         return embeddings
+
+     def get_embedding_dim(self) -> int:
+         return self.embedding_dim
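
A quick usage sketch of the generator above (run from the repository root; the model is downloaded on first use and the sentences are placeholders):

from src.embeddings import EmbeddingGenerator

emb = EmbeddingGenerator()  # defaults to all-MiniLM-L6-v2, which is 384-dimensional
vecs = emb.embed_batch(["remote work policy", "vacation days"])
print(vecs.shape, emb.get_embedding_dim())  # expected: (2, 384) 384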
src/rag_system.py ADDED
@@ -0,0 +1,239 @@
+ import os
+ import shutil
+ from pathlib import Path
+ from typing import List, Tuple, Optional
+ from datetime import datetime
+
+ from src.database import (
+     init_db,
+     get_db_session,
+     Document,
+     DocumentVersion,
+     DocumentChunk,
+ )
+ from src.document_processor import DocumentProcessor
+ from src.embeddings import EmbeddingGenerator
+ from src.vector_store import FAISSVectorStore
+
+
+ class IncrementalRAGSystem:
+
+     def __init__(
+         self,
+         database_url: str = None,
+         embedding_model: str = None,
+         index_path: str = None,
+         upload_dir: str = None,
+     ):
+
+         print("Initializing Incremental RAG System...")
+
+         self.database_url = database_url or os.getenv(
+             "DATABASE_URL", "sqlite:///./rag_system.db"
+         )
+         init_db(self.database_url)
+
+         self.processor = DocumentProcessor(chunk_size=512, chunk_overlap=50)
+         self.embedder = EmbeddingGenerator(model_name=embedding_model)
+         self.vector_store = FAISSVectorStore(
+             embedding_dim=self.embedder.get_embedding_dim(),
+             index_path=index_path or "./data/faiss_index",
+         )
+
+         self.upload_dir = upload_dir or "./uploads"
+         Path(self.upload_dir).mkdir(parents=True, exist_ok=True)
+
+         print("RAG System initialized successfully!")
+
+     def add_document(self, file_path: str, doc_name: str = None) -> dict:
+
+         if not Path(file_path).exists():
+             raise FileNotFoundError(f"File not found: {file_path}")
+
+         if doc_name is None:
+             doc_name = Path(file_path).stem
+
+         print(f"\nProcessing document: {doc_name}")
+
+         full_text, chunks = self.processor.process_document(file_path)
+         file_hash = self.processor.compute_file_hash(file_path)
+
+         print(f" - Extracted {len(chunks)} chunks")
+
+         session = get_db_session(self.database_url)
+
+         try:
+             document = session.query(Document).filter_by(doc_name=doc_name).first()
+
+             if document is None:
+                 document = Document(doc_name=doc_name)
+                 session.add(document)
+                 session.flush()
+                 version_number = 1
+                 print(f" - Created new document (ID: {document.id})")
+             else:
+                 max_version = (
+                     session.query(DocumentVersion)
+                     .filter_by(document_id=document.id)
+                     .count()
+                 )
+                 version_number = max_version + 1
+                 print(f" - Adding version {version_number} to existing document")
+
+             dest_path = (
+                 Path(self.upload_dir)
+                 / f"{doc_name}_v{version_number}{Path(file_path).suffix}"
+             )
+             shutil.copy2(file_path, dest_path)
+
+             version = DocumentVersion(
+                 document_id=document.id,
+                 version_number=version_number,
+                 file_path=str(dest_path),
+                 file_hash=file_hash,
+             )
+             session.add(version)
+             session.flush()
+
+             print(f" - Generating embeddings...")
+             embeddings = self.embedder.embed_batch(chunks)
+
+             metadata_list = [
+                 {
+                     "document_id": document.id,
+                     "version_id": version.id,
+                     "chunk_index": i,
+                     "doc_name": doc_name,
+                     "version_number": version_number,
+                     "content": chunk,
+                 }
+                 for i, chunk in enumerate(chunks)
+             ]
+
+             faiss_ids = self.vector_store.add_embeddings(embeddings, metadata_list)
+
+             for i, (chunk, faiss_id) in enumerate(zip(chunks, faiss_ids)):
+                 db_chunk = DocumentChunk(
+                     version_id=version.id,
+                     chunk_index=i,
+                     content=chunk,
+                     faiss_index=faiss_id,
+                 )
+                 session.add(db_chunk)
+
+             session.commit()
+
+             self.vector_store.save()
+
+             print(f"Successfully added {doc_name} v{version_number}")
+
+             return {
+                 "document_id": document.id,
+                 "document_name": doc_name,
+                 "version_id": version.id,
+                 "version_number": version_number,
+                 "num_chunks": len(chunks),
+                 "file_path": str(dest_path),
+             }
+
+         except Exception as e:
+             session.rollback()
+             raise e
+         finally:
+             session.close()
+
+     def query(
+         self, question: str, version_id: Optional[int] = None, k: int = 5
+     ) -> List[dict]:
+         print(f"\nQuerying: '{question}'")
+
+         query_embedding = self.embedder.embed_text(question)
+
+         results = self.vector_store.search(
+             query_embedding, k=k, version_filter=version_id
+         )
+
+         print(f" - Found {len(results)} relevant chunks")
+
+         formatted_results = []
+         for distance, metadata in results:
+             formatted_results.append(
+                 {
+                     "content": metadata.get("content", ""),
+                     "document_name": metadata.get("doc_name", ""),
+                     "version": metadata.get("version_number", ""),
+                     "chunk_index": metadata.get("chunk_index", ""),
+                     "similarity_score": 1 / (1 + distance),
+                 }
+             )
+
+         return formatted_results
+
+     def get_document_versions(self, doc_name: str) -> List[dict]:
+         session = get_db_session(self.database_url)
+
+         try:
+             document = session.query(Document).filter_by(doc_name=doc_name).first()
+
+             if not document:
+                 return []
+
+             versions = (
+                 session.query(DocumentVersion)
+                 .filter_by(document_id=document.id)
+                 .order_by(DocumentVersion.version_number)
+                 .all()
+             )
+
+             return [
+                 {
+                     "version_id": v.id,
+                     "version_number": v.version_number,
+                     "upload_date": v.upload_date.isoformat(),
+                     "file_path": v.file_path,
+                     "num_chunks": len(v.chunks),
+                 }
+                 for v in versions
+             ]
+         finally:
+             session.close()
+
+     def get_all_documents(self) -> List[dict]:
+         session = get_db_session(self.database_url)
+
+         try:
+             documents = session.query(Document).all()
+
+             result = []
+             for doc in documents:
+                 result.append(
+                     {
+                         "document_id": doc.id,
+                         "document_name": doc.doc_name,
+                         "created_at": doc.created_at.isoformat(),
+                         "num_versions": len(doc.versions),
+                     }
+                 )
+
+             return result
+         finally:
+             session.close()
+
+     def get_stats(self) -> dict:
+         session = get_db_session(self.database_url)
+
+         try:
+             num_documents = session.query(Document).count()
+             num_versions = session.query(DocumentVersion).count()
+             num_chunks = session.query(DocumentChunk).count()
+
+             vector_stats = self.vector_store.get_stats()
+
+             return {
+                 "num_documents": num_documents,
+                 "num_versions": num_versions,
+                 "num_chunks": num_chunks,
+                 "vector_store": vector_stats,
+             }
+         finally:
+             session.close()
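
A minimal end-to-end sketch of the class above, used directly rather than through the API (run from the repository root; the file path and question are placeholders):

from src.rag_system import IncrementalRAGSystem

rag = IncrementalRAGSystem()  # initializes the SQLite DB, FAISS index and ./uploads directory
info = rag.add_document("docs/policy_v1.pdf")  # stored as version 1; re-adding the same doc_name creates version 2
hits = rag.query("What is the vacation policy?", version_id=info["version_id"], k=5)
for h in hits:
    print(round(h["similarity_score"], 3), h["content"][:80])
print(rag.get_stats())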
src/vector_store.py ADDED
@@ -0,0 +1,126 @@
+ import faiss
+ import numpy as np
+ import pickle
+ from pathlib import Path
+ from typing import List, Tuple, Optional
+
+
+ class FAISSVectorStore:
+
+     def __init__(self, embedding_dim: int, index_path: str = None):
+
+         self.embedding_dim = embedding_dim
+         self.index_path = index_path or "./data/faiss_index"
+         self.index = None
+         self.id_to_metadata = {}  # Map FAISS ID to metadata
+         self.current_id = 0
+
+         Path(self.index_path).parent.mkdir(parents=True, exist_ok=True)
+
+         if Path(f"{self.index_path}.faiss").exists():
+             self.load()
+         else:
+             self._create_new_index()
+
+     def _create_new_index(self):
+         self.index = faiss.IndexFlatL2(self.embedding_dim)
+         self.id_to_metadata = {}
+         self.current_id = 0
+         print(f"Created new FAISS index with dimension {self.embedding_dim}")
+
+     def add_embeddings(self, embeddings: np.ndarray, metadata: List[dict]) -> List[int]:
+
+         if embeddings.shape[1] != self.embedding_dim:
+             raise ValueError(
+                 f"Embedding dimension mismatch: expected {self.embedding_dim}, "
+                 f"got {embeddings.shape[1]}"
+             )
+
+         embeddings = embeddings.astype("float32")
+
+         num_vectors = embeddings.shape[0]
+         ids = list(range(self.current_id, self.current_id + num_vectors))
+
+         self.index.add(embeddings)
+
+         for i, meta in zip(ids, metadata):
+             self.id_to_metadata[i] = meta
+
+         self.current_id += num_vectors
+
+         print(f"Added {num_vectors} vectors. Total: {self.index.ntotal}")
+         return ids
+
+     def search(
+         self,
+         query_embedding: np.ndarray,
+         k: int = 5,
+         version_filter: Optional[int] = None,
+     ) -> List[Tuple[float, dict]]:
+
+         if self.index.ntotal == 0:
+             return []
+
+         if query_embedding.ndim == 1:
+             query_embedding = query_embedding.reshape(1, -1)
+         query_embedding = query_embedding.astype("float32")
+
+         search_k = k * 10 if version_filter else k
+         distances, indices = self.index.search(
+             query_embedding, min(search_k, self.index.ntotal)
+         )
+
+         results = []
+         for dist, idx in zip(distances[0], indices[0]):
+             if idx == -1:
+                 continue
+
+             metadata = self.id_to_metadata.get(int(idx), {})
+
+             if version_filter is not None:
+                 if metadata.get("version_id") != version_filter:
+                     continue
+
+             results.append((float(dist), metadata))
+
+             if len(results) >= k:
+                 break
+
+         return results
+
+     def save(self):
+         faiss.write_index(self.index, f"{self.index_path}.faiss")
+
+         with open(f"{self.index_path}.meta", "wb") as f:
+             pickle.dump(
+                 {
+                     "id_to_metadata": self.id_to_metadata,
+                     "current_id": self.current_id,
+                     "embedding_dim": self.embedding_dim,
+                 },
+                 f,
+             )
+
+         print(f"Saved index to {self.index_path}")
+
+     def load(self):
+         try:
+             self.index = faiss.read_index(f"{self.index_path}.faiss")
+
+             with open(f"{self.index_path}.meta", "rb") as f:
+                 data = pickle.load(f)
+                 self.id_to_metadata = data["id_to_metadata"]
+                 self.current_id = data["current_id"]
+                 self.embedding_dim = data["embedding_dim"]
+
+             print(f"Loaded index from {self.index_path} ({self.index.ntotal} vectors)")
+         except Exception as e:
+             print(f"Error loading index: {e}")
+             self._create_new_index()
+
+     def get_stats(self) -> dict:
+         return {
+             "total_vectors": self.index.ntotal if self.index else 0,
+             "embedding_dim": self.embedding_dim,
+             "index_path": self.index_path,
+         }
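
A small isolation sketch of the store above (run from the repository root; the 384 dimension matches all-MiniLM-L6-v2 and the index path is a throwaway placeholder):

import numpy as np
from src.vector_store import FAISSVectorStore

store = FAISSVectorStore(embedding_dim=384, index_path="./data/test_index")
vecs = np.random.rand(3, 384).astype("float32")
meta = [{"version_id": 1, "chunk_index": i, "content": f"chunk {i}"} for i in range(3)]
store.add_embeddings(vecs, meta)
# search returns (L2 distance, metadata) pairs; version_filter keeps only rows whose version_id matches
for dist, m in store.search(vecs[0], k=2, version_filter=1):
    print(round(dist, 4), m["chunk_index"])
store.save()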