OGrohit committed on
Commit
7cfebd5
·
verified ·
1 Parent(s): 9214f55

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -207
app.py DELETED
@@ -1,207 +0,0 @@
1
- # ============================
2
- # AI-Based Network Intrusion Detection System (NIDS)
3
- # VOIS Internship – Final Project
4
- # ============================
5
-
6
- import streamlit as st
7
- import pandas as pd
8
- import numpy as np
9
- from sklearn.ensemble import RandomForestClassifier
10
- from sklearn.model_selection import train_test_split
11
- from sklearn.metrics import accuracy_score, confusion_matrix
12
- import matplotlib.pyplot as plt
13
- import seaborn as sns
14
- from groq import Groq
15
-
16
- # ============================
17
- # PAGE CONFIG
18
- # ============================
19
# Introductory copy rendered under the page title (Markdown).
INTRO_MARKDOWN = """
This project implements a **Random Forest–based Network Intrusion Detection System (NIDS)**.
It supports:
- Simulated traffic
- Real CIC-style CSV datasets
- Live packet analysis
- AI-based explanation using Groq
"""

# Page-level configuration is issued before any other Streamlit output.
st.set_page_config(layout="wide", page_title="AI-Based NIDS")

st.title("AI-Based Network Intrusion Detection System")
st.markdown(INTRO_MARKDOWN)
30
-
31
- # ============================
32
- # SESSION STATE INIT
33
- # ============================
34
# Create every slot the app reads later, without overwriting values that are
# already present in the session.
_unset_keys = [
    name
    for name in ("model", "accuracy", "conf_matrix", "features", "X_test", "y_test")
    if name not in st.session_state
]
for _name in _unset_keys:
    st.session_state[_name] = None
37
-
38
- # ============================
39
- # SIDEBAR – SETTINGS
40
- # ============================
41
sidebar = st.sidebar

# Section 1: API key used by the explanation feature (masked input).
sidebar.header("1. Settings")
groq_api_key = sidebar.text_input("Groq API Key", type="password")

# Section 2: where the training data comes from.
sidebar.header("2. Data Mode")
DATA_SOURCES = ("Simulation Mode", "CSV Upload Mode")
data_mode = sidebar.radio("Select Data Source", DATA_SOURCES)
49
-
50
- # ============================
51
- # DATA LOADING FUNCTIONS
52
- # ============================
53
def load_simulated_data(samples=2000, seed=42):
    """Generate a reproducible synthetic traffic table with binary labels.

    Args:
        samples: Number of rows to generate.
        seed: Seed for the private random generator. The default of 42
            yields the same values as seeding NumPy's global generator
            with 42 and drawing the columns in the same order.

    Returns:
        A DataFrame with the columns ``packet_size``, ``duration``,
        ``src_bytes``, ``dst_bytes``, ``failed_logins`` and ``label``.
        ``label`` is 1 when ``failed_logins > 2`` or ``src_bytes > 8000``,
        otherwise 0.
    """
    # A private RandomState keeps the data reproducible without reseeding
    # NumPy's global generator, which other code uses for "random" draws.
    rng = np.random.RandomState(seed)
    df = pd.DataFrame({
        "packet_size": rng.randint(20, 1500, samples),
        "duration": rng.uniform(0, 60, samples),
        "src_bytes": rng.randint(0, 10000, samples),
        "dst_bytes": rng.randint(0, 10000, samples),
        "failed_logins": rng.randint(0, 5, samples),
    })
    df["label"] = np.where(
        (df["failed_logins"] > 2) | (df["src_bytes"] > 8000),
        1, 0
    )
    return df
