import json
import os
import re
import shutil
import sqlite3
from contextlib import closing

import gradio as gr
from huggingface_hub import InferenceClient, hf_hub_download

# ---------------------------------
# Config
# ---------------------------------
DB_FILENAME = "auth_llm-v3.sqlite"
DB_PATH = f"./{DB_FILENAME}"
DATASET_REPO_ID = "ameyjoshi8198/auth-log-db"

# Required: importing this module raises KeyError if HF_TOKEN is not set.
HF_TOKEN = os.environ["HF_TOKEN"]
client = InferenceClient(token=HF_TOKEN)
MODEL_NAME = "inclusionAI/Ling-2.6-1T:novita"

# A local DB file smaller than this is treated as unusable and re-downloaded
# (presumably to catch an empty/placeholder file).
MIN_DB_SIZE_BYTES = 1024

# Three-letter month abbreviations, in calendar order (index + 1 == month number).
_MONTHS = (
    "jan", "feb", "mar", "apr", "may", "jun",
    "jul", "aug", "sep", "oct", "nov", "dec",
)

_IP_RE = re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b")

# "3pm", "3 pm", "11:30am" -- 12-hour clock with an explicit suffix.
_AMPM_HOUR_RE = re.compile(r"\b(\d{1,2})(?::\d{2})?\s*(am|pm)\b")
# "14:00", "9:30" -- 24-hour clock.
_CLOCK_HOUR_RE = re.compile(r"\b(\d{1,2}):\d{2}\b")
# "at 14", "around 9" -- bare number only when introduced by a time preposition,
# so day numbers ("Mar 12") and counts ("top 10") are not mistaken for hours.
_BARE_HOUR_RE = re.compile(r"\b(?:at|around)\s+(\d{1,2})\b(?![:.]\d)")

# Full or abbreviated month names as whole words, so "summary" does not yield
# "mar" and "decide" does not yield "dec". Note: the modal verb "may" is still
# indistinguishable from the month.
_MONTH_RE = re.compile(
    r"\b(jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|june?|july?"
    r"|aug(?:ust)?|sep(?:t(?:ember)?)?|oct(?:ober)?|nov(?:ember)?|dec(?:ember)?)\b"
)


# ---------------------------------
# DB setup
# ---------------------------------
def ensure_database():
    """Ensure a usable copy of the SQLite database exists at DB_PATH.

    If DB_PATH is missing or smaller than MIN_DB_SIZE_BYTES, the file is
    downloaded from the HF dataset repo and copied to DB_PATH.
    """
    if not os.path.exists(DB_PATH) or os.path.getsize(DB_PATH) < MIN_DB_SIZE_BYTES:
        print("Downloading SQLite database from HF dataset repo...")
        downloaded_path = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            filename=DB_FILENAME,
            token=HF_TOKEN,
        )
        # Compare absolute paths: "./x" and "/abs/dir/x" may be the same file,
        # and copying a file onto itself raises shutil.SameFileError.
        if os.path.abspath(downloaded_path) != os.path.abspath(DB_PATH):
            shutil.copy(downloaded_path, DB_PATH)
    print(f"Database ready at {DB_PATH}")
    print(f"Database size: {os.path.getsize(DB_PATH)} bytes")


def debug_database():
    """Print and return the names of all tables in the database at DB_PATH."""
    # sqlite3's own context manager only commits/rolls back; closing() closes.
    with closing(sqlite3.connect(DB_PATH)) as conn:
        rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()
    tables = [row[0] for row in rows]
    print("Available tables:", tables)
    return tables


# ---------------------------------
# Helpers
# ---------------------------------
def extract_ip(text):
    """Return the first dotted-quad IPv4 address in *text*, or None.

    Candidates with an octet above 255 (e.g. "999.1.1.1") are skipped.
    """
    for match in _IP_RE.finditer(text):
        candidate = match.group(0)
        if all(int(octet) <= 255 for octet in candidate.split(".")):
            return candidate
    return None


