File size: 15,527 Bytes
86a1f6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ab37ac
 
86a1f6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ab37ac
 
 
 
 
 
86a1f6e
 
 
4ab37ac
 
 
 
 
 
86a1f6e
4ab37ac
 
 
 
 
 
 
 
 
86a1f6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f86cef4
 
 
 
 
 
 
86a1f6e
 
 
 
 
 
f86cef4
 
 
 
 
86a1f6e
 
 
 
 
 
 
 
 
 
f86cef4
 
 
 
 
86a1f6e
 
 
 
 
 
f86cef4
 
86a1f6e
 
 
 
f86cef4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86a1f6e
 
 
f86cef4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86a1f6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
"""
AegisOps AI β€” FastAPI backend (complete)
Modes: Single Technique, APT Group, Kill Chain, Topology Lab
SSE streaming + CORS + health + PDF/Sigma export

Run:
    uvicorn server:api --reload --port 8000
"""
from __future__ import annotations

import asyncio
import json
import re
import sys
from pathlib import Path
from typing import AsyncIterator

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from dotenv import load_dotenv

load_dotenv()
ROOT = Path(__file__).parent
sys.path.insert(0, str(ROOT))

# Internal AegisOps imports
from agents.llm import live_health, get_model_routing_status
from demo_output import DEMO_INVOKE_RESULT
from graph import app as pipeline
from export import generate_pdf

# FastAPI application instance; run as `uvicorn server:api` (see module docstring).
api = FastAPI(title="AegisOps AI", version="5.0")

# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers for credentialed requests (the CORS spec forbids
# `Access-Control-Allow-Origin: *` with credentials) — confirm whether
# credentials are actually required, or pin explicit origins.
api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount bundled static assets only when the directory ships with the app.
assets_dir = ROOT / "assets"
if assets_dir.exists():
    api.mount("/assets", StaticFiles(directory=str(assets_dir)), name="assets")

@api.get("/", response_class=HTMLResponse)
async def index():
    """Serve the bundled single-page frontend, or a plain banner if absent."""
    page = ROOT / "index.html"
    if page.exists():
        return HTMLResponse(page.read_text(encoding="utf-8"))
    return HTMLResponse("<h1>AegisOps AI API Server is Running</h1>")

# ── Health ────────────────────────────────────────────────────────────────────
@api.get("/health")
async def health():
    """Health probe: report LLM backend availability as JSON."""
    status = dict(live_health())
    return JSONResponse(status)

@api.get("/model-routing")
async def model_routing():
    """Expose the current LLM model-routing status as JSON."""
    routing_status = get_model_routing_status()
    return JSONResponse(routing_status)

# ── Artifact helpers ──────────────────────────────────────────────────────────
def _extract_fenced(text: str, lang: str) -> str:
    m = re.search(rf"```{lang}\s*(.*?)\s*```", text or "", re.DOTALL | re.IGNORECASE)
    return m.group(1).strip() if m else ""

def _splunk_spl(red: str, tid: str) -> str:
    """Derive a Splunk SPL hunting query from the red agent's observables.

    Pulls the ```json block out of *red*, reads its "observables" list,
    and ORs up to 10 entries into the search clause. Falls back to a
    generic host/user aggregation when no observables can be parsed.
    """
    try:
        payload = _extract_fenced(red, "json")
        observables = json.loads(payload).get("observables", []) if payload else []
        observables = [str(item) for item in observables if item]
    except Exception:
        # Malformed / missing JSON: treat as "no observables".
        observables = []

    if not observables:
        return f'index=windows | eval mitre_technique="{tid}" | stats count by host, user'

    search_clause = " OR ".join(f'"{item}"' for item in observables[:10])
    return f'index=windows ({search_clause}) | eval mitre_technique="{tid}" | stats count by host'

def _parse_verifier(verifier: str) -> dict:
    try:
        m = re.search(r'```json\s*(.*?)\s*```', verifier, re.DOTALL)
        d = json.loads(m.group(1) if m else verifier)
        return {
            "coverage":          d.get("coverage_score", 0),
            "product_readiness": d.get("product_readiness_score", 0),
            "real_world":        d.get("real_world_applicability_score", 0),
            "safety_verdict":    d.get("safety_verdict", "PASS"),
            "verdict":           d.get("verdict", "PASS"),
            "covered_observables":     d.get("covered_observables", []),
            "missing_observables":     d.get("missing_observables", []),
            "production_gaps":         d.get("production_gaps", []),
            "improvement_suggestions": d.get("improvement_suggestions", []),
        }
    except Exception:
        return {"coverage": 0, "product_readiness": 0, "real_world": 0,
                "safety_verdict": "PENDING", "verdict": "PENDING",
                "covered_observables": [], "missing_observables": [],
                "production_gaps": [], "improvement_suggestions": []}

