ameyjoshi8198 commited on
Commit
2fe112f
·
verified ·
1 Parent(s): d5ab31f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -85
app.py CHANGED
@@ -1,23 +1,26 @@
1
  import os
 
 
 
2
  import sqlite3
3
  import gradio as gr
4
  from huggingface_hub import InferenceClient, hf_hub_download
5
 
6
- # ----------------------------
7
  # Config
8
- # ----------------------------
9
  DB_FILENAME = "auth_llm-v3.sqlite"
10
  DB_PATH = f"./{DB_FILENAME}"
11
-
12
- # Replace this with your actual dataset repo id
13
- DATASET_REPO_ID = "ameyjoshi8198/auth-log-db"
14
 
15
  HF_TOKEN = os.environ["HF_TOKEN"]
16
  client = InferenceClient(token=HF_TOKEN)
17
 
18
- # ----------------------------
19
- # Download DB from HF dataset
20
- # ----------------------------
 
 
21
  def ensure_database():
22
  if not os.path.exists(DB_PATH) or os.path.getsize(DB_PATH) < 1024:
23
  print("Downloading SQLite database from HF dataset repo...")
@@ -27,20 +30,12 @@ def ensure_database():
27
  filename=DB_FILENAME,
28
  token=HF_TOKEN
29
  )
30
-
31
  if downloaded_path != DB_PATH:
32
- import shutil
33
  shutil.copy(downloaded_path, DB_PATH)
34
 
35
- print(f"Database ready at {DB_PATH}")
36
- print(f"Database size: {os.path.getsize(DB_PATH)} bytes")
37
- else:
38
- print(f"Database already exists at {DB_PATH}")
39
- print(f"Database size: {os.path.getsize(DB_PATH)} bytes")
40
 
41
- # ----------------------------
42
- # Debug schema
43
- # ----------------------------
44
  def debug_database():
45
  conn = sqlite3.connect(DB_PATH)
46
  cursor = conn.cursor()
@@ -50,86 +45,207 @@ def debug_database():
50
  print("Available tables:", tables)
51
  return tables
52
 
53
- # ----------------------------
54
- # Retrieval
55
- # ----------------------------
56
- def retrieve_evidence(question):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  conn = sqlite3.connect(DB_PATH)
58
  conn.row_factory = sqlite3.Row
59
  cursor = conn.cursor()
 
 
 
 
60
 
61
- q = question.lower()
 
 
 
 
 
 
 
62
 
63
- if "top" in q or "suspicious" in q:
64
- cursor.execute("""
65
- SELECT src_ip, threat_score, severity, event_count, session_count,
66
- failed_password_hits, invalid_user_hits, top_usernames
67
- FROM ip_profiles
68
- ORDER BY threat_score DESC
69
- LIMIT 10
70
- """)
71
- elif "incident" in q:
72
- cursor.execute("""
73
- SELECT incident_id, src_ip, start_time, end_time, event_count,
74
- session_count, failed_password_hits, invalid_user_hits, top_usernames
75
- FROM incidents
76
- ORDER BY start_time DESC
77
- LIMIT 10
78
- """)
79
- elif "report" in q or "summary" in q:
80
- cursor.execute("""
81
- SELECT *
82
- FROM daily_summary
83
- ORDER BY daybucket DESC
84
- LIMIT 10
85
- """)
86
- elif "event type" in q or "common event" in q:
87
- cursor.execute("""
88
- SELECT event_type, COUNT(*) AS hits
89
- FROM events
90
- GROUP BY event_type
91
- ORDER BY hits DESC
92
- LIMIT 10
93
- """)
94
- else:
95
- cursor.execute("""
96
- SELECT src_ip, threat_score, severity, event_count, session_count,
97
- failed_password_hits, invalid_user_hits, top_usernames
98
- FROM ip_profiles
99
- ORDER BY threat_score DESC
100
- LIMIT 5
101
- """)
102
-
103
- rows = [dict(row) for row in cursor.fetchall()]
104
- conn.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  return rows
106
 
107
- # ----------------------------
108
- # Answering
109
- # ----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  def answer_question(question):
111
  try:
112
  evidence = retrieve_evidence(question)
113
 
114
- if not evidence:
115
  return "I could not find relevant evidence in the database for that question."
116
 
117
- prompt = f"""You are a security log analyst.
 
118
 
119
- Use ONLY the evidence below. Do not invent facts.
120
- If the evidence is insufficient, say so clearly.
121
-
122
- Evidence:
123
- {evidence}
124
 
125
  Question:
126
  {question}
 
 
 
127
  """
128
 
129
  response = client.chat_completion(
130
- model="meta-llama/Llama-4-Scout-17B-16E-Instruct:groq",
131
  messages=[{"role": "user", "content": prompt}],
132
- max_tokens=512
133
  )
134
 
135
  return response.choices[0].message.content
@@ -137,21 +253,21 @@ Question:
137
  except Exception as e:
138
  return f"Error: {str(e)}"
139
 
140
- # ----------------------------
141
  # Startup
