|
|
| """
|
Simple Integrated Pipeline - Direct connection between Log Analysis Agent and Retrieval Supervisor

This file replaces the complex full_pipeline structure with a straightforward LangGraph
that passes log analysis results directly to the retrieval supervisor, and then hands
both results to the response agent, which produces the final analysis and markdown report.
|
| """
|
|
|
| import os
|
| import sys
|
| import time
|
| from pathlib import Path
|
| from typing import Dict, Any, TypedDict
|
| from langchain.chat_models import init_chat_model
|
| from dotenv import load_dotenv
|
|
|
|
|
| from langgraph.graph import StateGraph, END, START
|
| from langchain_core.messages import HumanMessage
|
|
|
|
|
|
|
# Put the directory three ``.parent`` hops above this file at the front of
# ``sys.path``. This must run before the ``src.*`` imports that follow.
# NOTE(review): presumably that directory is the repository root containing
# the ``src`` package — confirm against the actual file location.
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
|
|
|
| from src.agents.log_analysis_agent.agent import LogAnalysisAgent
|
| from src.agents.retrieval_supervisor.supervisor import RetrievalSupervisor
|
| from src.agents.response_agent.response_agent import ResponseAgent
|
|
|
|
|
|
|
class PipelineState(TypedDict):
    """State dictionary carried from node to node through the pipeline graph.

    The first three keys are inputs supplied by the caller; the remaining
    keys are filled in by the pipeline stages as they run.
    """

    # --- caller-supplied inputs ---
    log_file: str  # path of the log file being analyzed
    query: str  # free-text user context; "" when none was given
    tactic: str  # tactic label forwarded to the response stage; "" when none

    # --- stage outputs ---
    log_analysis_result: Dict[str, Any]  # output of the log analysis stage
    retrieval_result: Dict[str, Any]  # output of the retrieval stage
    response_analysis: Dict[str, Any]  # first value returned by the response stage
    markdown_report: str  # second value returned by the response stage
|
|
|
|
|
def _build_llm_client(model_name: str, temperature: float):
    """Create the chat model that is shared by every agent in the pipeline.

    Model names containing both "gpt-oss" and "groq", and names containing
    both "gpt-5" and "openai", get family-specific reasoning options. Every
    other model is created with only the temperature.
    """
    if "gpt-oss" in model_name and "groq" in model_name:
        reasoning_effort = "medium"
        llm_client = init_chat_model(
            model_name,
            temperature=temperature,
            reasoning_effort=reasoning_effort,
            reasoning_format="hidden",
        )
        print(
            f"[INFO] Using GPT-OSS model: {model_name} with reasoning effort: {reasoning_effort}"
        )
    elif "gpt-5" in model_name and "openai" in model_name:
        # NOTE(review): temperature is not forwarded for this family,
        # presumably because these models do not accept it — confirm.
        reasoning_effort = "minimal"
        llm_client = init_chat_model(
            model_name,
            reasoning_effort=reasoning_effort,
        )
        print(
            f"[INFO] Using GPT-5 family model: {model_name} with reasoning effort: {reasoning_effort}"
        )
    else:
        llm_client = init_chat_model(model_name, temperature=temperature)
        print(f"[INFO] Initialized with {model_name}")

    return llm_client


