OGrohit commited on
Commit
be70060
·
verified ·
1 Parent(s): 7cfebd5

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -0
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sklearn.ensemble import RandomForestClassifier
5
+ from sklearn.model_selection import train_test_split
6
+ from sklearn.metrics import accuracy_score
7
+ from groq import Groq
8
+ import os
9
+
10
+ # --- PAGE SETUP ---
11
+ st.set_page_config(page_title="AI-NIDS Student Project", layout="wide")
12
+
13
+ st.title("AI-Based Network Intrusion Detection System")
14
+ st.markdown("""
15
+ **Student Project**: This system uses **Random Forest** to detect Network attacks and **Groq AI** to explain the packets.
16
+ """)
17
+
18
+ # --- CONFIGURATION ---
19
+ DATA_FILE = "Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv"
20
+
21
+ # --- SIDEBAR: SETTINGS ---
22
+ st.sidebar.header("1. Settings")
23
+ groq_api_key = st.sidebar.text_input("Groq API Key (starts with gsk_)", type="password")
24
+ st.sidebar.caption("[Get a free key here](https://console.groq.com/keys)")
25
+
26
+ st.sidebar.header("2. Model Training")
27
+
28
+ @st.cache_data
29
+ def load_data(filepath):
30
+ try:
31
+ df = pd.read_csv(filepath, nrows=15000)
32
+ df.columns = df.columns.str.strip()
33
+ df.replace([np.inf, -np.inf], np.nan, inplace=True)
34
+ df.dropna(inplace=True)
35
+ return df
36
+ except FileNotFoundError:
37
+ return None
38
+
39
+ def train_model(df):
40
+ features = ['Flow Duration', 'Total Fwd Packets', 'Total Backward Packets',
41
+ 'Total Length of Fwd Packets', 'Fwd Packet Length Max',
42
+ 'Flow IAT Mean', 'Flow IAT Std', 'Flow Packets/s']
43
+ target = 'Label'
44
+
45
+ missing_cols = [c for c in features if c not in df.columns]
46
+ if missing_cols:
47
+ st.error(f"Missing columns in CSV: {missing_cols}")
48
+ return None, 0, [], None, None
49
+
50
+ X = df[features]
51
+ y = df[target]
52
+
53
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
54
+
55
+ clf = RandomForestClassifier(n_estimators=10, max_depth=10, random_state=42)
56
+ clf.fit(X_train, y_train)
57
+
58
+ score = accuracy_score(y_test, clf.predict(X_test))
59
+ return clf, score, features, X_test, y_test
60
+
61
+ # --- APP LOGIC ---
62
+ df = load_data(DATA_FILE)
63
+
64
+ if df is None:
65
+ st.error(f"Error: File '{DATA_FILE}' not found. Please upload it to the Files tab.")
66
+ st.stop()
67
+
68
+ st.sidebar.success(f"Dataset Loaded: {len(df)} rows")
69
+
70
+ if st.sidebar.button("Train Model Now"):
71
+ with st.spinner("Training model..."):
72
+ clf, accuracy, feature_names, X_test, y_test = train_model(df)
73
+ if clf:
74
+ st.session_state['model'] = clf
75
+ st.session_state['features'] = feature_names
76
+ st.session_state['X_test'] = X_test
77
+ st.session_state['y_test'] = y_test
78
+ st.sidebar.success(f"Training Complete! Accuracy: {accuracy:.2%}")
79
+
80
+ st.header("3. Threat Analysis Dashboard")
81
+
82
+ if 'model' in st.session_state:
83
+ col1, col2 = st.columns(2)
84
+
85
+ with col1:
86
+ st.subheader("Simulation")
87
+ st.info("Pick a random packet from the test data to simulate live traffic.")
88
+
89
+ if st.button("🎲 Capture Random Packet"):
90
+ random_idx = np.random.randint(0, len(st.session_state['X_test']))
91
+ packet_data = st.session_state['X_test'].iloc[random_idx]
92
+ actual_label = st.session_state['y_test'].iloc[random_idx]
93
+
94
+ st.session_state['current_packet'] = packet_data
95
+ st.session_state['actual_label'] = actual_label
96
+
97
+ if 'current_packet' in st.session_state:
98
+ packet = st.session_state['current_packet']
99
+
100
+ with col1:
101
+ st.write("**Packet Header Info:**")
102
+ st.dataframe(packet, use_container_width=True)
103
+
104
+ with col2:
105
+ st.subheader("AI Detection Result")
106
+ prediction = st.session_state['model'].predict([packet])[0]
107
+
108
+ if prediction == "BENIGN":
109
+ st.success(f" STATUS: **SAFE (BENIGN)**")
110
+ else:
111
+ st.error(f"🚨 STATUS: **ATTACK DETECTED ({prediction})**")
112
+
113
+ st.caption(f"Ground Truth Label: {st.session_state['actual_label']}")
114
+
115
+ st.markdown("---")
116
+ st.subheader(" Ask AI Analyst (Groq)")
117
+
118
+ if st.button("Generate Explanation"):
119
+ if not groq_api_key:
120
+ st.warning(" Please enter your Groq API Key in the sidebar first.")
121
+ else:
122
+ try:
123
+ client = Groq(api_key=groq_api_key)
124
+
125
+ prompt = f"""
126
+ You are a cybersecurity analyst.
127
+ A network packet was detected as: {prediction}.
128
+
129
+ Packet Technical Details:
130
+ {packet.to_string()}
131
+
132
+ Please explain:
133
+ 1. Why these specific values (like Flow Duration or Packet Length) might indicate {prediction}.
134
+ 2. If it is BENIGN, explain why it looks normal.
135
+ 3. Keep the answer short and simple for a student.
136
+ """
137
+
138
+ with st.spinner("Groq is analyzing the packet..."):
139
+ completion = client.chat.completions.create(
140
+ model="llama-3.3-70b-versatile", # <--- UPDATED MODEL NAME
141
+ messages=[
142
+ {"role": "user", "content": prompt}
143
+ ],
144
+ temperature=0.6,
145
+ )
146
+ st.info(completion.choices[0].message.content)
147
+
148
+ except Exception as e:
149
+ st.error(f"API Error: {e}")
150
+ else:
151
+ st.info(" Waiting for model training. Click **'Train Model Now'** in the sidebar.")