def _build_response(result: dict, tid: str) -> dict:
    """Shape a raw pipeline state dict into the API response payload.

    Bundles agent outputs, derived artifacts (Sigma rule, Splunk query)
    and parsed verifier scores for technique *tid*.
    """
    red_text = result.get("red_output", "")
    blue_text = result.get("blue_output", "")
    verifier_text = result.get("verifier_output", "")
    return {
        "status": "success",
        "technique_id": tid,
        "verifier_model": result.get("verifier_model", "Unknown verifier model"),
        "verifier_model_role": result.get("verifier_model_role", "unknown"),
        "outputs": {
            "red": red_text,
            "blue": blue_text,
            "response": result.get("response_output", ""),
            "verifier": verifier_text,
        },
        "artifacts": {
            "sigma":    _extract_fenced(blue_text, "yaml"),
            "splunk":   _splunk_spl(red_text, tid),
            "raw_red":  red_text,
            "raw_blue": blue_text,
        },
        "scores":  _parse_verifier(verifier_text),
        "metrics": result.get("metrics", {}),
    }

# ── Mode Resolution ───────────────────────────────────────────────────────────
def _resolve_techniques(mode: str, technique_id: str) -> list[str]:
    """Return list of technique IDs to run based on the selected mode."""
    tid = technique_id.split("Β·")[0].strip().upper()
    if mode == "single":
        return [tid]
    if mode == "apt":
        try:
            from apt import get_apt_techniques
            techniques = get_apt_techniques(technique_id) 
            return [t["technique_id"] for t in techniques] or [tid]
        except Exception:
            return [tid]
    if mode == "chain":
        try:
            from chain import get_next_techniques
            chain = [tid] + [t["technique_id"] for t in get_next_techniques(tid)]
            return chain[:3]  # limit to 3 for demo purposes
        except Exception:
            return [tid]
    if mode == "topology":
        try:
            from topology import generate_attack_paths
            paths = generate_attack_paths(tid)
            if paths:
                return paths[0]["seed_techniques"][:3]
        except Exception:
            pass
        return [tid]
    return [tid]

# ── Streaming (SSE) ───────────────────────────────────────────────────────────
def _sse(event: str, data: dict) -> str:
    return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"

async def _stream_demo(technique_id: str) -> AsyncIterator[str]:
    """Replay the canned demo pipeline as an SSE stream.

    Emits a `start` event, then an `agent_start`/`agent_done` pair per
    stage (sleeping a scripted delay to mimic live latency), and finally
    a `done` event carrying the full artifact/score payload built from
    DEMO_INVOKE_RESULT.
    """
    canned = DEMO_INVOKE_RESULT
    # (sse key, result field, display label, simulated seconds of work)
    stage_plan = (
        ("red",      "red_output",      "Threat Agent",     3.8),
        ("blue",     "blue_output",     "Detection Agent",  3.2),
        ("response", "response_output", "Response Agent",   2.4),
        ("verifier", "verifier_output", "Validation Agent", 1.9),
    )

    yield _sse("start", {
        "demo": True,
        "technique_id": technique_id,
        "pipeline_version": "aegisops-production-hybrid-v1",
    })

    for agent_key, result_field, display_label, pause in stage_plan:
        yield _sse("agent_start", {"agent": agent_key, "label": display_label})
        await asyncio.sleep(pause)
        yield _sse("agent_done", {
            "agent": agent_key,
            "label": display_label,
            "output": canned.get(result_field, ""),
        })

    payload = _build_response(canned, technique_id)
    yield _sse("done", {
        "demo": True,
        "metrics": payload["metrics"],
        "artifacts": payload["artifacts"],
        "scores": payload["scores"],
        "verifier_model": payload.get("verifier_model"),
        "verifier_model_role": payload.get("verifier_model_role"),
    })