def create_simple_pipeline(
    model_name: str = "google_genai:gemini-2.0-flash",
    temperature: float = 0.1,
    max_log_analysis_iterations: int = 2,
    max_retrieval_iterations: int = 2,
    log_agent_output_dir: str = "analysis",
    response_agent_output_dir: str = "final_response",
    progress_callback=None,
    kb_path: str = "./cyber_knowledge_base",
):
    """Build and compile the three-stage analysis graph.

    The graph runs log analysis, then threat-intelligence retrieval, then
    response correlation, in that fixed order. All three agents share a
    single LLM client.

    Args:
        model_name: Model identifier passed to ``init_chat_model``.
        temperature: Sampling temperature (not forwarded for the GPT-5 family).
        max_log_analysis_iterations: Iteration cap given to the log analysis agent.
        max_retrieval_iterations: Iteration cap given to the retrieval supervisor.
        log_agent_output_dir: Output directory given to the log analysis agent.
        response_agent_output_dir: Output directory given to the response agent.
        progress_callback: Optional callable invoked as
            ``progress_callback(value, message)`` with an integer progress
            value and a status string around each stage.
        kb_path: Knowledge-base path given to the retrieval supervisor.

    Returns:
        The compiled graph; invoke it with a ``PipelineState``-shaped dict.
    """
    print("\n" + "=" * 60)
    print("INITIALIZING LLM CLIENT")
    print("=" * 60)
    print(f"Model: {model_name}")
    print(f"Temperature: {temperature}")
    print("=" * 60 + "\n")

    llm_client = _build_llm_client(model_name, temperature)

    log_agent = LogAnalysisAgent(
        model_name=model_name,
        output_dir=log_agent_output_dir,
        max_iterations=max_log_analysis_iterations,
        llm_client=llm_client,
    )

    retrieval_supervisor = RetrievalSupervisor(
        kb_path=kb_path,
        max_iterations=max_retrieval_iterations,
        llm_client=llm_client,
    )

    response_agent = ResponseAgent(
        model_name=model_name,
        output_dir=response_agent_output_dir,
        llm_client=llm_client,
    )

    def report_progress(value: int, message: str) -> None:
        """Forward a progress update if the caller supplied a callback."""
        if progress_callback:
            progress_callback(value, message)

    # Each node returns only the keys it produces; the graph merges these
    # partial updates into the state, so the input state is never mutated.

    def run_log_analysis(state: PipelineState) -> Dict[str, Any]:
        """Run log analysis and capture results."""
        print("\n" + "=" * 60)
        print("PHASE 1: LOG ANALYSIS")
        print("=" * 60)

        log_file = state["log_file"]
        print(f"Analyzing log file: {log_file}")

        report_progress(20, "Running log analysis...")
        analysis_result = log_agent.analyze(log_file)
        report_progress(40, "Log analysis completed")

        print(
            f"\nLog Analysis Assessment: {analysis_result.get('overall_assessment', 'UNKNOWN')}"
        )
        print(f"Abnormal Events: {len(analysis_result.get('abnormal_events', []))}")

        return {"log_analysis_result": analysis_result}

    def run_retrieval_with_context(state: PipelineState) -> Dict[str, Any]:
        """Transform log analysis results and run retrieval supervisor."""
        print("\n" + "=" * 60)
        print("PHASE 2: THREAT INTELLIGENCE RETRIEVAL")
        print("=" * 60)

        log_analysis_result = state["log_analysis_result"]
        assessment = log_analysis_result.get("overall_assessment", "UNKNOWN")
        user_query = state.get("query")

        query = create_retrieval_query(log_analysis_result, user_query)

        print(f"Generated retrieval query based on {assessment} assessment")
        print("\nStarting retrieval supervisor with log analysis context...\n")

        report_progress(50, "Running threat intelligence retrieval...")
        retrieval_result = retrieval_supervisor.invoke(
            query=query,
            log_analysis_report=log_analysis_result,
            context=user_query,
            trace=False,
        )
        report_progress(70, "Threat intelligence retrieval completed")

        return {"retrieval_result": retrieval_result}

    def run_response_analysis(state: PipelineState) -> Dict[str, Any]:
        """Run response agent to create Event ID → MITRE technique mappings."""
        print("\n" + "=" * 60)
        print("PHASE 3: RESPONSE CORRELATION ANALYSIS")
        print("=" * 60)
        print("Creating Event ID → MITRE technique mappings...")

        report_progress(80, "Running response correlation analysis...")
        response_analysis, markdown_report = response_agent.analyze_and_map(
            log_analysis_result=state["log_analysis_result"],
            retrieval_result=state["retrieval_result"],
            log_file=state["log_file"],
            tactic=state.get("tactic"),
        )
        report_progress(90, "Response analysis completed")

        # Name the configured output directory rather than a fixed folder,
        # so the message stays correct when the caller overrides it.
        print(
            f"Analysis complete! Results saved to {response_agent_output_dir} folder."
        )

        print("\n" + "=" * 60)
        print("PIPELINE COMPLETED")
        print("=" * 60)

        return {
            "response_analysis": response_analysis,
            "markdown_report": markdown_report,
        }

    workflow = StateGraph(PipelineState)

    workflow.add_node("log_analysis", run_log_analysis)
    workflow.add_node("retrieval", run_retrieval_with_context)
    workflow.add_node("response", run_response_analysis)

    # Strictly linear flow: log_analysis -> retrieval -> response -> END.
    workflow.set_entry_point("log_analysis")
    workflow.add_edge("log_analysis", "retrieval")
    workflow.add_edge("retrieval", "response")
    workflow.add_edge("response", END)

    return workflow.compile(name="simple_integrated_pipeline")