67
-
68
def preprocess_csv(df):
    """Reduce a CIC-style flow table to the columns the model trains on.

    Header names are stripped of surrounding whitespace, the CIC column
    names are mapped to the app's feature names, and only the required
    columns are kept. Rows holding NaN or +/-inf in those columns are
    dropped. The input frame is not modified.

    Args:
        df: Raw DataFrame as read from the uploaded CSV.

    Returns:
        A new DataFrame with exactly the columns ``packet_size``,
        ``duration``, ``src_bytes``, ``dst_bytes`` and ``label``. When the
        source had a ``Label`` column, ``label`` is 0 for "BENIGN"
        (case-insensitive, surrounding whitespace ignored) and 1 otherwise.

    Raises:
        KeyError: If any required column is missing after renaming.
    """
    column_map = {
        "Flow Duration": "duration",
        "Total Fwd Packets": "src_bytes",
        "Total Backward Packets": "dst_bytes",
        "Packet Length Mean": "packet_size",
        "Label": "label",
    }
    required = ["packet_size", "duration", "src_bytes", "dst_bytes", "label"]

    # Tolerate headers padded with spaces (e.g. " Label"); rename returns a
    # new frame, so the caller's DataFrame is left untouched.
    df = df.rename(
        columns=lambda name: name.strip() if isinstance(name, str) else name
    )
    # Only a CIC-style "Label" column is normalized; a pre-existing
    # lowercase "label" column is passed through unchanged.
    has_cic_label = "Label" in df.columns
    df = df.rename(columns=column_map)

    missing = [name for name in required if name not in df.columns]
    if missing:
        raise KeyError(
            f"CSV is missing required column(s): {', '.join(missing)}"
        )

    # Select first, then drop: a gap in an unrelated column must not discard
    # an otherwise usable row.
    df = df[required].replace([np.inf, -np.inf], np.nan).dropna()

    if has_cic_label:
        is_benign = df["label"].astype(str).str.strip().str.upper() == "BENIGN"
        df = df.assign(label=(~is_benign).astype(int))

    return df
85
-
86
- # ============================
87
- # MODEL TRAINING
88
- # ============================
89
def train_model(df):
    """Fit a random forest on ``df`` and evaluate it on a held-out split.

    Args:
        df: DataFrame containing a ``label`` column; every other column is
            used as a feature.

    Returns:
        A tuple ``(model, acc, cm, X_test, y_test)``: the fitted classifier,
        its accuracy on the test split, the confusion matrix for that split,
        and the test features and labels.
    """
    X = df.drop("label", axis=1)
    y = df["label"]

    # Fixed random_state keeps the 70/30 split reproducible between runs.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42
    )

    model = RandomForestClassifier(
        n_estimators=100,
        max_depth=12,
        random_state=42,
    )
    model.fit(X_train, y_train)

    # Run inference on the test split once and reuse it for both metrics.
    y_pred = model.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    cm = confusion_matrix(y_test, y_pred)

    return model, acc, cm, X_test, y_test
108
-
109
def plot_confusion_matrix(cm):
    """Render a confusion matrix as an annotated heatmap.

    Args:
        cm: Confusion matrix of integer counts; rows are labelled as actual
            classes and columns as predicted classes, in the order
            Normal, Intrusion.

    Returns:
        The Matplotlib figure containing the heatmap.
    """
    # Local import: the file's top-level import block only brings in pyplot.
    from matplotlib.figure import Figure

    # A standalone Figure is not registered with pyplot's global figure
    # manager, so figures do not pile up across script reruns when the caller
    # never closes them.
    fig = Figure()
    ax = fig.subplots()
    sns.heatmap(
        cm, annot=True, fmt="d",
        xticklabels=["Normal", "Intrusion"],
        yticklabels=["Normal", "Intrusion"],
        cmap="Blues", ax=ax
    )
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    ax.set_title("Confusion Matrix")
    return fig
121
-
122
- # ============================
123
- # TRAIN MODEL BUTTON
124
- # ============================
125
st.sidebar.header("3. Model Training")

# The uploader is only shown when the user picked CSV mode.
uploaded_file = None
if data_mode == "CSV Upload Mode":
    uploaded_file = st.sidebar.file_uploader("Upload CSV Dataset", type=["csv"])

