"""ORM models (Repository, CodeChunk, ChatTurn) and engine/session setup helpers."""

import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

from sqlalchemy import (
    JSON,
    Column,
    DateTime,
    Float,
    ForeignKey,
    Integer,
    String,
    Text,
    create_engine,
    inspect,
    text,
)
from sqlalchemy.orm import declarative_base, relationship, sessionmaker

Base = declarative_base()

# Engines and session factories are memoized per *resolved* database URL so
# repeated init_db() calls for the same database reuse one engine.
_ENGINE_CACHE = {}
_SESSION_FACTORY_CACHE = {}

# Parent of the directory containing this file; relative SQLite paths are
# anchored here instead of the process working directory.
SERVER_DIR = Path(__file__).resolve().parents[1]

# Columns added to ``repositories`` after its first release. Databases created
# before that lack them, and ``create_all`` never alters an existing table.
_RUNTIME_COLUMNS = ("source_url", "session_key", "session_expires_at")


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Stands in for ``datetime.utcnow`` (deprecated since Python 3.12) while
    producing the same naive-UTC values, so new timestamps stay comparable
    with rows that were written earlier.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class Repository(Base):
    """Row of ``repositories``: one tracked repository and its indexing state.

    Deleting a repository through the ORM also deletes its chunks and chat
    turns (``cascade="all, delete-orphan"`` on both relationships).
    """

    __tablename__ = "repositories"

    id = Column(Integer, primary_key=True)
    github_url = Column(String(1024), nullable=False, unique=True)
    source_url = Column(String(1024))
    session_key = Column(String(255), index=True)
    session_expires_at = Column(DateTime)
    owner = Column(String(255), nullable=False)
    name = Column(String(255), nullable=False)
    branch = Column(String(255), nullable=False, default="main")
    local_path = Column(String(1024))
    status = Column(String(64), nullable=False, default="queued")
    error_message = Column(Text)
    file_count = Column(Integer, nullable=False, default=0)
    chunk_count = Column(Integer, nullable=False, default=0)
    indexed_at = Column(DateTime)
    created_at = Column(DateTime, default=_utcnow)
    updated_at = Column(DateTime, default=_utcnow, onupdate=_utcnow)

    chunks = relationship(
        "CodeChunk", back_populates="repository", cascade="all, delete-orphan"
    )
    chat_turns = relationship(
        "ChatTurn", back_populates="repository", cascade="all, delete-orphan"
    )


class CodeChunk(Base):
    """Row of ``code_chunks``: a line range of one file in a repository."""

    __tablename__ = "code_chunks"

    id = Column(Integer, primary_key=True)
    repository_id = Column(Integer, ForeignKey("repositories.id"), nullable=False)
    file_path = Column(String(1024), nullable=False)
    language = Column(String(64), nullable=False)
    symbol_name = Column(String(255))
    symbol_type = Column(String(128), nullable=False, default="chunk")
    line_start = Column(Integer, nullable=False)
    line_end = Column(Integer, nullable=False)
    signature = Column(Text)
    content = Column(Text, nullable=False)
    searchable_text = Column(Text, nullable=False)
    # ``default=dict`` is a callable, so every row gets its own fresh dict.
    metadata_json = Column(JSON, nullable=False, default=dict)
    embedding_id = Column(Integer)
    rerank_score = Column(Float)
    created_at = Column(DateTime, default=_utcnow)

    repository = relationship("Repository", back_populates="chunks")


class ChatTurn(Base):
    """Row of ``chat_turns``: one message in a repository's conversation."""

    __tablename__ = "chat_turns"

    id = Column(Integer, primary_key=True)
    repository_id = Column(Integer, ForeignKey("repositories.id"), nullable=False)
    role = Column(String(32), nullable=False)
    content = Column(Text, nullable=False)
    answer_json = Column(JSON)
    created_at = Column(DateTime, default=_utcnow)

    repository = relationship("Repository", back_populates="chat_turns")