|
|
|
|
|
def create_retrieval_query(
    log_analysis_result: Dict[str, Any],
    user_query: str = None,
    max_events: int = 5,
) -> str:
    """Transform log analysis results into a retrieval query.

    Args:
        log_analysis_result: Log analysis output. The keys read are
            ``overall_assessment``, ``analysis_summary`` and
            ``abnormal_events`` (a list of dicts with optional
            ``event_id``, ``severity`` and ``event_description`` keys).
        user_query: Optional extra context appended to the end of the query.
        max_events: Maximum number of abnormal events listed in the query.
            Values below zero are treated as zero.

    Returns:
        A multi-line query string. When the assessment is ``"NORMAL"`` and
        no user query is given, a fixed one-line baseline request is
        returned instead.
    """
    assessment = log_analysis_result.get("overall_assessment", "UNKNOWN")

    # Nothing suspicious and no user guidance: only baseline intel is needed.
    if assessment == "NORMAL" and not user_query:
        return "Analyze this normal log activity and provide baseline threat intelligence for monitoring purposes."

    analysis_summary = log_analysis_result.get("analysis_summary", "")
    abnormal_events = log_analysis_result.get("abnormal_events") or []

    query_parts = [
        "Analyze the detected security anomalies and provide comprehensive threat intelligence.",
        "",
        f"Log Analysis Assessment: {assessment}",
        f"Summary: {analysis_summary}",
        "",
    ]

    # Cap the listing so the query stays short. The clamp stops a negative
    # limit from being read as "all but the last N" by the slice.
    listed_events = abnormal_events[: max(max_events, 0)]
    if listed_events:
        query_parts.append("Detected Anomalies:")
        for position, event in enumerate(listed_events, 1):
            event_id = event.get("event_id", "N/A")
            severity = event.get("severity", "Unknown")
            event_desc = event.get("event_description", "Unknown event")
            query_parts.append(
                f"{position}. Event {event_id} [{severity}]: {event_desc}"
            )
        query_parts.append("")

    query_parts.extend(
        [
            "Intelligence Requirements:",
            "1. Map findings to relevant MITRE ATT&CK techniques and tactics",
            "2. Provide threat actor attribution and campaign intelligence",
            "3. Generate actionable IOCs and detection recommendations",
            "4. Assess threat severity and recommend response actions",
        ]
    )

    if user_query:
        query_parts.extend(["", f"Additional Context: {user_query}"])

    return "\n".join(query_parts)
