Rifqi Hafizuddin committed on
Commit
7f3bb97
·
1 Parent(s): a4cf97a

[NOTICKET][DB] refactor code to new repo

Browse files
src/pipeline/db_pipeline/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from src.pipeline.db_pipeline.pipeline import run_db_pipeline
2
+
3
+ __all__ = ["run_db_pipeline"]
src/pipeline/db_pipeline/connector.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Connectors for user-provided databases.
2
+
3
+ The pipeline does not own user credentials — an API layer (outside this folder)
4
+ builds an Engine via `connect(...)` and passes it to `run_db_pipeline`. Use
5
+ `engine_scope(...)` for guaranteed disposal of the connection pool.
6
+ """
7
+
8
+ from contextlib import contextmanager
9
+ from typing import Iterator, Literal
10
+
11
+ from sqlalchemy import URL, create_engine
12
+ from sqlalchemy.engine import Engine
13
+
14
+ from src.middlewares.logging import get_logger
15
+
16
logger = get_logger("db_connector")

# Closed set of user-database kinds accepted by connect(); only
# "postgresql" has a working driver below — the others raise
# NotImplementedError until their drivers land.
DbType = Literal["postgresql", "mysql", "sqlserver"]
19
+
20
+
21
def get_postgres_engine(
    host: str, port: int, dbname: str, username: str, password: str
) -> Engine:
    """Create a SQLAlchemy engine for a PostgreSQL database.

    Components go through `URL.create`, which escapes each part, so
    passwords containing special characters (e.g. '@', '/', ':') are
    handled safely instead of corrupting the connection URL.
    """
    return create_engine(
        URL.create(
            drivername="postgresql+psycopg2",
            username=username,
            password=password,
            host=host,
            port=port,
            database=dbname,
        )
    )
34
+
35
+
36
def connect(
    db_type: DbType,
    host: str,
    port: int,
    dbname: str,
    username: str,
    password: str,
) -> Engine:
    """Open a SQLAlchemy engine against a user-provided database.

    Only PostgreSQL is implemented today; the other declared db_type
    values raise NotImplementedError, and anything outside DbType raises
    ValueError.
    """
    logger.info("connecting to user db", db_type=db_type, host=host, port=port, dbname=dbname)
    # Guard clauses for the not-yet-supported engines, then the one
    # working driver, then the catch-all for unknown values.
    if db_type == "sqlserver":
        raise NotImplementedError("SQL Server support coming soon")
    if db_type == "mysql":
        raise NotImplementedError("MySQL support coming soon")
    if db_type == "postgresql":
        return get_postgres_engine(host, port, dbname, username, password)
    raise ValueError(f"Unsupported db_type: {db_type}")
54
+
55
+
56
@contextmanager
def engine_scope(
    db_type: DbType,
    host: str,
    port: int,
    dbname: str,
    username: str,
    password: str,
) -> Iterator[Engine]:
    """Connect to a user database and guarantee pool disposal on exit.

    API callers should prefer this over a bare `connect(...)`: disposal
    runs in the `finally` even if the body raises, so a user's DB
    connection pool cannot leak between pipeline runs.
    """
    user_engine = connect(db_type, host, port, dbname, username, password)
    try:
        yield user_engine
    finally:
        user_engine.dispose()
src/pipeline/db_pipeline/extractor.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Schema introspection and per-column profiling for a user's database.
2
+
3
+ Identifiers (table/column names) are quoted via the engine's dialect preparer,
4
+ which handles reserved words, mixed case, and embedded quotes correctly across
5
+ dialects. Values used in SQL come from SQLAlchemy inspection of the DB itself,
6
+ not user input.
7
+ """
8
+
9
+ from typing import Optional
10
+
11
+ import pandas as pd
12
+ from sqlalchemy import Float, Integer, Numeric, inspect
13
+ from sqlalchemy.engine import Engine
14
+
15
+ from src.middlewares.logging import get_logger
16
+
17
logger = get_logger("db_extractor")

# Columns whose distinct/row ratio is at or below this are treated as
# low-cardinality, and profile_column() adds their most frequent values.
TOP_VALUES_THRESHOLD = 0.05  # show top values if distinct_ratio <= 5%
20
+
21
+
22
def _qi(engine: Engine, name: str) -> str:
    """Quote an identifier with the engine dialect's preparer.

    The preparer handles reserved words, mixed case, and embedded quotes
    for the target dialect. Dotted names are split and each part quoted
    individually, so multi-level qualification ("db.schema.table") works
    as well as two-part "schema.table" names — the original
    `partition(".")` only split on the first dot, leaving everything
    after it quoted as a single (wrong) identifier.

    Args:
        engine: Engine whose dialect supplies the identifier preparer.
        name: Bare or dot-qualified identifier from DB introspection.

    Returns:
        The dialect-quoted identifier string.
    """
    quote = engine.dialect.identifier_preparer.quote
    return ".".join(quote(part) for part in name.split("."))
