[KM-437][DB] Add db pipeline

#2
by rhbt6767 - opened
.gitignore CHANGED
@@ -26,6 +26,8 @@ test/users/user_accounts.csv
26
  .env.prd
27
  .env.example
28
 
 
 
29
  erd/
30
  playground/
31
  playground_retriever.py
 
26
  .env.prd
27
  .env.example
28
 
29
+ CLAUDE.md
30
+
31
  erd/
32
  playground/
33
  playground_retriever.py
src/api/v1/document.py CHANGED
@@ -1,21 +1,20 @@
1
  """Document management API endpoints."""
2
-
3
- from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, File, status
4
  from sqlalchemy.ext.asyncio import AsyncSession
5
  from src.db.postgres.connection import get_db
6
  from src.document.document_service import document_service
7
- from src.knowledge.processing_service import knowledge_processor
8
- from src.storage.az_blob.az_blob import blob_storage
9
  from src.middlewares.logging import get_logger, log_execution
10
  from src.middlewares.rate_limit import limiter
 
11
  from pydantic import BaseModel
12
  from typing import List
13
-
14
  logger = get_logger("document_api")
15
-
16
  router = APIRouter(prefix="/api/v1", tags=["Documents"])
17
-
18
-
19
  class DocumentResponse(BaseModel):
20
  id: str
21
  filename: str
@@ -23,8 +22,8 @@ class DocumentResponse(BaseModel):
23
  file_size: int
24
  file_type: str
25
  created_at: str
26
-
27
-
28
  @router.get("/documents/{user_id}", response_model=List[DocumentResponse])
29
  @log_execution(logger)
30
  async def list_documents(
@@ -44,8 +43,8 @@ async def list_documents(
44
  )
45
  for doc in documents
46
  ]
47
-
48
-
49
  @router.post("/document/upload")
50
  @limiter.limit("10/minute")
51
  @log_execution(logger)
@@ -57,57 +56,12 @@ async def upload_document(
57
  ):
58
  """Upload a document."""
59
  if not user_id:
60
- raise HTTPException(
61
- status_code=400,
62
- detail="user_id is required"
63
- )
64
-
65
- try:
66
- # Read file content
67
- content = await file.read()
68
- file_size = len(content)
69
-
70
- # Get file type
71
- filename = file.filename
72
- file_type = filename.split('.')[-1].lower() if '.' in filename else 'txt'
73
-
74
- if file_type not in ['pdf', 'docx', 'txt']:
75
- raise HTTPException(
76
- status_code=400,
77
- detail="Unsupported file type. Supported: pdf, docx, txt"
78
- )
79
-
80
- # Upload to blob storage
81
- blob_name = await blob_storage.upload_file(content, filename, user_id)
82
-
83
- # Create document record
84
- document = await document_service.create_document(
85
- db=db,
86
- user_id=user_id,
87
- filename=filename,
88
- blob_name=blob_name,
89
- file_size=file_size,
90
- file_type=file_type
91
- )
92
-
93
- return {
94
- "status": "success",
95
- "message": "Document uploaded successfully",
96
- "data": {
97
- "id": document.id,
98
- "filename": document.filename,
99
- "status": document.status
100
- }
101
- }
102
-
103
- except Exception as e:
104
- logger.error(f"Upload failed for user {user_id}", error=str(e))
105
- raise HTTPException(
106
- status_code=500,
107
- detail=f"Upload failed: {str(e)}"
108
- )
109
-
110
-
111
  @router.delete("/document/delete")
112
  @log_execution(logger)