142
- # ----------------------------
143
  ensure_database()
144
  debug_database()
145
 
146
- # ----------------------------
147
- # Gradio UI
148
- # ----------------------------
149
  demo = gr.Interface(
150
  fn=answer_question,
151
- inputs=gr.Textbox(label="Ask a question about the logs", lines=2),
152
- outputs=gr.Textbox(label="Answer", lines=14),
153
  title="Security Log Analyzer",
154
- description="Ask questions about the open source log dataset."
155
  )
156
 
157
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import os
2
+ import re
3
+ import json
4
+ import shutil
5
  import sqlite3
6
  import gradio as gr
7
  from huggingface_hub import InferenceClient, hf_hub_download
8
 
9
+ # ---------------------------------
10
  # Config
11
+ # ---------------------------------
12
  DB_FILENAME = "auth_llm-v3.sqlite"
13
  DB_PATH = f"./{DB_FILENAME}"
14
+ DATASET_REPO_ID = "YOUR_USERNAME/YOUR_DATASET_NAME"
 
 
15
 
16
  HF_TOKEN = os.environ["HF_TOKEN"]
17
  client = InferenceClient(token=HF_TOKEN)
18
 
19
+ MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct:groq"
20
+
21
+ # ---------------------------------
22
+ # DB setup
23
+ # ---------------------------------
24
  def ensure_database():
25
  if not os.path.exists(DB_PATH) or os.path.getsize(DB_PATH) < 1024:
26
  print("Downloading SQLite database from HF dataset repo...")
 
30
  filename=DB_FILENAME,
31
  token=HF_TOKEN
32
  )
 
33
  if downloaded_path != DB_PATH:
 
34
  shutil.copy(downloaded_path, DB_PATH)
35
 
36
+ print(f"Database ready at {DB_PATH}")
37
+ print(f"Database size: {os.path.getsize(DB_PATH)} bytes")
 
 
 
38
 
 
 
 
39
  def debug_database():
40
  conn = sqlite3.connect(DB_PATH)
41
  cursor = conn.cursor()
 
45
  print("Available tables:", tables)
46
  return tables
47
 
48
+ # ---------------------------------
49
+ # Helpers
50
+ # ---------------------------------
51
+ def extract_ip(text):
52
+ match = re.search(r"\b(?:\d{1,3}\.){3}\d{1,3}\b", text)
53
+ return match.group(0) if match else None
54
+
55
+ def extract_hour(text):
56
+ match = re.search(r"\b(\d{1,2})\s*(?:am|pm)?\b", text.lower())
57
+ return int(match.group(1)) if match else None
58
+
59
+ def extract_date_fragment(text):
60
+ months = [
61
+ "jan", "feb", "mar", "apr", "may", "jun",
62
+ "jul", "aug", "sep", "oct", "nov", "dec"
63
+ ]
64
+ t = text.lower()
65
+ for m in months:
66
+ if m in t:
67
+ return m
68
+ return None
69
+
70
+ def detect_intent(question):
71
+ q = question.lower()
72
+
73
+ if extract_ip(q):
74
+ return "ip_drilldown"
75
+ if "incident" in q:
76
+ return "incidents"
77
+ if "top" in q or "suspicious" in q or "threat" in q:
78
+ return "top_threats"
79
+ if "summary" in q or "report" in q:
80
+ return "summary"
81
+ if "event type" in q or "common event" in q:
82
+ return "event_types"
83
+ if "what happened" in q or "around" in q or "at" in q:
84
+ return "time_slice"
85
+ return "general"
86
+
87
+ # ---------------------------------
88
+ # SQL retrieval
89
+ # ---------------------------------
90
+ def query_db(sql, params=()):
91
  conn = sqlite3.connect(DB_PATH)
92
  conn.row_factory = sqlite3.Row
93
  cursor = conn.cursor()
94
+ cursor.execute(sql, params)
95
+ rows = [dict(r) for r in cursor.fetchall()]
96
+ conn.close()
97
+ return rows
98
 
99
+ def retrieve_top_threats():
100
+ return query_db("""
101
+ SELECT src_ip, threat_score, severity, event_count, session_count,
102
+ failed_password_hits, invalid_user_hits, top_usernames
103
+ FROM ip_profiles
104
+ ORDER BY threat_score DESC
105
+ LIMIT 10
106
+ """)
107
 
