Spaces:
Running
Running
| import os | |
| import re | |
| import json | |
| import shutil | |
| import sqlite3 | |
| import gradio as gr | |
| from huggingface_hub import InferenceClient, hf_hub_download | |
# ---------------------------------
# Config
# ---------------------------------
DB_FILENAME = "auth_llm-v3.sqlite"
DB_PATH = f"./{DB_FILENAME}"
DATASET_REPO_ID = "ameyjoshi8198/auth-log-db"

# Fail fast at startup (KeyError) when the secret is absent — the Space
# cannot do anything useful without it.
HF_TOKEN = os.environ["HF_TOKEN"]

client = InferenceClient(token=HF_TOKEN)
MODEL_NAME = "inclusionAI/Ling-2.6-1T:novita"
| # --------------------------------- | |
| # DB setup | |
| # --------------------------------- | |
def ensure_database():
    """Ensure the SQLite database exists locally.

    Downloads it from the HF dataset repo when the file is missing or
    suspiciously small (< 1 KiB), then reports its final location/size.
    """
    needs_fetch = not os.path.exists(DB_PATH) or os.path.getsize(DB_PATH) < 1024
    if needs_fetch:
        print("Downloading SQLite database from HF dataset repo...")
        cached = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            filename=DB_FILENAME,
            token=HF_TOKEN,
        )
        # hf_hub_download returns a cache path; mirror it to our fixed location.
        if cached != DB_PATH:
            shutil.copy(cached, DB_PATH)
    print(f"Database ready at {DB_PATH}")
    print(f"Database size: {os.path.getsize(DB_PATH)} bytes")
