| """ |
| CTI Bench Evaluation Script for Cybersecurity Retrieval System |
| |
| This script evaluates the retrieval supervisor system against the CTI Bench dataset, |
| including both CTI-ATE (attack technique extraction) and CTI-MCQ (multiple choice questions). |
| """ |
|
|
import os
import sys
import re
import json
import csv
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional
from datetime import datetime

import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score

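# Local import: the retrieval supervisor system under evaluation.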
from src.agents.retrieval_supervisor.supervisor import RetrievalSupervisor


class CTIBenchEvaluator:
    """Evaluator for CTI Bench dataset using the Retrieval Supervisor."""

    def __init__(
        self,
        supervisor: Optional[RetrievalSupervisor],
        dataset_dir: str = "cti_bench/datasets",
        output_dir: str = "cti_bench/eval_output",
        max_samples: Optional[int] = None,
    ):
        """
        Initialize the CTI Bench evaluator.

        Args:
            supervisor: RetrievalSupervisor instance (can be None for CSV processing)
            dataset_dir: Directory containing CTI Bench datasets
            output_dir: Directory to save evaluation results
            max_samples: Optional cap on the number of samples evaluated per dataset
                (useful for quick test runs)
        """
        self.supervisor = supervisor
        self.dataset_dir = Path(dataset_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.max_samples = max_samples

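        # Prompt template for the CTI-ATE task; {attack_description} is filled in per sample.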
        self.ate_query_template = """You are a cybersecurity expert specializing in cyber threat intelligence.
Extract all MITRE Enterprise attack patterns from the following text and map them to their corresponding MITRE technique IDs.
Provide reasoning for each identification.
Ensure the final line contains only the IDs for the main techniques, separated by commas, excluding any sub-technique IDs.

Example of the final line: T1071, T1560, T1547

Text:
{attack_description}
"""

    def load_datasets(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Load CTI-ATE and CTI-MCQ datasets."""
        try:
            ate_path = self.dataset_dir / "cti-ate.tsv"
            ate_df = pd.read_csv(ate_path, sep="\t")
            print(f"Loaded CTI-ATE dataset: {len(ate_df)} samples")

            mcq_path = self.dataset_dir / "cti-mcq.tsv"
            mcq_df = pd.read_csv(mcq_path, sep="\t")
            print(f"Loaded CTI-MCQ dataset: {len(mcq_df)} samples")

            return ate_df, mcq_df

        except Exception as e:
            print(f"Error loading datasets: {e}")
            raise

    def filter_dataset(self, df: pd.DataFrame, dataset_type: str) -> pd.DataFrame:
        """Filter a dataset down to the subset used for evaluation."""
        if dataset_type == "ate":
            # Keep only Enterprise-platform samples.
            filtered_df = df[df["Platform"] == "Enterprise"].copy()
            print(
                f"CTI-ATE filtered to Enterprise platform: {len(filtered_df)} samples"
            )
        elif dataset_type == "mcq":
            # Keep only questions whose source URL is a technique page.
            filtered_df = df[df["URL"].str.contains("techniques", na=False)].copy()
            print(f"CTI-MCQ filtered to technique URLs: {len(filtered_df)} samples")
        else:
            raise ValueError(f"Invalid dataset type: {dataset_type}")

        # Optionally cap the number of samples (wired to the --max-samples CLI flag).
        if self.max_samples is not None:
            filtered_df = filtered_df.head(self.max_samples)
            print(f"Limited to the first {len(filtered_df)} samples (max_samples)")

        return filtered_df

    def extract_technique_ids_from_response(self, response: str) -> List[str]:
        """
        Extract MITRE technique IDs from the response text.
        Simplified version: only checks the final line.

        Args:
            response: Response text from the supervisor

        Returns:
            List of extracted technique IDs, or empty list if not successful
        """
        lines = response.strip().split("\n")
        if not lines:
            return []

        final_line = lines[-1].strip()
        if not final_line:
            return []

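        # Technique IDs look like T1071; sub-technique IDs add a ".NNN" suffix (e.g., T1071.001).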
        technique_pattern = r"\bT\d{4}(?:\.\d{3})?\b"

        techniques_in_line = re.findall(technique_pattern, final_line)
        if not techniques_in_line:
            return []

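        # Reject the final line if it contains anything besides technique IDs, commas, and whitespace.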
        clean_line = re.sub(r"[T\d.,\s]", "", final_line)
        if len(clean_line) > 0:
            return []

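        # Return only main technique IDs; sub-technique IDs are dropped, as the prompt requests.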
        return [t for t in techniques_in_line if "." not in t]

    def extract_mcq_answer_from_response(self, response: str) -> str:
        """
        Extract the final answer (A, B, C, or D) from an MCQ response.

        Args:
            response: Response text from the supervisor

        Returns:
            Extracted answer letter, or an empty string if not found
        """
        lines = response.strip().split("\n")

        # First pass: look for a bare answer letter on one of the last few lines.
        for line in reversed(lines[-3:]):
            line = line.strip()
            if line in ["A", "B", "C", "D"]:
                return line

            match = re.search(r"\b([ABCD])\b(?:\s*[.)]?)\s*$", line)
            if match:
                return match.group(1)

        # Fallback: scan the whole response for answer-like patterns.
        answer_patterns = [
            r"(?:answer|choice|option).*?([ABCD])",
            r"\b([ABCD])\b(?:\s*[.)]?)\s*$",
            r"^([ABCD])$",
        ]

        for pattern in answer_patterns:
            matches = re.findall(pattern, response, re.IGNORECASE | re.MULTILINE)
            if matches:
                return matches[-1].upper()

        return ""

    def evaluate_ate_dataset(self, ate_df: pd.DataFrame) -> List[Dict[str, Any]]:
        """
        Evaluate the CTI-ATE dataset.

        Args:
            ate_df: Filtered CTI-ATE dataset

        Returns:
            List of evaluation results
        """
        results = []

        print(f"\n{'='*60}")
        print("EVALUATING CTI-ATE DATASET")
        print(f"{'='*60}")

        for i, (_, row) in enumerate(ate_df.iterrows()):
            print(f"Processing ATE sample {i + 1}/{len(ate_df)}: {row['URL']}")

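            # Retry up to max_retries times if no technique IDs can be extracted from the response.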
            max_retries = 3
            success = False
            result = None

            for attempt in range(max_retries):
                try:
                    print(f" Attempt {attempt + 1}/{max_retries}")

                    query = self.ate_query_template.format(
                        attack_description=row["Description"]
                    )

                    response = self.supervisor.invoke_direct_query(query, trace=False)

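                    # Pull the content of the last AI message out of the supervisor's response.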
| if "messages" in response and response["messages"]: |
| |
| last_message = None |
| for msg in reversed(response["messages"]): |
| try: |
| if ( |
| hasattr(msg, "content") |
| and hasattr(msg, "type") |
| and msg.type == "ai" |
| ): |
| last_message = msg |
| break |
| except (AttributeError, TypeError) as e: |
| |
| print(f" Warning: Error accessing message type: {e}") |
| continue |
|
|
| if last_message: |
| response_text = last_message.content |
| else: |
| |
| try: |
| response_text = response["messages"][-1].content |
| except (AttributeError, TypeError) as e: |
| print( |
| f" Warning: Error accessing last message content: {e}" |
| ) |
| response_text = str(response["messages"][-1]) |
| else: |
| response_text = str(response) |
|
|
| |
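                    # Extract the predicted technique IDs and parse the comma-separated ground truth.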
                    predicted_techniques = self.extract_technique_ids_from_response(
                        response_text
                    )

                    gt_techniques = [
                        t.strip() for t in row["GT"].split(",") if t.strip()
                    ]

                    if len(predicted_techniques) > 0:
                        success = True
                        result = {
                            "url": row["URL"],
                            "description": row["Description"],
                            "ground_truth": gt_techniques,
                            "predicted": predicted_techniques,
                            "response_text": response_text,
                            "success": True,
                            "attempts": attempt + 1,
                        }
                        print(f" GT: {gt_techniques}")
                        print(f" Predicted: {predicted_techniques}")
                        print(f" Success: {result['success']} (attempt {attempt + 1})")
                        break
                    else:
                        print(f" No techniques extracted on attempt {attempt + 1}")
                        if attempt == max_retries - 1:
                            # Give up and record an unsuccessful result.
                            result = {
                                "url": row["URL"],
                                "description": row["Description"],
                                "ground_truth": gt_techniques,
                                "predicted": [],
                                "response_text": response_text,
                                "success": False,
                                "attempts": max_retries,
                            }
                            print(f" GT: {gt_techniques}")
                            print(f" Predicted: {predicted_techniques}")
                            print(
                                f" Success: {result['success']} (all attempts failed)"
                            )
                            print(f" Response text: {response_text}")

                except Exception as e:
                    print(f" Error processing sample (attempt {attempt + 1}): {e}")
                    if attempt == max_retries - 1:
                        result = {
                            "url": row["URL"],
                            "description": row["Description"],
                            "ground_truth": [
                                t.strip() for t in row["GT"].split(",") if t.strip()
                            ],
                            "predicted": [],
                            "response_text": f"Error: {str(e)}",
                            "success": False,
                            "attempts": max_retries,
                        }
                        print(f" Success: {result['success']} (all attempts failed)")

            # Record the outcome for this sample (set on success or after the final attempt).
            results.append(result)

        return results

    def evaluate_mcq_dataset(self, mcq_df: pd.DataFrame) -> List[Dict[str, Any]]:
        """
        Evaluate the CTI-MCQ dataset.

        Args:
            mcq_df: Filtered CTI-MCQ dataset

        Returns:
            List of evaluation results
        """
        results = []

        print(f"\n{'='*60}")
        print("EVALUATING CTI-MCQ DATASET")
        print(f"{'='*60}")

        for i, (_, row) in enumerate(mcq_df.iterrows()):
            print(f"Processing MCQ sample {i + 1}/{len(mcq_df)}: {row['URL']}")

            try:
                query = row["Prompt"]

                response = self.supervisor.invoke_direct_query(query, trace=False)

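                # Pull the content of the last AI message out of the supervisor's response.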
| if "messages" in response and response["messages"]: |
| |
| last_message = None |
| for msg in reversed(response["messages"]): |
| try: |
| if ( |
| hasattr(msg, "content") |
| and hasattr(msg, "type") |
| and msg.type == "ai" |
| ): |
| last_message = msg |
| break |
| except (AttributeError, TypeError) as e: |
| |
| print(f" Warning: Error accessing message type: {e}") |
| continue |
|
|
| if last_message: |
| response_text = last_message.content |
| else: |
| |
| try: |
| response_text = response["messages"][-1].content |
| except (AttributeError, TypeError) as e: |
| print( |
| f" Warning: Error accessing last message content: {e}" |
| ) |
| response_text = str(response["messages"][-1]) |
| else: |
| response_text = str(response) |
|
|
| |
| predicted_answer = self.extract_mcq_answer_from_response(response_text) |
|
|
| |
| gt_answer = row["GT"].strip().upper() |
|
|
| |
| result = { |
| "url": row["URL"], |
| "prompt": row["Prompt"], |
| "ground_truth": gt_answer, |
| "predicted": predicted_answer, |
| "response_text": response_text, |
| "correct": predicted_answer == gt_answer, |
| "success": len(predicted_answer) > 0, |
| } |
|
|
| results.append(result) |
|
|
| print(f" GT: {gt_answer}") |
| print(f" Predicted: {predicted_answer}") |
| print(f" Correct: {result['correct']}") |
|
|
| except Exception as e: |
| print(f" Error processing sample: {e}") |
| result = { |
| "url": row["URL"], |
| "prompt": row["Prompt"], |
| "ground_truth": row["GT"].strip().upper(), |
| "predicted": "", |
| "response_text": f"Error: {str(e)}", |
| "correct": False, |
| "success": False, |
| } |
| results.append(result) |
|
|
| return results |
|
|
    def calculate_ate_metrics(self, results: List[Dict[str, Any]]) -> Dict[str, float]:
        """
        Calculate evaluation metrics for the ATE dataset using sample-level metrics.

        Args:
            results: List of ATE evaluation results

        Returns:
            Dictionary of calculated metrics
        """
        if not results:
            return {}

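        # Collect every technique ID that appears in either the ground truth or the predictions.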
        all_techniques = set()
        for result in results:
            all_techniques.update(result["ground_truth"])
            all_techniques.update(result["predicted"])

        all_techniques = sorted(list(all_techniques))

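        # Macro metrics: compute precision/recall/F1 per sample, then average across samples.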
        sample_precisions = []
        sample_recalls = []
        sample_f1s = []

        for result in results:
            gt_set = set(result["ground_truth"])
            pred_set = set(result["predicted"])

            if len(pred_set) == 0:
                precision = 0.0
            else:
                precision = len(gt_set.intersection(pred_set)) / len(pred_set)

            if len(gt_set) == 0:
                recall = 1.0 if len(pred_set) == 0 else 0.0
            else:
                recall = len(gt_set.intersection(pred_set)) / len(gt_set)

            if precision + recall == 0:
                f1 = 0.0
            else:
                f1 = 2 * (precision * recall) / (precision + recall)

            sample_precisions.append(precision)
            sample_recalls.append(recall)
            sample_f1s.append(f1)

        macro_precision = np.mean(sample_precisions)
        macro_recall = np.mean(sample_recalls)
        macro_f1 = np.mean(sample_f1s)

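        # Micro metrics: pool true positives, false positives, and false negatives across all samples.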
        total_tp = 0
        total_fp = 0
        total_fn = 0

        for result in results:
            gt_set = set(result["ground_truth"])
            pred_set = set(result["predicted"])

            tp = len(gt_set.intersection(pred_set))
            fp = len(pred_set - gt_set)
            fn = len(gt_set - pred_set)

            total_tp += tp
            total_fp += fp
            total_fn += fn

        if total_tp + total_fp == 0:
            micro_precision = 0.0
        else:
            micro_precision = total_tp / (total_tp + total_fp)

        if total_tp + total_fn == 0:
            micro_recall = 0.0
        else:
            micro_recall = total_tp / (total_tp + total_fn)

        if micro_precision + micro_recall == 0:
            micro_f1 = 0.0
        else:
            micro_f1 = (
                2 * (micro_precision * micro_recall) / (micro_precision + micro_recall)
            )

        exact_match = sum(
            1 for r in results if set(r["ground_truth"]) == set(r["predicted"])
        ) / len(results)
        success_rate = sum(1 for r in results if r["success"]) / len(results)

        return {
            "macro_f1": macro_f1,
            "macro_precision": macro_precision,
            "macro_recall": macro_recall,
            "micro_f1": micro_f1,
            "micro_precision": micro_precision,
            "micro_recall": micro_recall,
            "exact_match_ratio": exact_match,
            "success_rate": success_rate,
            "total_samples": len(results),
            "total_techniques": len(all_techniques),
        }

    def calculate_mcq_metrics(self, results: List[Dict[str, Any]]) -> Dict[str, float]:
        """
        Calculate evaluation metrics for the MCQ dataset.

        Args:
            results: List of MCQ evaluation results

        Returns:
            Dictionary of calculated metrics
        """
        if not results:
            return {}

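        # Score only samples where an answer letter was extracted; unanswered samples still count against the success rate.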
        y_true = []
        y_pred = []

        for result in results:
            if result["success"]:
                y_true.append(result["ground_truth"])
                y_pred.append(result["predicted"])

        if not y_true:
            return {
                "accuracy": 0.0,
                "f1_macro": 0.0,
                "f1_micro": 0.0,
                "precision_macro": 0.0,
                "recall_macro": 0.0,
                "success_rate": 0.0,
                "total_samples": len(results),
                "answered_samples": 0,
            }

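        # Treat each answer letter as a class and compute multi-class metrics with scikit-learn.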
        accuracy = accuracy_score(y_true, y_pred)
        f1_macro = f1_score(y_true, y_pred, average="macro", zero_division=0)
        f1_micro = f1_score(y_true, y_pred, average="micro", zero_division=0)
        precision_macro = precision_score(
            y_true, y_pred, average="macro", zero_division=0
        )
        recall_macro = recall_score(y_true, y_pred, average="macro", zero_division=0)

        success_rate = sum(1 for r in results if r["success"]) / len(results)

        return {
            "accuracy": accuracy,
            "f1_macro": f1_macro,
            "f1_micro": f1_micro,
            "precision_macro": precision_macro,
            "recall_macro": recall_macro,
            "success_rate": success_rate,
            "total_samples": len(results),
            "answered_samples": len(y_true),
        }

    def save_results_to_csv(
        self,
        results: List[Dict[str, Any]],
        dataset_type: str,
        model_name: Optional[str] = None,
    ):
        """
        Save evaluation results to a CSV file.

        Args:
            results: Evaluation results
            dataset_type: Type of dataset ("ate" or "mcq")
            model_name: Model name (if None, extracted from the supervisor)
        """
        if model_name is None:
            if self.supervisor is not None:
                model_name = self.supervisor.llm_model.split(":")[-1]
            else:
                model_name = "unknown_model"

        # Sanitize the model name so it is safe to use in a filename.
        sanitized_model_name = self._sanitize_filename(model_name)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        if dataset_type == "ate":
            csv_path = (
                self.output_dir / f"cti-ate_{sanitized_model_name}_{timestamp}.csv"
            )
            with open(csv_path, "w", newline="", encoding="utf-8") as f:
                writer = csv.writer(f)
                writer.writerow(["Description", "GT", "Predicted"])

                for result in results:
                    description = result["description"]
                    gt = ", ".join(result["ground_truth"])
                    predicted = ", ".join(result["predicted"])
                    writer.writerow([description, gt, predicted])

            print(f"ATE results saved to: {csv_path}")

        elif dataset_type == "mcq":
            csv_path = (
                self.output_dir / f"cti-mcq_{sanitized_model_name}_{timestamp}.csv"
            )
            with open(csv_path, "w", newline="", encoding="utf-8") as f:
                writer = csv.writer(f)
                writer.writerow(["Prompt", "GT", "Predicted"])

                for result in results:
                    prompt = result["prompt"]
                    writer.writerow(
                        [prompt, result["ground_truth"], result["predicted"]]
                    )

            print(f"MCQ results saved to: {csv_path}")
        else:
            raise ValueError(f"Invalid dataset type: {dataset_type}")

    def save_evaluation_summary(
        self,
        metrics: Dict[str, float],
        dataset_type: str,
        model_name: Optional[str] = None,
    ):
        """
        Save an evaluation summary to a JSON file.

        Args:
            metrics: Evaluation metrics
            dataset_type: Type of dataset ("ate" or "mcq")
            model_name: Model name (if None, extracted from the supervisor)
        """
        if model_name is None:
            if self.supervisor is not None:
                model_name = self.supervisor.llm_model.split(":")[-1]
            else:
                model_name = "unknown_model"

        sanitized_model_name = self._sanitize_filename(model_name)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        summary = {
            "evaluation_timestamp": datetime.now().isoformat(),
            "dataset_type": dataset_type,
            "model_name": model_name,
            "metrics": metrics,
        }

        summary_path = (
            self.output_dir
            / f"evaluation_summary_{dataset_type}_{sanitized_model_name}_{timestamp}.json"
        )
        with open(summary_path, "w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2)

        print(f"Evaluation summary saved to: {summary_path}")

    def _extract_dataset_type_from_filename(self, filename: str) -> str:
        """
        Extract dataset type from CSV filename.

        Args:
            filename: The filename (without extension) to extract dataset type from

        Returns:
            Dataset type ("ate" or "mcq")
        """
        if "cti-ate" in filename.lower():
            return "ate"
        elif "cti-mcq" in filename.lower():
            return "mcq"
        else:
            raise ValueError(f"Cannot determine dataset type from filename: {filename}")

    def _sanitize_filename(self, filename: str) -> str:
        """
        Sanitize a string to be safe for use in filenames.

        Args:
            filename: The string to sanitize

        Returns:
            Sanitized filename string
        """
        # Replace characters that are invalid in filenames with hyphens.
        sanitized = re.sub(r'[/\\:*?"<>|]', "-", filename)

        # Collapse runs of hyphens and strip leading/trailing hyphens.
        sanitized = re.sub(r"-+", "-", sanitized).strip("-")

        return sanitized if sanitized else "unknown"

    def read_csv_results(
        self, csv_path: str, dataset_type: str
    ) -> List[Dict[str, Any]]:
        """
        Read existing CSV results and convert to evaluation results format.

        Args:
            csv_path: Path to the CSV file
            dataset_type: Type of dataset ("ate" or "mcq")

        Returns:
            List of evaluation results in the same format as evaluate_*_dataset methods
        """
        try:
            df = pd.read_csv(csv_path)
            results = []

            if dataset_type == "ate":
                for _, row in df.iterrows():
                    gt_techniques = [
                        t.strip() for t in str(row["GT"]).split(",") if t.strip()
                    ]
                    predicted_techniques = [
                        t.strip() for t in str(row["Predicted"]).split(",") if t.strip()
                    ]

                    result = {
                        "url": f"csv_row_{len(results)}",
                        "description": str(row["Description"]),
                        "ground_truth": gt_techniques,
                        "predicted": predicted_techniques,
                        "response_text": f"GT: {', '.join(gt_techniques)}, Predicted: {', '.join(predicted_techniques)}",
                        "success": len(predicted_techniques) > 0,
                        "attempts": 1,
                    }
                    results.append(result)

            elif dataset_type == "mcq":
                for _, row in df.iterrows():
                    gt_answer = str(row["GT"]).strip().upper()
                    predicted_answer = str(row["Predicted"]).strip().upper()

                    result = {
                        "url": f"csv_row_{len(results)}",
                        "prompt": str(row["Prompt"]),
                        "ground_truth": gt_answer,
                        "predicted": predicted_answer,
                        "response_text": f"GT: {gt_answer}, Predicted: {predicted_answer}",
                        "correct": predicted_answer == gt_answer,
                        "success": len(predicted_answer) > 0,
                    }
                    results.append(result)

            else:
                raise ValueError(f"Invalid dataset type: {dataset_type}")

            print(f"Successfully read {len(results)} results from {csv_path}")
            return results

        except Exception as e:
            print(f"Error reading CSV file {csv_path}: {e}")
            raise

    def calculate_metrics_from_csv(
        self, csv_path: str, model_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Read existing CSV results, calculate metrics, and save a summary.

        Args:
            csv_path: Path to the CSV file
            model_name: Model name to use in the summary (if None, extracted from filename)

        Returns:
            Dictionary containing results and metrics
        """
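        # Infer the dataset type from the results CSV filename.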
        filename = Path(csv_path).stem
        dataset_type = self._extract_dataset_type_from_filename(filename)

        if model_name is None:
            # Result CSVs are named "<dataset>_<model>_<timestamp>", so the model name
            # is the second underscore-separated field.
            parts = filename.split("_")
            if len(parts) >= 2:
                model_name = parts[1]
            else:
                model_name = "unknown_model"

        print(f"Processing CSV file: {csv_path}")
        print(f"Dataset type: {dataset_type} (extracted from filename)")
        print(f"Model name: {model_name}")

        results = self.read_csv_results(csv_path, dataset_type)

        if dataset_type == "ate":
            metrics = self.calculate_ate_metrics(results)
        elif dataset_type == "mcq":
            metrics = self.calculate_mcq_metrics(results)
        else:
            raise ValueError(f"Invalid dataset type: {dataset_type}")

        # Save a JSON summary alongside the other evaluation outputs.
        sanitized_model_name = self._sanitize_filename(model_name)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        summary = {
            "evaluation_timestamp": datetime.now().isoformat(),
            "dataset_type": dataset_type,
            "model_name": model_name,
            "source_csv": csv_path,
            "metrics": metrics,
        }

        summary_path = (
            self.output_dir
            / f"evaluation_summary_{dataset_type}_{sanitized_model_name}_{timestamp}.json"
        )
        with open(summary_path, "w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2)

        print(f"Evaluation summary saved to: {summary_path}")

        print(f"\n{'='*60}")
        print(f"METRICS FROM CSV: {dataset_type.upper()}")
        print(f"{'='*60}")

        if dataset_type == "ate":
            print(f"Macro F1: {metrics.get('macro_f1', 0.0):.3f}")
            print(f"Macro Precision: {metrics.get('macro_precision', 0.0):.3f}")
            print(f"Macro Recall: {metrics.get('macro_recall', 0.0):.3f}")
            print(f"Micro F1: {metrics.get('micro_f1', 0.0):.3f}")
            print(f"Exact Match: {metrics.get('exact_match_ratio', 0.0):.3f}")
            print(f"Success Rate: {metrics.get('success_rate', 0.0):.3f}")
            print(f"Total Samples: {metrics.get('total_samples', 0)}")
        elif dataset_type == "mcq":
            print(f"Accuracy: {metrics.get('accuracy', 0.0):.3f}")
            print(f"F1 Macro: {metrics.get('f1_macro', 0.0):.3f}")
            print(f"Success Rate: {metrics.get('success_rate', 0.0):.3f}")
            print(f"Total Samples: {metrics.get('total_samples', 0)}")

        print(f"{'='*60}")

        return {
            "results": results,
            "metrics": metrics,
            "summary_path": str(summary_path),
        }

    def run_full_evaluation(self) -> Dict[str, Any]:
        """
        Run the complete evaluation pipeline.

        Returns:
            Dictionary containing all evaluation results and metrics
        """
        print("Starting CTI Bench evaluation...")
        print(f"Output directory: {self.output_dir}")

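        # Load both datasets and apply the evaluation filters.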
        ate_df, mcq_df = self.load_datasets()
        ate_filtered = self.filter_dataset(ate_df, "ate")
        mcq_filtered = self.filter_dataset(mcq_df, "mcq")

        # Run both evaluations.
        ate_results = self.evaluate_ate_dataset(ate_filtered)
        ate_metrics = self.calculate_ate_metrics(ate_results)

        mcq_results = self.evaluate_mcq_dataset(mcq_filtered)
        mcq_metrics = self.calculate_mcq_metrics(mcq_results)

        # Persist per-sample results and metric summaries.
        self.save_results_to_csv(ate_results, "ate")
        self.save_results_to_csv(mcq_results, "mcq")
        self.save_evaluation_summary(ate_metrics, "ate")
        self.save_evaluation_summary(mcq_metrics, "mcq")

        print(f"\n{'='*60}")
        print("EVALUATION SUMMARY")
        print(f"{'='*60}")
        print("CTI-ATE Results:")
        print(f" Macro F1: {ate_metrics.get('macro_f1', 0.0):.3f}")
        print(f" Macro Precision: {ate_metrics.get('macro_precision', 0.0):.3f}")
        print(f" Macro Recall: {ate_metrics.get('macro_recall', 0.0):.3f}")
        print(f" Micro F1: {ate_metrics.get('micro_f1', 0.0):.3f}")
        print(f" Exact Match: {ate_metrics.get('exact_match_ratio', 0.0):.3f}")
        print(f" Success Rate: {ate_metrics.get('success_rate', 0.0):.3f}")
        print(f" Total Samples: {ate_metrics.get('total_samples', 0)}")

        print("\nCTI-MCQ Results:")
        print(f" Accuracy: {mcq_metrics.get('accuracy', 0.0):.3f}")
        print(f" F1 Macro: {mcq_metrics.get('f1_macro', 0.0):.3f}")
        print(f" Success Rate: {mcq_metrics.get('success_rate', 0.0):.3f}")
        print(f" Total Samples: {mcq_metrics.get('total_samples', 0)}")
        print(f"{'='*60}")

        return {
            "ate_results": ate_results,
            "mcq_results": mcq_results,
            "ate_metrics": ate_metrics,
            "mcq_metrics": mcq_metrics,
        }

    def run_ate_evaluation(self) -> Dict[str, Any]:
        """
        Run evaluation on the ATE dataset only.

        Returns:
            Dictionary containing ATE evaluation results and metrics
        """
        print("Starting CTI-ATE evaluation...")
        print(f"Output directory: {self.output_dir}")

        ate_df, _ = self.load_datasets()
        ate_filtered = self.filter_dataset(ate_df, "ate")

        ate_results = self.evaluate_ate_dataset(ate_filtered)
        ate_metrics = self.calculate_ate_metrics(ate_results)

        self.save_results_to_csv(ate_results, "ate")
        self.save_evaluation_summary(ate_metrics, "ate")

        print(f"\n{'='*60}")
        print("CTI-ATE EVALUATION SUMMARY")
        print(f"{'='*60}")
        print("CTI-ATE Results:")
        print(f" Macro F1: {ate_metrics.get('macro_f1', 0.0):.3f}")
        print(f" Macro Precision: {ate_metrics.get('macro_precision', 0.0):.3f}")
        print(f" Macro Recall: {ate_metrics.get('macro_recall', 0.0):.3f}")
        print(f" Micro F1: {ate_metrics.get('micro_f1', 0.0):.3f}")
        print(f" Exact Match: {ate_metrics.get('exact_match_ratio', 0.0):.3f}")
        print(f" Success Rate: {ate_metrics.get('success_rate', 0.0):.3f}")
        print(f" Total Samples: {ate_metrics.get('total_samples', 0)}")
        print(f"{'='*60}")

        return {
            "ate_results": ate_results,
            "ate_metrics": ate_metrics,
        }

    def run_mcq_evaluation(self) -> Dict[str, Any]:
        """
        Run evaluation on the MCQ dataset only.

        Returns:
            Dictionary containing MCQ evaluation results and metrics
        """
        print("Starting CTI-MCQ evaluation...")
        print(f"Output directory: {self.output_dir}")

        _, mcq_df = self.load_datasets()
        mcq_filtered = self.filter_dataset(mcq_df, "mcq")

        mcq_results = self.evaluate_mcq_dataset(mcq_filtered)
        mcq_metrics = self.calculate_mcq_metrics(mcq_results)

        self.save_results_to_csv(mcq_results, "mcq")
        self.save_evaluation_summary(mcq_metrics, "mcq")

        print(f"\n{'='*60}")
        print("CTI-MCQ EVALUATION SUMMARY")
        print(f"{'='*60}")
        print("CTI-MCQ Results:")
        print(f" Accuracy: {mcq_metrics.get('accuracy', 0.0):.3f}")
        print(f" F1 Macro: {mcq_metrics.get('f1_macro', 0.0):.3f}")
        print(f" Success Rate: {mcq_metrics.get('success_rate', 0.0):.3f}")
        print(f" Total Samples: {mcq_metrics.get('total_samples', 0)}")
        print(f"{'='*60}")

        return {
            "mcq_results": mcq_results,
            "mcq_metrics": mcq_metrics,
        }


def main():
    """Main function to run the evaluation."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate Retrieval Supervisor on CTI Bench dataset"
    )
    parser.add_argument(
        "--dataset-dir",
        default="cti_bench/datasets",
        help="Directory containing CTI Bench datasets",
    )
    parser.add_argument(
        "--output-dir",
        default="cti_bench/eval_output",
        help="Directory to save evaluation results",
    )
    parser.add_argument(
        "--kb-path",
        default="./cyber_knowledge_base",
        help="Path to cyber knowledge base",
    )
    parser.add_argument(
        "--llm-model", default="google_genai:gemini-2.0-flash", help="LLM model to use"
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        help="Maximum number of samples to evaluate (for testing)",
    )

    args = parser.parse_args()

    try:
        print("Initializing Retrieval Supervisor...")
        supervisor = RetrievalSupervisor(
            llm_model=args.llm_model, kb_path=args.kb_path, max_iterations=3
        )

        evaluator = CTIBenchEvaluator(
            supervisor=supervisor,
            dataset_dir=args.dataset_dir,
            output_dir=args.output_dir,
            max_samples=args.max_samples,
        )

        evaluator.run_full_evaluation()

        print("Evaluation completed successfully!")

    except Exception as e:
        print(f"Evaluation failed: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)


| if __name__ == "__main__": |
| main() |
|
|