113
  async def delete_document(
@@ -116,31 +70,10 @@ async def delete_document(
116
  db: AsyncSession = Depends(get_db)
117
  ):
118
  """Delete a document."""
119
- document = await document_service.get_document(db, document_id)
120
-
121
- if not document:
122
- raise HTTPException(
123
- status_code=404,
124
- detail="Document not found"
125
- )
126
-
127
- if document.user_id != user_id:
128
- raise HTTPException(
129
- status_code=403,
130
- detail="Access denied"
131
- )
132
-
133
- success = await document_service.delete_document(db, document_id)
134
-
135
- if success:
136
- return {"status": "success", "message": "Document deleted successfully"}
137
- else:
138
- raise HTTPException(
139
- status_code=500,
140
- detail="Failed to delete document"
141
- )
142
-
143
-
144
  @router.post("/document/process")
145
  @log_execution(logger)
146
  async def process_document(
@@ -149,45 +82,6 @@ async def process_document(
149
  db: AsyncSession = Depends(get_db)
150
  ):
151
  """Process document and ingest to vector index."""
152
- document = await document_service.get_document(db, document_id)
153
-
154
- if not document:
155
- raise HTTPException(
156
- status_code=404,
157
- detail="Document not found"
158
- )
159
-
160
- if document.user_id != user_id:
161
- raise HTTPException(
162
- status_code=403,
163
- detail="Access denied"
164
- )
165
-
166
- try:
167
- # Update status to processing
168
- await document_service.update_document_status(db, document_id, "processing")
169
-
170
- # Process document
171
- chunks_count = await knowledge_processor.process_document(document, db)
172
-
173
- # Update status to completed
174
- await document_service.update_document_status(db, document_id, "completed")
175
-
176
- return {
177
- "status": "success",
178
- "message": "Document processed successfully",
179
- "data": {
180
- "document_id": document_id,
181
- "chunks_processed": chunks_count
182
- }
183
- }
184
-
185
- except Exception as e:
186
- logger.error(f"Processing failed for document {document_id}", error=str(e))
187
- await document_service.update_document_status(
188
- db, document_id, "failed", str(e)
189
- )
190
- raise HTTPException(
191
- status_code=500,
192
- detail=f"Processing failed: {str(e)}"
193
- )
 
1
  """Document management API endpoints."""
2
+
3
+ from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, File
4
  from sqlalchemy.ext.asyncio import AsyncSession
5
  from src.db.postgres.connection import get_db
6
  from src.document.document_service import document_service
 
 
7
  from src.middlewares.logging import get_logger, log_execution
8
  from src.middlewares.rate_limit import limiter
9
+ from src.pipeline.document_pipeline.document_pipeline import document_pipeline
10
  from pydantic import BaseModel
11
  from typing import List
12
+
13
  logger = get_logger("document_api")
14
+
15
  router = APIRouter(prefix="/api/v1", tags=["Documents"])
16
+
17
+
18
  class DocumentResponse(BaseModel):
19
  id: str
20
  filename: str
 
22
  file_size: int
23
  file_type: str
24
  created_at: str
25
+
26
+
27
  @router.get("/documents/{user_id}", response_model=List[DocumentResponse])
28
  @log_execution(logger)
29
  async def list_documents(
 
43
  )
44
  for doc in documents
45
  ]
46
+
47
+
48
  @router.post("/document/upload")
49
  @limiter.limit("10/minute")
50
  @log_execution(logger)
 
56
  ):
57
  """Upload a document."""
58
  if not user_id:
59
+ raise HTTPException(status_code=400, detail="user_id is required")
60
+
61
+ data = await document_pipeline.upload(file, user_id, db)
62
+ return {"status": "success", "message": "Document uploaded successfully", "data": data}
63
+
64
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  @router.delete("/document/delete")
66
  @log_execution(logger)
67
  async def delete_document(
 
70
  db: AsyncSession = Depends(get_db)
71
  ):
72
  """Delete a document."""
73
+ await document_pipeline.delete(document_id, user_id, db)
74
+ return {"status": "success", "message": "Document deleted successfully"}
75
+
76
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  @router.post("/document/process")
78
  @log_execution(logger)
79
  async def process_document(
 
82
  db: AsyncSession = Depends(get_db)
83
  ):
84
  """Process document and ingest to vector index."""