108
+ def retrieve_incidents():
109
+ return query_db("""
110
+ SELECT incident_id, src_ip, start_time, end_time, event_count,
111
+ session_count, failed_password_hits, invalid_user_hits, top_usernames
112
+ FROM incidents
113
+ ORDER BY start_time DESC
114
+ LIMIT 10
115
+ """)
116
+
117
+ def retrieve_summary():
118
+ return query_db("""
119
+ SELECT *
120
+ FROM daily_summary
121
+ ORDER BY daybucket DESC
122
+ LIMIT 10
123
+ """)
124
+
125
+ def retrieve_event_types():
126
+ return query_db("""
127
+ SELECT event_type, COUNT(*) AS hits
128
+ FROM events
129
+ GROUP BY event_type
130
+ ORDER BY hits DESC
131
+ LIMIT 10
132
+ """)
133
+
134
+ def retrieve_ip_drilldown(ip):
135
+ profile = query_db("""
136
+ SELECT *
137
+ FROM ip_profiles
138
+ WHERE src_ip = ?
139
+ """, (ip,))
140
+
141
+ incidents = query_db("""
142
+ SELECT *
143
+ FROM incidents
144
+ WHERE src_ip = ?
145
+ ORDER BY start_time DESC
146
+ LIMIT 10
147
+ """, (ip,))
148
+
149
+ explanations = query_db("""
150
+ SELECT *
151
+ FROM ip_explanations
152
+ WHERE src_ip = ?
153
+ """, (ip,))
154
+
155
+ recent_events = query_db("""
156
+ SELECT *
157
+ FROM events
158
+ WHERE src_ip = ?
159
+ ORDER BY timestamp DESC
160
+ LIMIT 25
161
+ """, (ip,))
162
+
163
+ return {
164
+ "profile": profile,
165
+ "incidents": incidents,
166
+ "explanations": explanations,
167
+ "recent_events": recent_events
168
+ }
169
+
170
+ def retrieve_time_slice(question):
171
+ hour = extract_hour(question)
172
+ month_fragment = extract_date_fragment(question)
173
+
174
+ sql = """
175
+ SELECT timestamp, src_ip, username, event_type, auth_phase, severity_hint
176
+ FROM events
177
+ WHERE 1=1
178
+ """
179
+ params = []
180
+
181
+ if hour is not None:
182
+ sql += " AND CAST(strftime('%H', timestamp) AS INTEGER) = ?"
183
+ params.append(hour)
184
+
185
+ if month_fragment:
186
+ sql += " AND lower(timestamp) LIKE ?"
187
+ params.append(f"%{month_fragment}%")
188
+
189
+ sql += " ORDER BY timestamp DESC LIMIT 50"
190
+
191
+ rows = query_db(sql, tuple(params))
192
  return rows
193
 
194
+ def retrieve_evidence(question):
195
+ intent = detect_intent(question)
196
+
197
+ if intent == "top_threats":
198
+ return {"intent": intent, "data": retrieve_top_threats()}
199
+ elif intent == "incidents":
200
+ return {"intent": intent, "data": retrieve_incidents()}
201
+ elif intent == "summary":
202
+ return {"intent": intent, "data": retrieve_summary()}
203
+ elif intent == "event_types":
204
+ return {"intent": intent, "data": retrieve_event_types()}
205
+ elif intent == "ip_drilldown":
206
+ ip = extract_ip(question)
207
+ return {"intent": intent, "data": retrieve_ip_drilldown(ip)}
208
+ elif intent == "time_slice":
209
+ return {"intent": intent, "data": retrieve_time_slice(question)}
210
+ else:
211
+ return {
212
+ "intent": "general",
213
+ "data": {
214
+ "top_threats": retrieve_top_threats(),
215
+ "recent_incidents": retrieve_incidents(),
216
+ "event_types": retrieve_event_types()
217
+ }
218
+ }
219
+
220
+ # ---------------------------------
221
+ # Answer generation
222
+ # ---------------------------------
223
  def answer_question(question):
224
  try:
225
  evidence = retrieve_evidence(question)
226
 
227
+ if not evidence or not evidence.get("data"):
228
  return "I could not find relevant evidence in the database for that question."
229
 
230
+ prompt = f"""
231
+ You are a security log analyst.
232
 
233
+ Use ONLY the evidence below.
234
+ Do not invent facts.
235
+ If the evidence is incomplete, say so clearly.
236
+ Prefer concrete observations over speculation.
 
237
 
238
  Question:
239
  {question}
240
+
241
+ Retrieved evidence:
242
+ {json.dumps(evidence, indent=2, default=str)}
243
  """
244
 
245
  response = client.chat_completion(
246
+ model=MODEL_NAME,
247
  messages=[{"role": "user", "content": prompt}],
248
+ max_tokens=700
249
  )
250
 
251
  return response.choices[0].message.content
 
253
  except Exception as e:
254
  return f"Error: {str(e)}"
255
 
256
+ # ---------------------------------
257
  # Startup
258
+ # ---------------------------------
259
  ensure_database()
260
  debug_database()
261
 
262
+ # ---------------------------------
263
+ # Gradio app
264
+ # ---------------------------------
265
  demo = gr.Interface(
266
  fn=answer_question,
267
+ inputs=gr.Textbox(label="Ask a question about the logs", lines=2, placeholder="e.g. Why is 173.234.31.186 suspicious?"),
268
+ outputs=gr.Textbox(label="Answer", lines=16),
269
  title="Security Log Analyzer",
270
+ description="Ask grounded questions about the open source SSH log dataset."
271
  )
272
 
273
  demo.launch(server_name="0.0.0.0", server_port=7860)