def _run_node(node_name: str, state: dict) -> dict:
    """Run one pipeline agent synchronously and return its state delta.

    Agent modules are imported lazily so they only load when a live run
    is requested. An unknown *node_name* raises KeyError, as before.
    """
    from agents.red_agent import run_red_agent
    from agents.blue_agent import run_blue_agent
    from agents.response_agent import run_response_agent
    from agents.verifier_agent import run_verifier_agent

    runners = {
        "red_agent":      run_red_agent,
        "blue_agent":     run_blue_agent,
        "response_agent": run_response_agent,
        "verifier_agent": run_verifier_agent,
    }
    runner = runners[node_name]
    return runner(state)

async def _stream_live(technique_id: str, mode: str) -> AsyncIterator[str]:
    """Run the real agent pipeline and stream progress as SSE events.

    Resolves *mode* into one or more technique IDs, then runs the four
    agents (red -> blue -> response -> verifier) for each technique in a
    thread-pool executor, yielding `agent_start`/`agent_done` events as
    each stage completes. Multi-technique modes additionally emit
    `technique_start`/`technique_done` progress events. Exactly one
    terminal `done` event is emitted — early (with error details) if any
    agent raises, otherwise once at the end with the merged payload.
    """
    techniques = _resolve_techniques(mode, technique_id)
    yield _sse("start", {
        "demo": False,
        "technique_id": technique_id,
        "mode": mode,
        "techniques": techniques,
        "pipeline_version": "aegisops-production-hybrid-v1",
    })

    all_results = []
    # NOTE(review): asyncio.get_running_loop() is the preferred call inside
    # a coroutine; get_event_loop() is deprecated here since Python 3.10 —
    # confirm target runtime before changing.
    loop = asyncio.get_event_loop()

    for i, tid in enumerate(techniques):
        if len(techniques) > 1:
            yield _sse("technique_start", {
                "technique_id": tid,
                "index": i,
                "total": len(techniques),
            })

        # Fixed execution order: (graph node, sse key, state field, UI label).
        agent_order = [
            ("red_agent",      "red",      "red_output",      "Threat Agent"),
            ("blue_agent",     "blue",     "blue_output",     "Detection Agent"),
            ("response_agent", "response", "response_output", "Response Agent"),
            ("verifier_agent", "verifier", "verifier_output", "Validation Agent"),
        ]
        state: dict = {"technique_id": tid}

        for node_name, key, field, label in agent_order:
            yield _sse("agent_start", {
                "agent": key,
                "label": label,
                "technique_id": tid,
            })
            try:
                # Default args pin the *current* state/node for the lambda
                # (avoids the late-binding closure pitfall in the loop).
                result = await loop.run_in_executor(
                    None, lambda s=state, n=node_name: _run_node(n, s)
                )
                state.update(result)
                yield _sse("agent_done", {
                    "agent": key,
                    "label": label,
                    "output": state.get(field, ""),
                    "technique_id": tid,
                })
            except Exception as exc:
                yield _sse("agent_error", {
                    "agent": key,
                    "label": label,
                    "error": str(exc),
                    "technique_id": tid,
                })
                # Emit a terminal error `done` so the frontend doesn't hang,
                # then abort the whole stream (remaining techniques are skipped).
                yield _sse("done", {
                    "demo": False,
                    "error": str(exc),
                    "metrics": {},
                    "artifacts": {"sigma": "", "splunk": "", "raw_red": "", "raw_blue": ""},
                    "scores": {
                        "coverage": 0, "product_readiness": 0, "real_world": 0,
                        "safety_verdict": "FAIL", "verdict": "FAIL",
                        "covered_observables": [], "missing_observables": [],
                        "production_gaps": [], "improvement_suggestions": [],
                    },
                    "verifier_model": None,
                    "verifier_model_role": None,
                })
                return

        all_results.append(state)

        # For multi-technique modes, yield a progress event (NOT `done`)
        # so the frontend knows this technique finished without stopping.
        if len(techniques) > 1:
            sub = _build_response(state, tid)
            yield _sse("technique_done", {
                "technique_id": tid,
                "index": i,
                "total": len(techniques),
                "scores": sub["scores"],
                "artifacts": sub["artifacts"],
                "metrics": sub["metrics"],
                "verifier_model": sub.get("verifier_model"),
                "verifier_model_role": sub.get("verifier_model_role"),
            })

    # Single final `done` event — fires exactly once on the success path.
    if not all_results:
        return

    # For a single technique, use its result directly.
    # For multi-technique runs, merge outputs for the final payload.
    if len(all_results) == 1:
        final = _build_response(all_results[0], techniques[0])
    else:
        # Merge: concatenate red/blue outputs, use the last verifier's scores.
        merged_red = "\n\n---\n\n".join(
            r.get("red_output", "") for r in all_results
        )
        merged_blue = "\n\n---\n\n".join(
            r.get("blue_output", "") for r in all_results
        )
        # NOTE(review): this mutates the last result's state dict in place
        # before reshaping it — fine while all_results is local to this call.
        last = all_results[-1]
        last["red_output"] = merged_red
        last["blue_output"] = merged_blue
        final = _build_response(last, technique_id)

    yield _sse("done", {
        "demo": False,
        "metrics": final["metrics"],
        "artifacts": final["artifacts"],
        "scores": final["scores"],
        "verifier_model": final.get("verifier_model"),
        "verifier_model_role": final.get("verifier_model_role"),
        "techniques_completed": [r.get("technique_id", "") for r in all_results],
    })