def extract_hour(text):
    """Return the hour of day (0-23) mentioned in *text*, or None.

    Recognised forms, in priority order:
      * 12-hour clock with am/pm ("3pm" -> 15, "12am" -> 0, "12pm" -> 12);
      * 24-hour clock "HH:MM" ("14:30" -> 14);
      * a bare number after "at"/"around" ("around 9" -> 9).
    """
    t = text.lower()

    match = _AMPM_HOUR_RE.search(t)
    if match:
        hour = int(match.group(1))
        if 1 <= hour <= 12:
            # 12am -> 0, 12pm -> 12, 1pm..11pm -> 13..23
            return hour % 12 + (12 if match.group(2) == "pm" else 0)

    for pattern in (_CLOCK_HOUR_RE, _BARE_HOUR_RE):
        match = pattern.search(t)
        if match and int(match.group(1)) <= 23:
            return int(match.group(1))

    return None


def extract_date_fragment(text):
    """Return the three-letter lowercase month abbreviation found in *text*, or None.

    Both abbreviations and full names are accepted ("March" -> "mar",
    "Sept" -> "sep"); only whole words match.
    """
    match = _MONTH_RE.search(text.lower())
    return match.group(1)[:3] if match else None


def _contains_word(text, *words):
    """Return True if any of *words* occurs in *text* as a whole word."""
    return any(re.search(rf"\b{re.escape(word)}\b", text) for word in words)


def detect_intent(question):
    """Classify *question* into one of the retrieval intents.

    Returns one of: "ip_drilldown", "incidents", "top_threats", "summary",
    "event_types", "time_slice", "general". Checks are ordered; the first
    match wins.
    """
    q = question.lower()
    if extract_ip(q):
        return "ip_drilldown"
    if "incident" in q:
        return "incidents"
    # Word match for "top" so "stop"/"laptop" don't trigger it.
    if _contains_word(q, "top") or "suspicious" in q or "threat" in q:
        return "top_threats"
    if "summary" in q or "report" in q:
        return "summary"
    if "event type" in q or "common event" in q:
        return "event_types"
    # Word match for "at"/"around": a plain substring test for "at" matched
    # "what", "that", "data", ... and routed nearly every question here.
    if "what happened" in q or _contains_word(q, "around", "at"):
        return "time_slice"
    return "general"


# ---------------------------------
# SQL retrieval
# ---------------------------------
def query_db(sql, params=()):
    """Run a parameterized query against DB_PATH and return rows as dicts.

    The connection is always closed, even if the query raises.
    """
    with closing(sqlite3.connect(DB_PATH)) as conn:
        conn.row_factory = sqlite3.Row
        return [dict(row) for row in conn.execute(sql, params).fetchall()]


def retrieve_top_threats():
    """Return the 10 ip_profiles rows with the highest threat_score."""
    return query_db("""
        SELECT src_ip, threat_score, severity, event_count, session_count,
               failed_password_hits, invalid_user_hits, top_usernames
        FROM ip_profiles
        ORDER BY threat_score DESC
        LIMIT 10
    """)


def retrieve_incidents():
    """Return the 10 most recent incidents, ordered by start_time."""
    return query_db("""
        SELECT incident_id, src_ip, start_time, end_time, event_count,
               session_count, failed_password_hits, invalid_user_hits,
               top_usernames
        FROM incidents
        ORDER BY start_time DESC
        LIMIT 10
    """)


def retrieve_summary():
    """Return the 10 most recent daily_summary rows, ordered by daybucket."""
    return query_db("""
        SELECT *
        FROM daily_summary
        ORDER BY daybucket DESC
        LIMIT 10
    """)


def retrieve_event_types():
    """Return the 10 most frequent event types with their counts ("hits")."""
    return query_db("""
        SELECT event_type, COUNT(*) AS hits
        FROM events
        GROUP BY event_type
        ORDER BY hits DESC
        LIMIT 10
    """)


def retrieve_ip_drilldown(ip):
    """Collect everything the database holds about a single source IP.

    Returns a dict with keys "profile", "incidents" (latest 10),
    "explanations" and "recent_events" (latest 25), each a list of row dicts.
    """
    profile = query_db("""
        SELECT *
        FROM ip_profiles
        WHERE src_ip = ?
    """, (ip,))

    incidents = query_db("""
        SELECT *
        FROM incidents
        WHERE src_ip = ?
        ORDER BY start_time DESC
        LIMIT 10
    """, (ip,))

    explanations = query_db("""
        SELECT *
        FROM ip_explanations
        WHERE src_ip = ?
    """, (ip,))

    recent_events = query_db("""
        SELECT *
        FROM events
        WHERE src_ip = ?
        ORDER BY timestamp DESC
        LIMIT 25
    """, (ip,))

    return {
        "profile": profile,
        "incidents": incidents,
        "explanations": explanations,
        "recent_events": recent_events,
    }