85
+ data = await document_pipeline.process(document_id, user_id, db)
86
+ return {"status": "success", "message": "Document processed successfully", "data": data}
87
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/knowledge/processing_service.py CHANGED
@@ -49,10 +49,14 @@ class KnowledgeProcessingService:
49
  LangChainDocument(
50
  page_content=chunk,
51
  metadata={
52
- "document_id": db_doc.id,
53
  "user_id": db_doc.user_id,
54
- "filename": db_doc.filename,
55
- "chunk_index": i,
 
 
 
 
 
56
  }
57
  )
58
  for i, chunk in enumerate(chunks)
@@ -104,11 +108,15 @@ class KnowledgeProcessingService:
104
  documents.append(LangChainDocument(
105
  page_content=chunk,
106
  metadata={
107
- "document_id": db_doc.id,
108
  "user_id": db_doc.user_id,
109
- "filename": db_doc.filename,
110
- "chunk_index": len(documents),
111
- "page_label": page.page_number,
 
 
 
 
 
112
  }
113
  ))
114
  else:
@@ -122,11 +130,15 @@ class KnowledgeProcessingService:
122
  documents.append(LangChainDocument(
123
  page_content=chunk,
124
  metadata={
125
- "document_id": db_doc.id,
126
  "user_id": db_doc.user_id,
127
- "filename": db_doc.filename,
128
- "chunk_index": len(documents),
129
- "page_label": page_num,
 
 
 
 
 
130
  }
131
  ))
132
 
 
49
  LangChainDocument(
50
  page_content=chunk,
51
  metadata={
 
52
  "user_id": db_doc.user_id,
53
+ "source_type": "document",
54
+ "data": {
55
+ "document_id": db_doc.id,
56
+ "filename": db_doc.filename,
57
+ "file_type": db_doc.file_type,
58
+ "chunk_index": i,
59
+ },
60
  }
61
  )
62
  for i, chunk in enumerate(chunks)
 
108
  documents.append(LangChainDocument(
109
  page_content=chunk,
110
  metadata={
 
111
  "user_id": db_doc.user_id,
112
+ "source_type": "document",
113
+ "data": {
114
+ "document_id": db_doc.id,
115
+ "filename": db_doc.filename,
116
+ "file_type": db_doc.file_type,
117
+ "chunk_index": len(documents),
118
+ "page_label": page.page_number,
119
+ },
120
  }
121
  ))
122
  else:
 
130
  documents.append(LangChainDocument(
131
  page_content=chunk,
132
  metadata={
 
133
  "user_id": db_doc.user_id,
134
+ "source_type": "document",
135
+ "data": {
136
+ "document_id": db_doc.id,
137
+ "filename": db_doc.filename,
138
+ "file_type": db_doc.file_type,
139
+ "chunk_index": len(documents),
140
+ "page_label": page_num,
141
+ },
142
  }
143
  ))
144
 