29
+
30
+
31
def get_schema(
    engine: Engine, exclude_tables: Optional[frozenset[str]] = None
) -> dict[str, list[dict]]:
    """Introspect every table and describe its columns.

    Returns a mapping of table name to a list of column descriptors:
    {name, type, is_numeric, is_primary_key, foreign_key}, where
    foreign_key is "referred_table.referred_column" or None.
    Tables listed in `exclude_tables` are skipped.
    """
    excluded = exclude_tables or frozenset()
    inspector = inspect(engine)
    schema: dict[str, list[dict]] = {}

    for table in inspector.get_table_names():
        if table in excluded:
            continue

        pk_info = inspector.get_pk_constraint(table)
        primary_cols = set(pk_info["constrained_columns"]) if pk_info else set()

        # Map each constrained column to its "table.column" target.
        references = {
            src: f"{fk['referred_table']}.{dst}"
            for fk in inspector.get_foreign_keys(table)
            for src, dst in zip(fk["constrained_columns"], fk["referred_columns"])
        }

        schema[table] = [
            {
                "name": column["name"],
                "type": str(column["type"]),
                "is_numeric": isinstance(column["type"], (Integer, Numeric, Float)),
                "is_primary_key": column["name"] in primary_cols,
                "foreign_key": references.get(column["name"]),
            }
            for column in inspector.get_columns(table)
        ]

    logger.info("extracted schema", table_count=len(schema))
    return schema
63
+
64
+
65
def get_row_count(engine: Engine, table_name: str) -> int:
    """Return the total number of rows in `table_name`.

    The scalar is cast to a plain `int`: `.iloc[0, 0]` yields a NumPy
    scalar, which would otherwise leak through the `-> int` annotation
    (and can trip strict JSON serialization downstream).
    """
    count = pd.read_sql(f"SELECT COUNT(*) FROM {_qi(engine, table_name)}", engine)
    return int(count.iloc[0, 0])
67
+
68
+
69
def profile_column(
    engine: Engine,
    table_name: str,
    col_name: str,
    is_numeric: bool,
    row_count: int,
) -> dict:
    """Compute per-column statistics for embedding.

    Produces null_count, distinct_count, distinct_ratio, min/max for
    numeric columns, the most frequent values for low-cardinality
    columns, and a handful of sample values. A zero-row table
    short-circuits to an all-zero profile.
    """
    if row_count == 0:
        return {
            "null_count": 0,
            "distinct_count": 0,
            "distinct_ratio": 0.0,
            "sample_values": [],
        }

    table_sql = _qi(engine, table_name)
    column_sql = _qi(engine, col_name)

    # One aggregate query instead of separate round-trips: nulls,
    # distincts, and (for numeric columns) min/max together.
    aggregates = [
        f"COUNT(*) - COUNT({column_sql}) AS nulls",
        f"COUNT(DISTINCT {column_sql}) AS distincts",
    ]
    if is_numeric:
        aggregates += [
            f"MIN({column_sql}) AS min_val",
            f"MAX({column_sql}) AS max_val",
        ]
    stats_row = pd.read_sql(
        f"SELECT {', '.join(aggregates)} FROM {table_sql}", engine
    ).iloc[0]

    distincts = int(stats_row["distincts"])
    ratio = distincts / row_count if row_count > 0 else 0

    profile = {
        "null_count": int(stats_row["nulls"]),
        "distinct_count": distincts,
        "distinct_ratio": round(ratio, 4),
    }

    if is_numeric:
        profile["min"] = stats_row["min_val"]
        profile["max"] = stats_row["max_val"]

    # Low-cardinality columns (<= TOP_VALUES_THRESHOLD distinct ratio)
    # also get their ten most frequent values with occurrence counts.
    if 0 < ratio <= TOP_VALUES_THRESHOLD:
        frequent = pd.read_sql(
            f"SELECT {column_sql}, COUNT(*) AS cnt FROM {table_sql} "
            f"GROUP BY {column_sql} ORDER BY cnt DESC LIMIT 10",
            engine,
        )
        profile["top_values"] = list(
            zip(frequent[col_name].tolist(), frequent["cnt"].tolist())
        )

    sample = pd.read_sql(f"SELECT {column_sql} FROM {table_sql} LIMIT 5", engine)
    profile["sample_values"] = sample[col_name].tolist()

    return profile