def retrieve_time_slice(question):
    """Return up to 50 recent events matching the hour/month named in *question*.

    Filters are only applied for the parts that could be extracted; with
    neither, the 50 most recent events are returned.
    """
    hour = extract_hour(question)
    month_fragment = extract_date_fragment(question)

    sql = """
        SELECT timestamp, src_ip, username, event_type, auth_phase, severity_hint
        FROM events
        WHERE 1=1
    """
    params = []

    if hour is not None:
        # NOTE(review): strftime only parses ISO-8601-like timestamps
        # ("YYYY-MM-DD HH:MM:SS"); confirm that matches events.timestamp.
        sql += " AND CAST(strftime('%H', timestamp) AS INTEGER) = ?"
        params.append(hour)

    if month_fragment:
        # Match either a textual month in the timestamp (e.g. "Mar 12 ...")
        # or, under the same ISO assumption as the hour filter, the numeric
        # month. TODO confirm the actual timestamp format and drop the other.
        month_number = f"{_MONTHS.index(month_fragment) + 1:02d}"
        sql += " AND (lower(timestamp) LIKE ? OR strftime('%m', timestamp) = ?)"
        params.extend([f"%{month_fragment}%", month_number])

    sql += " ORDER BY timestamp DESC LIMIT 50"

    return query_db(sql, tuple(params))


def retrieve_evidence(question):
    """Detect the intent of *question* and fetch matching rows.

    Returns {"intent": <intent>, "data": <rows or dict of row lists>}.
    """
    intent = detect_intent(question)

    if intent == "top_threats":
        return {"intent": intent, "data": retrieve_top_threats()}
    elif intent == "incidents":
        return {"intent": intent, "data": retrieve_incidents()}
    elif intent == "summary":
        return {"intent": intent, "data": retrieve_summary()}
    elif intent == "event_types":
        return {"intent": intent, "data": retrieve_event_types()}
    elif intent == "ip_drilldown":
        ip = extract_ip(question)
        return {"intent": intent, "data": retrieve_ip_drilldown(ip)}
    elif intent == "time_slice":
        return {"intent": intent, "data": retrieve_time_slice(question)}
    else:
        return {
            "intent": "general",
            "data": {
                "top_threats": retrieve_top_threats(),
                "recent_incidents": retrieve_incidents(),
                "event_types": retrieve_event_types(),
            },
        }


# ---------------------------------
# Answer generation
# ---------------------------------
def _has_evidence(data):
    """Return True if *data* contains at least one row.

    Dicts (ip_drilldown / general) count as evidence only if some value is
    non-empty; a plain bool() would treat {"profile": [], ...} as evidence.
    """
    if isinstance(data, dict):
        return any(_has_evidence(value) for value in data.values())
    return bool(data)


def answer_question(question):
    """Answer *question* with the LLM, grounded in rows retrieved from the DB.

    Returns the model's answer, a "no evidence" message when retrieval finds
    nothing, or an "Error: ..." string if retrieval or the model call fails.
    """
    if not question or not question.strip():
        return "Please enter a question about the logs."

    try:
        evidence = retrieve_evidence(question)

        if not evidence or not _has_evidence(evidence.get("data")):
            return "I could not find relevant evidence in the database for that question."

        prompt = f"""
You are a security log analyst.
Use ONLY the evidence below.
Do not invent facts.
If the evidence is incomplete, say so clearly.
Prefer concrete observations over speculation.

Question:
{question}

Retrieved evidence:
{json.dumps(evidence, indent=2, default=str)}
"""

        response = client.chat_completion(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1024,
        )

        content = response.choices[0].message.content
        return content or "The model returned an empty response."

    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the
        # request, but keep a trace in the server log.
        print(f"answer_question failed: {e!r}")
        return f"Error: {str(e)}"


# ---------------------------------
# Startup
# ---------------------------------
ensure_database()
debug_database()


# ---------------------------------
# Gradio app
# ---------------------------------
demo = gr.Interface(
    fn=answer_question,
    inputs=gr.Textbox(
        label="Ask a question about the logs",
        lines=2,
        placeholder="e.g. Why is 173.234.31.186 suspicious?",
    ),
    outputs=gr.Textbox(label="Answer", lines=16),
    title="Log Analyzer",
    description="Ask grounded questions about the open source SSH log dataset.",
)

demo.launch(server_name="0.0.0.0", server_port=7860)