# log-chatbot / app.py
# Last change: "Update app.py" by ameyjoshi8198 (commit 6b76364, verified)
import os
import re
import json
import shutil
import sqlite3
import gradio as gr
from huggingface_hub import InferenceClient, hf_hub_download
# ---------------------------------
# Config
# ---------------------------------
DB_FILENAME = "auth_llm-v3.sqlite"
# Local working copy of the database, relative to the current working directory.
DB_PATH = "./" + DB_FILENAME
# Hugging Face dataset repo that hosts DB_FILENAME.
DATASET_REPO_ID = "ameyjoshi8198/auth-log-db"
# Model id passed to client.chat_completion (the ":novita" suffix presumably
# selects an inference provider — confirm against the provider docs).
MODEL_NAME = "inclusionAI/Ling-2.6-1T:novita"

# Required: indexing os.environ raises KeyError at import time if HF_TOKEN is unset.
HF_TOKEN = os.environ["HF_TOKEN"]
client = InferenceClient(token=HF_TOKEN)
# ---------------------------------
# DB setup
# ---------------------------------
def ensure_database():
    """Make sure a usable copy of the SQLite database exists at DB_PATH.

    A missing file, or one smaller than 1024 bytes (treated as not a real
    database), triggers a download of DB_FILENAME from the DATASET_REPO_ID
    dataset repo. The downloaded file is copied to a temporary sibling path and
    then moved into place with os.replace. That way an interrupted copy cannot
    leave a truncated (but larger than 1024 bytes) file at DB_PATH that later
    runs would accept.

    Prints the final location and size of the database.
    """
    if not os.path.exists(DB_PATH) or os.path.getsize(DB_PATH) < 1024:
        print("Downloading SQLite database from HF dataset repo...")
        downloaded_path = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            filename=DB_FILENAME,
            token=HF_TOKEN,
        )
        # Compare resolved paths: DB_PATH is relative, and copying a file onto
        # itself would raise shutil.SameFileError.
        if os.path.abspath(downloaded_path) != os.path.abspath(DB_PATH):
            tmp_path = DB_PATH + ".tmp"
            shutil.copy(downloaded_path, tmp_path)
            # Atomic swap: DB_PATH is either the old file or the complete new one.
            os.replace(tmp_path, DB_PATH)
    print(f"Database ready at {DB_PATH}")
    print(f"Database size: {os.path.getsize(DB_PATH)} bytes")
def debug_database():
    """Print and return the names of all tables in the SQLite database at DB_PATH.

    Returns:
        list[str]: table names, in the order sqlite_master yields them (no ORDER BY).

    Raises:
        sqlite3.DatabaseError: if DB_PATH is not a valid SQLite database.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [row[0] for row in cursor.fetchall()]
    finally:
        # Close even if the query fails, so a bad file does not leak the handle.
        conn.close()
    print("Available tables:", tables)
    return tables
# ---------------------------------
# Helpers
# ---------------------------------
# Dotted-quad shape only; octet range is checked separately in extract_ip.
_IPV4_CANDIDATE_RE = re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b")


def extract_ip(text):
    """Return the first valid dotted-quad IPv4 address found in *text*, or None.

    Candidates with an octet above 255 (e.g. "999.1.1.1") are skipped, and the
    search continues, so out-of-range number runs are not treated as IPs.
    """
    for match in _IPV4_CANDIDATE_RE.finditer(text):
        candidate = match.group(0)
        if all(int(octet) <= 255 for octet in candidate.split(".")):
            return candidate
    return None
# An hour with an explicit am/pm marker, optionally with minutes ("3pm", "3:30 pm").
_AMPM_HOUR_RE = re.compile(r"\b(\d{1,2})(?::\d{2})?\s*(am|pm)\b")
# Any stand-alone 1-2 digit number, used as a 24-hour value.
_BARE_NUMBER_RE = re.compile(r"\b(\d{1,2})\b")


def extract_hour(text):
    """Return the hour (0-23) mentioned in *text*, or None.

    Explicit am/pm times take priority and are converted to 24-hour form
    ("3pm" -> 15, "12am" -> 0, "12pm" -> 12). The result is compared against
    strftime('%H', ...), which is 24-hour. Otherwise, the first stand-alone
    1-2 digit number in the range 0-23 is used. Larger numbers are ignored
    because they cannot be hours.
    """
    lowered = text.lower()

    ampm = _AMPM_HOUR_RE.search(lowered)
    if ampm:
        hour = int(ampm.group(1))
        if 1 <= hour <= 12:
            if ampm.group(2) == "am":
                return 0 if hour == 12 else hour
            return 12 if hour == 12 else hour + 12

    for match in _BARE_NUMBER_RE.finditer(lowered):
        hour = int(match.group(1))
        if 0 <= hour <= 23:
            return hour
    return None
# Whole-word month names, abbreviated or full ("dec", "December", "sept").
_MONTH_WORD_RE = re.compile(
    r"\b(jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|june?|july?"
    r"|aug(?:ust)?|sep(?:t(?:ember)?)?|oct(?:ober)?|nov(?:ember)?|dec(?:ember)?)\b"
)


def extract_date_fragment(text):
    """Return the lowercase three-letter abbreviation of the month named in *text*, or None.

    Only whole words are recognised. Words that merely contain a month prefix
    ("summary", "doctor", "junk", "decide") do not match. If several months are
    mentioned, the first one in the text wins. The modal verb "may" is still
    indistinguishable from the month May.
    """
    match = _MONTH_WORD_RE.search(text.lower())
    return match.group(1)[:3] if match else None
def detect_intent(question):
    """Classify *question* into one of the retrieval intents used by retrieve_evidence.

    Rules are checked in order and the first hit wins:
    an IPv4 address -> "ip_drilldown"; "incident" -> "incidents";
    "top"/"suspicious"/"threat" -> "top_threats"; "summary"/"report" -> "summary";
    "event type"/"common event" -> "event_types";
    "what happened"/"around"/"at" -> "time_slice"; otherwise "general".

    The short keywords "top" and "at" are matched as whole words. As bare
    substrings they fired on "stop", "laptop", "what", "that", "data",
    "attack", ..., and sent many unrelated questions to the wrong retriever.
    """
    q = question.lower()
    if extract_ip(q):
        return "ip_drilldown"
    if "incident" in q:
        return "incidents"
    if re.search(r"\btop\b", q) or "suspicious" in q or "threat" in q:
        return "top_threats"
    if "summary" in q or "report" in q:
        return "summary"
    if "event type" in q or "common event" in q:
        return "event_types"
    if "what happened" in q or "around" in q or re.search(r"\bat\b", q):
        return "time_slice"
    return "general"
# ---------------------------------
# SQL retrieval
# ---------------------------------
def query_db(sql, params=()):
    """Run *sql* against the SQLite database at DB_PATH and return every row as a dict.

    Args:
        sql: SQL statement, using ``?`` placeholders for values.
        params: sequence of values bound to the placeholders.

    Returns:
        list[dict]: one dict per result row, keyed by column name.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        # sqlite3.Row rows can be converted to column-name -> value dicts.
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()
        cursor.execute(sql, params)
        return [dict(row) for row in cursor.fetchall()]
    finally:
        # Close even when execute() raises (bad SQL, missing table, ...).
        conn.close()