# ── Endpoints ─────────────────────────────────────────────────────────────────
@api.post("/run")
async def run_streaming(request: Request):
    """Main SSE endpoint: stream a demo replay or a live pipeline run.

    Body fields: `demo` (bool, default True), `technique_id`,
    and `mode` (UI label or internal key).
    """
    body = await request.json()
    use_demo = body.get("demo", True)
    technique = body.get("technique_id", "T1059.001").strip()
    raw_mode = body.get("mode", "single").lower().replace(" ", "_")

    # Normalize UI-facing mode labels to internal keys.
    aliases = {"single_technique": "single", "apt_group": "apt",
               "kill_chain": "chain", "topology_lab": "topology"}
    mode = aliases.get(raw_mode, raw_mode)

    stream = _stream_demo(technique) if use_demo else _stream_live(technique, mode)

    # Disable proxy buffering so SSE chunks flush immediately.
    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(stream, media_type="text/event-stream",
                             headers=sse_headers)

class DrillRequest(BaseModel):
    """Request body for the legacy /api/run-drill endpoint."""
    technique_id: str = "T1059.001"  # MITRE ATT&CK technique to drill
    demo_mode: bool = True  # True: return canned demo output instead of a live run

@api.post("/api/run-drill")
async def run_drill(req: DrillRequest):
    """Legacy non-streaming endpoint for single runs.

    In demo mode the canned DEMO_INVOKE_RESULT is reshaped and returned
    without touching the pipeline; otherwise the graph pipeline is
    invoked in a worker thread so the event loop is not blocked.
    """
    if req.demo_mode:
        result = DEMO_INVOKE_RESULT
    else:
        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated in this context since Python 3.10.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, lambda: pipeline.invoke({"technique_id": req.technique_id})
        )
    return _build_response(result, req.technique_id)

@api.post("/export/pdf")
async def export_pdf(request: Request):
    """Render the red/blue outputs for a technique as a downloadable PDF."""
    body = await request.json()
    tid = body.get("technique_id", "T1059.001")
    red_text = body.get("red_output", "")
    blue_text = body.get("blue_output", "")
    document = generate_pdf(tid, red_text, blue_text)
    disposition = f"attachment; filename=aegisops_{tid}.pdf"
    return StreamingResponse(
        iter([document]),
        media_type="application/pdf",
        headers={"Content-Disposition": disposition},
    )

@api.post("/export/sigma")
async def export_sigma(request: Request):
    """Extract the Sigma rule from the blue output and download it as .yml."""
    body = await request.json()
    sigma = _extract_fenced(body.get("blue_output", ""), "yaml")
    tid = body.get("technique_id", "rule")
    disposition = f"attachment; filename=sigma_{tid}.yml"
    return StreamingResponse(
        iter([sigma.encode()]),
        media_type="text/plain",
        headers={"Content-Disposition": disposition},
    )

@api.get("/topology")
async def get_topology(seed: str = "T1566.001"):
    """Return the generated network topology and attack paths for *seed*."""
    from topology import generate_topology, generate_attack_paths
    # Dict literal preserves the original call order: topology first, then paths.
    return JSONResponse({
        "topology": generate_topology(seed),
        "paths": generate_attack_paths(seed),
    })

@api.get("/intel/group")
async def get_intel_group(name: str = "APT28"):
    """Return threat-intel metadata and known techniques for an APT group."""
    from apt import get_group_info, get_apt_techniques
    group_info = get_group_info(name)
    group_techniques = get_apt_techniques(name)
    return JSONResponse({"group": group_info, "techniques": group_techniques})