src/pipeline/db_pipeline/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from src.pipeline.db_pipeline.pipeline import run_db_pipeline
2
+
3
+ __all__ = ["run_db_pipeline"]
src/pipeline/db_pipeline/connector.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Connectors for user-provided databases.
2
+
3
+ The pipeline does not own user credentials — an API layer (outside this folder)
4
+ builds an Engine via `connect(...)` and passes it to `run_db_pipeline`. Use
5
+ `engine_scope(...)` for guaranteed disposal of the connection pool.
6
+ """
7
+
8
+ from contextlib import contextmanager
9
+ from typing import Iterator, Literal
10
+
11
+ from sqlalchemy import URL, create_engine
12
+ from sqlalchemy.engine import Engine
13
+
14
+ from src.middlewares.logging import get_logger
15
+
16
+ logger = get_logger("db_connector")
17
+
18
+ DbType = Literal["postgresql", "mysql", "sqlserver"]
19
+
20
+
21
def get_postgres_engine(
    host: str, port: int, dbname: str, username: str, password: str
) -> Engine:
    """Create a SQLAlchemy engine for a Postgres database.

    Credentials are passed through ``URL.create`` so special characters in
    the password are escaped correctly instead of corrupting a URL string.
    """
    return create_engine(
        URL.create(
            drivername="postgresql+psycopg2",
            host=host,
            port=port,
            database=dbname,
            username=username,
            password=password,
        )
    )
34
+
35
+
36
def connect(
    db_type: DbType,
    host: str,
    port: int,
    dbname: str,
    username: str,
    password: str,
) -> Engine:
    """Open a SQLAlchemy engine against a user-provided database.

    Only Postgres is implemented today; MySQL and SQL Server raise
    ``NotImplementedError``, and any other value raises ``ValueError``.
    """
    logger.info("connecting to user db", db_type=db_type, host=host, port=port, dbname=dbname)
    if db_type == "postgresql":
        return get_postgres_engine(host, port, dbname, username, password)
    if db_type == "sqlserver":
        raise NotImplementedError("SQL Server support coming soon")
    if db_type == "mysql":
        raise NotImplementedError("MySQL support coming soon")
    raise ValueError(f"Unsupported db_type: {db_type}")
54
+
55
+
56
@contextmanager
def engine_scope(
    db_type: DbType,
    host: str,
    port: int,
    dbname: str,
    username: str,
    password: str,
) -> Iterator[Engine]:
    """Context manager around `connect` that disposes the pool on exit.

    API callers should prefer this over a bare `connect(...)` so a user
    database's connection pool never outlives a single pipeline run.
    """
    engine = connect(db_type, host, port, dbname, username, password)
    try:
        yield engine
    finally:
        # Always release pooled connections, even if the pipeline raised.
        engine.dispose()
src/pipeline/db_pipeline/extractor.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Schema introspection and per-column profiling for a user's database.
2
+
3
+ Identifiers (table/column names) are quoted via the engine's dialect preparer,
4
+ which handles reserved words, mixed case, and embedded quotes correctly across
5
+ dialects. Values used in SQL come from SQLAlchemy inspection of the DB itself,
6
+ not user input.
7
+ """
8
+
9
+ from typing import Optional
10
+
11
+ import pandas as pd
12
+ from sqlalchemy import Float, Integer, Numeric, inspect
13
+ from sqlalchemy.engine import Engine
14
+
15
+ from src.middlewares.logging import get_logger
16
+
17
+ logger = get_logger("db_extractor")
18
+
19
+ TOP_VALUES_THRESHOLD = 0.05 # show top values if distinct_ratio <= 5%
20
+
21
+
22
def _qi(engine: Engine, name: str) -> str:
    """Quote an identifier via the engine's dialect preparer.

    A single dotted name is treated as "schema.table" and each part is
    quoted separately.
    """
    quote = engine.dialect.identifier_preparer.quote
    if "." not in name:
        return quote(name)
    schema, _, table = name.partition(".")
    return f"{quote(schema)}.{quote(table)}"
29
+
30
+
31
def get_schema(
    engine: Engine, exclude_tables: Optional[frozenset[str]] = None
) -> dict[str, list[dict]]:
    """Introspect every table (minus `exclude_tables`) in the default schema.

    Returns:
        {table_name: [{name, type, is_numeric, is_primary_key, foreign_key}, ...]}
    """
    skip = exclude_tables or frozenset()
    inspector = inspect(engine)
    schema: dict[str, list[dict]] = {}
    for table in inspector.get_table_names():
        if table in skip:
            continue

        pk_info = inspector.get_pk_constraint(table)
        primary_cols = set(pk_info["constrained_columns"]) if pk_info else set()

        # Map each local FK column to its "referred_table.referred_column".
        fk_targets: dict[str, str] = {}
        for fk in inspector.get_foreign_keys(table):
            for local, remote in zip(fk["constrained_columns"], fk["referred_columns"]):
                fk_targets[local] = f"{fk['referred_table']}.{remote}"

        schema[table] = [
            {
                "name": column["name"],
                "type": str(column["type"]),
                "is_numeric": isinstance(column["type"], (Integer, Numeric, Float)),
                "is_primary_key": column["name"] in primary_cols,
                "foreign_key": fk_targets.get(column["name"]),
            }
            for column in inspector.get_columns(table)
        ]
    logger.info("extracted schema", table_count=len(schema))
    return schema