_TOP_THREATS_SQL = """
SELECT src_ip, threat_score, severity, event_count, session_count,
failed_password_hits, invalid_user_hits, top_usernames
FROM ip_profiles
ORDER BY threat_score DESC
LIMIT 10
"""


def retrieve_top_threats():
    """Return the 10 ip_profiles rows with the highest threat_score, as dicts."""
    rows = query_db(_TOP_THREATS_SQL)
    return rows
_RECENT_INCIDENTS_SQL = """
SELECT incident_id, src_ip, start_time, end_time, event_count,
session_count, failed_password_hits, invalid_user_hits, top_usernames
FROM incidents
ORDER BY start_time DESC
LIMIT 10
"""


def retrieve_incidents():
    """Return up to 10 incidents rows ordered by start_time descending, as dicts."""
    rows = query_db(_RECENT_INCIDENTS_SQL)
    return rows
_DAILY_SUMMARY_SQL = """
SELECT *
FROM daily_summary
ORDER BY daybucket DESC
LIMIT 10
"""


def retrieve_summary():
    """Return up to 10 daily_summary rows (all columns) ordered by daybucket descending."""
    rows = query_db(_DAILY_SUMMARY_SQL)
    return rows
_EVENT_TYPE_COUNTS_SQL = """
SELECT event_type, COUNT(*) AS hits
FROM events
GROUP BY event_type
ORDER BY hits DESC
LIMIT 10
"""


def retrieve_event_types():
    """Return the 10 most frequent event_type values in events with their counts ("hits")."""
    rows = query_db(_EVENT_TYPE_COUNTS_SQL)
    return rows
# (result key, SQL) pairs run in this order for a single src_ip; each SQL
# takes the IP as its only bound parameter.
_IP_DRILLDOWN_QUERIES = (
    ("profile", """
SELECT *
FROM ip_profiles
WHERE src_ip = ?
"""),
    ("incidents", """
SELECT *
FROM incidents
WHERE src_ip = ?
ORDER BY start_time DESC
LIMIT 10
"""),
    ("explanations", """
SELECT *
FROM ip_explanations
WHERE src_ip = ?
"""),
    ("recent_events", """
SELECT *
FROM events
WHERE src_ip = ?
ORDER BY timestamp DESC
LIMIT 25
"""),
)


def retrieve_ip_drilldown(ip):
    """Collect everything stored about one source IP.

    Returns:
        dict: "profile", "incidents" (up to 10), "explanations" and
        "recent_events" (up to 25), each a list of row dicts filtered by
        src_ip == *ip*.
    """
    return {key: query_db(sql, (ip,)) for key, sql in _IP_DRILLDOWN_QUERIES}