125
+
126
+
127
def profile_table(engine: Engine, table_name: str, columns: list[dict]) -> list[dict]:
    """Profile every column of one table.

    Yields one {"col", "profile", "text"} entry per column. A failure in
    any single column is logged and skipped so the rest of the table is
    still profiled; an empty table returns an empty list.
    """
    row_count = get_row_count(engine, table_name)
    if row_count == 0:
        logger.info("skipping empty table", table=table_name)
        return []

    entries: list[dict] = []
    for column in columns:
        # Both profiling and text rendering are guarded: a bad column
        # must not abort the whole table.
        try:
            profile = profile_column(
                engine, table_name, column["name"], column.get("is_numeric", False), row_count
            )
            description = build_text(table_name, row_count, column, profile)
            entries.append({"col": column, "profile": profile, "text": description})
        except Exception as exc:
            logger.error(
                "column profiling failed",
                table=table_name,
                column=column["name"],
                error=str(exc),
            )
            continue
    return entries
155
+
156
+
157
def build_text(table_name: str, row_count: int, col: dict, profile: dict) -> str:
    """Render a column's schema and profile as human-readable text.

    This string is what gets embedded, so it packs table, column, key
    role (PK / FK target), and the profiling stats into one compact
    multi-line description.
    """
    if col.get("is_primary_key"):
        role = " [PRIMARY KEY]"
    elif col.get("foreign_key"):
        role = f" [FK -> {col['foreign_key']}]"
    else:
        role = ""

    lines = [
        f"Table: {table_name} ({row_count} rows)",
        f"Column: {col['name']} ({col['type']}){role}",
        f"Null count: {profile['null_count']}",
        f"Distinct count: {profile['distinct_count']} ({profile['distinct_ratio']:.1%})",
    ]
    if "min" in profile:
        lines.append(f"Min: {profile['min']}, Max: {profile['max']}")
    if "top_values" in profile:
        rendered = ", ".join(f"{v} ({c})" for v, c in profile["top_values"])
        lines.append(f"Top values: {rendered}")
    lines.append(f"Sample values: {profile['sample_values']}")
    return "\n".join(lines)
src/pipeline/db_pipeline/pipeline.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """End-to-end DB ingestion pipeline: introspect user's DB -> profile columns ->
2
+ build text -> embed + store in the shared PGVector collection.
3
+
4
+ Each column becomes one LangChainDocument with metadata tagging user_id and
5
+ source_type='database', so it is retrievable via the existing retriever.
6
+ """
7
+
8
+ import asyncio
9
+ from typing import Optional
10
+
11
+ from langchain_core.documents import Document as LangChainDocument
12
+ from sqlalchemy.engine import Engine
13
+
14
+ from src.db.postgres.vector_store import get_vector_store
15
+ from src.middlewares.logging import get_logger
16
+ from src.pipeline.db_pipeline.extractor import get_schema, profile_table
17
+
18
+ logger = get_logger("db_pipeline")
19
+
20
+
21
def _to_document(user_id: str, table_name: str, entry: dict) -> LangChainDocument:
    """Wrap one profiled column as a vector-store document.

    Metadata tags user_id and source_type='database' so the existing
    retriever can filter these chunks alongside other sources.
    """
    column = entry["col"]
    metadata = {
        "user_id": user_id,
        "source_type": "database",
        "table_name": table_name,
        "column_name": column["name"],
        "column_type": column["type"],
        "is_primary_key": column.get("is_primary_key", False),
        "foreign_key": column.get("foreign_key"),
    }
    return LangChainDocument(page_content=entry["text"], metadata=metadata)
35
+
36
+
37
async def run_db_pipeline(
    user_id: str,
    engine: Engine,
    exclude_tables: Optional[frozenset[str]] = None,
) -> int:
    """Ingest a user's database into the shared PGVector collection.

    Schema introspection and per-column profiling are blocking
    SQLAlchemy/pandas work, so they run via `asyncio.to_thread`; only
    the async vector-store writes stay on the event loop.

    Returns:
        Total number of chunks ingested.
    """
    vector_store = get_vector_store()
    logger.info("db pipeline start", user_id=user_id)

    schema = await asyncio.to_thread(get_schema, engine, exclude_tables)

    ingested = 0
    for table, cols in schema.items():
        logger.info("profiling table", table=table, columns=len(cols))
        profiled = await asyncio.to_thread(profile_table, engine, table, cols)
        documents = [_to_document(user_id, table, entry) for entry in profiled]
        if documents:
            await vector_store.aadd_documents(documents)
            ingested += len(documents)
            logger.info("ingested chunks", table=table, count=len(documents))

    logger.info("db pipeline complete", user_id=user_id, total=ingested)
    return ingested