def init_db(database_url: Optional[str] = None):
    """Create (or fetch the cached) engine and session factory for a database.

    Args:
        database_url: SQLAlchemy URL. When ``None``, the ``DATABASE_URL``
            environment variable is used, falling back to
            ``sqlite:///./codebase_rag.db``.

    Returns:
        An ``(engine, session_factory)`` tuple. On first use for a URL the
        tables are created and missing runtime columns are added; later calls
        with the same resolved URL return the cached pair.
    """
    if database_url is None:
        database_url = os.getenv("DATABASE_URL", "sqlite:///./codebase_rag.db")
    database_url = resolve_database_url(database_url)

    if database_url in _ENGINE_CACHE:
        return _ENGINE_CACHE[database_url], _SESSION_FACTORY_CACHE[database_url]

    # sqlite3 connections refuse use from other threads unless this is disabled.
    if database_url.startswith("sqlite"):
        connect_args = {"check_same_thread": False}
    else:
        connect_args = {}

    engine = create_engine(database_url, echo=False, connect_args=connect_args)
    Base.metadata.create_all(engine)
    _ensure_runtime_columns(engine)
    session_local = sessionmaker(bind=engine)

    _ENGINE_CACHE[database_url] = engine
    _SESSION_FACTORY_CACHE[database_url] = session_local
    return engine, session_local


def resolve_database_url(database_url: str) -> str:
    """Turn a file-based ``sqlite:///`` URL into one with an absolute path.

    Relative paths are anchored at ``SERVER_DIR``. The parent directory and
    the database file itself are created if missing. Any ``?query`` suffix is
    preserved and is not treated as part of the file name.

    Args:
        database_url: Any SQLAlchemy URL.

    Returns:
        The URL unchanged when it is not ``sqlite:///...``, or when it names
        an in-memory database (empty path or ``:memory:``) or a SQLite
        ``file:`` URI; otherwise the URL rewritten with the resolved path.
    """
    prefix = "sqlite:///"
    if not database_url.startswith(prefix):
        return database_url

    # Split off URL query arguments so they never end up in the file name.
    sqlite_path, separator, query = database_url.removeprefix(prefix).partition("?")

    # None of these name a plain filesystem path, so there is nothing to
    # create or resolve.
    if sqlite_path in ("", ":memory:") or sqlite_path.startswith("file:"):
        return database_url

    path = Path(sqlite_path)
    if not path.is_absolute():
        path = SERVER_DIR / path
    path.parent.mkdir(parents=True, exist_ok=True)
    path.touch(exist_ok=True)
    return f"sqlite:///{path.resolve()}{separator}{query}"


def _ensure_runtime_columns(engine):
    """Add late-introduced ``repositories`` columns to a pre-existing table.

    Does nothing when the table is absent or already has every column in
    ``_RUNTIME_COLUMNS``. Each column's SQL type is compiled from the model
    for the engine's dialect, so the statement is valid on backends that do
    not accept SQLite's type names.

    Args:
        engine: Engine bound to the database to migrate.
    """
    inspector = inspect(engine)
    if "repositories" not in inspector.get_table_names():
        return

    existing = {column["name"] for column in inspector.get_columns("repositories")}
    missing = [name for name in _RUNTIME_COLUMNS if name not in existing]
    if not missing:
        return

    model_columns = Repository.__table__.columns
    with engine.begin() as connection:
        for column_name in missing:
            column_type = model_columns[column_name].type.compile(
                dialect=engine.dialect
            )
            connection.execute(
                text(
                    f"ALTER TABLE repositories ADD COLUMN {column_name} {column_type}"
                )
            )


def get_db_session(database_url: Optional[str] = None):
    """Open a new session on the database identified by ``database_url``.

    Args:
        database_url: Passed straight to ``init_db``; ``None`` selects the
            default database.

    Returns:
        A new session from the cached session factory. This function does not
        close it; the caller owns the session.
    """
    _, session_local = init_db(database_url)
    return session_local()