Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| from sklearn.ensemble import RandomForestClassifier | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.metrics import accuracy_score | |
| from groq import Groq | |
| import os | |
# --- PAGE SETUP ---
# Page config, title, and a one-paragraph description of the app.
st.set_page_config(
    page_title="AI-NIDS Student Project",
    layout="wide",
)
st.title("AI-Based Network Intrusion Detection System")
st.markdown("""
**Student Project**: This system uses **Random Forest** to detect Network attacks and **Groq AI** to explain the packets.
""")

# --- CONFIGURATION ---
# CSV file the app loads at startup (path relative to the working directory).
DATA_FILE = "Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv"

# --- SIDEBAR: SETTINGS ---
# Everything created inside this context is rendered in the sidebar.
with st.sidebar:
    st.header("1. Settings")
    # Masked input; the value is passed to the Groq client further down.
    groq_api_key = st.text_input("Groq API Key (starts with gsk_)", type="password")
    st.caption("[Get a free key here](https://console.groq.com/keys)")
    st.header("2. Model Training")
def load_data(filepath, nrows=15000):
    """Load and clean the flow CSV used for training.

    Reads at most ``nrows`` rows, strips surrounding whitespace from the
    column names, converts +/-infinity to NaN, and drops every row that
    contains a NaN.

    Args:
        filepath: Path of the CSV file to read.
        nrows: Maximum number of data rows to read (default 15000).

    Returns:
        The cleaned DataFrame, or None when ``filepath`` does not exist.
    """
    # Keep the try body to the single call that can raise FileNotFoundError.
    try:
        df = pd.read_csv(filepath, nrows=nrows)
    except FileNotFoundError:
        return None

    # Column names in the CSV may carry stray spaces; normalise them so
    # later lookups by name work.
    df.columns = df.columns.str.strip()

    # Infinite values are treated as missing, then all incomplete rows go.
    return df.replace([np.inf, -np.inf], np.nan).dropna()
def train_model(df):
    """Train a Random Forest classifier on a fixed set of flow features.

    The data is split 70/30 into train and test sets (``random_state=42``)
    and a 10-tree forest with ``max_depth=10`` is fitted on the train part.

    Args:
        df: DataFrame containing the feature columns listed below and a
            'Label' column.

    Returns:
        A tuple ``(clf, score, features, X_test, y_test)`` where ``score``
        is the accuracy on the held-out test split. When a required column
        is missing, an error is shown in the UI and
        ``(None, 0, [], None, None)`` is returned.
    """
    features = ['Flow Duration', 'Total Fwd Packets', 'Total Backward Packets',
                'Total Length of Fwd Packets', 'Fwd Packet Length Max',
                'Flow IAT Mean', 'Flow IAT Std', 'Flow Packets/s']
    target = 'Label'

    # Validate the label column together with the features: without this a
    # CSV lacking 'Label' would raise KeyError at df[target] instead of
    # producing the friendly error below.
    missing_cols = [c for c in features + [target] if c not in df.columns]
    if missing_cols:
        st.error(f"Missing columns in CSV: {missing_cols}")
        return None, 0, [], None, None

    X = df[features]
    y = df[target]
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42
    )

    clf = RandomForestClassifier(n_estimators=10, max_depth=10, random_state=42)
    clf.fit(X_train, y_train)

    # Accuracy on the held-out split only.
    score = accuracy_score(y_test, clf.predict(X_test))
    return clf, score, features, X_test, y_test
# --- APP LOGIC ---
# Load the dataset; without it nothing else on the page can work.
df = load_data(DATA_FILE)
if df is None:
    st.error(f"Error: File '{DATA_FILE}' not found. Please upload it to the Files tab.")
    st.stop()
st.sidebar.success(f"Dataset Loaded: {len(df)} rows")

# Train on demand and keep the results in session_state so they survive
# the script rerun that follows every widget interaction.
if st.sidebar.button("Train Model Now"):
    with st.spinner("Training model..."):
        clf, accuracy, feature_names, X_test, y_test = train_model(df)
    # Compare against None explicitly: train_model signals failure with
    # None, and the truthiness of an estimator object is not a reliable
    # success indicator.
    if clf is not None:
        st.session_state['model'] = clf
        st.session_state['features'] = feature_names
        st.session_state['X_test'] = X_test
        st.session_state['y_test'] = y_test
        st.sidebar.success(f"Training Complete! Accuracy: {accuracy:.2%}")

st.header("3. Threat Analysis Dashboard")
if 'model' in st.session_state:
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Simulation")
        st.info("Pick a random packet from the test data to simulate live traffic.")
        if st.button("🎲 Capture Random Packet"):
            # Pick one row of the held-out test set and remember it,
            # together with its true label, across reruns.
            random_idx = np.random.randint(0, len(st.session_state['X_test']))
            packet_data = st.session_state['X_test'].iloc[random_idx]
            actual_label = st.session_state['y_test'].iloc[random_idx]
            st.session_state['current_packet'] = packet_data
            st.session_state['actual_label'] = actual_label

    if 'current_packet' in st.session_state:
        packet = st.session_state['current_packet']

        with col1:
            st.write("**Packet Header Info:**")
            st.dataframe(packet, use_container_width=True)

        with col2:
            st.subheader("AI Detection Result")
            # Predict on a one-row DataFrame so the column names the model
            # was fitted with are preserved (a bare list of Series loses
            # them and makes scikit-learn warn about feature names).
            prediction = st.session_state['model'].predict(pd.DataFrame([packet]))[0]
            if prediction == "BENIGN":
                st.success(" STATUS: **SAFE (BENIGN)**")
            else:
                st.error(f"🚨 STATUS: **ATTACK DETECTED ({prediction})**")
            st.caption(f"Ground Truth Label: {st.session_state['actual_label']}")

        st.markdown("---")
        st.subheader(" Ask AI Analyst (Groq)")
        if st.button("Generate Explanation"):
            if not groq_api_key:
                st.warning(" Please enter your Groq API Key in the sidebar first.")
            else:
                # UI boundary: any failure of the remote call is shown to
                # the user instead of crashing the app.
                try:
                    client = Groq(api_key=groq_api_key)
                    prompt = f"""
                    You are a cybersecurity analyst.
                    A network packet was detected as: {prediction}.
                    Packet Technical Details:
                    {packet.to_string()}
                    Please explain:
                    1. Why these specific values (like Flow Duration or Packet Length) might indicate {prediction}.
                    2. If it is BENIGN, explain why it looks normal.
                    3. Keep the answer short and simple for a student.
                    """
                    with st.spinner("Groq is analyzing the packet..."):
                        completion = client.chat.completions.create(
                            model="llama-3.3-70b-versatile",  # <--- UPDATED MODEL NAME
                            messages=[
                                {"role": "user", "content": prompt}
                            ],
                            temperature=0.6,
                        )
                    st.info(completion.choices[0].message.content)
                except Exception as e:
                    st.error(f"API Error: {e}")
else:
    st.info(" Waiting for model training. Click **'Train Model Now'** in the sidebar.")