63
+
64
+
65
def get_row_count(engine: Engine, table_name: str) -> int:
    """Return COUNT(*) for `table_name` as a plain Python int.

    `.iloc[0, 0]` yields a numpy scalar, not the `int` this signature
    promises; cast so downstream comparisons and any JSON serialization
    of profiles see a native int.
    """
    count = pd.read_sql(f"SELECT COUNT(*) FROM {_qi(engine, table_name)}", engine).iloc[0, 0]
    return int(count)
67
+
68
+
69
def profile_column(
    engine: Engine,
    table_name: str,
    col_name: str,
    is_numeric: bool,
    row_count: int,
) -> dict:
    """Compute a statistics profile for a single column.

    Always includes null_count, distinct_count, distinct_ratio, and
    sample_values; numeric columns additionally get min/max/mean/median,
    and low-cardinality columns get top_values.
    """
    if row_count == 0:
        # Nothing to aggregate; keep the shape callers expect.
        return {
            "null_count": 0,
            "distinct_count": 0,
            "distinct_ratio": 0.0,
            "sample_values": [],
        }

    table_sql = _qi(engine, table_name)
    col_sql = _qi(engine, col_name)

    # One round-trip for null/distinct counts plus the numeric aggregates.
    projections = [
        f"COUNT(*) - COUNT({col_sql}) AS nulls",
        f"COUNT(DISTINCT {col_sql}) AS distincts",
    ]
    if is_numeric:
        # PERCENTILE_CONT works on Postgres and SQL Server; MySQL will need
        # a dialect-specific fallback when that connector is added.
        projections += [
            f"MIN({col_sql}) AS min_val",
            f"MAX({col_sql}) AS max_val",
            f"AVG({col_sql}) AS mean_val",
            f"PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY {col_sql}) AS median_val",
        ]
    stats_row = pd.read_sql(f"SELECT {', '.join(projections)} FROM {table_sql}", engine).iloc[0]

    null_count = int(stats_row["nulls"])
    distinct_count = int(stats_row["distincts"])
    distinct_ratio = distinct_count / row_count if row_count > 0 else 0

    profile = {
        "null_count": null_count,
        "distinct_count": distinct_count,
        "distinct_ratio": round(distinct_ratio, 4),
    }

    if is_numeric:
        profile["min"] = stats_row["min_val"]
        profile["max"] = stats_row["max_val"]
        profile["mean"] = stats_row["mean_val"]
        profile["median"] = stats_row["median_val"]

    if 0 < distinct_ratio <= TOP_VALUES_THRESHOLD:
        # Low-cardinality column: record the ten most frequent values.
        freq = pd.read_sql(
            f"SELECT {col_sql}, COUNT(*) AS cnt FROM {table_sql} "
            f"GROUP BY {col_sql} ORDER BY cnt DESC LIMIT 10",
            engine,
        )
        profile["top_values"] = list(zip(freq[col_name].tolist(), freq["cnt"].tolist()))

    # NOTE(review): LIMIT without ORDER BY means sample_values are not
    # deterministic across runs — confirm that is acceptable.
    sample = pd.read_sql(f"SELECT {col_sql} FROM {table_sql} LIMIT 5", engine)
    profile["sample_values"] = sample[col_name].tolist()

    return profile
133
+
134
+
135
def profile_table(engine: Engine, table_name: str, columns: list[dict]) -> list[dict]:
    """Profile every column of one table.

    Returns [{col, profile, text}, ...]. A column whose profiling fails is
    logged and dropped so one bad column cannot abort the whole table;
    empty tables are skipped entirely.
    """
    row_count = get_row_count(engine, table_name)
    if row_count == 0:
        logger.info("skipping empty table", table=table_name)
        return []

    profiled: list[dict] = []
    for column in columns:
        try:
            stats = profile_column(
                engine, table_name, column["name"], column.get("is_numeric", False), row_count
            )
            text = build_text(table_name, row_count, column, stats)
        except Exception as e:
            logger.error(
                "column profiling failed",
                table=table_name,
                column=column["name"],
                error=str(e),
            )
            continue
        profiled.append({"col": column, "profile": stats, "text": text})
    return profiled