# Two-digit month numbers keyed by the abbreviations extract_date_fragment returns.
_MONTH_NUMBERS = {
    "jan": "01", "feb": "02", "mar": "03", "apr": "04", "may": "05", "jun": "06",
    "jul": "07", "aug": "08", "sep": "09", "oct": "10", "nov": "11", "dec": "12",
}


def retrieve_time_slice(question):
    """Return up to 50 events filtered by the hour and/or month mentioned in *question*.

    Both filters are optional and combined with AND. With neither present,
    this returns the 50 events with the greatest timestamp values.

    NOTE(review): the events.timestamp format is not visible here. The hour
    filter uses strftime, which only parses SQLite time strings (e.g.
    ISO-8601), and for such values a month-name LIKE can never match. So the
    month filter also accepts the numeric month from strftime('%m', ...).
    Textual timestamps ("Dec 10 ...") still match through the LIKE branch.
    """
    hour = extract_hour(question)
    month_fragment = extract_date_fragment(question)
    sql = """
SELECT timestamp, src_ip, username, event_type, auth_phase, severity_hint
FROM events
WHERE 1=1
"""
    params = []
    if hour is not None:
        sql += " AND CAST(strftime('%H', timestamp) AS INTEGER) = ?"
        params.append(hour)
    if month_fragment:
        month_number = _MONTH_NUMBERS.get(month_fragment)
        if month_number is None:
            sql += " AND lower(timestamp) LIKE ?"
            params.append(f"%{month_fragment}%")
        else:
            # strftime returns NULL for unparsable text, so the OR falls back to LIKE.
            sql += " AND (lower(timestamp) LIKE ? OR strftime('%m', timestamp) = ?)"
            params.extend([f"%{month_fragment}%", month_number])
    sql += " ORDER BY timestamp DESC LIMIT 50"
    return query_db(sql, tuple(params))
def retrieve_evidence(question):
    """Route *question* to the retriever for its detected intent.

    Returns:
        dict: {"intent": <intent>, "data": <retrieved rows>}. For the
        "general" intent (or anything unrecognised), "data" bundles top
        threats, incidents and event-type counts under "top_threats",
        "recent_incidents" and "event_types".
    """
    intent = detect_intent(question)
    retrievers = {
        "top_threats": retrieve_top_threats,
        "incidents": retrieve_incidents,
        "summary": retrieve_summary,
        "event_types": retrieve_event_types,
        "ip_drilldown": lambda: retrieve_ip_drilldown(extract_ip(question)),
        "time_slice": lambda: retrieve_time_slice(question),
    }
    retriever = retrievers.get(intent)
    if retriever is not None:
        return {"intent": intent, "data": retriever()}

    # Fallback: a broad overview of the database.
    overview = {
        "top_threats": retrieve_top_threats(),
        "recent_incidents": retrieve_incidents(),
        "event_types": retrieve_event_types(),
    }
    return {"intent": "general", "data": overview}
# ---------------------------------
# Answer generation
# ---------------------------------
def _has_evidence(data):
    """Return True if *data* (a list of rows, or a dict of such lists) holds any row."""
    if isinstance(data, dict):
        # ip_drilldown / general results are dicts whose keys are always present,
        # so the dict itself is always truthy; look at the row lists instead.
        return any(data.values())
    return bool(data)


def _build_prompt(question, evidence):
    """Build the grounded-analysis prompt sent to the chat model."""
    return f"""
You are a security log analyst.
Use ONLY the evidence below.
Do not invent facts.
If the evidence is incomplete, say so clearly.
Prefer concrete observations over speculation.
Question:
{question}
Retrieved evidence:
{json.dumps(evidence, indent=2, default=str)}
"""


def answer_question(question):
    """Answer *question* using only evidence retrieved from the log database.

    Returns:
        str: the model's answer; a prompt to enter a question when the input is
        empty; a "could not find relevant evidence" message when retrieval
        returned no rows; or "Error: ..." if anything raised. This function
        never raises, because it backs a Gradio UI callback.
    """
    try:
        if not question or not question.strip():
            return "Please enter a question about the logs."
        evidence = retrieve_evidence(question)
        if not evidence or not _has_evidence(evidence.get("data")):
            return "I could not find relevant evidence in the database for that question."
        response = client.chat_completion(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": _build_prompt(question, evidence)}],
            max_tokens=1024,
        )
        return response.choices[0].message.content
    except Exception as e:
        # UI boundary: show the error to the user, but also log it server-side.
        print(f"answer_question failed: {e!r}")
        return f"Error: {str(e)}"
# ---------------------------------
# Startup
# ---------------------------------
# Fetch the database if needed and print which tables it contains before serving.
ensure_database()
debug_database()

# ---------------------------------
# Gradio app
# ---------------------------------
_question_input = gr.Textbox(
    label="Ask a question about the logs",
    lines=2,
    placeholder="e.g. Why is 173.234.31.186 suspicious?",
)
_answer_output = gr.Textbox(label="Answer", lines=16)

demo = gr.Interface(
    fn=answer_question,
    inputs=_question_input,
    outputs=_answer_output,
    title="Log Analyzer",
    description="Ask grounded questions about the open source SSH log dataset.",
)
# Listen on all interfaces, port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)