|
|
|
|
|
def analyze_log_file(
    log_file: str,
    query: str = None,
    tactic: str = None,
    model_name: str = "google_genai:gemini-2.0-flash",
    temperature: float = 0.1,
    max_log_analysis_iterations: int = 2,
    max_retrieval_iterations: int = 2,
    log_agent_output_dir: str = "analysis",
    response_agent_output_dir: str = "final_response",
    progress_callback=None,
):
    """
    Analyze a single log file through the integrated pipeline.

    Args:
        log_file: Path to the log file to analyze
        query: Optional user query for additional context
        tactic: Optional tactic name for organizing output
        model_name: Name of the model to use (e.g., "google_genai:gemini-2.0-flash", "groq:gpt-oss-120b", "groq:llama-3.1-8b-instant")
        temperature: Temperature for model generation
        max_log_analysis_iterations: Maximum number of iterations for the log analysis agent
        max_retrieval_iterations: Maximum number of iterations for the retrieval supervisor
        log_agent_output_dir: Directory to save log agent output
        response_agent_output_dir: Directory to save response agent output
        progress_callback: Optional callable invoked as
            ``progress_callback(value, message)`` with an integer progress
            value and a status string as the run advances

    Returns:
        The final pipeline state dict, or ``None`` when ``log_file`` does
        not exist (an error line is printed in that case).
    """
    if not os.path.exists(log_file):
        print(f"Error: Log file not found: {log_file}")
        return None

    print("Starting integrated pipeline analysis...")
    print(f"Log file: {log_file}")
    print(f"Model: {model_name}")
    if tactic:
        print(f"Tactic: {tactic}")
    print(f"User query: {query or 'None'}")

    # Announce initialization before the pipeline is built, so the update
    # is visible while the agents are being constructed, not afterwards.
    if progress_callback:
        progress_callback(10, "Initializing pipeline...")

    pipeline = create_simple_pipeline(
        model_name=model_name,
        temperature=temperature,
        max_log_analysis_iterations=max_log_analysis_iterations,
        max_retrieval_iterations=max_retrieval_iterations,
        log_agent_output_dir=log_agent_output_dir,
        response_agent_output_dir=response_agent_output_dir,
        progress_callback=progress_callback,
    )

    # Seed every state key; absent optional inputs become empty strings.
    initial_state = {
        "log_file": log_file,
        "log_analysis_result": {},
        "retrieval_result": {},
        "response_analysis": {},
        "query": query or "",
        "tactic": tactic or "",
        "markdown_report": "",
    }

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments during a long run.
    started = time.perf_counter()
    final_state = pipeline.invoke(initial_state)
    elapsed = time.perf_counter() - started

    if progress_callback:
        progress_callback(100, "Analysis complete!")

    print(f"\nTotal execution time: {elapsed:.2f} seconds")
    print("Analysis complete!")
    return final_state
|
|
|
|
|
def main():
    """Command-line entry point.

    Usage: ``<script> <log_file> [query] [--model MODEL_NAME]``.

    Exits with status 1 when arguments are missing or malformed, when the
    log file does not exist, or when the pipeline raises. On success the
    markdown report is printed.
    """
    if len(sys.argv) < 2:
        print(
            "Usage: python simple_pipeline.py <log_file> [query] [--model MODEL_NAME]"
        )
        print("\nExamples:")
        print("  python simple_pipeline.py sample_log.json")
        print(
            "  python simple_pipeline.py sample_log.json 'Focus on credential access attacks'"
        )
        print("  python simple_pipeline.py sample_log.json --model groq:gpt-oss-120b")
        print("\nAvailable models:")
        print("  - google_genai:gemini-2.0-flash")
        print("  - google_genai:gemini-1.5-flash")
        print("  - groq:gpt-oss-120b")
        print("  - groq:gpt-oss-20b")
        print("  - groq:llama-3.1-8b-instant")
        print("  - groq:llama-3.3-70b-versatile")
        sys.exit(1)

    log_file = sys.argv[1]
    query = None

    model_name = "google_genai:gemini-2.0-flash"
    temperature = 0.1
    max_log_analysis_iterations = 2
    max_retrieval_iterations = 2
    log_agent_output_dir = "analysis"
    response_agent_output_dir = "final_response"

    # Walk the remaining arguments: "--model" consumes the next token, and
    # any other token is the free-text query (the last one given wins).
    remaining = iter(sys.argv[2:])
    for arg in remaining:
        if arg == "--model":
            model_value = next(remaining, None)
            if model_value is None:
                # A trailing "--model" used to be taken as the query text.
                print("Error: --model requires a model name")
                sys.exit(1)
            model_name = model_value
        else:
            query = arg

    load_dotenv()

    try:
        final_state = analyze_log_file(
            log_file,
            query,
            tactic=None,
            model_name=model_name,
            temperature=temperature,
            max_log_analysis_iterations=max_log_analysis_iterations,
            max_retrieval_iterations=max_retrieval_iterations,
            log_agent_output_dir=log_agent_output_dir,
            response_agent_output_dir=response_agent_output_dir,
        )
        if final_state is None:
            # analyze_log_file has already printed the reason (missing log
            # file), so exit cleanly. SystemExit is not caught by the
            # handler below.
            sys.exit(1)
        print(final_state["markdown_report"])
    except Exception as e:
        print(f"Error: {e}")
        sys.exit(1)
|
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|