if st.sidebar.button("Train Model"):
    with st.spinner("Training model..."):
        if data_mode == "Simulation Mode":
            df = load_simulated_data()
        else:
            if uploaded_file is None:
                st.sidebar.error("Please upload a CSV file first.")
                st.stop()
            try:
                raw_df = pd.read_csv(uploaded_file)
                df = preprocess_csv(raw_df)
            except (KeyError, ValueError) as exc:
                # Covers unparsable/empty files (pandas parser errors derive
                # from ValueError) and CSVs lacking the required columns.
                st.sidebar.error(f"Could not use the uploaded CSV: {exc}")
                st.stop()

        if df.empty:
            # Nothing to split or fit on; report it instead of letting the
            # train/test split raise.
            st.sidebar.error("No usable rows remain after preprocessing.")
            st.stop()

        model, acc, cm, X_test, y_test = train_model(df)

        st.session_state.model = model
        st.session_state.accuracy = acc
        st.session_state.conf_matrix = cm
        st.session_state.X_test = X_test
        st.session_state.y_test = y_test
        st.session_state.features = list(X_test.columns)

        # A packet captured from a previous model may have a different
        # feature set; drop it so it is never fed to the new model.
        for stale_key in ("packet", "actual"):
            if stale_key in st.session_state:
                del st.session_state[stale_key]

    st.sidebar.success(f"Training completed (Accuracy: {acc:.2%})")
151
-
152
- # ============================
153
- # DASHBOARD
154
- # ============================
155
st.header("Threat Analysis Dashboard")

if st.session_state.model is not None:
    st.metric("Model Accuracy", f"{st.session_state.accuracy:.2%}")
    st.pyplot(plot_confusion_matrix(st.session_state.conf_matrix))

    st.markdown("---")
    st.subheader("Live Packet Simulation")

    if st.button("Capture Random Packet"):
        # Pick one row of the held-out test split and remember it, together
        # with its true label, across reruns.
        idx = np.random.randint(0, len(st.session_state.X_test))
        st.session_state.packet = st.session_state.X_test.iloc[idx]
        st.session_state.actual = st.session_state.y_test.iloc[idx]

    # Discard a packet kept from an earlier model whose feature set differs
    # from the current one; predicting on it would fail.
    if "packet" in st.session_state:
        current_features = list(st.session_state.X_test.columns)
        if list(st.session_state.packet.index) != current_features:
            del st.session_state["packet"]

    if "packet" in st.session_state:
        packet = st.session_state.packet
        # One-row DataFrame so the model receives the same column names it
        # was fitted with.
        packet_frame = packet.to_frame().T
        pred = st.session_state.model.predict(packet_frame)[0]

        st.write("Packet Data")
        st.dataframe(packet_frame)

        if pred == 1:
            st.error("Prediction: Intrusion Detected")
        else:
            st.success("Prediction: Normal Traffic")

        st.caption(f"Ground Truth: {st.session_state.actual}")

        st.markdown("---")
        st.subheader("AI Explanation (Groq)")

        if st.button("Generate Explanation"):
            if not groq_api_key:
                st.warning("Enter Groq API key first.")
            else:
                verdict = "Intrusion" if pred == 1 else "Normal"
                prompt = (
                    "\nYou are a cybersecurity analyst.\n"
                    f"The following packet was classified as {verdict}.\n"
                    "\nPacket details:\n"
                    f"{packet.to_string()}\n"
                    "\nExplain briefly in simple terms.\n"
                )
                try:
                    client = Groq(api_key=groq_api_key)
                    response = client.chat.completions.create(
                        model="llama-3.3-70b-versatile",
                        messages=[{"role": "user", "content": prompt}],
                        temperature=0.6
                    )
                except Exception as exc:
                    # UI boundary: a bad key or network failure should show a
                    # message instead of crashing the page.
                    st.error(f"Could not generate explanation: {exc}")
                else:
                    st.info(response.choices[0].message.content)
else:
    st.info("Train the model to begin analysis.")