def debug_database():
    """Print and return the list of table names found in the database.

    Returns:
        list[str]: names from sqlite_master with type 'table'.

    The connection is now released in a finally block; the original leaked
    it whenever the query raised.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [row[0] for row in cursor.fetchall()]
    finally:
        conn.close()
    print("Available tables:", tables)
    return tables
| # --------------------------------- | |
| # Helpers | |
| # --------------------------------- | |
def extract_ip(text):
    """Return the first dotted-quad IPv4 address in *text*, or None."""
    found = re.search(r"\b(?:\d{1,3}\.){3}\d{1,3}\b", text)
    if found is None:
        return None
    return found.group(0)
def extract_hour(text):
    """Extract an hour-of-day (0-23) mentioned in free text, or None.

    Fixes the original's handling of am/pm: "3 pm" now yields 15 — the
    original returned 3, which filtered the wrong hour against 24-hour
    event timestamps. Bare numbers ("at 14") still pass through
    unchanged, so existing behavior without a meridiem is preserved.
    """
    match = re.search(r"\b(\d{1,2})\s*(am|pm)?\b", text.lower())
    if not match:
        return None
    hour = int(match.group(1))
    meridiem = match.group(2)
    if meridiem == "pm" and hour != 12:
        hour += 12          # 1pm..11pm -> 13..23
    elif meridiem == "am" and hour == 12:
        hour = 0            # 12am is midnight
    return hour
def extract_date_fragment(text):
    """Return a 3-letter month abbreviation ("jan".."dec") found in *text*.

    Matches whole month tokens (abbreviated or full names) instead of the
    original bare substring test, which false-positived on words like
    "maybe" -> "may" and "market" -> "mar". Returns None when no month
    is mentioned.
    """
    pattern = (
        r"\b(jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|"
        r"jun(?:e)?|jul(?:y)?|aug(?:ust)?|sep(?:t(?:ember)?)?|"
        r"oct(?:ober)?|nov(?:ember)?|dec(?:ember)?)\b"
    )
    match = re.search(pattern, text.lower())
    # Normalize full names ("december") back to the 3-letter fragment.
    return match.group(1)[:3] if match else None
def detect_intent(question):
    """Classify a natural-language question into a retrieval intent.

    Rules run in priority order; the first match wins. The time-slice
    rule now matches "at"/"around" as whole words only — the original
    substring test (`"at" in q`) fired on words like "status", "data" or
    "what", misrouting many general questions to the time-slice branch.
    """
    q = question.lower()
    # An explicit IPv4 address always wins (same pattern as extract_ip,
    # inlined here so the check is self-contained).
    if re.search(r"\b(?:\d{1,3}\.){3}\d{1,3}\b", q):
        return "ip_drilldown"
    if "incident" in q:
        return "incidents"
    if any(word in q for word in ("top", "suspicious", "threat")):
        return "top_threats"
    if "summary" in q or "report" in q:
        return "summary"
    if "event type" in q or "common event" in q:
        return "event_types"
    if "what happened" in q or re.search(r"\b(?:around|at)\b", q):
        return "time_slice"
    return "general"
| # --------------------------------- | |
| # SQL retrieval | |
| # --------------------------------- | |
def query_db(sql, params=(), db_path=None):
    """Run a read query and return the rows as a list of plain dicts.

    Args:
        sql: SQL text using "?" placeholders.
        params: parameter tuple bound to the placeholders.
        db_path: optional database path override; defaults to the
            module-level DB_PATH. Backward-compatible generalization —
            existing callers are unaffected.

    Returns:
        list[dict]: one dict per row, keyed by column name.

    The connection is closed in a finally block, so it is no longer
    leaked when the query raises (the original leaked it on error).
    """
    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    try:
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()
        cursor.execute(sql, params)
        return [dict(r) for r in cursor.fetchall()]
    finally:
        conn.close()
def retrieve_top_threats():
    """Return the ten highest-scoring source IPs from ip_profiles."""
    sql = """
        SELECT src_ip, threat_score, severity, event_count, session_count,
               failed_password_hits, invalid_user_hits, top_usernames
        FROM ip_profiles
        ORDER BY threat_score DESC
        LIMIT 10
    """
    return query_db(sql)
def retrieve_incidents():
    """Return the ten most recent incidents (newest first)."""
    sql = """
        SELECT incident_id, src_ip, start_time, end_time, event_count,
               session_count, failed_password_hits, invalid_user_hits, top_usernames
        FROM incidents
        ORDER BY start_time DESC
        LIMIT 10
    """
    return query_db(sql)
def retrieve_summary():
    """Return the ten most recent rows of daily_summary (newest first)."""
    sql = """
        SELECT *
        FROM daily_summary
        ORDER BY daybucket DESC
        LIMIT 10
    """
    return query_db(sql)
def retrieve_event_types():
    """Return the ten most frequent event types with their hit counts."""
    sql = """
        SELECT event_type, COUNT(*) AS hits
        FROM events
        GROUP BY event_type
        ORDER BY hits DESC
        LIMIT 10
    """
    return query_db(sql)
def retrieve_ip_drilldown(ip):
    """Gather everything the database knows about one source IP.

    Returns a dict with the IP's profile row(s), its ten most recent
    incidents, any stored explanations, and its 25 most recent events.
    """
    by_ip = (ip,)
    return {
        "profile": query_db(
            "SELECT * FROM ip_profiles WHERE src_ip = ?", by_ip
        ),
        "incidents": query_db(
            "SELECT * FROM incidents WHERE src_ip = ? "
            "ORDER BY start_time DESC LIMIT 10",
            by_ip,
        ),
        "explanations": query_db(
            "SELECT * FROM ip_explanations WHERE src_ip = ?", by_ip
        ),
        "recent_events": query_db(
            "SELECT * FROM events WHERE src_ip = ? "
            "ORDER BY timestamp DESC LIMIT 25",
            by_ip,
        ),
    }
def retrieve_time_slice(question):
    """Fetch up to 50 recent events filtered by any hour / month cue in
    the question.  With no recognizable time cues this degrades to the
    50 most recent events overall.
    """
    clauses = []
    bindings = []

    hour = extract_hour(question)
    if hour is not None:
        clauses.append("CAST(strftime('%H', timestamp) AS INTEGER) = ?")
        bindings.append(hour)

    month_fragment = extract_date_fragment(question)
    if month_fragment:
        # NOTE(review): assumes timestamps embed a textual month name
        # (e.g. syslog-style "Dec 10 06:55:46") — confirm against schema.
        clauses.append("lower(timestamp) LIKE ?")
        bindings.append(f"%{month_fragment}%")

    sql = (
        "SELECT timestamp, src_ip, username, event_type, auth_phase, severity_hint"
        " FROM events WHERE 1=1"
    )
    for clause in clauses:
        sql += " AND " + clause
    sql += " ORDER BY timestamp DESC LIMIT 50"

    return query_db(sql, tuple(bindings))
def retrieve_evidence(question):
    """Route the question to the matching retriever and return the result
    bundled with the detected intent."""
    intent = detect_intent(question)

    # Intents that need the question text itself.
    if intent == "ip_drilldown":
        return {"intent": intent, "data": retrieve_ip_drilldown(extract_ip(question))}
    if intent == "time_slice":
        return {"intent": intent, "data": retrieve_time_slice(question)}

    # Zero-argument retrievers dispatch through a table.
    simple_retrievers = {
        "top_threats": retrieve_top_threats,
        "incidents": retrieve_incidents,
        "summary": retrieve_summary,
        "event_types": retrieve_event_types,
    }
    retriever = simple_retrievers.get(intent)
    if retriever is not None:
        return {"intent": intent, "data": retriever()}

    # Fallback: a broad overview for open-ended questions.
    return {
        "intent": "general",
        "data": {
            "top_threats": retrieve_top_threats(),
            "recent_incidents": retrieve_incidents(),
            "event_types": retrieve_event_types(),
        },
    }
| # --------------------------------- | |
| # Answer generation | |
| # --------------------------------- | |
def answer_question(question):
    """Gradio handler: retrieve DB evidence for *question* and ask the
    LLM for a grounded answer.

    Always returns a string — failures are surfaced as error text rather
    than raised, so the UI never breaks.
    """
    try:
        evidence = retrieve_evidence(question)
        if not evidence or not evidence.get("data"):
            return "I could not find relevant evidence in the database for that question."

        # Evidence is serialized verbatim; default=str covers any
        # non-JSON values (timestamps, etc.).
        prompt = f"""
You are a security log analyst.
Use ONLY the evidence below.
Do not invent facts.
If the evidence is incomplete, say so clearly.
Prefer concrete observations over speculation.
Question:
{question}
Retrieved evidence:
{json.dumps(evidence, indent=2, default=str)}
"""
        completion = client.chat_completion(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1024,
        )
        return completion.choices[0].message.content
    except Exception as e:  # top-level UI boundary: report, don't crash
        return f"Error: {str(e)}"
| # --------------------------------- | |
| # Startup | |
| # --------------------------------- | |
ensure_database()
debug_database()

# ---------------------------------
# Gradio app
# ---------------------------------
question_box = gr.Textbox(
    label="Ask a question about the logs",
    lines=2,
    placeholder="e.g. Why is 173.234.31.186 suspicious?",
)
answer_box = gr.Textbox(label="Answer", lines=16)

demo = gr.Interface(
    fn=answer_question,
    inputs=question_box,
    outputs=answer_box,
    title="Log Analyzer",
    description="Ask grounded questions about the open source SSH log dataset.",
)
demo.launch(server_name="0.0.0.0", server_port=7860)