OGrohit commited on
Commit
a95d8aa
·
verified ·
1 Parent(s): edd309d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -0
app.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================
2
+ # AI-Based Network Intrusion Detection System (NIDS)
3
+ # VOIS Internship – Final Project
4
+ # ============================
5
+
6
+ import streamlit as st
7
+ import pandas as pd
8
+ import numpy as np
9
+ from sklearn.ensemble import RandomForestClassifier
10
+ from sklearn.model_selection import train_test_split
11
+ from sklearn.metrics import accuracy_score, confusion_matrix
12
+ import matplotlib.pyplot as plt
13
+ import seaborn as sns
14
+ from groq import Groq
15
+
16
+ # ============================
17
+ # PAGE CONFIG
18
+ # ============================
19
+ st.set_page_config(page_title="AI-Based NIDS", layout="wide")
20
+
21
+ st.title("AI-Based Network Intrusion Detection System")
22
+ st.markdown("""
23
+ This project implements a **Random Forest–based Network Intrusion Detection System (NIDS)**.
24
+ It supports:
25
+ - Simulated traffic
26
+ - Real CIC-style CSV datasets
27
+ - Live packet analysis
28
+ - AI-based explanation using Groq
29
+ """)
30
+
31
+ # ============================
32
+ # SESSION STATE INIT
33
+ # ============================
34
+ for key in ["model", "accuracy", "conf_matrix", "features", "X_test", "y_test"]:
35
+ if key not in st.session_state:
36
+ st.session_state[key] = None
37
+
38
+ # ============================
39
+ # SIDEBAR – SETTINGS
40
+ # ============================
41
+ st.sidebar.header("1. Settings")
42
+ groq_api_key = st.sidebar.text_input("Groq API Key", type="password")
43
+
44
+ st.sidebar.header("2. Data Mode")
45
+ data_mode = st.sidebar.radio(
46
+ "Select Data Source",
47
+ ("Simulation Mode", "CSV Upload Mode")
48
+ )
49
+
50
+ # ============================
51
+ # DATA LOADING FUNCTIONS
52
+ # ============================
53
+ def load_simulated_data(samples=2000):
54
+ np.random.seed(42)
55
+ df = pd.DataFrame({
56
+ "packet_size": np.random.randint(20, 1500, samples),
57
+ "duration": np.random.uniform(0, 60, samples),
58
+ "src_bytes": np.random.randint(0, 10000, samples),
59
+ "dst_bytes": np.random.randint(0, 10000, samples),
60
+ "failed_logins": np.random.randint(0, 5, samples),
61
+ })
62
+ df["label"] = np.where(
63
+ (df["failed_logins"] > 2) | (df["src_bytes"] > 8000),
64
+ 1, 0
65
+ )
66
+ return df
67
+
68
+ def preprocess_csv(df):
69
+ df = df.replace([np.inf, -np.inf], np.nan).dropna()
70
+
71
+ # Normalize CIC-like labels
72
+ if "Label" in df.columns:
73
+ df["Label"] = df["Label"].apply(lambda x: 0 if x == "BENIGN" else 1)
74
+
75
+ df = df.rename(columns={
76
+ "Flow Duration": "duration",
77
+ "Total Fwd Packets": "src_bytes",
78
+ "Total Backward Packets": "dst_bytes",
79
+ "Packet Length Mean": "packet_size",
80
+ "Label": "label"
81
+ })
82
+
83
+ required = ["packet_size", "duration", "src_bytes", "dst_bytes", "label"]
84
+ return df[required]
85
+
86
+ # ============================
87
+ # MODEL TRAINING
88
+ # ============================
89
+ def train_model(df):
90
+ X = df.drop("label", axis=1)
91
+ y = df["label"]
92
+
93
+ X_train, X_test, y_train, y_test = train_test_split(
94
+ X, y, test_size=0.3, random_state=42
95
+ )
96
+
97
+ model = RandomForestClassifier(
98
+ n_estimators=100,
99
+ max_depth=12,
100
+ random_state=42
101
+ )
102
+ model.fit(X_train, y_train)
103
+
104
+ acc = accuracy_score(y_test, model.predict(X_test))
105
+ cm = confusion_matrix(y_test, model.predict(X_test))
106
+
107
+ return model, acc, cm, X_test, y_test
108
+
109
+ def plot_confusion_matrix(cm):
110
+ fig, ax = plt.subplots()
111
+ sns.heatmap(
112
+ cm, annot=True, fmt="d",
113
+ xticklabels=["Normal", "Intrusion"],
114
+ yticklabels=["Normal", "Intrusion"],
115
+ cmap="Blues", ax=ax
116
+ )
117
+ ax.set_xlabel("Predicted")
118
+ ax.set_ylabel("Actual")
119
+ ax.set_title("Confusion Matrix")
120
+ return fig
121
+
122
+ # ============================
123
+ # TRAIN MODEL BUTTON
124
+ # ============================
125
+ st.sidebar.header("3. Model Training")
126
+
127
+ uploaded_file = None
128
+ if data_mode == "CSV Upload Mode":
129
+ uploaded_file = st.sidebar.file_uploader("Upload CSV Dataset", type=["csv"])
130
+
131
+ if st.sidebar.button("Train Model"):
132
+ with st.spinner("Training model..."):
133
+ if data_mode == "Simulation Mode":
134
+ df = load_simulated_data()
135
+ else:
136
+ if uploaded_file is None:
137
+ st.sidebar.error("Please upload a CSV file first.")
138
+ st.stop()
139
+ raw_df = pd.read_csv(uploaded_file)
140
+ df = preprocess_csv(raw_df)
141
+
142
+ model, acc, cm, X_test, y_test = train_model(df)
143
+
144
+ st.session_state.model = model
145
+ st.session_state.accuracy = acc
146
+ st.session_state.conf_matrix = cm
147
+ st.session_state.X_test = X_test
148
+ st.session_state.y_test = y_test
149
+
150
+ st.sidebar.success(f"Training completed (Accuracy: {acc:.2%})")
151
+
152
+ # ============================
153
+ # DASHBOARD
154
+ # ============================
155
+ st.header("Threat Analysis Dashboard")
156
+
157
+ if st.session_state.model is not None:
158
+ st.metric("Model Accuracy", f"{st.session_state.accuracy:.2%}")
159
+ st.pyplot(plot_confusion_matrix(st.session_state.conf_matrix))
160
+
161
+ st.markdown("---")
162
+ st.subheader("Live Packet Simulation")
163
+
164
+ if st.button("Capture Random Packet"):
165
+ idx = np.random.randint(0, len(st.session_state.X_test))
166
+ st.session_state.packet = st.session_state.X_test.iloc[idx]
167
+ st.session_state.actual = st.session_state.y_test.iloc[idx]
168
+
169
+ if "packet" in st.session_state:
170
+ packet = st.session_state.packet
171
+ pred = st.session_state.model.predict([packet])[0]
172
+
173
+ st.write("Packet Data")
174
+ st.dataframe(packet.to_frame().T)
175
+
176
+ if pred == 1:
177
+ st.error("Prediction: Intrusion Detected")
178
+ else:
179
+ st.success("Prediction: Normal Traffic")
180
+
181
+ st.caption(f"Ground Truth: {st.session_state.actual}")
182
+
183
+ st.markdown("---")
184
+ st.subheader("AI Explanation (Groq)")
185
+
186
+ if st.button("Generate Explanation"):
187
+ if not groq_api_key:
188
+ st.warning("Enter Groq API key first.")
189
+ else:
190
+ client = Groq(api_key=groq_api_key)
191
+ prompt = f"""
192
+ You are a cybersecurity analyst.
193
+ The following packet was classified as {'Intrusion' if pred == 1 else 'Normal'}.
194
+
195
+ Packet details:
196
+ {packet.to_string()}
197
+
198
+ Explain briefly in simple terms.
199
+ """
200
+ response = client.chat.completions.create(
201
+ model="llama-3.3-70b-versatile",
202
+ messages=[{"role": "user", "content": prompt}],
203
+ temperature=0.6
204
+ )
205
+ st.info(response.choices[0].message.content)
206
+ else:
207
+ st.info("Train the model to begin analysis.")