163
+
164
+
165
def build_text(table_name: str, row_count: int, col: dict, profile: dict) -> str:
    """Render one column's profile as the plain-text chunk to be embedded."""
    if col.get("is_primary_key"):
        key_label = " [PRIMARY KEY]"
    elif col.get("foreign_key"):
        key_label = f" [FK -> {col['foreign_key']}]"
    else:
        key_label = ""

    lines = [
        f"Table: {table_name} ({row_count} rows)",
        f"Column: {col['name']} ({col['type']}){key_label}",
        f"Null count: {profile['null_count']}",
        f"Distinct count: {profile['distinct_count']} ({profile['distinct_ratio']:.1%})",
    ]
    # Numeric aggregates are only present for numeric columns.
    if "min" in profile:
        lines.append(f"Min: {profile['min']}, Max: {profile['max']}")
        lines.append(f"Mean: {profile['mean']}, Median: {profile['median']}")
    if "top_values" in profile:
        rendered = ", ".join(f"{v} ({c})" for v, c in profile["top_values"])
        lines.append(f"Top values: {rendered}")
    lines.append(f"Sample values: {profile['sample_values']}")
    return "\n".join(lines)
src/pipeline/db_pipeline/pipeline.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """End-to-end DB ingestion pipeline: introspect user's DB -> profile columns ->
2
+ build text -> embed + store in the shared PGVector collection.
3
+
4
+ Each column becomes one LangChainDocument with metadata tagging user_id and
5
+ source_type='database', so it is retrievable via the existing retriever.
6
+ """
7
+
8
+ import asyncio
9
+ from typing import Optional
10
+
11
+ from langchain_core.documents import Document as LangChainDocument
12
+ from sqlalchemy.engine import Engine
13
+
14
+ from src.db.postgres.vector_store import get_vector_store
15
+ from src.middlewares.logging import get_logger
16
+ from src.pipeline.db_pipeline.extractor import get_schema, profile_table
17
+
18
+ logger = get_logger("db_pipeline")
19
+
20
+
21
def _to_document(user_id: str, table_name: str, entry: dict) -> LangChainDocument:
    """Wrap one profiled-column entry as a LangChainDocument for the vector store."""
    column = entry["col"]
    metadata = {
        "user_id": user_id,
        "source_type": "database",
        "data": {
            "table_name": table_name,
            "column_name": column["name"],
            "column_type": column["type"],
            "is_primary_key": column.get("is_primary_key", False),
            "foreign_key": column.get("foreign_key"),
        },
    }
    return LangChainDocument(page_content=entry["text"], metadata=metadata)
37
+
38
+
39
async def run_db_pipeline(
    user_id: str,
    engine: Engine,
    exclude_tables: Optional[frozenset[str]] = None,
) -> int:
    """Ingest a user's database schema into the shared PGVector collection.

    Introspects the DB, profiles each column, and embeds one text chunk per
    column. Sync SQLAlchemy/pandas work is pushed to a worker thread via
    `asyncio.to_thread`; only the vector-store writes run on the event loop.

    Returns:
        Total number of chunks ingested.
    """
    vector_store = get_vector_store()
    logger.info("db pipeline start", user_id=user_id)

    schema = await asyncio.to_thread(get_schema, engine, exclude_tables)

    ingested = 0
    for table_name, columns in schema.items():
        logger.info("profiling table", table=table_name, columns=len(columns))
        entries = await asyncio.to_thread(profile_table, engine, table_name, columns)
        if not entries:
            continue
        documents = [_to_document(user_id, table_name, e) for e in entries]
        await vector_store.aadd_documents(documents)
        ingested += len(documents)
        logger.info("ingested chunks", table=table_name, count=len(documents))

    logger.info("db pipeline complete", user_id=user_id, total=ingested)
    return ingested
src/pipeline/document_pipeline/__init__.py ADDED
File without changes
src/pipeline/document_pipeline/document_pipeline.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Document upload and processing pipeline."""
2
+
3
+ from fastapi import HTTPException, UploadFile
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+
6
+ from src.document.document_service import document_service
7
+ from src.knowledge.processing_service import knowledge_processor
8
+ from src.middlewares.logging import get_logger
9
+ from src.storage.az_blob.az_blob import blob_storage
10
+
11
+ logger = get_logger("document_pipeline")
12
+
13
+ SUPPORTED_FILE_TYPES = ["pdf", "docx", "txt"]
14
+
15
+
16
class DocumentPipeline:
    """Orchestrates the full document upload, process, and delete flows.

    Raises fastapi.HTTPException on validation/ownership failures so the
    API layer can surface them directly.
    """

    async def upload(self, file: UploadFile, user_id: str, db: AsyncSession) -> dict:
        """Validate -> upload to blob storage -> persist a document row.

        Args:
            file: Uploaded file; its extension determines the file type
                (no extension defaults to "txt").
            user_id: Owner of the new document.
            db: Async session used to create the document record.

        Returns:
            {"id", "filename", "status"} of the created document.

        Raises:
            HTTPException: 400 when the extension is not supported.
        """
        content = await file.read()
        # NOTE(review): assumes file.filename is set; a multipart upload
        # without a filename would raise here — confirm upstream validation.
        file_type = file.filename.split(".")[-1].lower() if "." in file.filename else "txt"

        if file_type not in SUPPORTED_FILE_TYPES:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file type. Supported: {SUPPORTED_FILE_TYPES}",
            )

        blob_name = await blob_storage.upload_file(content, file.filename, user_id)
        document = await document_service.create_document(
            db=db,
            user_id=user_id,
            filename=file.filename,
            blob_name=blob_name,
            file_size=len(content),
            file_type=file_type,
        )

        logger.info(f"Uploaded document {document.id} for user {user_id}")
        return {"id": document.id, "filename": document.filename, "status": document.status}

    async def process(self, document_id: str, user_id: str, db: AsyncSession) -> dict:
        """Validate ownership -> extract text -> chunk -> ingest to vector store.

        Returns:
            {"document_id", "chunks_processed"} on success.

        Raises:
            HTTPException: 404 unknown document, 403 wrong owner, 500 when
                processing fails (the document status is set to "failed" first).
        """
        document = await self._get_owned_document(document_id, user_id, db)

        try:
            await document_service.update_document_status(db, document_id, "processing")
            chunks_count = await knowledge_processor.process_document(document, db)
            await document_service.update_document_status(db, document_id, "completed")

            logger.info(f"Processed document {document_id}: {chunks_count} chunks")
            return {"document_id": document_id, "chunks_processed": chunks_count}

        except Exception as e:
            logger.error(f"Processing failed for document {document_id}", error=str(e))
            await document_service.update_document_status(db, document_id, "failed", str(e))
            raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")

    async def delete(self, document_id: str, user_id: str, db: AsyncSession) -> dict:
        """Validate ownership -> delete the document.

        Returns:
            {"document_id"} of the deleted document.

        Raises:
            HTTPException: 404 unknown document, 403 wrong owner, 500 when
                the service reports the deletion failed.
        """
        await self._get_owned_document(document_id, user_id, db)

        # delete_document reports failure via its return value; surface it
        # as a 500 instead of silently claiming success (the pre-refactor
        # endpoint performed this check and it was dropped in the move).
        success = await document_service.delete_document(db, document_id)
        if not success:
            raise HTTPException(status_code=500, detail="Failed to delete document")

        logger.info(f"Deleted document {document_id} for user {user_id}")
        return {"document_id": document_id}

    async def _get_owned_document(self, document_id: str, user_id: str, db: AsyncSession):
        """Fetch a document and enforce that `user_id` owns it (404/403 otherwise)."""
        document = await document_service.get_document(db, document_id)
        if not document:
            raise HTTPException(status_code=404, detail="Document not found")
        if document.user_id != user_id:
            raise HTTPException(status_code=403, detail="Access denied")
        return document


document_pipeline = DocumentPipeline()